From 42e47e5a3fa7219b136f6a5de7c74a89a79245c6 Mon Sep 17 00:00:00 2001 From: Mike Kruskal <62662355+mkruskal-google@users.noreply.github.com> Date: Thu, 29 Sep 2022 10:17:52 -0700 Subject: [PATCH] Refactoring Java parsing (3.16.x) (#10668) * Update changelog * Porting java cleanup * Extension patch * Remove extra allocations * Fixing merge issues * More merge fixes * Fix TextFormat parser --- CHANGES.txt | 19 +- .../com/google/protobuf/AbstractMessage.java | 27 +- .../com/google/protobuf/ArrayDecoders.java | 146 +++--- .../com/google/protobuf/BinaryReader.java | 32 +- .../protobuf/CodedInputStreamReader.java | 51 +- .../DescriptorMessageInfoFactory.java | 4 +- .../com/google/protobuf/DynamicMessage.java | 5 +- .../com/google/protobuf/ExtensionSchema.java | 1 + .../google/protobuf/ExtensionSchemaFull.java | 5 +- .../google/protobuf/ExtensionSchemaLite.java | 52 +- .../java/com/google/protobuf/FieldSet.java | 63 ++- .../google/protobuf/GeneratedMessageLite.java | 167 ++++-- .../google/protobuf/GeneratedMessageV3.java | 110 +++- .../google/protobuf/MessageLiteToString.java | 117 +++-- .../google/protobuf/MessageReflection.java | 439 +++++++++++++++- .../com/google/protobuf/MessageSchema.java | 485 +++++++++++------- .../com/google/protobuf/MessageSetSchema.java | 17 +- .../protobuf/NewInstanceSchemaLite.java | 5 +- .../main/java/com/google/protobuf/Reader.java | 8 + .../java/com/google/protobuf/SchemaUtil.java | 30 +- .../java/com/google/protobuf/TextFormat.java | 2 +- .../google/protobuf/UnknownFieldSetLite.java | 38 +- .../protobuf/UnknownFieldSetLiteSchema.java | 12 +- .../java/com/google/protobuf/LiteTest.java | 36 +- .../google/protobuf/util/FieldMaskTree.java | 10 +- .../protobuf/compiler/java/java_enum_field.cc | 94 ++-- .../protobuf/compiler/java/java_enum_field.h | 94 ++-- .../protobuf/compiler/java/java_field.cc | 4 +- .../protobuf/compiler/java/java_field.h | 5 +- .../protobuf/compiler/java/java_map_field.cc | 30 +- .../protobuf/compiler/java/java_map_field.h | 34 +- .../protobuf/compiler/java/java_message.cc | 220 ++------ .../compiler/java/java_message_builder.cc | 163 ++++-- .../compiler/java/java_message_builder.h | 5 + .../compiler/java/java_message_field.cc | 126 ++--- .../compiler/java/java_message_field.h | 87 ++-- .../compiler/java/java_primitive_field.cc | 66 +-- .../compiler/java/java_primitive_field.h | 90 ++-- .../compiler/java/java_string_field.cc | 59 +-- .../compiler/java/java_string_field.h | 87 ++-- 40 files changed, 1923 insertions(+), 1122 deletions(-) diff --git a/CHANGES.txt b/CHANGES.txt index 1391c7eb0306..887619f73885 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,4 +1,21 @@ -2022-09-13 version 16.2 (C++/Java/Python/PHP/Objective-C/C#/Ruby) +2022-09-27 version 3.16.3 (C++/Java/Python/PHP/Objective-C/C#/Ruby) + + Java + * Refactoring java full runtime to reuse sub-message builders and prepare to + migrate parsing logic from parse constructor to builder. + * Move proto wireformat parsing functionality from the private "parsing + constructor" to the Builder class. + * Change the Lite runtime to prefer merging from the wireformat into mutable + messages rather than building up a new immutable object before merging. This + way results in fewer allocations and copy operations. + * Make message-type extensions merge from wire-format instead of building up + instances and merging afterwards. This has much better performance. + * Fix TextFormat parser to build up recurring (but supposedly not repeated) + sub-messages directly from text rather than building a new sub-message and + merging the fully formed message into the existing field. + + +2022-09-13 version 3.16.2 (C++/Java/Python/PHP/Objective-C/C#/Ruby) C++ * Reduce memory consumption of MessageSet parsing diff --git a/java/core/src/main/java/com/google/protobuf/AbstractMessage.java b/java/core/src/main/java/com/google/protobuf/AbstractMessage.java index 1364fce41ec4..ebf4318b9342 100644 --- a/java/core/src/main/java/com/google/protobuf/AbstractMessage.java +++ b/java/core/src/main/java/com/google/protobuf/AbstractMessage.java @@ -426,27 +426,22 @@ public BuilderType mergeFrom( throws IOException { boolean discardUnknown = input.shouldDiscardUnknownFields(); final UnknownFieldSet.Builder unknownFields = - discardUnknown ? null : UnknownFieldSet.newBuilder(getUnknownFields()); - while (true) { - final int tag = input.readTag(); - if (tag == 0) { - break; - } - - MessageReflection.BuilderAdapter builderAdapter = - new MessageReflection.BuilderAdapter(this); - if (!MessageReflection.mergeFieldFrom( - input, unknownFields, extensionRegistry, getDescriptorForType(), builderAdapter, tag)) { - // end group tag - break; - } - } + discardUnknown ? null : getUnknownFieldSetBuilder(); + MessageReflection.mergeMessageFrom(this, unknownFields, input, extensionRegistry); if (unknownFields != null) { - setUnknownFields(unknownFields.build()); + setUnknownFieldSetBuilder(unknownFields); } return (BuilderType) this; } + protected UnknownFieldSet.Builder getUnknownFieldSetBuilder() { + return UnknownFieldSet.newBuilder(getUnknownFields()); + } + + protected void setUnknownFieldSetBuilder(final UnknownFieldSet.Builder builder) { + setUnknownFields(builder.build()); + } + @Override public BuilderType mergeUnknownFields(final UnknownFieldSet unknownFields) { setUnknownFields( diff --git a/java/core/src/main/java/com/google/protobuf/ArrayDecoders.java b/java/core/src/main/java/com/google/protobuf/ArrayDecoders.java index 1217e112e0ce..39b79278c7b9 100644 --- a/java/core/src/main/java/com/google/protobuf/ArrayDecoders.java +++ b/java/core/src/main/java/com/google/protobuf/ArrayDecoders.java @@ -234,6 +234,29 @@ static int decodeBytes(byte[] data, int position, Registers registers) @SuppressWarnings({"unchecked", "rawtypes"}) static int decodeMessageField( Schema schema, byte[] data, int position, int limit, Registers registers) throws IOException { + Object msg = schema.newInstance(); + int offset = mergeMessageField(msg, schema, data, position, limit, registers); + schema.makeImmutable(msg); + registers.object1 = msg; + return offset; + } + + /** Decodes a group value. */ + @SuppressWarnings({"unchecked", "rawtypes"}) + static int decodeGroupField( + Schema schema, byte[] data, int position, int limit, int endGroup, Registers registers) + throws IOException { + Object msg = schema.newInstance(); + int offset = mergeGroupField(msg, schema, data, position, limit, endGroup, registers); + schema.makeImmutable(msg); + registers.object1 = msg; + return offset; + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + static int mergeMessageField( + Object msg, Schema schema, byte[] data, int position, int limit, Registers registers) + throws IOException { int length = data[position++]; if (length < 0) { position = decodeVarint32(length, data, position, registers); @@ -242,27 +265,28 @@ static int decodeMessageField( if (length < 0 || length > limit - position) { throw InvalidProtocolBufferException.truncatedMessage(); } - Object result = schema.newInstance(); - schema.mergeFrom(result, data, position, position + length, registers); - schema.makeImmutable(result); - registers.object1 = result; + schema.mergeFrom(msg, data, position, position + length, registers); + registers.object1 = msg; return position + length; } - /** Decodes a group value. */ @SuppressWarnings({"unchecked", "rawtypes"}) - static int decodeGroupField( - Schema schema, byte[] data, int position, int limit, int endGroup, Registers registers) + static int mergeGroupField( + Object msg, + Schema schema, + byte[] data, + int position, + int limit, + int endGroup, + Registers registers) throws IOException { // A group field must has a MessageSchema (the only other subclass of Schema is MessageSetSchema // and it can't be used in group fields). final MessageSchema messageSchema = (MessageSchema) schema; - Object result = messageSchema.newInstance(); // It's OK to directly use parseProto2Message since proto3 doesn't have group. final int endPosition = - messageSchema.parseProto2Message(result, data, position, limit, endGroup, registers); - messageSchema.makeImmutable(result); - registers.object1 = result; + messageSchema.parseProto2Message(msg, data, position, limit, endGroup, registers); + registers.object1 = msg; return endPosition; } @@ -847,26 +871,19 @@ static int decodeExtension( break; } case ENUM: - { - IntArrayList list = new IntArrayList(); - position = decodePackedVarint32List(data, position, list, registers); - UnknownFieldSetLite unknownFields = message.unknownFields; - if (unknownFields == UnknownFieldSetLite.getDefaultInstance()) { - unknownFields = null; - } - unknownFields = - SchemaUtil.filterUnknownEnumList( - fieldNumber, - list, - extension.descriptor.getEnumType(), - unknownFields, - unknownFieldSchema); - if (unknownFields != null) { - message.unknownFields = unknownFields; + { + IntArrayList list = new IntArrayList(); + position = decodePackedVarint32List(data, position, list, registers); + SchemaUtil.filterUnknownEnumList( + message, + fieldNumber, + list, + extension.descriptor.getEnumType(), + null, + unknownFieldSchema); + extensions.setField(extension.descriptor, list); + break; } - extensions.setField(extension.descriptor, list); - break; - } default: throw new IllegalStateException( "Type cannot be packed: " + extension.descriptor.getLiteType()); @@ -878,13 +895,8 @@ static int decodeExtension( position = decodeVarint32(data, position, registers); Object enumValue = extension.descriptor.getEnumType().findValueByNumber(registers.int1); if (enumValue == null) { - UnknownFieldSetLite unknownFields = ((GeneratedMessageLite) message).unknownFields; - if (unknownFields == UnknownFieldSetLite.getDefaultInstance()) { - unknownFields = UnknownFieldSetLite.newInstance(); - ((GeneratedMessageLite) message).unknownFields = unknownFields; - } SchemaUtil.storeUnknownEnum( - fieldNumber, registers.int1, unknownFields, unknownFieldSchema); + message, fieldNumber, registers.int1, null, unknownFieldSchema); return position; } // Note, we store the integer value instead of the actual enum object in FieldSet. @@ -941,20 +953,45 @@ static int decodeExtension( value = registers.object1; break; case GROUP: - final int endTag = (fieldNumber << 3) | WireFormat.WIRETYPE_END_GROUP; - position = decodeGroupField( - Protobuf.getInstance().schemaFor(extension.getMessageDefaultInstance().getClass()), - data, position, limit, endTag, registers); - value = registers.object1; - break; - + { + final int endTag = (fieldNumber << 3) | WireFormat.WIRETYPE_END_GROUP; + final Schema fieldSchema = + Protobuf.getInstance() + .schemaFor(extension.getMessageDefaultInstance().getClass()); + if (extension.isRepeated()) { + position = decodeGroupField(fieldSchema, data, position, limit, endTag, registers); + extensions.addRepeatedField(extension.descriptor, registers.object1); + } else { + Object oldValue = extensions.getField(extension.descriptor); + if (oldValue == null) { + oldValue = fieldSchema.newInstance(); + extensions.setField(extension.descriptor, oldValue); + } + position = + mergeGroupField( + oldValue, fieldSchema, data, position, limit, endTag, registers); + } + return position; + } case MESSAGE: - position = decodeMessageField( - Protobuf.getInstance().schemaFor(extension.getMessageDefaultInstance().getClass()), - data, position, limit, registers); - value = registers.object1; - break; - + { + final Schema fieldSchema = + Protobuf.getInstance() + .schemaFor(extension.getMessageDefaultInstance().getClass()); + if (extension.isRepeated()) { + position = decodeMessageField(fieldSchema, data, position, limit, registers); + extensions.addRepeatedField(extension.descriptor, registers.object1); + } else { + Object oldValue = extensions.getField(extension.descriptor); + if (oldValue == null) { + oldValue = fieldSchema.newInstance(); + extensions.setField(extension.descriptor, oldValue); + } + position = + mergeMessageField(oldValue, fieldSchema, data, position, limit, registers); + } + return position; + } case ENUM: throw new IllegalStateException("Shouldn't reach here."); } @@ -962,17 +999,6 @@ static int decodeExtension( if (extension.isRepeated()) { extensions.addRepeatedField(extension.descriptor, value); } else { - switch (extension.getLiteType()) { - case MESSAGE: - case GROUP: - Object oldValue = extensions.getField(extension.descriptor); - if (oldValue != null) { - value = Internal.mergeMessage(oldValue, value); - } - break; - default: - break; - } extensions.setField(extension.descriptor, value); } } diff --git a/java/core/src/main/java/com/google/protobuf/BinaryReader.java b/java/core/src/main/java/com/google/protobuf/BinaryReader.java index d64574c2a581..3a0e04dfa238 100644 --- a/java/core/src/main/java/com/google/protobuf/BinaryReader.java +++ b/java/core/src/main/java/com/google/protobuf/BinaryReader.java @@ -247,6 +247,15 @@ public T readMessageBySchemaWithCheck( private T readMessage(Schema schema, ExtensionRegistryLite extensionRegistry) throws IOException { + T newInstance = schema.newInstance(); + mergeMessageField(newInstance, schema, extensionRegistry); + schema.makeImmutable(newInstance); + return newInstance; + } + + @Override + public void mergeMessageField( + T target, Schema schema, ExtensionRegistryLite extensionRegistry) throws IOException { int size = readVarint32(); requireBytes(size); @@ -256,15 +265,10 @@ private T readMessage(Schema schema, ExtensionRegistryLite extensionRegis limit = newLimit; try { - // Allocate and read the message. - T message = schema.newInstance(); - schema.mergeFrom(message, this, extensionRegistry); - schema.makeImmutable(message); - + schema.mergeFrom(target, this, extensionRegistry); if (pos != newLimit) { throw InvalidProtocolBufferException.parseFailure(); } - return message; } finally { // Restore the limit. limit = prevLimit; @@ -287,19 +291,23 @@ public T readGroupBySchemaWithCheck( private T readGroup(Schema schema, ExtensionRegistryLite extensionRegistry) throws IOException { + T newInstance = schema.newInstance(); + mergeGroupField(newInstance, schema, extensionRegistry); + schema.makeImmutable(newInstance); + return newInstance; + } + + @Override + public void mergeGroupField( + T target, Schema schema, ExtensionRegistryLite extensionRegistry) throws IOException { int prevEndGroupTag = endGroupTag; endGroupTag = WireFormat.makeTag(WireFormat.getTagFieldNumber(tag), WIRETYPE_END_GROUP); try { - // Allocate and read the message. - T message = schema.newInstance(); - schema.mergeFrom(message, this, extensionRegistry); - schema.makeImmutable(message); - + schema.mergeFrom(target, this, extensionRegistry); if (tag != endGroupTag) { throw InvalidProtocolBufferException.parseFailure(); } - return message; } finally { // Restore the old end group tag. endGroupTag = prevEndGroupTag; diff --git a/java/core/src/main/java/com/google/protobuf/CodedInputStreamReader.java b/java/core/src/main/java/com/google/protobuf/CodedInputStreamReader.java index 7658f629d371..1d992d75d068 100644 --- a/java/core/src/main/java/com/google/protobuf/CodedInputStreamReader.java +++ b/java/core/src/main/java/com/google/protobuf/CodedInputStreamReader.java @@ -197,9 +197,15 @@ public T readGroupBySchemaWithCheck(Schema schema, ExtensionRegistryLite return readGroup(schema, extensionRegistry); } - // Should have the same semantics of CodedInputStream#readMessage() - private T readMessage(Schema schema, ExtensionRegistryLite extensionRegistry) - throws IOException { + @Override + public void mergeMessageField( + T target, Schema schema, ExtensionRegistryLite extensionRegistry) throws IOException { + requireWireType(WIRETYPE_LENGTH_DELIMITED); + mergeMessageFieldInternal(target, schema, extensionRegistry); + } + + private void mergeMessageFieldInternal( + T target, Schema schema, ExtensionRegistryLite extensionRegistry) throws IOException { int size = input.readUInt32(); if (input.recursionDepth >= input.recursionLimit) { throw InvalidProtocolBufferException.recursionLimitExceeded(); @@ -207,39 +213,54 @@ private T readMessage(Schema schema, ExtensionRegistryLite extensionRegis // Push the new limit. final int prevLimit = input.pushLimit(size); - // Allocate and read the message. - T message = schema.newInstance(); ++input.recursionDepth; - schema.mergeFrom(message, this, extensionRegistry); - schema.makeImmutable(message); + schema.mergeFrom(target, this, extensionRegistry); input.checkLastTagWas(0); --input.recursionDepth; // Restore the previous limit. input.popLimit(prevLimit); - return message; } - private T readGroup(Schema schema, ExtensionRegistryLite extensionRegistry) + // Should have the same semantics of CodedInputStream#readMessage() + private T readMessage(Schema schema, ExtensionRegistryLite extensionRegistry) throws IOException { + T newInstance = schema.newInstance(); + mergeMessageFieldInternal(newInstance, schema, extensionRegistry); + schema.makeImmutable(newInstance); + return newInstance; + } + + @Override + public void mergeGroupField( + T target, Schema schema, ExtensionRegistryLite extensionRegistry) throws IOException { + requireWireType(WIRETYPE_START_GROUP); + mergeGroupFieldInternal(target, schema, extensionRegistry); + } + + private void mergeGroupFieldInternal( + T target, Schema schema, ExtensionRegistryLite extensionRegistry) throws IOException { int prevEndGroupTag = endGroupTag; endGroupTag = WireFormat.makeTag(WireFormat.getTagFieldNumber(tag), WIRETYPE_END_GROUP); try { - // Allocate and read the message. - T message = schema.newInstance(); - schema.mergeFrom(message, this, extensionRegistry); - schema.makeImmutable(message); - + schema.mergeFrom(target, this, extensionRegistry); if (tag != endGroupTag) { throw InvalidProtocolBufferException.parseFailure(); } - return message; } finally { // Restore the old end group tag. endGroupTag = prevEndGroupTag; } } + private T readGroup(Schema schema, ExtensionRegistryLite extensionRegistry) + throws IOException { + T newInstance = schema.newInstance(); + mergeGroupFieldInternal(newInstance, schema, extensionRegistry); + schema.makeImmutable(newInstance); + return newInstance; + } + @Override public ByteString readBytes() throws IOException { requireWireType(WIRETYPE_LENGTH_DELIMITED); diff --git a/java/core/src/main/java/com/google/protobuf/DescriptorMessageInfoFactory.java b/java/core/src/main/java/com/google/protobuf/DescriptorMessageInfoFactory.java index 7975136596a8..21ded52d95ae 100644 --- a/java/core/src/main/java/com/google/protobuf/DescriptorMessageInfoFactory.java +++ b/java/core/src/main/java/com/google/protobuf/DescriptorMessageInfoFactory.java @@ -402,8 +402,8 @@ private static StructuralMessageInfo convertProto3( boolean enforceUtf8 = true; for (int i = 0; i < fieldDescriptors.size(); ++i) { FieldDescriptor fd = fieldDescriptors.get(i); - if (fd.getContainingOneof() != null) { - // Build a oneof member field. + if (fd.getContainingOneof() != null && !fd.getContainingOneof().isSynthetic()) { + // Build a oneof member field. But only if it is a real oneof, not a proto3 optional builder.withField(buildOneofMember(messageType, fd, oneofState, enforceUtf8, null)); continue; } diff --git a/java/core/src/main/java/com/google/protobuf/DynamicMessage.java b/java/core/src/main/java/com/google/protobuf/DynamicMessage.java index 8beebba24d65..51e6b0c27c1f 100644 --- a/java/core/src/main/java/com/google/protobuf/DynamicMessage.java +++ b/java/core/src/main/java/com/google/protobuf/DynamicMessage.java @@ -421,7 +421,10 @@ public DynamicMessage buildPartial() { fields.makeImmutable(); DynamicMessage result = new DynamicMessage( - type, fields, java.util.Arrays.copyOf(oneofCases, oneofCases.length), unknownFields); + type, + fields, + java.util.Arrays.copyOf(oneofCases, oneofCases.length), + unknownFields); return result; } diff --git a/java/core/src/main/java/com/google/protobuf/ExtensionSchema.java b/java/core/src/main/java/com/google/protobuf/ExtensionSchema.java index 2eae22d26a07..bd391a2c15db 100644 --- a/java/core/src/main/java/com/google/protobuf/ExtensionSchema.java +++ b/java/core/src/main/java/com/google/protobuf/ExtensionSchema.java @@ -59,6 +59,7 @@ abstract class ExtensionSchema> { * or UnknownFieldSetLite in lite runtime. */ abstract UB parseExtension( + Object containerMessage, Reader reader, Object extension, ExtensionRegistryLite extensionRegistry, diff --git a/java/core/src/main/java/com/google/protobuf/ExtensionSchemaFull.java b/java/core/src/main/java/com/google/protobuf/ExtensionSchemaFull.java index 90558518b996..9376e87800dd 100644 --- a/java/core/src/main/java/com/google/protobuf/ExtensionSchemaFull.java +++ b/java/core/src/main/java/com/google/protobuf/ExtensionSchemaFull.java @@ -85,6 +85,7 @@ void makeImmutable(Object message) { @Override UB parseExtension( + Object containerMessage, Reader reader, Object extensionObject, ExtensionRegistryLite extensionRegistry, @@ -202,7 +203,7 @@ UB parseExtension( } else { unknownFields = SchemaUtil.storeUnknownEnum( - fieldNumber, number, unknownFields, unknownFieldSchema); + containerMessage, fieldNumber, number, unknownFields, unknownFieldSchema); } } value = enumList; @@ -221,7 +222,7 @@ UB parseExtension( Object enumValue = extension.descriptor.getEnumType().findValueByNumber(number); if (enumValue == null) { return SchemaUtil.storeUnknownEnum( - fieldNumber, number, unknownFields, unknownFieldSchema); + containerMessage, fieldNumber, number, unknownFields, unknownFieldSchema); } value = enumValue; } else { diff --git a/java/core/src/main/java/com/google/protobuf/ExtensionSchemaLite.java b/java/core/src/main/java/com/google/protobuf/ExtensionSchemaLite.java index 437cca2d96bd..7e20ed2ff6bc 100644 --- a/java/core/src/main/java/com/google/protobuf/ExtensionSchemaLite.java +++ b/java/core/src/main/java/com/google/protobuf/ExtensionSchemaLite.java @@ -32,7 +32,6 @@ import com.google.protobuf.GeneratedMessageLite.ExtensionDescriptor; import java.io.IOException; -import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -67,6 +66,7 @@ void makeImmutable(Object message) { @Override UB parseExtension( + Object containerMessage, Reader reader, Object extensionObject, ExtensionRegistryLite extensionRegistry, @@ -178,6 +178,7 @@ UB parseExtension( reader.readEnumList(list); unknownFields = SchemaUtil.filterUnknownEnumList( + containerMessage, fieldNumber, list, extension.descriptor.getEnumType(), @@ -199,7 +200,7 @@ UB parseExtension( Object enumValue = extension.descriptor.getEnumType().findValueByNumber(number); if (enumValue == null) { return SchemaUtil.storeUnknownEnum( - fieldNumber, number, unknownFields, unknownFieldSchema); + containerMessage, fieldNumber, number, unknownFields, unknownFieldSchema); } // Note, we store the integer value instead of the actual enum object in FieldSet. // This is also different from full-runtime where we store EnumValueDescriptor. @@ -253,12 +254,46 @@ UB parseExtension( value = reader.readString(); break; case GROUP: + // Special case handling for non-repeated sub-messages: merge in-place rather than + // building up new sub-messages and merging those, which is too slow. + // TODO(b/249368670): clean this up + if (!extension.isRepeated()) { + Object oldValue = extensions.getField(extension.descriptor); + if (oldValue instanceof GeneratedMessageLite) { + Schema extSchema = Protobuf.getInstance().schemaFor(oldValue); + if (!((GeneratedMessageLite) oldValue).isMutable()) { + Object newValue = extSchema.newInstance(); + extSchema.mergeFrom(newValue, oldValue); + extensions.setField(extension.descriptor, newValue); + oldValue = newValue; + } + reader.mergeGroupField(oldValue, extSchema, extensionRegistry); + return unknownFields; + } + } value = reader.readGroup( extension.getMessageDefaultInstance().getClass(), extensionRegistry); break; case MESSAGE: + // Special case handling for non-repeated sub-messages: merge in-place rather than + // building up new sub-messages and merging those, which is too slow. + // TODO(b/249368670): clean this up + if (!extension.isRepeated()) { + Object oldValue = extensions.getField(extension.descriptor); + if (oldValue instanceof GeneratedMessageLite) { + Schema extSchema = Protobuf.getInstance().schemaFor(oldValue); + if (!((GeneratedMessageLite) oldValue).isMutable()) { + Object newValue = extSchema.newInstance(); + extSchema.mergeFrom(newValue, oldValue); + extensions.setField(extension.descriptor, newValue); + oldValue = newValue; + } + reader.mergeMessageField(oldValue, extSchema, extensionRegistry); + return unknownFields; + } + } value = reader.readMessage( extension.getMessageDefaultInstance().getClass(), extensionRegistry); @@ -274,6 +309,7 @@ UB parseExtension( switch (extension.getLiteType()) { case MESSAGE: case GROUP: + // TODO(b/249368670): this shouldn't be reachable, clean this up Object oldValue = extensions.getField(extension.descriptor); if (oldValue != null) { value = Internal.mergeMessage(oldValue, value); @@ -527,15 +563,13 @@ void parseMessageSetItem( throws IOException { GeneratedMessageLite.GeneratedExtension extension = (GeneratedMessageLite.GeneratedExtension) extensionObject; - Object value = extension.getMessageDefaultInstance().newBuilderForType().buildPartial(); - Reader reader = BinaryReader.newInstance(ByteBuffer.wrap(data.toByteArray()), true); + MessageLite.Builder builder = extension.getMessageDefaultInstance().newBuilderForType(); - Protobuf.getInstance().mergeFrom(value, reader, extensionRegistry); - extensions.setField(extension.descriptor, value); + final CodedInputStream input = data.newCodedInput(); - if (reader.getFieldNumber() != Reader.READ_DONE) { - throw InvalidProtocolBufferException.invalidEndTag(); - } + builder.mergeFrom(input, extensionRegistry); + extensions.setField(extension.descriptor, builder.buildPartial()); + input.checkLastTagWas(0); } } diff --git a/java/core/src/main/java/com/google/protobuf/FieldSet.java b/java/core/src/main/java/com/google/protobuf/FieldSet.java index f64b50a839ec..aebfefbe92d3 100644 --- a/java/core/src/main/java/com/google/protobuf/FieldSet.java +++ b/java/core/src/main/java/com/google/protobuf/FieldSet.java @@ -39,6 +39,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Map.Entry; /** * A class which represents an arbitrary set of fields of some message type. This is used to @@ -123,6 +124,12 @@ public void makeImmutable() { if (isImmutable) { return; } + for (int i = 0; i < fields.getNumArrayEntries(); ++i) { + Entry entry = fields.getArrayEntryAt(i); + if (entry.getValue() instanceof GeneratedMessageLite) { + ((GeneratedMessageLite) entry.getValue()).makeImmutable(); + } + } fields.makeImmutable(); isImmutable = true; } @@ -938,8 +945,27 @@ private Builder(SmallSortedMap fields) { this.isMutable = true; } - /** Creates the FieldSet */ + /** + * Creates the FieldSet + * + * @throws UninitializedMessageException if a message field is missing required fields. + */ public FieldSet build() { + return buildImpl(false); + } + + /** Creates the FieldSet but does not validate that all required fields are present. */ + public FieldSet buildPartial() { + return buildImpl(true); + } + + /** + * Creates the FieldSet. + * + * @param partial controls whether to do a build() or buildPartial() when converting submessage + * builders to messages. + */ + private FieldSet buildImpl(boolean partial) { if (fields.isEmpty()) { return FieldSet.emptySet(); } @@ -948,7 +974,7 @@ public FieldSet build() { if (hasNestedBuilders) { // Make a copy of the fields map with all Builders replaced by Message. fieldsForBuild = cloneAllFieldsMap(fields, /* copyList */ false); - replaceBuilders(fieldsForBuild); + replaceBuilders(fieldsForBuild, partial); } FieldSet fieldSet = new FieldSet<>(fieldsForBuild); fieldSet.hasLazyField = hasLazyField; @@ -956,22 +982,22 @@ public FieldSet build() { } private static > void replaceBuilders( - SmallSortedMap fieldMap) { + SmallSortedMap fieldMap, boolean partial) { for (int i = 0; i < fieldMap.getNumArrayEntries(); i++) { - replaceBuilders(fieldMap.getArrayEntryAt(i)); + replaceBuilders(fieldMap.getArrayEntryAt(i), partial); } for (Map.Entry entry : fieldMap.getOverflowEntries()) { - replaceBuilders(entry); + replaceBuilders(entry, partial); } } private static > void replaceBuilders( - Map.Entry entry) { - entry.setValue(replaceBuilders(entry.getKey(), entry.getValue())); + Map.Entry entry, boolean partial) { + entry.setValue(replaceBuilders(entry.getKey(), entry.getValue(), partial)); } private static > Object replaceBuilders( - T descriptor, Object value) { + T descriptor, Object value, boolean partial) { if (value == null) { return value; } @@ -986,7 +1012,7 @@ private static > Object replaceBuilders( List list = (List) value; for (int i = 0; i < list.size(); i++) { Object oldElement = list.get(i); - Object newElement = replaceBuilder(oldElement); + Object newElement = replaceBuilder(oldElement, partial); if (newElement != oldElement) { // If the list contains a Message.Builder, then make a copy of that list and then // modify the Message.Builder into a Message and return the new list. This way, the @@ -1000,14 +1026,21 @@ private static > Object replaceBuilders( } return list; } else { - return replaceBuilder(value); + return replaceBuilder(value, partial); } } return value; } - private static Object replaceBuilder(Object value) { - return (value instanceof MessageLite.Builder) ? ((MessageLite.Builder) value).build() : value; + private static Object replaceBuilder(Object value, boolean partial) { + if (!(value instanceof MessageLite.Builder)) { + return value; + } + MessageLite.Builder builder = (MessageLite.Builder) value; + if (partial) { + return builder.buildPartial(); + } + return builder.build(); } /** Returns a new Builder using the fields from {@code fieldSet}. */ @@ -1026,7 +1059,7 @@ public Map getAllFields() { if (fields.isImmutable()) { result.makeImmutable(); } else { - replaceBuilders(result); + replaceBuilders(result, true); } return result; } @@ -1049,7 +1082,7 @@ public boolean hasField(final T descriptor) { */ public Object getField(final T descriptor) { Object value = getFieldAllowBuilders(descriptor); - return replaceBuilders(descriptor, value); + return replaceBuilders(descriptor, value, true); } /** Same as {@link #getField(F)}, but allow a {@link MessageLite.Builder} to be returned. */ @@ -1136,7 +1169,7 @@ public Object getRepeatedField(final T descriptor, final int index) { ensureIsMutable(); } Object value = getRepeatedFieldAllowBuilders(descriptor, index); - return replaceBuilder(value); + return replaceBuilder(value, true); } /** diff --git a/java/core/src/main/java/com/google/protobuf/GeneratedMessageLite.java b/java/core/src/main/java/com/google/protobuf/GeneratedMessageLite.java index 7db8f32ee019..808fde8e762f 100644 --- a/java/core/src/main/java/com/google/protobuf/GeneratedMessageLite.java +++ b/java/core/src/main/java/com/google/protobuf/GeneratedMessageLite.java @@ -62,11 +62,50 @@ public abstract class GeneratedMessageLite< BuilderType extends GeneratedMessageLite.Builder> extends AbstractMessageLite { + /* For use by lite runtime only */ + static final int UNINITIALIZED_SERIALIZED_SIZE = 0x7FFFFFFF; + private static final int MUTABLE_FLAG_MASK = 0x80000000; + private static final int MEMOIZED_SERIALIZED_SIZE_MASK = 0x7FFFFFFF; + + /** + * We use the high bit of memoizedSerializedSize as the explicit mutability flag. It didn't make + * sense to have negative sizes anyway. Messages start as mutable. + * + *

Adding a standalone boolean would have added 8 bytes to every message instance. + * + *

We also reserve 0x7FFFFFFF as the "uninitialized" value. + */ + private int memoizedSerializedSize = MUTABLE_FLAG_MASK | UNINITIALIZED_SERIALIZED_SIZE; + + /* For use by the runtime only */ + static final int UNINITIALIZED_HASH_CODE = 0; + /** For use by generated code only. Lazily initialized to reduce allocations. */ protected UnknownFieldSetLite unknownFields = UnknownFieldSetLite.getDefaultInstance(); - /** For use by generated code only. */ - protected int memoizedSerializedSize = -1; + boolean isMutable() { + return (memoizedSerializedSize & MUTABLE_FLAG_MASK) != 0; + } + + void markImmutable() { + memoizedSerializedSize &= ~MUTABLE_FLAG_MASK; + } + + int getMemoizedHashCode() { + return memoizedHashCode; + } + + void setMemoizedHashCode(int value) { + memoizedHashCode = value; + } + + void clearMemoizedHashCode() { + memoizedHashCode = UNINITIALIZED_HASH_CODE; + } + + boolean hashCodeIsNotMemoized() { + return UNINITIALIZED_HASH_CODE == getMemoizedHashCode(); + } @Override @SuppressWarnings("unchecked") // Guaranteed by runtime. @@ -86,6 +125,10 @@ public final BuilderType newBuilderForType() { return (BuilderType) dynamicMethod(MethodToInvoke.NEW_BUILDER); } + MessageType newMutableInstance() { + return (MessageType) dynamicMethod(MethodToInvoke.NEW_MUTABLE_INSTANCE); + } + /** * A reflective toString function. This is primarily intended as a developer aid, while keeping * binary size down. The first line of the {@code toString()} representation includes a commented @@ -106,11 +149,19 @@ public String toString() { @SuppressWarnings("unchecked") // Guaranteed by runtime @Override public int hashCode() { - if (memoizedHashCode != 0) { - return memoizedHashCode; + if (isMutable()) { + return computeHashCode(); } - memoizedHashCode = Protobuf.getInstance().schemaFor(this).hashCode(this); - return memoizedHashCode; + + if (hashCodeIsNotMemoized()) { + setMemoizedHashCode(computeHashCode()); + } + + return getMemoizedHashCode(); + } + + int computeHashCode() { + return Protobuf.getInstance().schemaFor(this).hashCode(this); } @SuppressWarnings("unchecked") // Guaranteed by isInstance + runtime @@ -173,6 +224,7 @@ protected void mergeLengthDelimitedField(int fieldNumber, ByteString value) { /** Called by subclasses to complete parsing. For use by generated code only. */ protected void makeImmutable() { Protobuf.getInstance().schemaFor(this).makeImmutable(this); + markImmutable(); } protected final < @@ -198,8 +250,7 @@ public final boolean isInitialized() { @SuppressWarnings("unchecked") public final BuilderType toBuilder() { BuilderType builder = (BuilderType) dynamicMethod(MethodToInvoke.NEW_BUILDER); - builder.mergeFrom((MessageType) this); - return builder; + return builder.mergeFrom((MessageType) this); } /** @@ -256,14 +307,22 @@ protected Object dynamicMethod(MethodToInvoke method) { return dynamicMethod(method, null, null); } + void clearMemoizedSerializedSize() { + setMemoizedSerializedSize(UNINITIALIZED_SERIALIZED_SIZE); + } + @Override int getMemoizedSerializedSize() { - return memoizedSerializedSize; + return memoizedSerializedSize & MEMOIZED_SERIALIZED_SIZE_MASK; } @Override void setMemoizedSerializedSize(int size) { - memoizedSerializedSize = size; + if (size < 0) { + throw new IllegalStateException("serialized size must be non-negative, was " + size); + } + memoizedSerializedSize = + (memoizedSerializedSize & MUTABLE_FLAG_MASK) | (size & MEMOIZED_SERIALIZED_SIZE_MASK); } @Override @@ -273,12 +332,42 @@ public void writeTo(CodedOutputStream output) throws IOException { .writeTo(this, CodedOutputStreamWriter.forCodedOutput(output)); } + @Override + int getSerializedSize(Schema schema) { + if (isMutable()) { + // The serialized size should never be memoized for mutable instances. + int size = computeSerializedSize(schema); + if (size < 0) { + throw new IllegalStateException("serialized size must be non-negative, was " + size); + } + return size; + } + + // If memoizedSerializedSize has already been set, return it. + if (getMemoizedSerializedSize() != UNINITIALIZED_SERIALIZED_SIZE) { + return getMemoizedSerializedSize(); + } + + // Need to compute and memoize the serialized size. + int size = computeSerializedSize(schema); + setMemoizedSerializedSize(size); + return size; + } + @Override public int getSerializedSize() { - if (memoizedSerializedSize == -1) { - memoizedSerializedSize = Protobuf.getInstance().schemaFor(this).getSerializedSize(this); + // Calling this with 'null' to delay schema lookup in case the serializedSize is already + // memoized. + return getSerializedSize(null); + } + + private int computeSerializedSize(Schema nullableSchema) { + if (nullableSchema == null) { + return Protobuf.getInstance().schemaFor(this).getSerializedSize(this); + } else { + return ((Schema>) nullableSchema) + .getSerializedSize(this); } - return memoizedSerializedSize; } /** Constructs a {@link MessageInfo} for this message type. */ @@ -318,6 +407,7 @@ Object buildMessageInfo() throws Exception { protected static > void registerDefaultInstance( Class clazz, T defaultInstance) { defaultInstanceMap.put(clazz, defaultInstance); + defaultInstance.makeImmutable(); } protected static Object newMessageInfo( @@ -342,13 +432,19 @@ public abstract static class Builder< private final MessageType defaultInstance; protected MessageType instance; - protected boolean isBuilt; protected Builder(MessageType defaultInstance) { this.defaultInstance = defaultInstance; - this.instance = - (MessageType) defaultInstance.dynamicMethod(MethodToInvoke.NEW_MUTABLE_INSTANCE); - isBuilt = false; + if (defaultInstance.isMutable()) { + throw new IllegalArgumentException("Default instance must be immutable."); + } + // this.instance should be set to defaultInstance but some tests rely on newBuilder().build() + // creating unique instances. + this.instance = newMutableInstance(); + } + + private MessageType newMutableInstance() { + return defaultInstance.newMutableInstance(); } /** @@ -356,15 +452,13 @@ protected Builder(MessageType defaultInstance) { * state before the write happens to preserve immutability guarantees. */ protected final void copyOnWrite() { - if (isBuilt) { + if (!instance.isMutable()) { copyOnWriteInternal(); - isBuilt = false; } } protected void copyOnWriteInternal() { - MessageType newInstance = - (MessageType) instance.dynamicMethod(MethodToInvoke.NEW_MUTABLE_INSTANCE); + MessageType newInstance = newMutableInstance(); mergeFromInstance(newInstance, instance); instance = newInstance; } @@ -376,27 +470,28 @@ public final boolean isInitialized() { @Override public final BuilderType clear() { - // No need to copy on write since we're dropping the instance anyways. - instance = (MessageType) instance.dynamicMethod(MethodToInvoke.NEW_MUTABLE_INSTANCE); + // No need to copy on write since we're dropping the instance anyway. + if (defaultInstance.isMutable()) { + throw new IllegalArgumentException("Default instance must be immutable."); + } + instance = newMutableInstance(); // should be defaultInstance; return (BuilderType) this; } @Override public BuilderType clone() { BuilderType builder = (BuilderType) getDefaultInstanceForType().newBuilderForType(); - builder.mergeFrom(buildPartial()); + builder.instance = buildPartial(); return builder; } @Override public MessageType buildPartial() { - if (isBuilt) { + if (!instance.isMutable()) { return instance; } instance.makeImmutable(); - - isBuilt = true; return instance; } @@ -416,12 +511,15 @@ protected BuilderType internalMergeFrom(MessageType message) { /** All subclasses implement this. */ public BuilderType mergeFrom(MessageType message) { + if (getDefaultInstanceForType().equals(message)) { + return (BuilderType) this; + } copyOnWrite(); mergeFromInstance(instance, message); return (BuilderType) this; } - private void mergeFromInstance(MessageType dest, MessageType src) { + private static void mergeFromInstance(MessageType dest, MessageType src) { Protobuf.getInstance().schemaFor(dest).mergeFrom(dest, src); } @@ -932,7 +1030,9 @@ void internalSetExtensionSet(FieldSet extensions) { @Override protected void copyOnWriteInternal() { super.copyOnWriteInternal(); - instance.extensions = instance.extensions.clone(); + if (instance.extensions != FieldSet.emptySet()) { + instance.extensions = instance.extensions.clone(); + } } private FieldSet ensureExtensionsAreMutable() { @@ -946,7 +1046,7 @@ private FieldSet ensureExtensionsAreMutable() { @Override public final MessageType buildPartial() { - if (isBuilt) { + if (!instance.isMutable()) { return instance; } @@ -1530,7 +1630,7 @@ public T parsePartialFrom( T instance, CodedInputStream input, ExtensionRegistryLite extensionRegistry) throws InvalidProtocolBufferException { @SuppressWarnings("unchecked") // Guaranteed by protoc - T result = (T) instance.dynamicMethod(MethodToInvoke.NEW_MUTABLE_INSTANCE); + T result = instance.newMutableInstance(); try { // TODO(yilunchong): Try to make input with type CodedInpuStream.ArrayDecoder use // fast path. @@ -1561,15 +1661,12 @@ public T parsePartialFrom( T instance, byte[] input, int offset, int length, ExtensionRegistryLite extensionRegistry) throws InvalidProtocolBufferException { @SuppressWarnings("unchecked") // Guaranteed by protoc - T result = (T) instance.dynamicMethod(MethodToInvoke.NEW_MUTABLE_INSTANCE); + T result = instance.newMutableInstance(); try { Schema schema = Protobuf.getInstance().schemaFor(result); schema.mergeFrom( result, input, offset, offset + length, new ArrayDecoders.Registers(extensionRegistry)); schema.makeImmutable(result); - if (result.memoizedHashCode != 0) { - throw new RuntimeException(); - } } catch (InvalidProtocolBufferException e) { if (e.getThrownFromInputStream()) { e = new InvalidProtocolBufferException(e); diff --git a/java/core/src/main/java/com/google/protobuf/GeneratedMessageV3.java b/java/core/src/main/java/com/google/protobuf/GeneratedMessageV3.java index 86f88a0228f2..e32260406858 100644 --- a/java/core/src/main/java/com/google/protobuf/GeneratedMessageV3.java +++ b/java/core/src/main/java/com/google/protobuf/GeneratedMessageV3.java @@ -133,6 +133,10 @@ public Descriptor getDescriptorForType() { return internalGetFieldAccessorTable().descriptor; } + // TODO(b/248143958): This method should be removed. It enables parsing directly into an + // "immutable" message. Have to leave it for now to support old gencode. + // @deprecated use newBuilder().mergeFrom() instead + @Deprecated protected void mergeFromAndMakeImmutableInternal( CodedInputStream input, ExtensionRegistryLite extensionRegistry) throws InvalidProtocolBufferException { @@ -299,13 +303,14 @@ public Object getRepeatedField(final FieldDescriptor field, final int index) { @Override public UnknownFieldSet getUnknownFields() { - throw new UnsupportedOperationException( - "This is supposed to be overridden by subclasses."); + return unknownFields; } /** * Called by subclasses to parse an unknown field. * + *

TODO(b/248153893) remove this method + * * @return {@code true} unless the tag is an end-group tag. */ protected boolean parseUnknownField( @@ -323,6 +328,8 @@ protected boolean parseUnknownField( /** * Delegates to parseUnknownField. This method is obsolete, but we must retain it for * compatibility with older generated code. + * + *

TODO(b/248153893) remove this method */ protected boolean parseUnknownFieldProto3( CodedInputStream input, @@ -547,8 +554,18 @@ public abstract static class Builder > // to dispatch dirty invalidations. See GeneratedMessageV3.BuilderListener. private boolean isClean; - private UnknownFieldSet unknownFields = - UnknownFieldSet.getDefaultInstance(); + /** + * This field holds either an {@link UnknownFieldSet} or {@link UnknownFieldSet.Builder}. + * + *

We use an object because it should only be one or the other of those things at a time and + * Object is the only common base. This also saves space. + * + *

Conversions are lazy: if {@link #setUnknownFields} is called, this will contain {@link + * UnknownFieldSet}. If unknown fields are merged into this builder, the current {@link + * UnknownFieldSet} will be converted to a {@link UnknownFieldSet.Builder} and left that way + * until either {@link #setUnknownFields} or {@link #buildPartial} or {@link #build} is called. + */ + private Object unknownFieldsOrBuilder = UnknownFieldSet.getDefaultInstance(); protected Builder() { this(null); @@ -604,7 +621,7 @@ public BuilderType clone() { */ @Override public BuilderType clear() { - unknownFields = UnknownFieldSet.getDefaultInstance(); + unknownFieldsOrBuilder = UnknownFieldSet.getDefaultInstance(); onChanged(); return (BuilderType) this; } @@ -757,7 +774,7 @@ public BuilderType addRepeatedField(final FieldDescriptor field, final Object va } private BuilderType setUnknownFieldsInternal(final UnknownFieldSet unknownFields) { - this.unknownFields = unknownFields; + unknownFieldsOrBuilder = unknownFields; onChanged(); return (BuilderType) this; } @@ -776,12 +793,20 @@ protected BuilderType setUnknownFieldsProto3(final UnknownFieldSet unknownFields } @Override - public BuilderType mergeUnknownFields( - final UnknownFieldSet unknownFields) { - return setUnknownFields( - UnknownFieldSet.newBuilder(this.unknownFields) - .mergeFrom(unknownFields) - .build()); + public BuilderType mergeUnknownFields(final UnknownFieldSet unknownFields) { + if (UnknownFieldSet.getDefaultInstance().equals(unknownFields)) { + return (BuilderType) this; + } + + if (UnknownFieldSet.getDefaultInstance().equals(unknownFieldsOrBuilder)) { + unknownFieldsOrBuilder = unknownFields; + onChanged(); + return (BuilderType) this; + } + + getUnknownFieldSetBuilder().mergeFrom(unknownFields); + onChanged(); + return (BuilderType) this; } @@ -817,7 +842,50 @@ public boolean isInitialized() { @Override public final UnknownFieldSet getUnknownFields() { - return unknownFields; + if (unknownFieldsOrBuilder instanceof UnknownFieldSet) { + return (UnknownFieldSet) unknownFieldsOrBuilder; + } else { + return ((UnknownFieldSet.Builder) unknownFieldsOrBuilder).buildPartial(); + } + } + + /** + * Called by generated subclasses to parse an unknown field. + * + * @return {@code true} unless the tag is an end-group tag. + */ + protected boolean parseUnknownField( + CodedInputStream input, ExtensionRegistryLite extensionRegistry, int tag) + throws IOException { + if (input.shouldDiscardUnknownFields()) { + return input.skipField(tag); + } + return getUnknownFieldSetBuilder().mergeFieldFrom(tag, input); + } + + /** Called by generated subclasses to add to the unknown field set. */ + protected final void mergeUnknownLengthDelimitedField(int number, ByteString bytes) { + getUnknownFieldSetBuilder().mergeLengthDelimitedField(number, bytes); + } + + /** Called by generated subclasses to add to the unknown field set. */ + protected final void mergeUnknownVarintField(int number, int value) { + getUnknownFieldSetBuilder().mergeVarintField(number, value); + } + + @Override + protected UnknownFieldSet.Builder getUnknownFieldSetBuilder() { + if (unknownFieldsOrBuilder instanceof UnknownFieldSet) { + unknownFieldsOrBuilder = ((UnknownFieldSet) unknownFieldsOrBuilder).toBuilder(); + } + onChanged(); + return (UnknownFieldSet.Builder) unknownFieldsOrBuilder; + } + + @Override + protected void setUnknownFieldSetBuilder(UnknownFieldSet.Builder builder) { + unknownFieldsOrBuilder = builder; + onChanged(); } /** @@ -1609,7 +1677,7 @@ protected boolean extensionsAreInitialized() { private FieldSet buildExtensions() { return extensions == null ? (FieldSet) FieldSet.emptySet() - : extensions.build(); + : extensions.buildPartial(); } @Override @@ -1815,6 +1883,20 @@ protected final void mergeExtensionFields(final ExtendableMessage other) { } } + @Override + protected boolean parseUnknownField( + CodedInputStream input, ExtensionRegistryLite extensionRegistry, int tag) + throws IOException { + ensureExtensionsIsMutable(); + return MessageReflection.mergeFieldFrom( + input, + input.shouldDiscardUnknownFields() ? null : getUnknownFieldSetBuilder(), + extensionRegistry, + getDescriptorForType(), + new MessageReflection.ExtensionBuilderAdapter(extensions), + tag); + } + private void verifyContainingType(final FieldDescriptor field) { if (field.getContainingType() != getDescriptorForType()) { throw new IllegalArgumentException( diff --git a/java/core/src/main/java/com/google/protobuf/MessageLiteToString.java b/java/core/src/main/java/com/google/protobuf/MessageLiteToString.java index 4aea9528ac01..3a3a70f13ffa 100644 --- a/java/core/src/main/java/com/google/protobuf/MessageLiteToString.java +++ b/java/core/src/main/java/com/google/protobuf/MessageLiteToString.java @@ -32,12 +32,15 @@ import java.lang.reflect.Method; import java.lang.reflect.Modifier; +import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Set; -import java.util.TreeSet; +import java.util.TreeMap; /** Helps generate {@link String} representations of {@link MessageLite} protos. */ final class MessageLiteToString { @@ -46,6 +49,11 @@ final class MessageLiteToString { private static final String BUILDER_LIST_SUFFIX = "OrBuilderList"; private static final String MAP_SUFFIX = "Map"; private static final String BYTES_SUFFIX = "Bytes"; + private static final char[] INDENT_BUFFER = new char[80]; + + static { + Arrays.fill(INDENT_BUFFER, ' '); + } /** * Returns a {@link String} representation of the {@link MessageLite} object. The first line of @@ -73,37 +81,51 @@ private static void reflectivePrintWithIndent( // Build a map of method name to method. We're looking for methods like getFoo(), hasFoo(), // getFooList() and getFooMap() which might be useful for building an object's string // representation. - Map nameToNoArgMethod = new HashMap(); - Map nameToMethod = new HashMap(); - Set getters = new TreeSet(); + Set setters = new HashSet<>(); + Map hazzers = new HashMap<>(); + Map getters = new TreeMap<>(); for (Method method : messageLite.getClass().getDeclaredMethods()) { - nameToMethod.put(method.getName(), method); - if (method.getParameterTypes().length == 0) { - nameToNoArgMethod.put(method.getName(), method); + if (Modifier.isStatic(method.getModifiers())) { + continue; + } + if (method.getName().length() < 3) { + continue; + } - if (method.getName().startsWith("get")) { - getters.add(method.getName()); - } + if (method.getName().startsWith("set")) { + setters.add(method.getName()); + continue; + } + + if (!Modifier.isPublic(method.getModifiers())) { + continue; + } + + if (method.getParameterTypes().length != 0) { + continue; + } + + if (method.getName().startsWith("has")) { + hazzers.put(method.getName(), method); + } else if (method.getName().startsWith("get")) { + getters.put(method.getName(), method); } } - for (String getter : getters) { - String suffix = getter.startsWith("get") ? getter.substring(3) : getter; + for (Entry getter : getters.entrySet()) { + String suffix = getter.getKey().substring(3); if (suffix.endsWith(LIST_SUFFIX) && !suffix.endsWith(BUILDER_LIST_SUFFIX) // Sometimes people have fields named 'list' that aren't repeated. && !suffix.equals(LIST_SUFFIX)) { - String camelCase = - suffix.substring(0, 1).toLowerCase() - + suffix.substring(1, suffix.length() - LIST_SUFFIX.length()); // Try to reflectively get the value and toString() the field as if it were repeated. This // only works if the method names have not been proguarded out or renamed. - Method listMethod = nameToNoArgMethod.get(getter); + Method listMethod = getter.getValue(); if (listMethod != null && listMethod.getReturnType().equals(List.class)) { printField( buffer, indent, - camelCaseToSnakeCase(camelCase), + suffix.substring(0, suffix.length() - LIST_SUFFIX.length()), GeneratedMessageLite.invokeOrDie(listMethod, messageLite)); continue; } @@ -111,12 +133,9 @@ private static void reflectivePrintWithIndent( if (suffix.endsWith(MAP_SUFFIX) // Sometimes people have fields named 'map' that aren't maps. && !suffix.equals(MAP_SUFFIX)) { - String camelCase = - suffix.substring(0, 1).toLowerCase() - + suffix.substring(1, suffix.length() - MAP_SUFFIX.length()); // Try to reflectively get the value and toString() the field as if it were a map. This only // works if the method names have not been proguarded out or renamed. - Method mapMethod = nameToNoArgMethod.get(getter); + Method mapMethod = getter.getValue(); if (mapMethod != null && mapMethod.getReturnType().equals(Map.class) // Skip the deprecated getter method with no prefix "Map" when the field name ends with @@ -127,29 +146,25 @@ private static void reflectivePrintWithIndent( printField( buffer, indent, - camelCaseToSnakeCase(camelCase), + suffix.substring(0, suffix.length() - MAP_SUFFIX.length()), GeneratedMessageLite.invokeOrDie(mapMethod, messageLite)); continue; } } - Method setter = nameToMethod.get("set" + suffix); - if (setter == null) { + if (!setters.contains("set" + suffix)) { continue; } if (suffix.endsWith(BYTES_SUFFIX) - && nameToNoArgMethod.containsKey( - "get" + suffix.substring(0, suffix.length() - "Bytes".length()))) { + && getters.containsKey("get" + suffix.substring(0, suffix.length() - "Bytes".length()))) { // Heuristic to skip bytes based accessors for string fields. continue; } - String camelCase = suffix.substring(0, 1).toLowerCase() + suffix.substring(1); - // Try to reflectively get the value and toString() the field as if it were optional. This // only works if the method names have not been proguarded out or renamed. - Method getMethod = nameToNoArgMethod.get("get" + suffix); - Method hasMethod = nameToNoArgMethod.get("has" + suffix); + Method getMethod = getter.getValue(); + Method hasMethod = hazzers.get("has" + suffix); // TODO(dweis): Fix proto3 semantics. if (getMethod != null) { Object value = GeneratedMessageLite.invokeOrDie(getMethod, messageLite); @@ -159,7 +174,7 @@ private static void reflectivePrintWithIndent( : (Boolean) GeneratedMessageLite.invokeOrDie(hasMethod, messageLite); // TODO(dweis): This doesn't stop printing oneof case twice: value and enum style. if (hasValue) { - printField(buffer, indent, camelCaseToSnakeCase(camelCase), value); + printField(buffer, indent, suffix, value); } continue; } @@ -215,10 +230,10 @@ private static boolean isDefaultValue(Object o) { * * @param buffer the buffer to write to * @param indent the number of spaces the proto should be indented by - * @param name the field name (in lower underscore case) + * @param name the field name (in PascalCase) * @param object the object value of the field */ - static final void printField(StringBuilder buffer, int indent, String name, Object object) { + static void printField(StringBuilder buffer, int indent, String name, Object object) { if (object instanceof List) { List list = (List) object; for (Object entry : list) { @@ -235,10 +250,8 @@ static final void printField(StringBuilder buffer, int indent, String name, Obje } buffer.append('\n'); - for (int i = 0; i < indent; i++) { - buffer.append(' '); - } - buffer.append(name); + indent(indent, buffer); + buffer.append(pascalCaseToSnakeCase(name)); if (object instanceof String) { buffer.append(": \"").append(TextFormatEscaper.escapeText((String) object)).append('"'); @@ -248,9 +261,7 @@ static final void printField(StringBuilder buffer, int indent, String name, Obje buffer.append(" {"); reflectivePrintWithIndent((GeneratedMessageLite) object, buffer, indent + 2); buffer.append("\n"); - for (int i = 0; i < indent; i++) { - buffer.append(' '); - } + indent(indent, buffer); buffer.append("}"); } else if (object instanceof Map.Entry) { buffer.append(" {"); @@ -258,19 +269,33 @@ static final void printField(StringBuilder buffer, int indent, String name, Obje printField(buffer, indent + 2, "key", entry.getKey()); printField(buffer, indent + 2, "value", entry.getValue()); buffer.append("\n"); - for (int i = 0; i < indent; i++) { - buffer.append(' '); - } + indent(indent, buffer); buffer.append("}"); } else { buffer.append(": ").append(object.toString()); } } - private static final String camelCaseToSnakeCase(String camelCase) { + private static void indent(int indent, StringBuilder buffer) { + while (indent > 0) { + int partialIndent = indent; + if (partialIndent > INDENT_BUFFER.length) { + partialIndent = INDENT_BUFFER.length; + } + buffer.append(INDENT_BUFFER, 0, partialIndent); + indent -= partialIndent; + } + } + + private static String pascalCaseToSnakeCase(String pascalCase) { + if (pascalCase.isEmpty()) { + return pascalCase; + } + StringBuilder builder = new StringBuilder(); - for (int i = 0; i < camelCase.length(); i++) { - char ch = camelCase.charAt(i); + builder.append(Character.toLowerCase(pascalCase.charAt(0))); + for (int i = 1; i < pascalCase.length(); i++) { + char ch = pascalCase.charAt(i); if (Character.isUpperCase(ch)) { builder.append("_"); } diff --git a/java/core/src/main/java/com/google/protobuf/MessageReflection.java b/java/core/src/main/java/com/google/protobuf/MessageReflection.java index 6741e1cb9cdf..13a6b8de5e59 100644 --- a/java/core/src/main/java/com/google/protobuf/MessageReflection.java +++ b/java/core/src/main/java/com/google/protobuf/MessageReflection.java @@ -30,6 +30,7 @@ package com.google.protobuf; +import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.FieldDescriptor; import java.io.IOException; import java.util.ArrayList; @@ -323,6 +324,34 @@ Object parseMessageFromBytes( Message defaultInstance) throws IOException; + /** + * Read the given group field from the wire, merging with the existing field if it is already + * present. + * + *

For extensions, defaultInstance must be specified. For regular fields, defaultInstance can + * be null. + */ + void mergeGroup( + CodedInputStream input, + ExtensionRegistryLite extensionRegistry, + FieldDescriptor field, + Message defaultInstance) + throws IOException; + + /** + * Read the given message field from the wire, merging with the existing field if it is already + * present. + * + *

For extensions, defaultInstance must be specified. For regular fields, defaultInstance can + * be null. + */ + void mergeMessage( + CodedInputStream input, + ExtensionRegistryLite extensionRegistry, + FieldDescriptor field, + Message defaultInstance) + throws IOException; + /** Returns the UTF8 validation level for the field. */ WireFormat.Utf8Validation getUtf8Validation(Descriptors.FieldDescriptor descriptor); @@ -349,6 +378,7 @@ MergeTarget newEmptyTargetForField( static class BuilderAdapter implements MergeTarget { private final Message.Builder builder; + private boolean hasNestedBuilders = true; @Override public Descriptors.Descriptor getDescriptorForType() { @@ -364,6 +394,17 @@ public Object getField(Descriptors.FieldDescriptor field) { return builder.getField(field); } + private Message.Builder getFieldBuilder(Descriptors.FieldDescriptor field) { + if (hasNestedBuilders) { + try { + return builder.getFieldBuilder(field); + } catch (UnsupportedOperationException e) { + hasNestedBuilders = false; + } + } + return null; + } + @Override public boolean hasField(Descriptors.FieldDescriptor field) { return builder.hasField(field); @@ -371,6 +412,12 @@ public boolean hasField(Descriptors.FieldDescriptor field) { @Override public MergeTarget setField(Descriptors.FieldDescriptor field, Object value) { + if (!field.isRepeated() && value instanceof MessageLite.Builder) { + if (value != getFieldBuilder(field)) { + builder.setField(field, ((MessageLite.Builder) value).buildPartial()); + } + return this; + } builder.setField(field, value); return this; } @@ -384,12 +431,18 @@ public MergeTarget clearField(Descriptors.FieldDescriptor field) { @Override public MergeTarget setRepeatedField( Descriptors.FieldDescriptor field, int index, Object value) { + if (value instanceof MessageLite.Builder) { + value = ((MessageLite.Builder) value).buildPartial(); + } builder.setRepeatedField(field, index, value); return this; } @Override public MergeTarget addRepeatedField(Descriptors.FieldDescriptor field, Object value) { + if (value instanceof MessageLite.Builder) { + value = ((MessageLite.Builder) value).buildPartial(); + } builder.addRepeatedField(field, value); return this; } @@ -499,15 +552,88 @@ public Object parseMessageFromBytes( return subBuilder.buildPartial(); } + @Override + public void mergeGroup( + CodedInputStream input, + ExtensionRegistryLite extensionRegistry, + FieldDescriptor field, + Message defaultInstance) + throws IOException { + if (!field.isRepeated()) { + Message.Builder subBuilder; + if (hasField(field)) { + subBuilder = getFieldBuilder(field); + if (subBuilder != null) { + input.readGroup(field.getNumber(), subBuilder, extensionRegistry); + return; + } else { + subBuilder = newMessageFieldInstance(field, defaultInstance); + subBuilder.mergeFrom((Message) getField(field)); + } + } else { + subBuilder = newMessageFieldInstance(field, defaultInstance); + } + input.readGroup(field.getNumber(), subBuilder, extensionRegistry); + Object unused = setField(field, subBuilder.buildPartial()); + } else { + Message.Builder subBuilder = newMessageFieldInstance(field, defaultInstance); + input.readGroup(field.getNumber(), subBuilder, extensionRegistry); + Object unused = addRepeatedField(field, subBuilder.buildPartial()); + } + } + + @Override + public void mergeMessage( + CodedInputStream input, + ExtensionRegistryLite extensionRegistry, + Descriptors.FieldDescriptor field, + Message defaultInstance) + throws IOException { + if (!field.isRepeated()) { + Message.Builder subBuilder; + if (hasField(field)) { + subBuilder = getFieldBuilder(field); + if (subBuilder != null) { + input.readMessage(subBuilder, extensionRegistry); + return; + } else { + subBuilder = newMessageFieldInstance(field, defaultInstance); + subBuilder.mergeFrom((Message) getField(field)); + } + } else { + subBuilder = newMessageFieldInstance(field, defaultInstance); + } + input.readMessage(subBuilder, extensionRegistry); + Object unused = setField(field, subBuilder.buildPartial()); + } else { + Message.Builder subBuilder = newMessageFieldInstance(field, defaultInstance); + input.readMessage(subBuilder, extensionRegistry); + Object unused = addRepeatedField(field, subBuilder.buildPartial()); + } + } + + private Message.Builder newMessageFieldInstance( + FieldDescriptor field, Message defaultInstance) { + // When default instance is not null. The field is an extension field. + if (defaultInstance != null) { + return defaultInstance.newBuilderForType(); + } else { + return builder.newBuilderForField(field); + } + } + @Override public MergeTarget newMergeTargetForField( Descriptors.FieldDescriptor field, Message defaultInstance) { Message.Builder subBuilder; - if (defaultInstance != null) { - subBuilder = defaultInstance.newBuilderForType(); - } else { - subBuilder = builder.newBuilderForField(field); + if (!field.isRepeated() && hasField(field)) { + subBuilder = getFieldBuilder(field); + if (subBuilder != null) { + return new BuilderAdapter(subBuilder); + } } + + subBuilder = newMessageFieldInstance(field, defaultInstance); if (!field.isRepeated()) { Message originalMessage = (Message) getField(field); if (originalMessage != null) { @@ -543,7 +669,7 @@ public WireFormat.Utf8Validation getUtf8Validation(Descriptors.FieldDescriptor d @Override public Object finish() { - return builder.buildPartial(); + return builder; } } @@ -665,6 +791,276 @@ public Object parseMessage( return subBuilder.buildPartial(); } + @Override + public void mergeGroup( + CodedInputStream input, + ExtensionRegistryLite extensionRegistry, + FieldDescriptor field, + Message defaultInstance) + throws IOException { + if (!field.isRepeated()) { + if (hasField(field)) { + MessageLite.Builder current = ((MessageLite) getField(field)).toBuilder(); + input.readGroup(field.getNumber(), current, extensionRegistry); + Object unused = setField(field, current.buildPartial()); + return; + } + Message.Builder subBuilder = defaultInstance.newBuilderForType(); + input.readGroup(field.getNumber(), subBuilder, extensionRegistry); + Object unused = setField(field, subBuilder.buildPartial()); + } else { + Message.Builder subBuilder = defaultInstance.newBuilderForType(); + input.readGroup(field.getNumber(), subBuilder, extensionRegistry); + Object unused = addRepeatedField(field, subBuilder.buildPartial()); + } + } + + @Override + public void mergeMessage( + CodedInputStream input, + ExtensionRegistryLite extensionRegistry, + FieldDescriptor field, + Message defaultInstance) + throws IOException { + if (!field.isRepeated()) { + if (hasField(field)) { + MessageLite.Builder current = ((MessageLite) getField(field)).toBuilder(); + input.readMessage(current, extensionRegistry); + Object unused = setField(field, current.buildPartial()); + return; + } + Message.Builder subBuilder = defaultInstance.newBuilderForType(); + input.readMessage(subBuilder, extensionRegistry); + Object unused = setField(field, subBuilder.buildPartial()); + } else { + Message.Builder subBuilder = defaultInstance.newBuilderForType(); + input.readMessage(subBuilder, extensionRegistry); + Object unused = addRepeatedField(field, subBuilder.buildPartial()); + } + } + + @Override + public Object parseMessageFromBytes( + ByteString bytes, + ExtensionRegistryLite registry, + Descriptors.FieldDescriptor field, + Message defaultInstance) + throws IOException { + Message.Builder subBuilder = defaultInstance.newBuilderForType(); + if (!field.isRepeated()) { + Message originalMessage = (Message) getField(field); + if (originalMessage != null) { + subBuilder.mergeFrom(originalMessage); + } + } + subBuilder.mergeFrom(bytes, registry); + return subBuilder.buildPartial(); + } + + @Override + public MergeTarget newMergeTargetForField( + Descriptors.FieldDescriptor descriptor, Message defaultInstance) { + throw new UnsupportedOperationException("newMergeTargetForField() called on FieldSet object"); + } + + @Override + public MergeTarget newEmptyTargetForField( + Descriptors.FieldDescriptor descriptor, Message defaultInstance) { + throw new UnsupportedOperationException("newEmptyTargetForField() called on FieldSet object"); + } + + @Override + public WireFormat.Utf8Validation getUtf8Validation(Descriptors.FieldDescriptor descriptor) { + if (descriptor.needsUtf8Check()) { + return WireFormat.Utf8Validation.STRICT; + } + // TODO(b/248145492): support lazy strings for ExtesnsionSet. + return WireFormat.Utf8Validation.LOOSE; + } + + @Override + public Object finish() { + throw new UnsupportedOperationException("finish() called on FieldSet object"); + } + } + + static class ExtensionBuilderAdapter implements MergeTarget { + + private final FieldSet.Builder extensions; + + ExtensionBuilderAdapter(FieldSet.Builder extensions) { + this.extensions = extensions; + } + + @Override + public Descriptors.Descriptor getDescriptorForType() { + throw new UnsupportedOperationException("getDescriptorForType() called on FieldSet object"); + } + + @Override + public Object getField(Descriptors.FieldDescriptor field) { + return extensions.getField(field); + } + + @Override + public boolean hasField(Descriptors.FieldDescriptor field) { + return extensions.hasField(field); + } + + @Override + public MergeTarget setField(Descriptors.FieldDescriptor field, Object value) { + extensions.setField(field, value); + return this; + } + + @Override + public MergeTarget clearField(Descriptors.FieldDescriptor field) { + extensions.clearField(field); + return this; + } + + @Override + public MergeTarget setRepeatedField( + Descriptors.FieldDescriptor field, int index, Object value) { + extensions.setRepeatedField(field, index, value); + return this; + } + + @Override + public MergeTarget addRepeatedField(Descriptors.FieldDescriptor field, Object value) { + extensions.addRepeatedField(field, value); + return this; + } + + @Override + public boolean hasOneof(Descriptors.OneofDescriptor oneof) { + return false; + } + + @Override + public MergeTarget clearOneof(Descriptors.OneofDescriptor oneof) { + // Nothing to clear. + return this; + } + + @Override + public Descriptors.FieldDescriptor getOneofFieldDescriptor(Descriptors.OneofDescriptor oneof) { + return null; + } + + @Override + public ContainerType getContainerType() { + return ContainerType.EXTENSION_SET; + } + + @Override + public ExtensionRegistry.ExtensionInfo findExtensionByName( + ExtensionRegistry registry, String name) { + return registry.findImmutableExtensionByName(name); + } + + @Override + public ExtensionRegistry.ExtensionInfo findExtensionByNumber( + ExtensionRegistry registry, Descriptors.Descriptor containingType, int fieldNumber) { + return registry.findImmutableExtensionByNumber(containingType, fieldNumber); + } + + @Override + public Object parseGroup( + CodedInputStream input, + ExtensionRegistryLite registry, + Descriptors.FieldDescriptor field, + Message defaultInstance) + throws IOException { + Message.Builder subBuilder = defaultInstance.newBuilderForType(); + if (!field.isRepeated()) { + Message originalMessage = (Message) getField(field); + if (originalMessage != null) { + subBuilder.mergeFrom(originalMessage); + } + } + input.readGroup(field.getNumber(), subBuilder, registry); + return subBuilder.buildPartial(); + } + + @Override + public Object parseMessage( + CodedInputStream input, + ExtensionRegistryLite registry, + Descriptors.FieldDescriptor field, + Message defaultInstance) + throws IOException { + Message.Builder subBuilder = defaultInstance.newBuilderForType(); + if (!field.isRepeated()) { + Message originalMessage = (Message) getField(field); + if (originalMessage != null) { + subBuilder.mergeFrom(originalMessage); + } + } + input.readMessage(subBuilder, registry); + return subBuilder.buildPartial(); + } + + @Override + public void mergeGroup( + CodedInputStream input, + ExtensionRegistryLite extensionRegistry, + FieldDescriptor field, + Message defaultInstance) + throws IOException { + if (!field.isRepeated()) { + if (hasField(field)) { + Object fieldOrBuilder = extensions.getFieldAllowBuilders(field); + MessageLite.Builder subBuilder; + if (fieldOrBuilder instanceof MessageLite.Builder) { + subBuilder = (MessageLite.Builder) fieldOrBuilder; + } else { + subBuilder = ((MessageLite) fieldOrBuilder).toBuilder(); + extensions.setField(field, subBuilder); + } + input.readGroup(field.getNumber(), subBuilder, extensionRegistry); + return; + } + Message.Builder subBuilder = defaultInstance.newBuilderForType(); + input.readGroup(field.getNumber(), subBuilder, extensionRegistry); + Object unused = setField(field, subBuilder); + } else { + Message.Builder subBuilder = defaultInstance.newBuilderForType(); + input.readGroup(field.getNumber(), subBuilder, extensionRegistry); + Object unused = addRepeatedField(field, subBuilder.buildPartial()); + } + } + + @Override + public void mergeMessage( + CodedInputStream input, + ExtensionRegistryLite extensionRegistry, + FieldDescriptor field, + Message defaultInstance) + throws IOException { + if (!field.isRepeated()) { + if (hasField(field)) { + Object fieldOrBuilder = extensions.getFieldAllowBuilders(field); + MessageLite.Builder subBuilder; + if (fieldOrBuilder instanceof MessageLite.Builder) { + subBuilder = (MessageLite.Builder) fieldOrBuilder; + } else { + subBuilder = ((MessageLite) fieldOrBuilder).toBuilder(); + extensions.setField(field, subBuilder); + } + input.readMessage(subBuilder, extensionRegistry); + return; + } + Message.Builder subBuilder = defaultInstance.newBuilderForType(); + input.readMessage(subBuilder, extensionRegistry); + Object unused = setField(field, subBuilder); + } else { + Message.Builder subBuilder = defaultInstance.newBuilderForType(); + input.readMessage(subBuilder, extensionRegistry); + Object unused = addRepeatedField(field, subBuilder.buildPartial()); + } + } + @Override public Object parseMessageFromBytes( ByteString bytes, @@ -700,7 +1096,7 @@ public WireFormat.Utf8Validation getUtf8Validation(Descriptors.FieldDescriptor d if (descriptor.needsUtf8Check()) { return WireFormat.Utf8Validation.STRICT; } - // TODO(liujisi): support lazy strings for ExtesnsionSet. + // TODO(b/248145492): support lazy strings for ExtesnsionSet. return WireFormat.Utf8Validation.LOOSE; } @@ -829,13 +1225,13 @@ static boolean mergeFieldFrom( switch (field.getType()) { case GROUP: { - value = target.parseGroup(input, extensionRegistry, field, defaultInstance); - break; + target.mergeGroup(input, extensionRegistry, field, defaultInstance); + return true; } case MESSAGE: { - value = target.parseMessage(input, extensionRegistry, field, defaultInstance); - break; + target.mergeMessage(input, extensionRegistry, field, defaultInstance); + return true; } case ENUM: final int rawValue = input.readEnum(); @@ -870,6 +1266,29 @@ static boolean mergeFieldFrom( return true; } + /** Read a message from the given input stream into the provided target and UnknownFieldSet. */ + static void mergeMessageFrom( + Message.Builder target, + UnknownFieldSet.Builder unknownFields, + CodedInputStream input, + ExtensionRegistryLite extensionRegistry) + throws IOException { + BuilderAdapter builderAdapter = new BuilderAdapter(target); + Descriptor descriptorForType = target.getDescriptorForType(); + while (true) { + final int tag = input.readTag(); + if (tag == 0) { + break; + } + + if (!mergeFieldFrom( + input, unknownFields, extensionRegistry, descriptorForType, builderAdapter, tag)) { + // end group tag + break; + } + } + } + /** Called by {@code #mergeFieldFrom()} to parse a MessageSet extension into MergeTarget. */ private static void mergeMessageSetExtensionFromCodedStream( CodedInputStream input, diff --git a/java/core/src/main/java/com/google/protobuf/MessageSchema.java b/java/core/src/main/java/com/google/protobuf/MessageSchema.java index 33c8e914b24e..8f873c1ef049 100644 --- a/java/core/src/main/java/com/google/protobuf/MessageSchema.java +++ b/java/core/src/main/java/com/google/protobuf/MessageSchema.java @@ -42,7 +42,6 @@ import static com.google.protobuf.ArrayDecoders.decodeFixed64List; import static com.google.protobuf.ArrayDecoders.decodeFloat; import static com.google.protobuf.ArrayDecoders.decodeFloatList; -import static com.google.protobuf.ArrayDecoders.decodeGroupField; import static com.google.protobuf.ArrayDecoders.decodeGroupList; import static com.google.protobuf.ArrayDecoders.decodeMessageField; import static com.google.protobuf.ArrayDecoders.decodeMessageList; @@ -66,6 +65,8 @@ import static com.google.protobuf.ArrayDecoders.decodeVarint32List; import static com.google.protobuf.ArrayDecoders.decodeVarint64; import static com.google.protobuf.ArrayDecoders.decodeVarint64List; +import static com.google.protobuf.ArrayDecoders.mergeGroupField; +import static com.google.protobuf.ArrayDecoders.mergeMessageField; import static com.google.protobuf.ArrayDecoders.skipField; import com.google.protobuf.ArrayDecoders.Registers; @@ -1176,6 +1177,7 @@ public int hashCode(T message) { @Override public void mergeFrom(T message, T other) { + checkMutable(message); if (other == null) { throw new NullPointerException(); } @@ -1374,47 +1376,83 @@ private void mergeSingleField(T message, T other, int pos) { } } - private void mergeMessage(T message, T other, int pos) { + private void mergeMessage(T targetParent, T sourceParent, int pos) { + if (!isFieldPresent(sourceParent, pos)) { + return; + } + final int typeAndOffset = typeAndOffsetAt(pos); final long offset = offset(typeAndOffset); - if (!isFieldPresent(other, pos)) { + final Object source = UNSAFE.getObject(sourceParent, offset); + if (source == null) { + throw new IllegalStateException( + "Source subfield " + numberAt(pos) + " is present but null: " + sourceParent); + } + + final Schema fieldSchema = getMessageFieldSchema(pos); + if (!isFieldPresent(targetParent, pos)) { + if (!isMutable(source)) { + // Can safely share source if it is immutable + UNSAFE.putObject(targetParent, offset, source); + } else { + // Make a safetey copy of source + final Object copyOfSource = fieldSchema.newInstance(); + fieldSchema.mergeFrom(copyOfSource, source); + UNSAFE.putObject(targetParent, offset, copyOfSource); + } + setFieldPresent(targetParent, pos); return; } - Object mine = UnsafeUtil.getObject(message, offset); - Object theirs = UnsafeUtil.getObject(other, offset); - if (mine != null && theirs != null) { - Object merged = Internal.mergeMessage(mine, theirs); - UnsafeUtil.putObject(message, offset, merged); - setFieldPresent(message, pos); - } else if (theirs != null) { - UnsafeUtil.putObject(message, offset, theirs); - setFieldPresent(message, pos); + // Sub-message is present, merge from source + Object target = UNSAFE.getObject(targetParent, offset); + if (!isMutable(target)) { + Object newInstance = fieldSchema.newInstance(); + fieldSchema.mergeFrom(newInstance, target); + UNSAFE.putObject(targetParent, offset, newInstance); + target = newInstance; } + fieldSchema.mergeFrom(target, source); } - private void mergeOneofMessage(T message, T other, int pos) { - int typeAndOffset = typeAndOffsetAt(pos); + private void mergeOneofMessage(T targetParent, T sourceParent, int pos) { int number = numberAt(pos); - long offset = offset(typeAndOffset); - - if (!isOneofPresent(other, number, pos)) { + if (!isOneofPresent(sourceParent, number, pos)) { return; } - Object mine = null; - if (isOneofPresent(message, number, pos)) { - mine = UnsafeUtil.getObject(message, offset); + + long offset = offset(typeAndOffsetAt(pos)); + final Object source = UNSAFE.getObject(sourceParent, offset); + if (source == null) { + throw new IllegalStateException( + "Source subfield " + numberAt(pos) + " is present but null: " + sourceParent); } - Object theirs = UnsafeUtil.getObject(other, offset); - if (mine != null && theirs != null) { - Object merged = Internal.mergeMessage(mine, theirs); - UnsafeUtil.putObject(message, offset, merged); - setOneofPresent(message, number, pos); - } else if (theirs != null) { - UnsafeUtil.putObject(message, offset, theirs); - setOneofPresent(message, number, pos); + + final Schema fieldSchema = getMessageFieldSchema(pos); + if (!isOneofPresent(targetParent, number, pos)) { + if (!isMutable(source)) { + // Can safely share source if it is immutable + UNSAFE.putObject(targetParent, offset, source); + } else { + // Make a safety copy of theirs + final Object copyOfSource = fieldSchema.newInstance(); + fieldSchema.mergeFrom(copyOfSource, source); + UNSAFE.putObject(targetParent, offset, copyOfSource); + } + setOneofPresent(targetParent, number, pos); + return; } + + // Sub-message is present, merge from source + Object target = UNSAFE.getObject(targetParent, offset); + if (!isMutable(target)) { + Object newInstance = fieldSchema.newInstance(); + fieldSchema.mergeFrom(newInstance, target); + UNSAFE.putObject(targetParent, offset, newInstance); + target = newInstance; + } + fieldSchema.mergeFrom(target, source); } @Override @@ -3853,6 +3891,7 @@ public void mergeFrom(T message, Reader reader, ExtensionRegistryLite extensionR if (extensionRegistry == null) { throw new NullPointerException(); } + checkMutable(message); mergeFromHelper(unknownFieldSchema, extensionSchema, message, reader, extensionRegistry); } @@ -3889,6 +3928,7 @@ private > void mergeFromHelper( } unknownFields = extensionSchema.parseExtension( + message, reader, extension, extensionRegistry, @@ -3955,21 +3995,10 @@ private > void mergeFromHelper( break; case 9: { // MESSAGE: - if (isFieldPresent(message, pos)) { - Object mergedResult = - Internal.mergeMessage( - UnsafeUtil.getObject(message, offset(typeAndOffset)), - reader.readMessageBySchemaWithCheck( - (Schema) getMessageFieldSchema(pos), extensionRegistry)); - UnsafeUtil.putObject(message, offset(typeAndOffset), mergedResult); - } else { - UnsafeUtil.putObject( - message, - offset(typeAndOffset), - reader.readMessageBySchemaWithCheck( - (Schema) getMessageFieldSchema(pos), extensionRegistry)); - setFieldPresent(message, pos); - } + final MessageLite current = (MessageLite) mutableMessageFieldForMerge(message, pos); + reader.mergeMessageField( + current, (Schema) getMessageFieldSchema(pos), extensionRegistry); + storeMessageField(message, pos, current); break; } case 10: // BYTES: @@ -3990,7 +4019,7 @@ private > void mergeFromHelper( } else { unknownFields = SchemaUtil.storeUnknownEnum( - number, enumValue, unknownFields, unknownFieldSchema); + message, number, enumValue, unknownFields, unknownFieldSchema); } break; } @@ -4012,21 +4041,10 @@ private > void mergeFromHelper( break; case 17: { // GROUP: - if (isFieldPresent(message, pos)) { - Object mergedResult = - Internal.mergeMessage( - UnsafeUtil.getObject(message, offset(typeAndOffset)), - reader.readGroupBySchemaWithCheck( - (Schema) getMessageFieldSchema(pos), extensionRegistry)); - UnsafeUtil.putObject(message, offset(typeAndOffset), mergedResult); - } else { - UnsafeUtil.putObject( - message, - offset(typeAndOffset), - reader.readGroupBySchemaWithCheck( - (Schema) getMessageFieldSchema(pos), extensionRegistry)); - setFieldPresent(message, pos); - } + final MessageLite current = (MessageLite) mutableMessageFieldForMerge(message, pos); + reader.mergeGroupField( + current, (Schema) getMessageFieldSchema(pos), extensionRegistry); + storeMessageField(message, pos, current); break; } case 18: // DOUBLE_LIST: @@ -4089,6 +4107,7 @@ private > void mergeFromHelper( reader.readEnumList(enumList); unknownFields = SchemaUtil.filterUnknownEnumList( + message, number, enumList, getEnumFieldVerifier(pos), @@ -4155,6 +4174,7 @@ private > void mergeFromHelper( reader.readEnumList(enumList); unknownFields = SchemaUtil.filterUnknownEnumList( + message, number, enumList, getEnumFieldVerifier(pos), @@ -4235,24 +4255,15 @@ private > void mergeFromHelper( readString(message, typeAndOffset, reader); setOneofPresent(message, number, pos); break; - case 60: // ONEOF_MESSAGE: - if (isOneofPresent(message, number, pos)) { - Object mergedResult = - Internal.mergeMessage( - UnsafeUtil.getObject(message, offset(typeAndOffset)), - reader.readMessageBySchemaWithCheck( - getMessageFieldSchema(pos), extensionRegistry)); - UnsafeUtil.putObject(message, offset(typeAndOffset), mergedResult); - } else { - UnsafeUtil.putObject( - message, - offset(typeAndOffset), - reader.readMessageBySchemaWithCheck( - getMessageFieldSchema(pos), extensionRegistry)); - setFieldPresent(message, pos); + case 60: + { // ONEOF_MESSAGE: + final MessageLite current = + (MessageLite) mutableOneofMessageFieldForMerge(message, number, pos); + reader.mergeMessageField( + current, (Schema) getMessageFieldSchema(pos), extensionRegistry); + storeOneofMessageField(message, number, pos, current); + break; } - setOneofPresent(message, number, pos); - break; case 61: // ONEOF_BYTES: UnsafeUtil.putObject(message, offset(typeAndOffset), reader.readBytes()); setOneofPresent(message, number, pos); @@ -4272,7 +4283,7 @@ private > void mergeFromHelper( } else { unknownFields = SchemaUtil.storeUnknownEnum( - number, enumValue, unknownFields, unknownFieldSchema); + message, number, enumValue, unknownFields, unknownFieldSchema); } break; } @@ -4296,17 +4307,19 @@ private > void mergeFromHelper( message, offset(typeAndOffset), Long.valueOf(reader.readSInt64())); setOneofPresent(message, number, pos); break; - case 68: // ONEOF_GROUP: - UnsafeUtil.putObject( - message, - offset(typeAndOffset), - reader.readGroupBySchemaWithCheck(getMessageFieldSchema(pos), extensionRegistry)); - setOneofPresent(message, number, pos); - break; + case 68: + { // ONEOF_GROUP: + final MessageLite current = + (MessageLite) mutableOneofMessageFieldForMerge(message, number, pos); + reader.mergeGroupField( + current, (Schema) getMessageFieldSchema(pos), extensionRegistry); + storeOneofMessageField(message, number, pos, current); + break; + } default: // Assume we've landed on an empty entry. Treat it as an unknown field. if (unknownFields == null) { - unknownFields = unknownFieldSchema.newBuilder(); + unknownFields = unknownFieldSchema.getBuilderFromMessage(message); } if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) { return; @@ -4333,7 +4346,8 @@ private > void mergeFromHelper( } finally { for (int i = checkInitializedCount; i < repeatedFieldOffsetStart; i++) { unknownFields = - filterMapUnknownEnumValues(message, intArray[i], unknownFields, unknownFieldSchema); + filterMapUnknownEnumValues( + message, intArray[i], unknownFields, unknownFieldSchema, message); } if (unknownFields != null) { unknownFieldSchema.setBuilderToMessage(message, unknownFields); @@ -4343,6 +4357,8 @@ private > void mergeFromHelper( @SuppressWarnings("ReferenceEquality") static UnknownFieldSetLite getMutableUnknownFields(Object message) { + // TODO(b/248560713) decide if we're keeping support for Full in schema classes and handle this + // better. UnknownFieldSetLite unknownFields = ((GeneratedMessageLite) message).unknownFields; if (unknownFields == UnknownFieldSetLite.getDefaultInstance()) { unknownFields = UnknownFieldSetLite.newInstance(); @@ -4603,24 +4619,13 @@ private int parseRepeatedField( } else { break; } - UnknownFieldSetLite unknownFields = ((GeneratedMessageLite) message).unknownFields; - if (unknownFields == UnknownFieldSetLite.getDefaultInstance()) { - // filterUnknownEnumList() expects the unknownFields parameter to be mutable or null. - // Since we don't know yet whether there exist unknown enum values, we'd better pass - // null to it instead of allocating a mutable instance. This is also needed to be - // consistent with the behavior of generated parser/builder. - unknownFields = null; - } - unknownFields = - SchemaUtil.filterUnknownEnumList( - number, - (ProtobufList) list, - getEnumFieldVerifier(bufferPosition), - unknownFields, - (UnknownFieldSchema) unknownFieldSchema); - if (unknownFields != null) { - ((GeneratedMessageLite) message).unknownFields = unknownFields; - } + SchemaUtil.filterUnknownEnumList( + message, + number, + (ProtobufList) list, + getEnumFieldVerifier(bufferPosition), + null, + (UnknownFieldSchema) unknownFieldSchema); break; case 33: // SINT32_LIST: case 47: // SINT32_LIST_PACKED: @@ -4774,20 +4779,11 @@ private int parseOneofField( break; case 60: // ONEOF_MESSAGE: if (wireType == WireFormat.WIRETYPE_LENGTH_DELIMITED) { + final Object current = mutableOneofMessageFieldForMerge(message, number, bufferPosition); position = - decodeMessageField( - getMessageFieldSchema(bufferPosition), data, position, limit, registers); - final Object oldValue = - unsafe.getInt(message, oneofCaseOffset) == number - ? unsafe.getObject(message, fieldOffset) - : null; - if (oldValue == null) { - unsafe.putObject(message, fieldOffset, registers.object1); - } else { - unsafe.putObject( - message, fieldOffset, Internal.mergeMessage(oldValue, registers.object1)); - } - unsafe.putInt(message, oneofCaseOffset, number); + mergeMessageField( + current, getMessageFieldSchema(bufferPosition), data, position, limit, registers); + storeOneofMessageField(message, number, bufferPosition, current); } break; case 61: // ONEOF_BYTES: @@ -4827,21 +4823,18 @@ private int parseOneofField( break; case 68: // ONEOF_GROUP: if (wireType == WireFormat.WIRETYPE_START_GROUP) { + final Object current = mutableOneofMessageFieldForMerge(message, number, bufferPosition); final int endTag = (tag & ~0x7) | WireFormat.WIRETYPE_END_GROUP; position = - decodeGroupField( - getMessageFieldSchema(bufferPosition), data, position, limit, endTag, registers); - final Object oldValue = - unsafe.getInt(message, oneofCaseOffset) == number - ? unsafe.getObject(message, fieldOffset) - : null; - if (oldValue == null) { - unsafe.putObject(message, fieldOffset, registers.object1); - } else { - unsafe.putObject( - message, fieldOffset, Internal.mergeMessage(oldValue, registers.object1)); - } - unsafe.putInt(message, oneofCaseOffset, number); + mergeGroupField( + current, + getMessageFieldSchema(bufferPosition), + data, + position, + limit, + endTag, + registers); + storeOneofMessageField(message, number, bufferPosition, current); } break; default: @@ -4878,6 +4871,7 @@ private EnumVerifier getEnumFieldVerifier(int pos) { int parseProto2Message( T message, byte[] data, int position, int limit, int endGroup, Registers registers) throws IOException { + checkMutable(message); final sun.misc.Unsafe unsafe = UNSAFE; int currentPresenceFieldOffset = NO_PRESENCE_SENTINEL; int currentPresenceField = 0; @@ -4994,18 +4988,11 @@ int parseProto2Message( break; case 9: // MESSAGE if (wireType == WireFormat.WIRETYPE_LENGTH_DELIMITED) { + final Object current = mutableMessageFieldForMerge(message, pos); position = - decodeMessageField( - getMessageFieldSchema(pos), data, position, limit, registers); - if ((currentPresenceField & presenceMask) == 0) { - unsafe.putObject(message, fieldOffset, registers.object1); - } else { - unsafe.putObject( - message, - fieldOffset, - Internal.mergeMessage( - unsafe.getObject(message, fieldOffset), registers.object1)); - } + mergeMessageField( + current, getMessageFieldSchema(pos), data, position, limit, registers); + storeMessageField(message, pos, current); currentPresenceField |= presenceMask; continue; } @@ -5054,20 +5041,18 @@ int parseProto2Message( break; case 17: // GROUP if (wireType == WireFormat.WIRETYPE_START_GROUP) { + final Object current = mutableMessageFieldForMerge(message, pos); final int endTag = (number << 3) | WireFormat.WIRETYPE_END_GROUP; position = - decodeGroupField( - getMessageFieldSchema(pos), data, position, limit, endTag, registers); - if ((currentPresenceField & presenceMask) == 0) { - unsafe.putObject(message, fieldOffset, registers.object1); - } else { - unsafe.putObject( - message, - fieldOffset, - Internal.mergeMessage( - unsafe.getObject(message, fieldOffset), registers.object1)); - } - + mergeGroupField( + current, + getMessageFieldSchema(pos), + data, + position, + limit, + endTag, + registers); + storeMessageField(message, pos, current); currentPresenceField |= presenceMask; continue; } @@ -5165,7 +5150,8 @@ int parseProto2Message( message, intArray[i], unknownFields, - (UnknownFieldSchema) unknownFieldSchema); + (UnknownFieldSchema) unknownFieldSchema, + message); } if (unknownFields != null) { ((UnknownFieldSchema) unknownFieldSchema) @@ -5183,9 +5169,65 @@ int parseProto2Message( return position; } + private Object mutableMessageFieldForMerge(T message, int pos) { + final Schema fieldSchema = getMessageFieldSchema(pos); + final long offset = offset(typeAndOffsetAt(pos)); + + // Field not present, create a new one + if (!isFieldPresent(message, pos)) { + return fieldSchema.newInstance(); + } + + // Field present, if mutable, ready to merge + final Object current = UNSAFE.getObject(message, offset); + if (isMutable(current)) { + return current; + } + + // Field present but immutable, make a new mutable copy + final Object newMessage = fieldSchema.newInstance(); + if (current != null) { + fieldSchema.mergeFrom(newMessage, current); + } + return newMessage; + } + + private void storeMessageField(T message, int pos, Object field) { + UNSAFE.putObject(message, offset(typeAndOffsetAt(pos)), field); + setFieldPresent(message, pos); + } + + private Object mutableOneofMessageFieldForMerge(T message, int fieldNumber, int pos) { + final Schema fieldSchema = getMessageFieldSchema(pos); + + // Field not present, create it and mark it present + if (!isOneofPresent(message, fieldNumber, pos)) { + return fieldSchema.newInstance(); + } + + // Field present, if mutable, ready to merge + final Object current = UNSAFE.getObject(message, offset(typeAndOffsetAt(pos))); + if (isMutable(current)) { + return current; + } + + // Field present but immutable, make a new mutable copy + final Object newMessage = fieldSchema.newInstance(); + if (current != null) { + fieldSchema.mergeFrom(newMessage, current); + } + return newMessage; + } + + private void storeOneofMessageField(T message, int fieldNumber, int pos, Object field) { + UNSAFE.putObject(message, offset(typeAndOffsetAt(pos)), field); + setOneofPresent(message, fieldNumber, pos); + } + /** Parses a proto3 message and returns the limit if parsing is successful. */ private int parseProto3Message( T message, byte[] data, int position, int limit, Registers registers) throws IOException { + checkMutable(message); final sun.misc.Unsafe unsafe = UNSAFE; int currentPresenceFieldOffset = NO_PRESENCE_SENTINEL; int currentPresenceField = 0; @@ -5307,16 +5349,11 @@ private int parseProto3Message( break; case 9: // MESSAGE: if (wireType == WireFormat.WIRETYPE_LENGTH_DELIMITED) { + final Object current = mutableMessageFieldForMerge(message, pos); position = - decodeMessageField( - getMessageFieldSchema(pos), data, position, limit, registers); - final Object oldValue = unsafe.getObject(message, fieldOffset); - if (oldValue == null) { - unsafe.putObject(message, fieldOffset, registers.object1); - } else { - unsafe.putObject( - message, fieldOffset, Internal.mergeMessage(oldValue, registers.object1)); - } + mergeMessageField( + current, getMessageFieldSchema(pos), data, position, limit, registers); + storeMessageField(message, pos, current); currentPresenceField |= presenceMask; continue; } @@ -5447,18 +5484,73 @@ public void mergeFrom(T message, byte[] data, int position, int limit, Registers @Override public void makeImmutable(T message) { - // Make all repeated/map fields immutable. - for (int i = checkInitializedCount; i < repeatedFieldOffsetStart; i++) { - long offset = offset(typeAndOffsetAt(intArray[i])); - Object mapField = UnsafeUtil.getObject(message, offset); - if (mapField == null) { - continue; - } - UnsafeUtil.putObject(message, offset, mapFieldSchema.toImmutable(mapField)); + if (!isMutable(message)) { + return; + } + + // TODO(b/248560713) decide if we're keeping support for Full in schema classes and handle this + // better. + if (message instanceof GeneratedMessageLite) { + GeneratedMessageLite generatedMessage = ((GeneratedMessageLite) message); + generatedMessage.clearMemoizedSerializedSize(); + generatedMessage.clearMemoizedHashCode(); + generatedMessage.markImmutable(); } - final int length = intArray.length; - for (int i = repeatedFieldOffsetStart; i < length; i++) { - listFieldSchema.makeImmutableListAt(message, intArray[i]); + + final int bufferLength = buffer.length; + for (int pos = 0; pos < bufferLength; pos += INTS_PER_FIELD) { + final int typeAndOffset = typeAndOffsetAt(pos); + final long offset = offset(typeAndOffset); + switch (type(typeAndOffset)) { + case 17: // GROUP + case 9: // MESSAGE + if (isFieldPresent(message, pos)) { + getMessageFieldSchema(pos).makeImmutable(UNSAFE.getObject(message, offset)); + } + break; + case 18: // DOUBLE_LIST: + case 19: // FLOAT_LIST: + case 20: // INT64_LIST: + case 21: // UINT64_LIST: + case 22: // INT32_LIST: + case 23: // FIXED64_LIST: + case 24: // FIXED32_LIST: + case 25: // BOOL_LIST: + case 26: // STRING_LIST: + case 27: // MESSAGE_LIST: + case 28: // BYTES_LIST: + case 29: // UINT32_LIST: + case 30: // ENUM_LIST: + case 31: // SFIXED32_LIST: + case 32: // SFIXED64_LIST: + case 33: // SINT32_LIST: + case 34: // SINT64_LIST: + case 35: // DOUBLE_LIST_PACKED: + case 36: // FLOAT_LIST_PACKED: + case 37: // INT64_LIST_PACKED: + case 38: // UINT64_LIST_PACKED: + case 39: // INT32_LIST_PACKED: + case 40: // FIXED64_LIST_PACKED: + case 41: // FIXED32_LIST_PACKED: + case 42: // BOOL_LIST_PACKED: + case 43: // UINT32_LIST_PACKED: + case 44: // ENUM_LIST_PACKED: + case 45: // SFIXED32_LIST_PACKED: + case 46: // SFIXED64_LIST_PACKED: + case 47: // SINT32_LIST_PACKED: + case 48: // SINT64_LIST_PACKED: + case 49: // GROUP_LIST: + listFieldSchema.makeImmutableListAt(message, offset); + break; + case 50: // MAP: + { + Object mapField = UNSAFE.getObject(message, offset); + if (mapField != null) { + UNSAFE.putObject(message, offset, mapFieldSchema.toImmutable(mapField)); + } + } + break; + } } unknownFieldSchema.makeImmutable(message); if (hasExtensions) { @@ -5495,8 +5587,12 @@ private final void mergeMap( extensionRegistry); } - private final UB filterMapUnknownEnumValues( - Object message, int pos, UB unknownFields, UnknownFieldSchema unknownFieldSchema) { + private UB filterMapUnknownEnumValues( + Object message, + int pos, + UB unknownFields, + UnknownFieldSchema unknownFieldSchema, + Object containerMessage) { int fieldNumber = numberAt(pos); long offset = offset(typeAndOffsetAt(pos)); Object mapField = UnsafeUtil.getObject(message, offset); @@ -5511,25 +5607,32 @@ private final UB filterMapUnknownEnumValues( // Filter unknown enum values. unknownFields = filterUnknownEnumMap( - pos, fieldNumber, mapData, enumVerifier, unknownFields, unknownFieldSchema); + pos, + fieldNumber, + mapData, + enumVerifier, + unknownFields, + unknownFieldSchema, + containerMessage); return unknownFields; } @SuppressWarnings("unchecked") - private final UB filterUnknownEnumMap( + private UB filterUnknownEnumMap( int pos, int number, Map mapData, EnumVerifier enumVerifier, UB unknownFields, - UnknownFieldSchema unknownFieldSchema) { + UnknownFieldSchema unknownFieldSchema, + Object containerMessage) { Metadata metadata = (Metadata) mapFieldSchema.forMapMetadata(getMapFieldDefaultEntry(pos)); for (Iterator> it = mapData.entrySet().iterator(); it.hasNext(); ) { Map.Entry entry = it.next(); if (!enumVerifier.isInRange((Integer) entry.getValue())) { if (unknownFields == null) { - unknownFields = unknownFieldSchema.newBuilder(); + unknownFields = unknownFieldSchema.getBuilderFromMessage(containerMessage); } int entrySize = MapEntryLite.computeSerializedSize(metadata, entry.getKey(), entry.getValue()); @@ -5746,6 +5849,28 @@ private static long offset(int value) { return value & OFFSET_MASK; } + private static boolean isMutable(Object message) { + if (message == null) { + return false; + } + + // TODO(b/248560713) decide if we're keeping support for Full in schema classes and handle this + // better. + if (message instanceof GeneratedMessageLite) { + return ((GeneratedMessageLite) message).isMutable(); + } + + // For other types, we'll assume this is true because that's what was + // happening before we started checking. + return true; + } + + private static void checkMutable(Object message) { + if (!isMutable(message)) { + throw new IllegalArgumentException("Mutating immutable message: " + message); + } + } + private static double doubleAt(T message, long offset) { return UnsafeUtil.getDouble(message, offset); } diff --git a/java/core/src/main/java/com/google/protobuf/MessageSetSchema.java b/java/core/src/main/java/com/google/protobuf/MessageSetSchema.java index 187dc8b8a5e1..eae93b912f6c 100644 --- a/java/core/src/main/java/com/google/protobuf/MessageSetSchema.java +++ b/java/core/src/main/java/com/google/protobuf/MessageSetSchema.java @@ -61,7 +61,13 @@ static MessageSetSchema newSchema( @SuppressWarnings("unchecked") @Override public T newInstance() { - return (T) defaultInstance.newBuilderForType().buildPartial(); + // TODO(b/248560713) decide if we're keeping support for Full in schema classes and handle this + // better. + if (defaultInstance instanceof GeneratedMessageLite) { + return (T) ((GeneratedMessageLite) defaultInstance).newMutableInstance(); + } else { + return (T) defaultInstance.newBuilderForType().buildPartial(); + } } @Override @@ -132,6 +138,8 @@ private void writeUnknownFieldsHelper( public void mergeFrom( T message, byte[] data, int position, int limit, ArrayDecoders.Registers registers) throws IOException { + // TODO(b/248560713) decide if we're keeping support for Full in schema classes and handle this + // better. UnknownFieldSetLite unknownFields = ((GeneratedMessageLite) message).unknownFields; if (unknownFields == UnknownFieldSetLite.getDefaultInstance()) { unknownFields = UnknownFieldSetLite.newInstance(); @@ -180,9 +188,12 @@ public void mergeFrom( if (wireType == WireFormat.WIRETYPE_VARINT) { position = ArrayDecoders.decodeVarint32(data, position, registers); typeId = registers.int1; + // TODO(b/248560713) decide if we're keeping support for Full in schema classes and + // handle this better. extension = - (GeneratedMessageLite.GeneratedExtension) extensionSchema - .findExtensionByNumber(registers.extensionRegistry, defaultInstance, typeId); + (GeneratedMessageLite.GeneratedExtension) + extensionSchema.findExtensionByNumber( + registers.extensionRegistry, defaultInstance, typeId); continue; } break; diff --git a/java/core/src/main/java/com/google/protobuf/NewInstanceSchemaLite.java b/java/core/src/main/java/com/google/protobuf/NewInstanceSchemaLite.java index 9b922667638d..00cfe3b6c7f0 100644 --- a/java/core/src/main/java/com/google/protobuf/NewInstanceSchemaLite.java +++ b/java/core/src/main/java/com/google/protobuf/NewInstanceSchemaLite.java @@ -33,7 +33,8 @@ final class NewInstanceSchemaLite implements NewInstanceSchema { @Override public Object newInstance(Object defaultInstance) { - return ((GeneratedMessageLite) defaultInstance) - .dynamicMethod(GeneratedMessageLite.MethodToInvoke.NEW_MUTABLE_INSTANCE); + // TODO(b/248560713) decide if we're keeping support for Full in schema classes and handle this + // better. + return ((GeneratedMessageLite) defaultInstance).newMutableInstance(); } } diff --git a/java/core/src/main/java/com/google/protobuf/Reader.java b/java/core/src/main/java/com/google/protobuf/Reader.java index 705096f2d80e..b99ee43d4fd6 100644 --- a/java/core/src/main/java/com/google/protobuf/Reader.java +++ b/java/core/src/main/java/com/google/protobuf/Reader.java @@ -158,6 +158,14 @@ T readMessageBySchemaWithCheck(Schema schema, ExtensionRegistryLite exten T readGroupBySchemaWithCheck(Schema schema, ExtensionRegistryLite extensionRegistry) throws IOException; + /** Read a message field from the wire format and merge the results into the given target. */ + void mergeMessageField(T target, Schema schema, ExtensionRegistryLite extensionRegistry) + throws IOException; + + /** Read a group field from the wire format and merge the results into the given target. */ + void mergeGroupField(T target, Schema schema, ExtensionRegistryLite extensionRegistry) + throws IOException; + /** * Reads and returns the next field of type {@code BYTES} and advances the reader to the next * field. diff --git a/java/core/src/main/java/com/google/protobuf/SchemaUtil.java b/java/core/src/main/java/com/google/protobuf/SchemaUtil.java index 4c8bb06f041c..0e4c42c2dcae 100644 --- a/java/core/src/main/java/com/google/protobuf/SchemaUtil.java +++ b/java/core/src/main/java/com/google/protobuf/SchemaUtil.java @@ -59,6 +59,8 @@ private SchemaUtil() {} * GeneratedMessageLite}. */ public static void requireGeneratedMessage(Class messageType) { + // TODO(b/248560713) decide if we're keeping support for Full in schema classes and handle this + // better. if (!GeneratedMessageLite.class.isAssignableFrom(messageType) && GENERATED_MESSAGE_CLASS != null && !GENERATED_MESSAGE_CLASS.isAssignableFrom(messageType)) { @@ -808,6 +810,8 @@ public static boolean shouldUseTableSwitch(int lo, int hi, int numFields) { private static Class getGeneratedMessageClass() { try { + // TODO(b/248560713) decide if we're keeping support for Full in schema classes and handle + // this better. return Class.forName("com.google.protobuf.GeneratedMessageV3"); } catch (Throwable e) { return null; @@ -901,6 +905,7 @@ static void mergeUnknownFields( /** Filters unrecognized enum values in a list. */ static UB filterUnknownEnumList( + Object containerMessage, int number, List enumList, EnumLiteMap enumMap, @@ -921,7 +926,9 @@ static UB filterUnknownEnumList( } ++writePos; } else { - unknownFields = storeUnknownEnum(number, enumValue, unknownFields, unknownFieldSchema); + unknownFields = + storeUnknownEnum( + containerMessage, number, enumValue, unknownFields, unknownFieldSchema); } } if (writePos != size) { @@ -931,7 +938,9 @@ static UB filterUnknownEnumList( for (Iterator it = enumList.iterator(); it.hasNext(); ) { int enumValue = it.next(); if (enumMap.findValueByNumber(enumValue) == null) { - unknownFields = storeUnknownEnum(number, enumValue, unknownFields, unknownFieldSchema); + unknownFields = + storeUnknownEnum( + containerMessage, number, enumValue, unknownFields, unknownFieldSchema); it.remove(); } } @@ -941,6 +950,7 @@ static UB filterUnknownEnumList( /** Filters unrecognized enum values in a list. */ static UB filterUnknownEnumList( + Object containerMessage, int number, List enumList, EnumVerifier enumVerifier, @@ -961,7 +971,9 @@ static UB filterUnknownEnumList( } ++writePos; } else { - unknownFields = storeUnknownEnum(number, enumValue, unknownFields, unknownFieldSchema); + unknownFields = + storeUnknownEnum( + containerMessage, number, enumValue, unknownFields, unknownFieldSchema); } } if (writePos != size) { @@ -971,7 +983,9 @@ static UB filterUnknownEnumList( for (Iterator it = enumList.iterator(); it.hasNext(); ) { int enumValue = it.next(); if (!enumVerifier.isInRange(enumValue)) { - unknownFields = storeUnknownEnum(number, enumValue, unknownFields, unknownFieldSchema); + unknownFields = + storeUnknownEnum( + containerMessage, number, enumValue, unknownFields, unknownFieldSchema); it.remove(); } } @@ -981,9 +995,13 @@ static UB filterUnknownEnumList( /** Stores an unrecognized enum value as an unknown value. */ static UB storeUnknownEnum( - int number, int enumValue, UB unknownFields, UnknownFieldSchema unknownFieldSchema) { + Object containerMessage, + int number, + int enumValue, + UB unknownFields, + UnknownFieldSchema unknownFieldSchema) { if (unknownFields == null) { - unknownFields = unknownFieldSchema.newBuilder(); + unknownFields = unknownFieldSchema.getBuilderFromMessage(containerMessage); } unknownFieldSchema.addVarint(unknownFields, number, enumValue); return unknownFields; diff --git a/java/core/src/main/java/com/google/protobuf/TextFormat.java b/java/core/src/main/java/com/google/protobuf/TextFormat.java index e781df333d10..c6c0359aa6b9 100644 --- a/java/core/src/main/java/com/google/protobuf/TextFormat.java +++ b/java/core/src/main/java/com/google/protobuf/TextFormat.java @@ -594,7 +594,7 @@ private void printFieldValue( case MESSAGE: case GROUP: - print((Message) value, generator); + print((MessageOrBuilder) value, generator); break; } } diff --git a/java/core/src/main/java/com/google/protobuf/UnknownFieldSetLite.java b/java/core/src/main/java/com/google/protobuf/UnknownFieldSetLite.java index b2cb7be4baef..37a14e214e4f 100644 --- a/java/core/src/main/java/com/google/protobuf/UnknownFieldSetLite.java +++ b/java/core/src/main/java/com/google/protobuf/UnknownFieldSetLite.java @@ -388,7 +388,7 @@ final void printWithIndent(StringBuilder buffer, int indent) { // Package private for unsafe experimental runtime. void storeField(int tag, Object value) { checkMutable(); - ensureCapacity(); + ensureCapacity(count + 1); tags[count] = tag; objects[count] = value; @@ -396,13 +396,23 @@ void storeField(int tag, Object value) { } /** Ensures that our arrays are long enough to store more metadata. */ - private void ensureCapacity() { - if (count == tags.length) { - int increment = count < (MIN_CAPACITY / 2) ? MIN_CAPACITY : count >> 1; - int newLength = count + increment; + private void ensureCapacity(int minCapacity) { + if (minCapacity > this.tags.length) { + // Increase by at least 50% + int newCapacity = count + count / 2; + + // Or new capacity if higher + if (newCapacity < minCapacity) { + newCapacity = minCapacity; + } + + // And never less than MIN_CAPACITY + if (newCapacity < MIN_CAPACITY) { + newCapacity = MIN_CAPACITY; + } - tags = Arrays.copyOf(tags, newLength); - objects = Arrays.copyOf(objects, newLength); + this.tags = Arrays.copyOf(this.tags, newCapacity); + this.objects = Arrays.copyOf(this.objects, newCapacity); } } @@ -487,4 +497,18 @@ private UnknownFieldSetLite mergeFrom(final CodedInputStream input) throws IOExc } return this; } + + UnknownFieldSetLite mergeFrom(UnknownFieldSetLite other) { + if (other.equals(getDefaultInstance())) { + return this; + } + + checkMutable(); + int newCount = this.count + other.count; + ensureCapacity(newCount); + System.arraycopy(other.tags, 0, tags, this.count, other.count); + System.arraycopy(other.objects, 0, objects, this.count, other.count); + this.count = newCount; + return this; + } } diff --git a/java/core/src/main/java/com/google/protobuf/UnknownFieldSetLiteSchema.java b/java/core/src/main/java/com/google/protobuf/UnknownFieldSetLiteSchema.java index ffd7232308c4..2cfdeca2450a 100644 --- a/java/core/src/main/java/com/google/protobuf/UnknownFieldSetLiteSchema.java +++ b/java/core/src/main/java/com/google/protobuf/UnknownFieldSetLiteSchema.java @@ -122,10 +122,14 @@ void writeAsMessageSetTo(UnknownFieldSetLite fields, Writer writer) throws IOExc } @Override - UnknownFieldSetLite merge(UnknownFieldSetLite message, UnknownFieldSetLite other) { - return other.equals(UnknownFieldSetLite.getDefaultInstance()) - ? message - : UnknownFieldSetLite.mutableCopyOf(message, other); + UnknownFieldSetLite merge(UnknownFieldSetLite target, UnknownFieldSetLite source) { + if (UnknownFieldSetLite.getDefaultInstance().equals(source)) { + return target; + } + if (UnknownFieldSetLite.getDefaultInstance().equals(target)) { + return UnknownFieldSetLite.mutableCopyOf(target, source); + } + return target.mergeFrom(source); } @Override diff --git a/java/lite/src/test/java/com/google/protobuf/LiteTest.java b/java/lite/src/test/java/com/google/protobuf/LiteTest.java index f2ce4614c797..ec62480227df 100644 --- a/java/lite/src/test/java/com/google/protobuf/LiteTest.java +++ b/java/lite/src/test/java/com/google/protobuf/LiteTest.java @@ -50,15 +50,6 @@ import com.google.protobuf.UnittestLite.TestAllTypesLiteOrBuilder; import com.google.protobuf.UnittestLite.TestHugeFieldNumbersLite; import com.google.protobuf.UnittestLite.TestNestedExtensionLite; -import map_lite_test.MapTestProto.TestMap; -import map_lite_test.MapTestProto.TestMap.MessageValue; -import protobuf_unittest.NestedExtensionLite; -import protobuf_unittest.NonNestedExtensionLite; -import protobuf_unittest.lite_equals_and_hash.LiteEqualsAndHash.Bar; -import protobuf_unittest.lite_equals_and_hash.LiteEqualsAndHash.BarPrime; -import protobuf_unittest.lite_equals_and_hash.LiteEqualsAndHash.Foo; -import protobuf_unittest.lite_equals_and_hash.LiteEqualsAndHash.TestOneofEquals; -import protobuf_unittest.lite_equals_and_hash.LiteEqualsAndHash.TestRecursiveOneof; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -71,6 +62,15 @@ import java.util.Iterator; import java.util.List; import junit.framework.TestCase; +import map_lite_test.MapTestProto.TestMap; +import map_lite_test.MapTestProto.TestMap.MessageValue; +import protobuf_unittest.NestedExtensionLite; +import protobuf_unittest.NonNestedExtensionLite; +import protobuf_unittest.lite_equals_and_hash.LiteEqualsAndHash.Bar; +import protobuf_unittest.lite_equals_and_hash.LiteEqualsAndHash.BarPrime; +import protobuf_unittest.lite_equals_and_hash.LiteEqualsAndHash.Foo; +import protobuf_unittest.lite_equals_and_hash.LiteEqualsAndHash.TestOneofEquals; +import protobuf_unittest.lite_equals_and_hash.LiteEqualsAndHash.TestRecursiveOneof; /** * Test lite runtime. @@ -183,16 +183,24 @@ public void testMemoization() throws Exception { TestAllExtensionsLite message = TestUtilLite.getAllLiteExtensionsSet(); // Test serialized size is memoized - message.memoizedSerializedSize = -1; + assertEquals( + GeneratedMessageLite.UNINITIALIZED_SERIALIZED_SIZE, + message.getMemoizedSerializedSize()); int size = message.getSerializedSize(); assertTrue(size > 0); - assertEquals(size, message.memoizedSerializedSize); + assertEquals(size, message.getMemoizedSerializedSize()); + message.clearMemoizedSerializedSize(); + assertEquals( + GeneratedMessageLite.UNINITIALIZED_SERIALIZED_SIZE, + message.getMemoizedSerializedSize()); // Test hashCode is memoized - assertEquals(0, message.memoizedHashCode); + assertTrue(message.hashCodeIsNotMemoized()); int hashCode = message.hashCode(); - assertTrue(hashCode != 0); - assertEquals(hashCode, message.memoizedHashCode); + assertFalse(message.hashCodeIsNotMemoized()); + assertEquals(hashCode, message.getMemoizedHashCode()); + message.clearMemoizedHashCode(); + assertTrue(message.hashCodeIsNotMemoized()); // Test isInitialized is memoized Field memo = message.getClass().getDeclaredField("memoizedIsInitialized"); diff --git a/java/util/src/main/java/com/google/protobuf/util/FieldMaskTree.java b/java/util/src/main/java/com/google/protobuf/util/FieldMaskTree.java index 352376e01509..2aa0916f0625 100644 --- a/java/util/src/main/java/com/google/protobuf/util/FieldMaskTree.java +++ b/java/util/src/main/java/com/google/protobuf/util/FieldMaskTree.java @@ -35,6 +35,7 @@ import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.FieldDescriptor; import com.google.protobuf.FieldMask; +import com.google.protobuf.GeneratedMessage; import com.google.protobuf.Message; import java.util.ArrayList; import java.util.List; @@ -304,9 +305,12 @@ private static void merge( // so we don't create unnecessary empty messages. continue; } - String childPath = path.isEmpty() ? entry.getKey() : path + "." + entry.getKey(); - Message.Builder childBuilder = ((Message) destination.getField(field)).toBuilder(); - merge(entry.getValue(), childPath, (Message) source.getField(field), childBuilder, options); + // This is a mess because of java proto API 1 still hanging around. + Message.Builder childBuilder = + destination instanceof GeneratedMessage.Builder + ? destination.getFieldBuilder(field) + : ((Message) destination.getField(field)).toBuilder(); + merge(entry.getValue(), path, (Message) source.getField(field), childBuilder, options); destination.setField(field, childBuilder.buildPartial()); continue; } diff --git a/src/google/protobuf/compiler/java/java_enum_field.cc b/src/google/protobuf/compiler/java/java_enum_field.cc index 9a0799ee4666..1a0b8e0f901c 100644 --- a/src/google/protobuf/compiler/java/java_enum_field.cc +++ b/src/google/protobuf/compiler/java/java_enum_field.cc @@ -112,13 +112,6 @@ void SetEnumVariables(const FieldDescriptor* descriptor, int messageBitIndex, (*variables)["set_mutable_bit_builder"] = GenerateSetBit(builderBitIndex); (*variables)["clear_mutable_bit_builder"] = GenerateClearBit(builderBitIndex); - // For repeated fields, one bit is used for whether the array is immutable - // in the parsing constructor. - (*variables)["get_mutable_bit_parser"] = - GenerateGetBitMutableLocal(builderBitIndex); - (*variables)["set_mutable_bit_parser"] = - GenerateSetBitMutableLocal(builderBitIndex); - (*variables)["get_has_field_bit_from_local"] = GenerateGetBitFromLocal(builderBitIndex); (*variables)["set_has_field_bit_to_local"] = @@ -316,32 +309,26 @@ void ImmutableEnumFieldGenerator::GenerateBuildingCode( printer->Print(variables_, "result.$name$_ = $name$_;\n"); } -void ImmutableEnumFieldGenerator::GenerateParsingCode( +void ImmutableEnumFieldGenerator::GenerateBuilderParsingCode( io::Printer* printer) const { if (SupportUnknownEnumValue(descriptor_->file())) { printer->Print(variables_, - "int rawValue = input.readEnum();\n" - "$set_has_field_bit_message$\n" - "$name$_ = rawValue;\n"); + "$name$_ = input.readEnum();\n" + "$set_has_field_bit_builder$\n"); } else { printer->Print(variables_, - "int rawValue = input.readEnum();\n" - " @SuppressWarnings(\"deprecation\")\n" - "$type$ value = $type$.$for_number$(rawValue);\n" - "if (value == null) {\n" - " unknownFields.mergeVarintField($number$, rawValue);\n" + "int tmpRaw = input.readEnum();\n" + "$type$ tmpValue =\n" + " $type$.forNumber(tmpRaw);\n" + "if (tmpValue == null) {\n" + " mergeUnknownVarintField($number$, tmpRaw);\n" "} else {\n" - " $set_has_field_bit_message$\n" - " $name$_ = rawValue;\n" + " $name$_ = tmpRaw;\n" + " $set_has_field_bit_builder$\n" "}\n"); } } -void ImmutableEnumFieldGenerator::GenerateParsingDoneCode( - io::Printer* printer) const { - // noop for enums -} - void ImmutableEnumFieldGenerator::GenerateSerializationCode( io::Printer* printer) const { printer->Print(variables_, @@ -502,6 +489,11 @@ void ImmutableEnumOneofFieldGenerator::GenerateBuilderMembers( printer->Annotate("{", "}", descriptor_); } +void ImmutableEnumOneofFieldGenerator::GenerateBuilderClearCode( + io::Printer* printer) const { + // No-op: Enum fields in oneofs are correctly cleared by clearing the oneof +} + void ImmutableEnumOneofFieldGenerator::GenerateBuildingCode( io::Printer* printer) const { printer->Print(variables_, @@ -522,7 +514,7 @@ void ImmutableEnumOneofFieldGenerator::GenerateMergingCode( } } -void ImmutableEnumOneofFieldGenerator::GenerateParsingCode( +void ImmutableEnumOneofFieldGenerator::GenerateBuilderParsingCode( io::Printer* printer) const { if (SupportUnknownEnumValue(descriptor_->file())) { printer->Print(variables_, @@ -532,10 +524,10 @@ void ImmutableEnumOneofFieldGenerator::GenerateParsingCode( } else { printer->Print(variables_, "int rawValue = input.readEnum();\n" - "@SuppressWarnings(\"deprecation\")\n" - "$type$ value = $type$.$for_number$(rawValue);\n" + "$type$ value =\n" + " $type$.forNumber(rawValue);\n" "if (value == null) {\n" - " unknownFields.mergeVarintField($number$, rawValue);\n" + " mergeUnknownVarintField($number$, rawValue);\n" "} else {\n" " $set_oneof_case_message$;\n" " $oneof_name$_ = rawValue;\n" @@ -914,36 +906,29 @@ void RepeatedImmutableEnumFieldGenerator::GenerateBuildingCode( "result.$name$_ = $name$_;\n"); } -void RepeatedImmutableEnumFieldGenerator::GenerateParsingCode( +void RepeatedImmutableEnumFieldGenerator::GenerateBuilderParsingCode( io::Printer* printer) const { // Read and store the enum if (SupportUnknownEnumValue(descriptor_->file())) { printer->Print(variables_, - "int rawValue = input.readEnum();\n" - "if (!$get_mutable_bit_parser$) {\n" - " $name$_ = new java.util.ArrayList();\n" - " $set_mutable_bit_parser$;\n" - "}\n" - "$name$_.add(rawValue);\n"); + "int tmpRaw = input.readEnum();\n" + "ensure$capitalized_name$IsMutable();\n" + "$name$_.add(tmpRaw);\n"); } else { - printer->Print( - variables_, - "int rawValue = input.readEnum();\n" - "@SuppressWarnings(\"deprecation\")\n" - "$type$ value = $type$.$for_number$(rawValue);\n" - "if (value == null) {\n" - " unknownFields.mergeVarintField($number$, rawValue);\n" - "} else {\n" - " if (!$get_mutable_bit_parser$) {\n" - " $name$_ = new java.util.ArrayList();\n" - " $set_mutable_bit_parser$;\n" - " }\n" - " $name$_.add(rawValue);\n" - "}\n"); + printer->Print(variables_, + "int tmpRaw = input.readEnum();\n" + "$type$ tmpValue =\n" + " $type$.forNumber(tmpRaw);\n" + "if (tmpValue == null) {\n" + " mergeUnknownVarintField($number$, tmpRaw);\n" + "} else {\n" + " ensure$capitalized_name$IsMutable();\n" + " $name$_.add(tmpRaw);\n" + "}\n"); } } -void RepeatedImmutableEnumFieldGenerator::GenerateParsingCodeFromPacked( +void RepeatedImmutableEnumFieldGenerator::GenerateBuilderParsingCodeFromPacked( io::Printer* printer) const { // Wrap GenerateParsingCode's contents with a while loop. @@ -953,7 +938,7 @@ void RepeatedImmutableEnumFieldGenerator::GenerateParsingCodeFromPacked( "while(input.getBytesUntilLimit() > 0) {\n"); printer->Indent(); - GenerateParsingCode(printer); + GenerateBuilderParsingCode(printer); printer->Outdent(); printer->Print(variables_, @@ -961,15 +946,6 @@ void RepeatedImmutableEnumFieldGenerator::GenerateParsingCodeFromPacked( "input.popLimit(oldLimit);\n"); } -void RepeatedImmutableEnumFieldGenerator::GenerateParsingDoneCode( - io::Printer* printer) const { - printer->Print( - variables_, - "if ($get_mutable_bit_parser$) {\n" - " $name$_ = java.util.Collections.unmodifiableList($name$_);\n" - "}\n"); -} - void RepeatedImmutableEnumFieldGenerator::GenerateSerializationCode( io::Printer* printer) const { if (descriptor_->is_packed()) { diff --git a/src/google/protobuf/compiler/java/java_enum_field.h b/src/google/protobuf/compiler/java/java_enum_field.h index 95c7db578f2f..318e013d5741 100644 --- a/src/google/protobuf/compiler/java/java_enum_field.h +++ b/src/google/protobuf/compiler/java/java_enum_field.h @@ -64,24 +64,24 @@ class ImmutableEnumFieldGenerator : public ImmutableFieldGenerator { // implements ImmutableFieldGenerator // --------------------------------------- - int GetNumBitsForMessage() const; - int GetNumBitsForBuilder() const; - void GenerateInterfaceMembers(io::Printer* printer) const; - void GenerateMembers(io::Printer* printer) const; - void GenerateBuilderMembers(io::Printer* printer) const; - void GenerateInitializationCode(io::Printer* printer) const; - void GenerateBuilderClearCode(io::Printer* printer) const; - void GenerateMergingCode(io::Printer* printer) const; - void GenerateBuildingCode(io::Printer* printer) const; - void GenerateParsingCode(io::Printer* printer) const; - void GenerateParsingDoneCode(io::Printer* printer) const; - void GenerateSerializationCode(io::Printer* printer) const; - void GenerateSerializedSizeCode(io::Printer* printer) const; - void GenerateFieldBuilderInitializationCode(io::Printer* printer) const; - void GenerateEqualsCode(io::Printer* printer) const; - void GenerateHashCode(io::Printer* printer) const; - - std::string GetBoxedType() const; + int GetNumBitsForMessage() const override; + int GetNumBitsForBuilder() const override; + void GenerateInterfaceMembers(io::Printer* printer) const override; + void GenerateMembers(io::Printer* printer) const override; + void GenerateBuilderMembers(io::Printer* printer) const override; + void GenerateInitializationCode(io::Printer* printer) const override; + void GenerateBuilderClearCode(io::Printer* printer) const override; + void GenerateBuilderParsingCode(io::Printer* printer) const override; + void GenerateMergingCode(io::Printer* printer) const override; + void GenerateBuildingCode(io::Printer* printer) const override; + void GenerateSerializationCode(io::Printer* printer) const override; + void GenerateSerializedSizeCode(io::Printer* printer) const override; + void GenerateFieldBuilderInitializationCode( + io::Printer* printer) const override; + void GenerateEqualsCode(io::Printer* printer) const override; + void GenerateHashCode(io::Printer* printer) const override; + + std::string GetBoxedType() const override; protected: const FieldDescriptor* descriptor_; @@ -99,15 +99,16 @@ class ImmutableEnumOneofFieldGenerator : public ImmutableEnumFieldGenerator { Context* context); ~ImmutableEnumOneofFieldGenerator(); - void GenerateMembers(io::Printer* printer) const; - void GenerateBuilderMembers(io::Printer* printer) const; - void GenerateMergingCode(io::Printer* printer) const; - void GenerateBuildingCode(io::Printer* printer) const; - void GenerateParsingCode(io::Printer* printer) const; - void GenerateSerializationCode(io::Printer* printer) const; - void GenerateSerializedSizeCode(io::Printer* printer) const; - void GenerateEqualsCode(io::Printer* printer) const; - void GenerateHashCode(io::Printer* printer) const; + void GenerateMembers(io::Printer* printer) const override; + void GenerateBuilderMembers(io::Printer* printer) const override; + void GenerateBuilderClearCode(io::Printer* printer) const override; + void GenerateMergingCode(io::Printer* printer) const override; + void GenerateBuildingCode(io::Printer* printer) const override; + void GenerateBuilderParsingCode(io::Printer* printer) const override; + void GenerateSerializationCode(io::Printer* printer) const override; + void GenerateSerializedSizeCode(io::Printer* printer) const override; + void GenerateEqualsCode(io::Printer* printer) const override; + void GenerateHashCode(io::Printer* printer) const override; private: GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ImmutableEnumOneofFieldGenerator); @@ -121,25 +122,26 @@ class RepeatedImmutableEnumFieldGenerator : public ImmutableFieldGenerator { ~RepeatedImmutableEnumFieldGenerator(); // implements ImmutableFieldGenerator --------------------------------------- - int GetNumBitsForMessage() const; - int GetNumBitsForBuilder() const; - void GenerateInterfaceMembers(io::Printer* printer) const; - void GenerateMembers(io::Printer* printer) const; - void GenerateBuilderMembers(io::Printer* printer) const; - void GenerateInitializationCode(io::Printer* printer) const; - void GenerateBuilderClearCode(io::Printer* printer) const; - void GenerateMergingCode(io::Printer* printer) const; - void GenerateBuildingCode(io::Printer* printer) const; - void GenerateParsingCode(io::Printer* printer) const; - void GenerateParsingCodeFromPacked(io::Printer* printer) const; - void GenerateParsingDoneCode(io::Printer* printer) const; - void GenerateSerializationCode(io::Printer* printer) const; - void GenerateSerializedSizeCode(io::Printer* printer) const; - void GenerateFieldBuilderInitializationCode(io::Printer* printer) const; - void GenerateEqualsCode(io::Printer* printer) const; - void GenerateHashCode(io::Printer* printer) const; - - std::string GetBoxedType() const; + int GetNumBitsForMessage() const override; + int GetNumBitsForBuilder() const override; + void GenerateInterfaceMembers(io::Printer* printer) const override; + void GenerateMembers(io::Printer* printer) const override; + void GenerateBuilderMembers(io::Printer* printer) const override; + void GenerateInitializationCode(io::Printer* printer) const override; + void GenerateBuilderClearCode(io::Printer* printer) const override; + void GenerateMergingCode(io::Printer* printer) const override; + void GenerateBuildingCode(io::Printer* printer) const override; + void GenerateBuilderParsingCode(io::Printer* printer) const override; + void GenerateBuilderParsingCodeFromPacked( + io::Printer* printer) const override; + void GenerateSerializationCode(io::Printer* printer) const override; + void GenerateSerializedSizeCode(io::Printer* printer) const override; + void GenerateFieldBuilderInitializationCode( + io::Printer* printer) const override; + void GenerateEqualsCode(io::Printer* printer) const override; + void GenerateHashCode(io::Printer* printer) const override; + + std::string GetBoxedType() const override; private: const FieldDescriptor* descriptor_; diff --git a/src/google/protobuf/compiler/java/java_field.cc b/src/google/protobuf/compiler/java/java_field.cc index 2f775a68a66e..229b3b3a676b 100644 --- a/src/google/protobuf/compiler/java/java_field.cc +++ b/src/google/protobuf/compiler/java/java_field.cc @@ -185,7 +185,7 @@ static inline void ReportUnexpectedPackedFieldsCall(io::Printer* printer) { // but this method should be overridden. // - This FieldGenerator doesn't support packing, and this method // should never have been called. - GOOGLE_LOG(FATAL) << "GenerateParsingCodeFromPacked() " + GOOGLE_LOG(FATAL) << "GenerateBuilderParsingCodeFromPacked() " << "called on field generator that does not support packing."; } @@ -193,7 +193,7 @@ static inline void ReportUnexpectedPackedFieldsCall(io::Printer* printer) { ImmutableFieldGenerator::~ImmutableFieldGenerator() {} -void ImmutableFieldGenerator::GenerateParsingCodeFromPacked( +void ImmutableFieldGenerator::GenerateBuilderParsingCodeFromPacked( io::Printer* printer) const { ReportUnexpectedPackedFieldsCall(printer); } diff --git a/src/google/protobuf/compiler/java/java_field.h b/src/google/protobuf/compiler/java/java_field.h index df6c38d75226..11ec277824b4 100644 --- a/src/google/protobuf/compiler/java/java_field.h +++ b/src/google/protobuf/compiler/java/java_field.h @@ -77,9 +77,8 @@ class ImmutableFieldGenerator { virtual void GenerateBuilderClearCode(io::Printer* printer) const = 0; virtual void GenerateMergingCode(io::Printer* printer) const = 0; virtual void GenerateBuildingCode(io::Printer* printer) const = 0; - virtual void GenerateParsingCode(io::Printer* printer) const = 0; - virtual void GenerateParsingCodeFromPacked(io::Printer* printer) const; - virtual void GenerateParsingDoneCode(io::Printer* printer) const = 0; + virtual void GenerateBuilderParsingCode(io::Printer* printer) const = 0; + virtual void GenerateBuilderParsingCodeFromPacked(io::Printer* printer) const; virtual void GenerateSerializationCode(io::Printer* printer) const = 0; virtual void GenerateSerializedSizeCode(io::Printer* printer) const = 0; virtual void GenerateFieldBuilderInitializationCode( diff --git a/src/google/protobuf/compiler/java/java_map_field.cc b/src/google/protobuf/compiler/java/java_map_field.cc index 5db199d38fa4..1fb8f1e630ca 100644 --- a/src/google/protobuf/compiler/java/java_map_field.cc +++ b/src/google/protobuf/compiler/java/java_map_field.cc @@ -138,13 +138,6 @@ void SetMessageVariables(const FieldDescriptor* descriptor, int messageBitIndex, descriptor->options().deprecated() ? "@java.lang.Deprecated " : ""; (*variables)["on_changed"] = "onChanged();"; - // For repeated fields, one bit is used for whether the array is immutable - // in the parsing constructor. - (*variables)["get_mutable_bit_parser"] = - GenerateGetBitMutableLocal(builderBitIndex); - (*variables)["set_mutable_bit_parser"] = - GenerateSetBitMutableLocal(builderBitIndex); - (*variables)["default_entry"] = (*variables)["capitalized_name"] + "DefaultEntryHolder.defaultEntry"; (*variables)["map_field_parameter"] = (*variables)["default_entry"]; @@ -681,27 +674,19 @@ void ImmutableMapFieldGenerator::GenerateBuildingCode( "result.$name$_.makeImmutable();\n"); } -void ImmutableMapFieldGenerator::GenerateParsingCode( +void ImmutableMapFieldGenerator::GenerateBuilderParsingCode( io::Printer* printer) const { - printer->Print(variables_, - "if (!$get_mutable_bit_parser$) {\n" - " $name$_ = com.google.protobuf.MapField.newMapField(\n" - " $map_field_parameter$);\n" - " $set_mutable_bit_parser$;\n" - "}\n"); if (!SupportUnknownEnumValue(descriptor_->file()) && GetJavaType(ValueField(descriptor_)) == JAVATYPE_ENUM) { printer->Print( variables_, "com.google.protobuf.ByteString bytes = input.readBytes();\n" "com.google.protobuf.MapEntry<$type_parameters$>\n" - "$name$__ = $default_entry$.getParserForType().parseFrom(bytes);\n"); - printer->Print( - variables_, + "$name$__ = $default_entry$.getParserForType().parseFrom(bytes);\n" "if ($value_enum_type$.forNumber($name$__.getValue()) == null) {\n" - " unknownFields.mergeLengthDelimitedField($number$, bytes);\n" + " mergeUnknownLengthDelimitedField($number$, bytes);\n" "} else {\n" - " $name$_.getMutableMap().put(\n" + " internalGetMutable$capitalized_name$().getMutableMap().put(\n" " $name$__.getKey(), $name$__.getValue());\n" "}\n"); } else { @@ -710,16 +695,11 @@ void ImmutableMapFieldGenerator::GenerateParsingCode( "com.google.protobuf.MapEntry<$type_parameters$>\n" "$name$__ = input.readMessage(\n" " $default_entry$.getParserForType(), extensionRegistry);\n" - "$name$_.getMutableMap().put(\n" + "internalGetMutable$capitalized_name$().getMutableMap().put(\n" " $name$__.getKey(), $name$__.getValue());\n"); } } -void ImmutableMapFieldGenerator::GenerateParsingDoneCode( - io::Printer* printer) const { - // Nothing to do here. -} - void ImmutableMapFieldGenerator::GenerateSerializationCode( io::Printer* printer) const { printer->Print(variables_, diff --git a/src/google/protobuf/compiler/java/java_map_field.h b/src/google/protobuf/compiler/java/java_map_field.h index 2ff1f7673e9c..4e46222f9335 100644 --- a/src/google/protobuf/compiler/java/java_map_field.h +++ b/src/google/protobuf/compiler/java/java_map_field.h @@ -46,23 +46,23 @@ class ImmutableMapFieldGenerator : public ImmutableFieldGenerator { ~ImmutableMapFieldGenerator(); // implements ImmutableFieldGenerator --------------------------------------- - int GetNumBitsForMessage() const; - int GetNumBitsForBuilder() const; - void GenerateInterfaceMembers(io::Printer* printer) const; - void GenerateMembers(io::Printer* printer) const; - void GenerateBuilderMembers(io::Printer* printer) const; - void GenerateInitializationCode(io::Printer* printer) const; - void GenerateBuilderClearCode(io::Printer* printer) const; - void GenerateMergingCode(io::Printer* printer) const; - void GenerateBuildingCode(io::Printer* printer) const; - void GenerateParsingCode(io::Printer* printer) const; - void GenerateParsingDoneCode(io::Printer* printer) const; - void GenerateSerializationCode(io::Printer* printer) const; - void GenerateSerializedSizeCode(io::Printer* printer) const; - void GenerateFieldBuilderInitializationCode(io::Printer* printer) const; - void GenerateEqualsCode(io::Printer* printer) const; - void GenerateHashCode(io::Printer* printer) const; - + int GetNumBitsForMessage() const override; + int GetNumBitsForBuilder() const override; + void GenerateInterfaceMembers(io::Printer* printer) const override; + void GenerateMembers(io::Printer* printer) const override; + void GenerateBuilderMembers(io::Printer* printer) const override; + void GenerateInitializationCode(io::Printer* printer) const override; + void GenerateBuilderClearCode(io::Printer* printer) const override; + void GenerateMergingCode(io::Printer* printer) const override; + void GenerateBuildingCode(io::Printer* printer) const override; + void GenerateBuilderParsingCode(io::Printer* printer) const override; + void GenerateSerializationCode(io::Printer* printer) const override; + void GenerateSerializedSizeCode(io::Printer* printer) const override; + void GenerateFieldBuilderInitializationCode( + io::Printer* printer) const override; + void GenerateEqualsCode(io::Printer* printer) const override; + void GenerateHashCode(io::Printer* printer) const override; + std::string GetBoxedType() const; private: diff --git a/src/google/protobuf/compiler/java/java_message.cc b/src/google/protobuf/compiler/java/java_message.cc index f2df25f0feb6..22005bcf3106 100644 --- a/src/google/protobuf/compiler/java/java_message.cc +++ b/src/google/protobuf/compiler/java/java_message.cc @@ -52,6 +52,7 @@ #include #include #include +#include #include #include #include @@ -368,6 +369,7 @@ void ImmutableMessageGenerator::Generate(io::Printer* printer) { "}\n" "\n"); + // TODO(b/248149118): Remove this superfluous override. printer->Print( "@java.lang.Override\n" "public final com.google.protobuf.UnknownFieldSet\n" @@ -375,10 +377,6 @@ void ImmutableMessageGenerator::Generate(io::Printer* printer) { " return this.unknownFields;\n" "}\n"); - if (context_->HasGeneratedMethods(descriptor_)) { - GenerateParsingConstructor(printer); - } - GenerateDescriptorMethods(printer); // Nested types @@ -627,9 +625,9 @@ void ImmutableMessageGenerator::GenerateMessageSerializationMethods( } if (descriptor_->options().message_set_wire_format()) { - printer->Print("unknownFields.writeAsMessageSetTo(output);\n"); + printer->Print("getUnknownFields().writeAsMessageSetTo(output);\n"); } else { - printer->Print("unknownFields.writeTo(output);\n"); + printer->Print("getUnknownFields().writeTo(output);\n"); } printer->Outdent(); @@ -658,9 +656,10 @@ void ImmutableMessageGenerator::GenerateMessageSerializationMethods( } if (descriptor_->options().message_set_wire_format()) { - printer->Print("size += unknownFields.getSerializedSizeAsMessageSet();\n"); + printer->Print( + "size += getUnknownFields().getSerializedSizeAsMessageSet();\n"); } else { - printer->Print("size += unknownFields.getSerializedSize();\n"); + printer->Print("size += getUnknownFields().getSerializedSize();\n"); } printer->Print( @@ -1057,7 +1056,8 @@ void ImmutableMessageGenerator::GenerateEqualsAndHashCode( // false for non-canonical ordering when running in LITE_RUNTIME but it's // the best we can do. printer->Print( - "if (!unknownFields.equals(other.unknownFields)) return false;\n"); + "if (!getUnknownFields().equals(other.getUnknownFields())) return " + "false;\n"); if (descriptor_->extension_range_count() > 0) { printer->Print( "if (!getExtensionFields().equals(other.getExtensionFields()))\n" @@ -1131,7 +1131,7 @@ void ImmutableMessageGenerator::GenerateEqualsAndHashCode( printer->Print("hash = hashFields(hash, getExtensionFields());\n"); } - printer->Print("hash = (29 * hash) + unknownFields.hashCode();\n"); + printer->Print("hash = (29 * hash) + getUnknownFields().hashCode();\n"); printer->Print( "memoizedHashCode = hash;\n" "return hash;\n"); @@ -1156,186 +1156,33 @@ void ImmutableMessageGenerator::GenerateExtensionRegistrationCode( } } -// =================================================================== -void ImmutableMessageGenerator::GenerateParsingConstructor( - io::Printer* printer) { - std::unique_ptr sorted_fields( - SortFieldsByNumber(descriptor_)); - - printer->Print( - "private $classname$(\n" - " com.google.protobuf.CodedInputStream input,\n" - " com.google.protobuf.ExtensionRegistryLite extensionRegistry)\n" - " throws com.google.protobuf.InvalidProtocolBufferException {\n", - "classname", descriptor_->name()); - printer->Indent(); - - // Initialize all fields to default. - printer->Print( - "this();\n" - "if (extensionRegistry == null) {\n" - " throw new java.lang.NullPointerException();\n" - "}\n"); - - // Use builder bits to track mutable repeated fields. - int totalBuilderBits = 0; - for (int i = 0; i < descriptor_->field_count(); i++) { - const ImmutableFieldGenerator& field = - field_generators_.get(descriptor_->field(i)); - totalBuilderBits += field.GetNumBitsForBuilder(); - } - int totalBuilderInts = (totalBuilderBits + 31) / 32; - for (int i = 0; i < totalBuilderInts; i++) { - printer->Print("int mutable_$bit_field_name$ = 0;\n", "bit_field_name", - GetBitFieldName(i)); - } - - printer->Print( - "com.google.protobuf.UnknownFieldSet.Builder unknownFields =\n" - " com.google.protobuf.UnknownFieldSet.newBuilder();\n"); - - printer->Print("try {\n"); - printer->Indent(); - - printer->Print( - "boolean done = false;\n" - "while (!done) {\n"); - printer->Indent(); - - printer->Print( - "int tag = input.readTag();\n" - "switch (tag) {\n"); - printer->Indent(); - - printer->Print( - "case 0:\n" // zero signals EOF / limit reached - " done = true;\n" - " break;\n"); - - for (int i = 0; i < descriptor_->field_count(); i++) { - const FieldDescriptor* field = sorted_fields[i]; - uint32_t tag = WireFormatLite::MakeTag( - field->number(), WireFormat::WireTypeForFieldType(field->type())); - - printer->Print("case $tag$: {\n", "tag", - StrCat(static_cast(tag))); - printer->Indent(); - - field_generators_.get(field).GenerateParsingCode(printer); - - printer->Outdent(); - printer->Print( - " break;\n" - "}\n"); - - if (field->is_packable()) { - // To make packed = true wire compatible, we generate parsing code from a - // packed version of this field regardless of field->options().packed(). - uint32_t packed_tag = WireFormatLite::MakeTag( - field->number(), WireFormatLite::WIRETYPE_LENGTH_DELIMITED); - printer->Print("case $tag$: {\n", "tag", - StrCat(static_cast(packed_tag))); - printer->Indent(); - - field_generators_.get(field).GenerateParsingCodeFromPacked(printer); - - printer->Outdent(); - printer->Print( - " break;\n" - "}\n"); - } - } - - printer->Print( - "default: {\n" - " if (!parseUnknownField(\n" - " input, unknownFields, extensionRegistry, tag)) {\n" - " done = true;\n" // it's an endgroup tag - " }\n" - " break;\n" - "}\n"); - - printer->Outdent(); - printer->Outdent(); - printer->Print( - " }\n" // switch (tag) - "}\n"); // while (!done) - - printer->Outdent(); - printer->Print( - "} catch (com.google.protobuf.InvalidProtocolBufferException e) {\n" - " throw e.setUnfinishedMessage(this);\n" - "} catch (java.io.IOException e) {\n" - " throw new com.google.protobuf.InvalidProtocolBufferException(\n" - " e).setUnfinishedMessage(this);\n" - "} finally {\n"); - printer->Indent(); - - // Make repeated field list immutable. - for (int i = 0; i < descriptor_->field_count(); i++) { - const FieldDescriptor* field = sorted_fields[i]; - field_generators_.get(field).GenerateParsingDoneCode(printer); - } - - // Make unknown fields immutable. - printer->Print("this.unknownFields = unknownFields.build();\n"); - - // Make extensions immutable. - printer->Print("makeExtensionsImmutable();\n"); - - printer->Outdent(); - printer->Outdent(); - printer->Print( - " }\n" // finally - "}\n"); -} - // =================================================================== void ImmutableMessageGenerator::GenerateParser(io::Printer* printer) { printer->Print( "$visibility$ static final com.google.protobuf.Parser<$classname$>\n" - " PARSER = new com.google.protobuf.AbstractParser<$classname$>() {\n", - "visibility", - ExposePublicParser(descriptor_->file()) ? "@java.lang.Deprecated public" - : "private", - "classname", descriptor_->name()); - printer->Indent(); - printer->Print( - "@java.lang.Override\n" - "public $classname$ parsePartialFrom(\n" - " com.google.protobuf.CodedInputStream input,\n" - " com.google.protobuf.ExtensionRegistryLite extensionRegistry)\n" - " throws com.google.protobuf.InvalidProtocolBufferException {\n", - "classname", descriptor_->name()); - if (context_->HasGeneratedMethods(descriptor_)) { - printer->Print(" return new $classname$(input, extensionRegistry);\n", - "classname", descriptor_->name()); - } else { - // When parsing constructor isn't generated, use builder to parse - // messages. Note, will fallback to use reflection based mergeFieldFrom() - // in AbstractMessage.Builder. - printer->Indent(); - printer->Print( - "Builder builder = newBuilder();\n" - "try {\n" - " builder.mergeFrom(input, extensionRegistry);\n" - "} catch (com.google.protobuf.InvalidProtocolBufferException e) {\n" - " throw e.setUnfinishedMessage(builder.buildPartial());\n" - "} catch (java.io.IOException e) {\n" - " throw new com.google.protobuf.InvalidProtocolBufferException(\n" - " e.getMessage()).setUnfinishedMessage(\n" - " builder.buildPartial());\n" - "}\n" - "return builder.buildPartial();\n"); - printer->Outdent(); - } - printer->Print("}\n"); - printer->Outdent(); - printer->Print( + " PARSER = new com.google.protobuf.AbstractParser<$classname$>() {\n" + " @java.lang.Override\n" + " public $classname$ parsePartialFrom(\n" + " com.google.protobuf.CodedInputStream input,\n" + " com.google.protobuf.ExtensionRegistryLite extensionRegistry)\n" + " throws com.google.protobuf.InvalidProtocolBufferException {\n" + " Builder builder = newBuilder();\n" + " try {\n" + " builder.mergeFrom(input, extensionRegistry);\n" + " } catch (com.google.protobuf.InvalidProtocolBufferException e) {\n" + " throw e.setUnfinishedMessage(builder.buildPartial());\n" + " } catch (com.google.protobuf.UninitializedMessageException e) {\n" + " throw " + "e.asInvalidProtocolBufferException().setUnfinishedMessage(builder." + "buildPartial());\n" + " } catch (java.io.IOException e) {\n" + " throw new com.google.protobuf.InvalidProtocolBufferException(e)\n" + " .setUnfinishedMessage(builder.buildPartial());\n" + " }\n" + " return builder.buildPartial();\n" + " }\n" "};\n" - "\n"); - - printer->Print( + "\n" "public static com.google.protobuf.Parser<$classname$> parser() {\n" " return PARSER;\n" "}\n" @@ -1345,6 +1192,9 @@ void ImmutableMessageGenerator::GenerateParser(io::Printer* printer) { " return PARSER;\n" "}\n" "\n", + "visibility", + ExposePublicParser(descriptor_->file()) ? "@java.lang.Deprecated public" + : "private", "classname", descriptor_->name()); } diff --git a/src/google/protobuf/compiler/java/java_message_builder.cc b/src/google/protobuf/compiler/java/java_message_builder.cc index 320852b1be97..24ea648c79f3 100644 --- a/src/google/protobuf/compiler/java/java_message_builder.cc +++ b/src/google/protobuf/compiler/java/java_message_builder.cc @@ -58,6 +58,9 @@ namespace protobuf { namespace compiler { namespace java { +using internal::WireFormat; +using internal::WireFormatLite; + namespace { std::string MapValueImmutableClassdName(const Descriptor* descriptor, ClassNameResolver* name_resolver) { @@ -285,43 +288,63 @@ void MessageBuilderGenerator::GenerateDescriptorMethods(io::Printer* printer) { void MessageBuilderGenerator::GenerateCommonBuilderMethods( io::Printer* printer) { + // Decide if we really need to have the "maybeForceBuilderInitialization()" + // method. + // TODO(b/249158148): Remove the need for this entirely + bool need_maybe_force_builder_init = false; + for (int i = 0; i < descriptor_->field_count(); i++) { + if (descriptor_->field(i)->message_type() != nullptr && + !IsRealOneof(descriptor_->field(i)) && + HasHasbit(descriptor_->field(i))) { + need_maybe_force_builder_init = true; + break; + } + } + + const char* force_builder_init = need_maybe_force_builder_init + ? " maybeForceBuilderInitialization();" + : ""; + printer->Print( "// Construct using $classname$.newBuilder()\n" "private Builder() {\n" - " maybeForceBuilderInitialization();\n" + "$force_builder_init$\n" "}\n" "\n", - "classname", name_resolver_->GetImmutableClassName(descriptor_)); + "classname", name_resolver_->GetImmutableClassName(descriptor_), + "force_builder_init", force_builder_init); printer->Print( "private Builder(\n" " com.google.protobuf.GeneratedMessage$ver$.BuilderParent parent) {\n" " super(parent);\n" - " maybeForceBuilderInitialization();\n" + "$force_builder_init$\n" "}\n", "classname", name_resolver_->GetImmutableClassName(descriptor_), "ver", - GeneratedCodeVersionSuffix()); + GeneratedCodeVersionSuffix(), "force_builder_init", force_builder_init); - printer->Print( - "private void maybeForceBuilderInitialization() {\n" - " if (com.google.protobuf.GeneratedMessage$ver$\n" - " .alwaysUseFieldBuilders) {\n", - "ver", GeneratedCodeVersionSuffix()); + if (need_maybe_force_builder_init) { + printer->Print( + "private void maybeForceBuilderInitialization() {\n" + " if (com.google.protobuf.GeneratedMessage$ver$\n" + " .alwaysUseFieldBuilders) {\n", + "ver", GeneratedCodeVersionSuffix()); - printer->Indent(); - printer->Indent(); - for (int i = 0; i < descriptor_->field_count(); i++) { - if (!IsRealOneof(descriptor_->field(i))) { - field_generators_.get(descriptor_->field(i)) - .GenerateFieldBuilderInitializationCode(printer); + printer->Indent(); + printer->Indent(); + for (int i = 0; i < descriptor_->field_count(); i++) { + if (!IsRealOneof(descriptor_->field(i))) { + field_generators_.get(descriptor_->field(i)) + .GenerateFieldBuilderInitializationCode(printer); + } } - } - printer->Outdent(); - printer->Outdent(); + printer->Outdent(); + printer->Outdent(); - printer->Print( - " }\n" - "}\n"); + printer->Print( + " }\n" + "}\n"); + } printer->Print( "@java.lang.Override\n" @@ -331,10 +354,8 @@ void MessageBuilderGenerator::GenerateCommonBuilderMethods( printer->Indent(); for (int i = 0; i < descriptor_->field_count(); i++) { - if (!IsRealOneof(descriptor_->field(i))) { - field_generators_.get(descriptor_->field(i)) - .GenerateBuilderClearCode(printer); - } + field_generators_.get(descriptor_->field(i)) + .GenerateBuilderClearCode(printer); } for (auto oneof : oneofs_) { @@ -575,7 +596,7 @@ void MessageBuilderGenerator::GenerateCommonBuilderMethods( printer->Print(" this.mergeExtensionFields(other);\n"); } - printer->Print(" this.mergeUnknownFields(other.unknownFields);\n"); + printer->Print(" this.mergeUnknownFields(other.getUnknownFields());\n"); printer->Print(" onChanged();\n"); @@ -596,20 +617,92 @@ void MessageBuilderGenerator::GenerateBuilderParsingMethods( " com.google.protobuf.CodedInputStream input,\n" " com.google.protobuf.ExtensionRegistryLite extensionRegistry)\n" " throws java.io.IOException {\n" - " $classname$ parsedMessage = null;\n" + " if (extensionRegistry == null) {\n" + " throw new java.lang.NullPointerException();\n" + " }\n" " try {\n" - " parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry);\n" + " boolean done = false;\n" + " while (!done) {\n" + " int tag = input.readTag();\n" + " switch (tag) {\n" + " case 0:\n" // zero signals EOF / limit reached + " done = true;\n" + " break;\n"); + printer->Indent(); // method + printer->Indent(); // try + printer->Indent(); // while + printer->Indent(); // switch + GenerateBuilderFieldParsingCases(printer); + printer->Outdent(); // switch + printer->Outdent(); // while + printer->Outdent(); // try + printer->Outdent(); // method + printer->Print( + " default: {\n" + " if (!super.parseUnknownField(input, extensionRegistry, tag)) " + "{\n" + " done = true; // was an endgroup tag\n" + " }\n" + " break;\n" + " } // default:\n" + " } // switch (tag)\n" + " } // while (!done)\n" " } catch (com.google.protobuf.InvalidProtocolBufferException e) {\n" - " parsedMessage = ($classname$) e.getUnfinishedMessage();\n" " throw e.unwrapIOException();\n" " } finally {\n" - " if (parsedMessage != null) {\n" - " mergeFrom(parsedMessage);\n" - " }\n" - " }\n" + " onChanged();\n" + " } // finally\n" " return this;\n" - "}\n", - "classname", name_resolver_->GetImmutableClassName(descriptor_)); + "}\n"); +} + +void MessageBuilderGenerator::GenerateBuilderFieldParsingCases( + io::Printer* printer) { + std::unique_ptr sorted_fields( + SortFieldsByNumber(descriptor_)); + for (int i = 0; i < descriptor_->field_count(); i++) { + const FieldDescriptor* field = sorted_fields[i]; + GenerateBuilderFieldParsingCase(printer, field); + if (field->is_packable()) { + GenerateBuilderPackedFieldParsingCase(printer, field); + } + } +} + +void MessageBuilderGenerator::GenerateBuilderFieldParsingCase( + io::Printer* printer, const FieldDescriptor* field) { + uint32_t tag = WireFormatLite::MakeTag( + field->number(), WireFormat::WireTypeForFieldType(field->type())); + std::string tagString = StrCat(static_cast(tag)); + printer->Print("case $tag$: {\n", "tag", tagString); + printer->Indent(); + + field_generators_.get(field).GenerateBuilderParsingCode(printer); + + printer->Outdent(); + printer->Print( + " break;\n" + "} // case $tag$\n", + "tag", tagString); +} + +void MessageBuilderGenerator::GenerateBuilderPackedFieldParsingCase( + io::Printer* printer, const FieldDescriptor* field) { + // To make packed = true wire compatible, we generate parsing code from a + // packed version of this field regardless of field->options().packed(). + uint32_t tag = WireFormatLite::MakeTag( + field->number(), WireFormatLite::WIRETYPE_LENGTH_DELIMITED); + std::string tagString = StrCat(static_cast(tag)); + printer->Print("case $tag$: {\n", "tag", tagString); + printer->Indent(); + + field_generators_.get(field).GenerateBuilderParsingCodeFromPacked(printer); + + printer->Outdent(); + printer->Print( + " break;\n" + "} // case $tag$\n", + "tag", tagString); } // =================================================================== diff --git a/src/google/protobuf/compiler/java/java_message_builder.h b/src/google/protobuf/compiler/java/java_message_builder.h index fcd73b343626..96f289a838d8 100644 --- a/src/google/protobuf/compiler/java/java_message_builder.h +++ b/src/google/protobuf/compiler/java/java_message_builder.h @@ -70,6 +70,11 @@ class MessageBuilderGenerator { void GenerateCommonBuilderMethods(io::Printer* printer); void GenerateDescriptorMethods(io::Printer* printer); void GenerateBuilderParsingMethods(io::Printer* printer); + void GenerateBuilderFieldParsingCases(io::Printer* printer); + void GenerateBuilderFieldParsingCase(io::Printer* printer, + const FieldDescriptor* field); + void GenerateBuilderPackedFieldParsingCase(io::Printer* printer, + const FieldDescriptor* field); void GenerateIsInitialized(io::Printer* printer); const Descriptor* descriptor_; diff --git a/src/google/protobuf/compiler/java/java_message_field.cc b/src/google/protobuf/compiler/java/java_message_field.cc index f657c1795742..a650b770471a 100644 --- a/src/google/protobuf/compiler/java/java_message_field.cc +++ b/src/google/protobuf/compiler/java/java_message_field.cc @@ -103,13 +103,6 @@ void SetMessageVariables(const FieldDescriptor* descriptor, int messageBitIndex, (*variables)["set_mutable_bit_builder"] = GenerateSetBit(builderBitIndex); (*variables)["clear_mutable_bit_builder"] = GenerateClearBit(builderBitIndex); - // For repeated fields, one bit is used for whether the array is immutable - // in the parsing constructor. - (*variables)["get_mutable_bit_parser"] = - GenerateGetBitMutableLocal(builderBitIndex); - (*variables)["set_mutable_bit_parser"] = - GenerateSetBitMutableLocal(builderBitIndex); - (*variables)["get_has_field_bit_from_local"] = GenerateGetBitFromLocal(builderBitIndex); (*variables)["set_has_field_bit_to_local"] = @@ -457,35 +450,21 @@ void ImmutableMessageFieldGenerator::GenerateBuildingCode( } } -void ImmutableMessageFieldGenerator::GenerateParsingCode( +void ImmutableMessageFieldGenerator::GenerateBuilderParsingCode( io::Printer* printer) const { - printer->Print(variables_, - "$type$.Builder subBuilder = null;\n" - "if ($is_field_present_message$) {\n" - " subBuilder = $name$_.toBuilder();\n" - "}\n"); - if (GetType(descriptor_) == FieldDescriptor::TYPE_GROUP) { printer->Print(variables_, - "$name$_ = input.readGroup($number$, $type$.$get_parser$,\n" - " extensionRegistry);\n"); + "input.readGroup($number$,\n" + " get$capitalized_name$FieldBuilder().getBuilder(),\n" + " extensionRegistry);\n" + "$set_has_field_bit_builder$\n"); } else { printer->Print(variables_, - "$name$_ = input.readMessage($type$.$get_parser$, " - "extensionRegistry);\n"); + "input.readMessage(\n" + " get$capitalized_name$FieldBuilder().getBuilder(),\n" + " extensionRegistry);\n" + "$set_has_field_bit_builder$\n"); } - - printer->Print(variables_, - "if (subBuilder != null) {\n" - " subBuilder.mergeFrom($name$_);\n" - " $name$_ = subBuilder.buildPartial();\n" - "}\n" - "$set_has_field_bit_message$\n"); -} - -void ImmutableMessageFieldGenerator::GenerateParsingDoneCode( - io::Printer* printer) const { - // noop for messages. } void ImmutableMessageFieldGenerator::GenerateSerializationCode( @@ -737,6 +716,15 @@ void ImmutableMessageOneofFieldGenerator::GenerateBuilderMembers( printer->Annotate("{", "}", descriptor_); } +void ImmutableMessageOneofFieldGenerator::GenerateBuilderClearCode( + io::Printer* printer) const { + // Make sure the builder gets cleared. + printer->Print(variables_, + "if ($name$Builder_ != null) {\n" + " $name$Builder_.clear();\n" + "}\n"); +} + void ImmutableMessageOneofFieldGenerator::GenerateBuildingCode( io::Printer* printer) const { printer->Print(variables_, "if ($has_oneof_case_message$) {\n"); @@ -757,32 +745,21 @@ void ImmutableMessageOneofFieldGenerator::GenerateMergingCode( "merge$capitalized_name$(other.get$capitalized_name$());\n"); } -void ImmutableMessageOneofFieldGenerator::GenerateParsingCode( +void ImmutableMessageOneofFieldGenerator::GenerateBuilderParsingCode( io::Printer* printer) const { - printer->Print(variables_, - "$type$.Builder subBuilder = null;\n" - "if ($has_oneof_case_message$) {\n" - " subBuilder = (($type$) $oneof_name$_).toBuilder();\n" - "}\n"); - if (GetType(descriptor_) == FieldDescriptor::TYPE_GROUP) { - printer->Print( - variables_, - "$oneof_name$_ = input.readGroup($number$, $type$.$get_parser$,\n" - " extensionRegistry);\n"); + printer->Print(variables_, + "input.readGroup($number$,\n" + " get$capitalized_name$FieldBuilder().getBuilder(),\n" + " extensionRegistry);\n" + "$set_oneof_case_message$;\n"); } else { - printer->Print( - variables_, - "$oneof_name$_ =\n" - " input.readMessage($type$.$get_parser$, extensionRegistry);\n"); + printer->Print(variables_, + "input.readMessage(\n" + " get$capitalized_name$FieldBuilder().getBuilder(),\n" + " extensionRegistry);\n" + "$set_oneof_case_message$;\n"); } - - printer->Print(variables_, - "if (subBuilder != null) {\n" - " subBuilder.mergeFrom(($type$) $oneof_name$_);\n" - " $oneof_name$_ = subBuilder.buildPartial();\n" - "}\n"); - printer->Print(variables_, "$set_oneof_case_message$;\n"); } void ImmutableMessageOneofFieldGenerator::GenerateSerializationCode( @@ -1233,10 +1210,12 @@ void RepeatedImmutableMessageFieldGenerator::GenerateInitializationCode( void RepeatedImmutableMessageFieldGenerator::GenerateBuilderClearCode( io::Printer* printer) const { PrintNestedBuilderCondition(printer, - "$name$_ = java.util.Collections.emptyList();\n" - "$clear_mutable_bit_builder$;\n", + "$name$_ = java.util.Collections.emptyList();\n", + "$name$_ = null;\n" "$name$Builder_.clear();\n"); + + printer->Print(variables_, "$clear_mutable_bit_builder$;\n"); } void RepeatedImmutableMessageFieldGenerator::GenerateMergingCode( @@ -1291,34 +1270,25 @@ void RepeatedImmutableMessageFieldGenerator::GenerateBuildingCode( "result.$name$_ = $name$Builder_.build();\n"); } -void RepeatedImmutableMessageFieldGenerator::GenerateParsingCode( +void RepeatedImmutableMessageFieldGenerator::GenerateBuilderParsingCode( io::Printer* printer) const { - printer->Print(variables_, - "if (!$get_mutable_bit_parser$) {\n" - " $name$_ = new java.util.ArrayList<$type$>();\n" - " $set_mutable_bit_parser$;\n" - "}\n"); - if (GetType(descriptor_) == FieldDescriptor::TYPE_GROUP) { - printer->Print( - variables_, - "$name$_.add(input.readGroup($number$, $type$.$get_parser$,\n" - " extensionRegistry));\n"); + printer->Print(variables_, + "$type$ m =\n" + " input.readGroup($number$,\n" + " $type$.$get_parser$,\n" + " extensionRegistry);\n"); } else { - printer->Print( - variables_, - "$name$_.add(\n" - " input.readMessage($type$.$get_parser$, extensionRegistry));\n"); + printer->Print(variables_, + "$type$ m =\n" + " input.readMessage(\n" + " $type$.$get_parser$,\n" + " extensionRegistry);\n"); } -} - -void RepeatedImmutableMessageFieldGenerator::GenerateParsingDoneCode( - io::Printer* printer) const { - printer->Print( - variables_, - "if ($get_mutable_bit_parser$) {\n" - " $name$_ = java.util.Collections.unmodifiableList($name$_);\n" - "}\n"); + PrintNestedBuilderCondition(printer, + "ensure$capitalized_name$IsMutable();\n" + "$name$_.add(m);\n", + "$name$Builder_.addMessage(m);\n"); } void RepeatedImmutableMessageFieldGenerator::GenerateSerializationCode( diff --git a/src/google/protobuf/compiler/java/java_message_field.h b/src/google/protobuf/compiler/java/java_message_field.h index 36fa49208cbe..07e025670392 100644 --- a/src/google/protobuf/compiler/java/java_message_field.h +++ b/src/google/protobuf/compiler/java/java_message_field.h @@ -65,24 +65,24 @@ class ImmutableMessageFieldGenerator : public ImmutableFieldGenerator { // implements ImmutableFieldGenerator // --------------------------------------- - int GetNumBitsForMessage() const; - int GetNumBitsForBuilder() const; - void GenerateInterfaceMembers(io::Printer* printer) const; - void GenerateMembers(io::Printer* printer) const; - void GenerateBuilderMembers(io::Printer* printer) const; - void GenerateInitializationCode(io::Printer* printer) const; - void GenerateBuilderClearCode(io::Printer* printer) const; - void GenerateMergingCode(io::Printer* printer) const; - void GenerateBuildingCode(io::Printer* printer) const; - void GenerateParsingCode(io::Printer* printer) const; - void GenerateParsingDoneCode(io::Printer* printer) const; - void GenerateSerializationCode(io::Printer* printer) const; - void GenerateSerializedSizeCode(io::Printer* printer) const; - void GenerateFieldBuilderInitializationCode(io::Printer* printer) const; - void GenerateEqualsCode(io::Printer* printer) const; - void GenerateHashCode(io::Printer* printer) const; - - std::string GetBoxedType() const; + int GetNumBitsForMessage() const override; + int GetNumBitsForBuilder() const override; + void GenerateInterfaceMembers(io::Printer* printer) const override; + void GenerateMembers(io::Printer* printer) const override; + void GenerateBuilderMembers(io::Printer* printer) const override; + void GenerateInitializationCode(io::Printer* printer) const override; + void GenerateBuilderClearCode(io::Printer* printer) const override; + void GenerateMergingCode(io::Printer* printer) const override; + void GenerateBuildingCode(io::Printer* printer) const override; + void GenerateBuilderParsingCode(io::Printer* printer) const override; + void GenerateSerializationCode(io::Printer* printer) const override; + void GenerateSerializedSizeCode(io::Printer* printer) const override; + void GenerateFieldBuilderInitializationCode( + io::Printer* printer) const override; + void GenerateEqualsCode(io::Printer* printer) const override; + void GenerateHashCode(io::Printer* printer) const override; + + std::string GetBoxedType() const override; protected: const FieldDescriptor* descriptor_; @@ -110,13 +110,14 @@ class ImmutableMessageOneofFieldGenerator Context* context); ~ImmutableMessageOneofFieldGenerator(); - void GenerateMembers(io::Printer* printer) const; - void GenerateBuilderMembers(io::Printer* printer) const; - void GenerateBuildingCode(io::Printer* printer) const; - void GenerateMergingCode(io::Printer* printer) const; - void GenerateParsingCode(io::Printer* printer) const; - void GenerateSerializationCode(io::Printer* printer) const; - void GenerateSerializedSizeCode(io::Printer* printer) const; + void GenerateMembers(io::Printer* printer) const override; + void GenerateBuilderMembers(io::Printer* printer) const override; + void GenerateBuilderClearCode(io::Printer* printer) const override; + void GenerateBuildingCode(io::Printer* printer) const override; + void GenerateMergingCode(io::Printer* printer) const override; + void GenerateBuilderParsingCode(io::Printer* printer) const override; + void GenerateSerializationCode(io::Printer* printer) const override; + void GenerateSerializedSizeCode(io::Printer* printer) const override; private: GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ImmutableMessageOneofFieldGenerator); @@ -130,24 +131,24 @@ class RepeatedImmutableMessageFieldGenerator : public ImmutableFieldGenerator { ~RepeatedImmutableMessageFieldGenerator(); // implements ImmutableFieldGenerator --------------------------------------- - int GetNumBitsForMessage() const; - int GetNumBitsForBuilder() const; - void GenerateInterfaceMembers(io::Printer* printer) const; - void GenerateMembers(io::Printer* printer) const; - void GenerateBuilderMembers(io::Printer* printer) const; - void GenerateInitializationCode(io::Printer* printer) const; - void GenerateBuilderClearCode(io::Printer* printer) const; - void GenerateMergingCode(io::Printer* printer) const; - void GenerateBuildingCode(io::Printer* printer) const; - void GenerateParsingCode(io::Printer* printer) const; - void GenerateParsingDoneCode(io::Printer* printer) const; - void GenerateSerializationCode(io::Printer* printer) const; - void GenerateSerializedSizeCode(io::Printer* printer) const; - void GenerateFieldBuilderInitializationCode(io::Printer* printer) const; - void GenerateEqualsCode(io::Printer* printer) const; - void GenerateHashCode(io::Printer* printer) const; - - std::string GetBoxedType() const; + int GetNumBitsForMessage() const override; + int GetNumBitsForBuilder() const override; + void GenerateInterfaceMembers(io::Printer* printer) const override; + void GenerateMembers(io::Printer* printer) const override; + void GenerateBuilderMembers(io::Printer* printer) const override; + void GenerateInitializationCode(io::Printer* printer) const override; + void GenerateBuilderClearCode(io::Printer* printer) const override; + void GenerateMergingCode(io::Printer* printer) const override; + void GenerateBuildingCode(io::Printer* printer) const override; + void GenerateBuilderParsingCode(io::Printer* printer) const override; + void GenerateSerializationCode(io::Printer* printer) const override; + void GenerateSerializedSizeCode(io::Printer* printer) const override; + void GenerateFieldBuilderInitializationCode( + io::Printer* printer) const override; + void GenerateEqualsCode(io::Printer* printer) const override; + void GenerateHashCode(io::Printer* printer) const override; + + std::string GetBoxedType() const override; protected: const FieldDescriptor* descriptor_; diff --git a/src/google/protobuf/compiler/java/java_primitive_field.cc b/src/google/protobuf/compiler/java/java_primitive_field.cc index 65cc05adcfc0..21bb866acb65 100644 --- a/src/google/protobuf/compiler/java/java_primitive_field.cc +++ b/src/google/protobuf/compiler/java/java_primitive_field.cc @@ -168,13 +168,6 @@ void SetPrimitiveVariables(const FieldDescriptor* descriptor, (*variables)["set_mutable_bit_builder"] = GenerateSetBit(builderBitIndex); (*variables)["clear_mutable_bit_builder"] = GenerateClearBit(builderBitIndex); - // For repeated fields, one bit is used for whether the array is immutable - // in the parsing constructor. - (*variables)["get_mutable_bit_parser"] = - GenerateGetBitMutableLocal(builderBitIndex); - (*variables)["set_mutable_bit_parser"] = - GenerateSetBitMutableLocal(builderBitIndex); - (*variables)["get_has_field_bit_from_local"] = GenerateGetBitFromLocal(builderBitIndex); (*variables)["set_has_field_bit_to_local"] = @@ -354,16 +347,11 @@ void ImmutablePrimitiveFieldGenerator::GenerateBuildingCode( } } -void ImmutablePrimitiveFieldGenerator::GenerateParsingCode( +void ImmutablePrimitiveFieldGenerator::GenerateBuilderParsingCode( io::Printer* printer) const { printer->Print(variables_, - "$set_has_field_bit_message$\n" - "$name$_ = input.read$capitalized_type$();\n"); -} - -void ImmutablePrimitiveFieldGenerator::GenerateParsingDoneCode( - io::Printer* printer) const { - // noop for primitives. + "$name$_ = input.read$capitalized_type$();\n" + "$set_has_field_bit_builder$\n"); } void ImmutablePrimitiveFieldGenerator::GenerateSerializationCode( @@ -568,6 +556,12 @@ void ImmutablePrimitiveOneofFieldGenerator::GenerateBuilderMembers( printer->Annotate("{", "}", descriptor_); } +void ImmutablePrimitiveOneofFieldGenerator::GenerateBuilderClearCode( + io::Printer* printer) const { + // No-Op: When a primitive field is in a oneof, clearing the oneof clears that + // field. +} + void ImmutablePrimitiveOneofFieldGenerator::GenerateBuildingCode( io::Printer* printer) const { printer->Print(variables_, @@ -582,7 +576,7 @@ void ImmutablePrimitiveOneofFieldGenerator::GenerateMergingCode( "set$capitalized_name$(other.get$capitalized_name$());\n"); } -void ImmutablePrimitiveOneofFieldGenerator::GenerateParsingCode( +void ImmutablePrimitiveOneofFieldGenerator::GenerateBuilderParsingCode( io::Printer* printer) const { printer->Print(variables_, "$set_oneof_case_message$;\n" @@ -842,38 +836,24 @@ void RepeatedImmutablePrimitiveFieldGenerator::GenerateBuildingCode( "result.$name$_ = $name$_;\n"); } -void RepeatedImmutablePrimitiveFieldGenerator::GenerateParsingCode( +void RepeatedImmutablePrimitiveFieldGenerator::GenerateBuilderParsingCode( io::Printer* printer) const { printer->Print(variables_, - "if (!$get_mutable_bit_parser$) {\n" - " $name$_ = $create_list$;\n" - " $set_mutable_bit_parser$;\n" - "}\n" - "$repeated_add$(input.read$capitalized_type$());\n"); + "$type$ v = input.read$capitalized_type$();\n" + "ensure$capitalized_name$IsMutable();\n" + "$repeated_add$(v);\n"); } -void RepeatedImmutablePrimitiveFieldGenerator::GenerateParsingCodeFromPacked( - io::Printer* printer) const { - printer->Print( - variables_, - "int length = input.readRawVarint32();\n" - "int limit = input.pushLimit(length);\n" - "if (!$get_mutable_bit_parser$ && input.getBytesUntilLimit() > 0) {\n" - " $name$_ = $create_list$;\n" - " $set_mutable_bit_parser$;\n" - "}\n" - "while (input.getBytesUntilLimit() > 0) {\n" - " $repeated_add$(input.read$capitalized_type$());\n" - "}\n" - "input.popLimit(limit);\n"); -} - -void RepeatedImmutablePrimitiveFieldGenerator::GenerateParsingDoneCode( - io::Printer* printer) const { +void RepeatedImmutablePrimitiveFieldGenerator:: + GenerateBuilderParsingCodeFromPacked(io::Printer* printer) const { printer->Print(variables_, - "if ($get_mutable_bit_parser$) {\n" - " $name_make_immutable$; // C\n" - "}\n"); + "int length = input.readRawVarint32();\n" + "int limit = input.pushLimit(length);\n" + "ensure$capitalized_name$IsMutable();\n" + "while (input.getBytesUntilLimit() > 0) {\n" + " $repeated_add$(input.read$capitalized_type$());\n" + "}\n" + "input.popLimit(limit);\n"); } void RepeatedImmutablePrimitiveFieldGenerator::GenerateSerializationCode( diff --git a/src/google/protobuf/compiler/java/java_primitive_field.h b/src/google/protobuf/compiler/java/java_primitive_field.h index db20750e262d..e74044a92c58 100644 --- a/src/google/protobuf/compiler/java/java_primitive_field.h +++ b/src/google/protobuf/compiler/java/java_primitive_field.h @@ -65,24 +65,24 @@ class ImmutablePrimitiveFieldGenerator : public ImmutableFieldGenerator { // implements ImmutableFieldGenerator // --------------------------------------- - int GetNumBitsForMessage() const; - int GetNumBitsForBuilder() const; - void GenerateInterfaceMembers(io::Printer* printer) const; - void GenerateMembers(io::Printer* printer) const; - void GenerateBuilderMembers(io::Printer* printer) const; - void GenerateInitializationCode(io::Printer* printer) const; - void GenerateBuilderClearCode(io::Printer* printer) const; - void GenerateMergingCode(io::Printer* printer) const; - void GenerateBuildingCode(io::Printer* printer) const; - void GenerateParsingCode(io::Printer* printer) const; - void GenerateParsingDoneCode(io::Printer* printer) const; - void GenerateSerializationCode(io::Printer* printer) const; - void GenerateSerializedSizeCode(io::Printer* printer) const; - void GenerateFieldBuilderInitializationCode(io::Printer* printer) const; - void GenerateEqualsCode(io::Printer* printer) const; - void GenerateHashCode(io::Printer* printer) const; - - std::string GetBoxedType() const; + int GetNumBitsForMessage() const override; + int GetNumBitsForBuilder() const override; + void GenerateInterfaceMembers(io::Printer* printer) const override; + void GenerateMembers(io::Printer* printer) const override; + void GenerateBuilderMembers(io::Printer* printer) const override; + void GenerateInitializationCode(io::Printer* printer) const override; + void GenerateBuilderClearCode(io::Printer* printer) const override; + void GenerateMergingCode(io::Printer* printer) const override; + void GenerateBuildingCode(io::Printer* printer) const override; + void GenerateBuilderParsingCode(io::Printer* printer) const override; + void GenerateSerializationCode(io::Printer* printer) const override; + void GenerateSerializedSizeCode(io::Printer* printer) const override; + void GenerateFieldBuilderInitializationCode( + io::Printer* printer) const override; + void GenerateEqualsCode(io::Printer* printer) const override; + void GenerateHashCode(io::Printer* printer) const override; + + std::string GetBoxedType() const override; protected: const FieldDescriptor* descriptor_; @@ -101,13 +101,14 @@ class ImmutablePrimitiveOneofFieldGenerator int builderBitIndex, Context* context); ~ImmutablePrimitiveOneofFieldGenerator(); - void GenerateMembers(io::Printer* printer) const; - void GenerateBuilderMembers(io::Printer* printer) const; - void GenerateBuildingCode(io::Printer* printer) const; - void GenerateMergingCode(io::Printer* printer) const; - void GenerateParsingCode(io::Printer* printer) const; - void GenerateSerializationCode(io::Printer* printer) const; - void GenerateSerializedSizeCode(io::Printer* printer) const; + void GenerateMembers(io::Printer* printer) const override; + void GenerateBuilderMembers(io::Printer* printer) const override; + void GenerateBuilderClearCode(io::Printer* printer) const override; + void GenerateBuildingCode(io::Printer* printer) const override; + void GenerateMergingCode(io::Printer* printer) const override; + void GenerateBuilderParsingCode(io::Printer* printer) const override; + void GenerateSerializationCode(io::Printer* printer) const override; + void GenerateSerializedSizeCode(io::Printer* printer) const override; private: GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ImmutablePrimitiveOneofFieldGenerator); @@ -122,25 +123,26 @@ class RepeatedImmutablePrimitiveFieldGenerator virtual ~RepeatedImmutablePrimitiveFieldGenerator(); // implements ImmutableFieldGenerator --------------------------------------- - int GetNumBitsForMessage() const; - int GetNumBitsForBuilder() const; - void GenerateInterfaceMembers(io::Printer* printer) const; - void GenerateMembers(io::Printer* printer) const; - void GenerateBuilderMembers(io::Printer* printer) const; - void GenerateInitializationCode(io::Printer* printer) const; - void GenerateBuilderClearCode(io::Printer* printer) const; - void GenerateMergingCode(io::Printer* printer) const; - void GenerateBuildingCode(io::Printer* printer) const; - void GenerateParsingCode(io::Printer* printer) const; - void GenerateParsingCodeFromPacked(io::Printer* printer) const; - void GenerateParsingDoneCode(io::Printer* printer) const; - void GenerateSerializationCode(io::Printer* printer) const; - void GenerateSerializedSizeCode(io::Printer* printer) const; - void GenerateFieldBuilderInitializationCode(io::Printer* printer) const; - void GenerateEqualsCode(io::Printer* printer) const; - void GenerateHashCode(io::Printer* printer) const; - - std::string GetBoxedType() const; + int GetNumBitsForMessage() const override; + int GetNumBitsForBuilder() const override; + void GenerateInterfaceMembers(io::Printer* printer) const override; + void GenerateMembers(io::Printer* printer) const override; + void GenerateBuilderMembers(io::Printer* printer) const override; + void GenerateInitializationCode(io::Printer* printer) const override; + void GenerateBuilderClearCode(io::Printer* printer) const override; + void GenerateMergingCode(io::Printer* printer) const override; + void GenerateBuildingCode(io::Printer* printer) const override; + void GenerateBuilderParsingCode(io::Printer* printer) const override; + void GenerateBuilderParsingCodeFromPacked( + io::Printer* printer) const override; + void GenerateSerializationCode(io::Printer* printer) const override; + void GenerateSerializedSizeCode(io::Printer* printer) const override; + void GenerateFieldBuilderInitializationCode( + io::Printer* printer) const override; + void GenerateEqualsCode(io::Printer* printer) const override; + void GenerateHashCode(io::Printer* printer) const override; + + std::string GetBoxedType() const override; private: const FieldDescriptor* descriptor_; diff --git a/src/google/protobuf/compiler/java/java_string_field.cc b/src/google/protobuf/compiler/java/java_string_field.cc index 2e9a9e789a0f..722db1dd8af7 100644 --- a/src/google/protobuf/compiler/java/java_string_field.cc +++ b/src/google/protobuf/compiler/java/java_string_field.cc @@ -120,13 +120,6 @@ void SetPrimitiveVariables(const FieldDescriptor* descriptor, (*variables)["set_mutable_bit_builder"] = GenerateSetBit(builderBitIndex); (*variables)["clear_mutable_bit_builder"] = GenerateClearBit(builderBitIndex); - // For repeated fields, one bit is used for whether the array is immutable - // in the parsing constructor. - (*variables)["get_mutable_bit_parser"] = - GenerateGetBitMutableLocal(builderBitIndex); - (*variables)["set_mutable_bit_parser"] = - GenerateSetBitMutableLocal(builderBitIndex); - (*variables)["get_has_field_bit_from_local"] = GenerateGetBitFromLocal(builderBitIndex); (*variables)["set_has_field_bit_to_local"] = @@ -415,26 +408,19 @@ void ImmutableStringFieldGenerator::GenerateBuildingCode( printer->Print(variables_, "result.$name$_ = $name$_;\n"); } -void ImmutableStringFieldGenerator::GenerateParsingCode( +void ImmutableStringFieldGenerator::GenerateBuilderParsingCode( io::Printer* printer) const { if (CheckUtf8(descriptor_)) { printer->Print(variables_, - "java.lang.String s = input.readStringRequireUtf8();\n" - "$set_has_field_bit_message$\n" - "$name$_ = s;\n"); + "$name$_ = input.readStringRequireUtf8();\n" + "$set_has_field_bit_builder$\n"); } else { printer->Print(variables_, - "com.google.protobuf.ByteString bs = input.readBytes();\n" - "$set_has_field_bit_message$\n" - "$name$_ = bs;\n"); + "$name$_ = input.readBytes();\n" + "$set_has_field_bit_builder$\n"); } } -void ImmutableStringFieldGenerator::GenerateParsingDoneCode( - io::Printer* printer) const { - // noop for strings. -} - void ImmutableStringFieldGenerator::GenerateSerializationCode( io::Printer* printer) const { printer->Print(variables_, @@ -658,6 +644,11 @@ void ImmutableStringOneofFieldGenerator::GenerateBuilderMembers( "}\n"); } +void ImmutableStringOneofFieldGenerator::GenerateBuilderClearCode( + io::Printer* printer) const { + // No-Op: String fields in oneofs are correctly cleared by clearing the oneof +} + void ImmutableStringOneofFieldGenerator::GenerateMergingCode( io::Printer* printer) const { // Allow a slight breach of abstraction here in order to avoid forcing @@ -676,7 +667,7 @@ void ImmutableStringOneofFieldGenerator::GenerateBuildingCode( "}\n"); } -void ImmutableStringOneofFieldGenerator::GenerateParsingCode( +void ImmutableStringOneofFieldGenerator::GenerateBuilderParsingCode( io::Printer* printer) const { if (CheckUtf8(descriptor_)) { printer->Print(variables_, @@ -966,35 +957,21 @@ void RepeatedImmutableStringFieldGenerator::GenerateBuildingCode( "result.$name$_ = $name$_;\n"); } -void RepeatedImmutableStringFieldGenerator::GenerateParsingCode( +void RepeatedImmutableStringFieldGenerator::GenerateBuilderParsingCode( io::Printer* printer) const { if (CheckUtf8(descriptor_)) { printer->Print(variables_, - "java.lang.String s = input.readStringRequireUtf8();\n"); + "java.lang.String s = input.readStringRequireUtf8();\n" + "ensure$capitalized_name$IsMutable();\n" + "$name$_.add(s);\n"); } else { printer->Print(variables_, - "com.google.protobuf.ByteString bs = input.readBytes();\n"); - } - printer->Print(variables_, - "if (!$get_mutable_bit_parser$) {\n" - " $name$_ = new com.google.protobuf.LazyStringArrayList();\n" - " $set_mutable_bit_parser$;\n" - "}\n"); - if (CheckUtf8(descriptor_)) { - printer->Print(variables_, "$name$_.add(s);\n"); - } else { - printer->Print(variables_, "$name$_.add(bs);\n"); + "com.google.protobuf.ByteString bs = input.readBytes();\n" + "ensure$capitalized_name$IsMutable();\n" + "$name$_.add(bs);\n"); } } -void RepeatedImmutableStringFieldGenerator::GenerateParsingDoneCode( - io::Printer* printer) const { - printer->Print(variables_, - "if ($get_mutable_bit_parser$) {\n" - " $name$_ = $name$_.getUnmodifiableView();\n" - "}\n"); -} - void RepeatedImmutableStringFieldGenerator::GenerateSerializationCode( io::Printer* printer) const { printer->Print(variables_, diff --git a/src/google/protobuf/compiler/java/java_string_field.h b/src/google/protobuf/compiler/java/java_string_field.h index 1c00ae81c291..4aaabfb79db6 100644 --- a/src/google/protobuf/compiler/java/java_string_field.h +++ b/src/google/protobuf/compiler/java/java_string_field.h @@ -65,24 +65,24 @@ class ImmutableStringFieldGenerator : public ImmutableFieldGenerator { // implements ImmutableFieldGenerator // --------------------------------------- - int GetNumBitsForMessage() const; - int GetNumBitsForBuilder() const; - void GenerateInterfaceMembers(io::Printer* printer) const; - void GenerateMembers(io::Printer* printer) const; - void GenerateBuilderMembers(io::Printer* printer) const; - void GenerateInitializationCode(io::Printer* printer) const; - void GenerateBuilderClearCode(io::Printer* printer) const; - void GenerateMergingCode(io::Printer* printer) const; - void GenerateBuildingCode(io::Printer* printer) const; - void GenerateParsingCode(io::Printer* printer) const; - void GenerateParsingDoneCode(io::Printer* printer) const; - void GenerateSerializationCode(io::Printer* printer) const; - void GenerateSerializedSizeCode(io::Printer* printer) const; - void GenerateFieldBuilderInitializationCode(io::Printer* printer) const; - void GenerateEqualsCode(io::Printer* printer) const; - void GenerateHashCode(io::Printer* printer) const; - - std::string GetBoxedType() const; + int GetNumBitsForMessage() const override; + int GetNumBitsForBuilder() const override; + void GenerateInterfaceMembers(io::Printer* printer) const override; + void GenerateMembers(io::Printer* printer) const override; + void GenerateBuilderMembers(io::Printer* printer) const override; + void GenerateInitializationCode(io::Printer* printer) const override; + void GenerateBuilderClearCode(io::Printer* printer) const override; + void GenerateMergingCode(io::Printer* printer) const override; + void GenerateBuildingCode(io::Printer* printer) const override; + void GenerateBuilderParsingCode(io::Printer* printer) const override; + void GenerateSerializationCode(io::Printer* printer) const override; + void GenerateSerializedSizeCode(io::Printer* printer) const override; + void GenerateFieldBuilderInitializationCode( + io::Printer* printer) const override; + void GenerateEqualsCode(io::Printer* printer) const override; + void GenerateHashCode(io::Printer* printer) const override; + + std::string GetBoxedType() const override; protected: const FieldDescriptor* descriptor_; @@ -102,13 +102,14 @@ class ImmutableStringOneofFieldGenerator ~ImmutableStringOneofFieldGenerator(); private: - void GenerateMembers(io::Printer* printer) const; - void GenerateBuilderMembers(io::Printer* printer) const; - void GenerateMergingCode(io::Printer* printer) const; - void GenerateBuildingCode(io::Printer* printer) const; - void GenerateParsingCode(io::Printer* printer) const; - void GenerateSerializationCode(io::Printer* printer) const; - void GenerateSerializedSizeCode(io::Printer* printer) const; + void GenerateMembers(io::Printer* printer) const override; + void GenerateBuilderMembers(io::Printer* printer) const override; + void GenerateBuilderClearCode(io::Printer* printer) const override; + void GenerateMergingCode(io::Printer* printer) const override; + void GenerateBuildingCode(io::Printer* printer) const override; + void GenerateBuilderParsingCode(io::Printer* printer) const override; + void GenerateSerializationCode(io::Printer* printer) const override; + void GenerateSerializedSizeCode(io::Printer* printer) const override; GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ImmutableStringOneofFieldGenerator); }; @@ -121,24 +122,24 @@ class RepeatedImmutableStringFieldGenerator : public ImmutableFieldGenerator { ~RepeatedImmutableStringFieldGenerator(); // implements ImmutableFieldGenerator --------------------------------------- - int GetNumBitsForMessage() const; - int GetNumBitsForBuilder() const; - void GenerateInterfaceMembers(io::Printer* printer) const; - void GenerateMembers(io::Printer* printer) const; - void GenerateBuilderMembers(io::Printer* printer) const; - void GenerateInitializationCode(io::Printer* printer) const; - void GenerateBuilderClearCode(io::Printer* printer) const; - void GenerateMergingCode(io::Printer* printer) const; - void GenerateBuildingCode(io::Printer* printer) const; - void GenerateParsingCode(io::Printer* printer) const; - void GenerateParsingDoneCode(io::Printer* printer) const; - void GenerateSerializationCode(io::Printer* printer) const; - void GenerateSerializedSizeCode(io::Printer* printer) const; - void GenerateFieldBuilderInitializationCode(io::Printer* printer) const; - void GenerateEqualsCode(io::Printer* printer) const; - void GenerateHashCode(io::Printer* printer) const; - - std::string GetBoxedType() const; + int GetNumBitsForMessage() const override; + int GetNumBitsForBuilder() const override; + void GenerateInterfaceMembers(io::Printer* printer) const override; + void GenerateMembers(io::Printer* printer) const override; + void GenerateBuilderMembers(io::Printer* printer) const override; + void GenerateInitializationCode(io::Printer* printer) const override; + void GenerateBuilderClearCode(io::Printer* printer) const override; + void GenerateMergingCode(io::Printer* printer) const override; + void GenerateBuildingCode(io::Printer* printer) const override; + void GenerateBuilderParsingCode(io::Printer* printer) const override; + void GenerateSerializationCode(io::Printer* printer) const override; + void GenerateSerializedSizeCode(io::Printer* printer) const override; + void GenerateFieldBuilderInitializationCode( + io::Printer* printer) const override; + void GenerateEqualsCode(io::Printer* printer) const override; + void GenerateHashCode(io::Printer* printer) const override; + + std::string GetBoxedType() const override; private: const FieldDescriptor* descriptor_;