Skip to content

Commit

Permalink
Apply patch
Browse files Browse the repository at this point in the history
deannagarcia committed Sep 13, 2022

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 5b37c91 commit d1635e1
Showing 4 changed files with 149 additions and 35 deletions.
27 changes: 18 additions & 9 deletions src/google/protobuf/extension_set_inl.h
Original file line number Diff line number Diff line change
@@ -206,16 +206,21 @@ const char* ExtensionSet::ParseMessageSetItemTmpl(
const char* ptr, const Msg* extendee, internal::InternalMetadata* metadata,
internal::ParseContext* ctx) {
std::string payload;
uint32_t type_id = 0;
bool payload_read = false;
uint32_t type_id;
enum class State { kNoTag, kHasType, kHasPayload, kDone };
State state = State::kNoTag;

while (!ctx->Done(&ptr)) {
uint32_t tag = static_cast<uint8_t>(*ptr++);
if (tag == WireFormatLite::kMessageSetTypeIdTag) {
uint64_t tmp;
ptr = ParseBigVarint(ptr, &tmp);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
type_id = tmp;
if (payload_read) {
if (state == State::kNoTag) {
type_id = tmp;
state = State::kHasType;
} else if (state == State::kHasPayload) {
type_id = tmp;
ExtensionInfo extension;
bool was_packed_on_wire;
if (!FindExtension(2, type_id, extendee, ctx, &extension,
@@ -241,20 +246,24 @@ const char* ExtensionSet::ParseMessageSetItemTmpl(
GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) &&
tmp_ctx.EndedAtLimit());
}
type_id = 0;
state = State::kDone;
}
} else if (tag == WireFormatLite::kMessageSetMessageTag) {
if (type_id != 0) {
if (state == State::kHasType) {
ptr = ParseFieldMaybeLazily(static_cast<uint64_t>(type_id) * 8 + 2, ptr,
extendee, metadata, ctx);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr);
type_id = 0;
state = State::kDone;
} else {
std::string tmp;
int32_t size = ReadSize(&ptr);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
ptr = ctx->ReadString(ptr, size, &payload);
ptr = ctx->ReadString(ptr, size, &tmp);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
payload_read = true;
if (state == State::kNoTag) {
payload = std::move(tmp);
state = State::kHasPayload;
}
}
} else {
ptr = ReadTag(ptr - 1, &tag);
26 changes: 18 additions & 8 deletions src/google/protobuf/wire_format.cc
Original file line number Diff line number Diff line change
@@ -657,9 +657,11 @@ struct WireFormat::MessageSetParser {
const char* _InternalParse(const char* ptr, internal::ParseContext* ctx) {
// Parse a MessageSetItem
auto metadata = reflection->MutableInternalMetadata(msg);
enum class State { kNoTag, kHasType, kHasPayload, kDone };
State state = State::kNoTag;

std::string payload;
uint32_t type_id = 0;
bool payload_read = false;
while (!ctx->Done(&ptr)) {
// We use 64 bit tags in order to allow typeid's that span the whole
// range of 32 bit numbers.
@@ -668,8 +670,11 @@ struct WireFormat::MessageSetParser {
uint64_t tmp;
ptr = ParseBigVarint(ptr, &tmp);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
type_id = tmp;
if (payload_read) {
if (state == State::kNoTag) {
type_id = tmp;
state = State::kHasType;
} else if (state == State::kHasPayload) {
type_id = tmp;
const FieldDescriptor* field;
if (ctx->data().pool == nullptr) {
field = reflection->FindKnownExtensionByNumber(type_id);
@@ -696,17 +701,17 @@ struct WireFormat::MessageSetParser {
GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) &&
tmp_ctx.EndedAtLimit());
}
type_id = 0;
state = State::kDone;
}
continue;
} else if (tag == WireFormatLite::kMessageSetMessageTag) {
if (type_id == 0) {
if (state == State::kNoTag) {
int32_t size = ReadSize(&ptr);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
ptr = ctx->ReadString(ptr, size, &payload);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
payload_read = true;
} else {
state = State::kHasPayload;
} else if (state == State::kHasType) {
// We're now parsing the payload
const FieldDescriptor* field = nullptr;
if (descriptor->IsExtensionNumber(type_id)) {
@@ -720,7 +725,12 @@ struct WireFormat::MessageSetParser {
ptr = WireFormat::_InternalParseAndMergeField(
msg, ptr, ctx, static_cast<uint64_t>(type_id) * 8 + 2, reflection,
field);
type_id = 0;
state = State::kDone;
} else {
int32_t size = ReadSize(&ptr);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
ptr = ctx->Skip(ptr, size);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
}
} else {
// An unknown field in MessageSetItem.
27 changes: 18 additions & 9 deletions src/google/protobuf/wire_format_lite.h
Original file line number Diff line number Diff line change
@@ -1845,6 +1845,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
// we can parse it later.
std::string message_data;

enum class State { kNoTag, kHasType, kHasPayload, kDone };
State state = State::kNoTag;

while (true) {
const uint32_t tag = input->ReadTagNoLastTag();
if (tag == 0) return false;
@@ -1853,26 +1856,34 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
case WireFormatLite::kMessageSetTypeIdTag: {
uint32_t type_id;
if (!input->ReadVarint32(&type_id)) return false;
last_type_id = type_id;

if (!message_data.empty()) {
if (state == State::kNoTag) {
last_type_id = type_id;
state = State::kHasType;
} else if (state == State::kHasPayload) {
// We saw some message data before the type_id. Have to parse it
// now.
io::CodedInputStream sub_input(
reinterpret_cast<const uint8_t*>(message_data.data()),
static_cast<int>(message_data.size()));
sub_input.SetRecursionLimit(input->RecursionBudget());
if (!ms.ParseField(last_type_id, &sub_input)) {
if (!ms.ParseField(type_id, &sub_input)) {
return false;
}
message_data.clear();
state = State::kDone;
}

break;
}

case WireFormatLite::kMessageSetMessageTag: {
if (last_type_id == 0) {
if (state == State::kHasType) {
// Already saw type_id, so we can parse this directly.
if (!ms.ParseField(last_type_id, input)) {
return false;
}
state = State::kDone;
} else if (state == State::kNoTag) {
// We haven't seen a type_id yet. Append this data to message_data.
uint32_t length;
if (!input->ReadVarint32(&length)) return false;
@@ -1883,11 +1894,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
auto ptr = reinterpret_cast<uint8_t*>(&message_data[0]);
ptr = io::CodedOutputStream::WriteVarint32ToArray(length, ptr);
if (!input->ReadRaw(ptr, length)) return false;
state = State::kHasPayload;
} else {
// Already saw type_id, so we can parse this directly.
if (!ms.ParseField(last_type_id, input)) {
return false;
}
if (!ms.SkipField(tag, input)) return false;
}

break;
104 changes: 95 additions & 9 deletions src/google/protobuf/wire_format_unittest.inc
Original file line number Diff line number Diff line change
@@ -580,28 +580,54 @@ TEST(WireFormatTest, ParseMessageSet) {
EXPECT_EQ(message_set.DebugString(), dynamic_message_set.DebugString());
}

TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
namespace {
// Returns the serialized bytes of a MessageSet item-start tag
// (WireFormatLite::kMessageSetItemStartTag) and nothing else. Tests
// concatenate this with the other Build* fragments to hand-assemble an item.
std::string BuildMessageSetItemStart() {
  std::string data;
  {
    // The streams live in an inner scope so they are destroyed before `data`
    // is returned (presumably this is what flushes buffered bytes into
    // `data` -- confirm against CodedOutputStream's contract).
    io::StringOutputStream output_stream(&data);
    io::CodedOutputStream coded_output(&output_stream);
    coded_output.WriteTag(WireFormatLite::kMessageSetItemStartTag);
  }
  return data;
}
std::string BuildMessageSetItemEnd() {
std::string data;
{
io::StringOutputStream output_stream(&data);
io::CodedOutputStream coded_output(&output_stream);
coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag);
}
return data;
}
// Returns the serialized payload field of a MessageSet item: the
// length-delimited kMessageSetMessageNumber field holding a
// TestMessageSetExtension1 whose `i` is set to `value`.
// The type_id field is intentionally NOT written here; callers combine this
// fragment with BuildMessageSetItemTypeId() in whatever order a test needs.
std::string BuildMessageSetTestExtension1(int value = 123) {
  std::string data;
  {
    UNITTEST::TestMessageSetExtension1 message;
    message.set_i(value);
    io::StringOutputStream output_stream(&data);
    io::CodedOutputStream coded_output(&output_stream);
    // Write the message content: field tag, byte length, then the body.
    WireFormatLite::WriteTag(WireFormatLite::kMessageSetMessageNumber,
                             WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
                             &coded_output);
    // ByteSizeLong() is called before SerializeWithCachedSizes(); the latter's
    // name suggests it relies on the sizes cached by that call -- keep the
    // order.
    coded_output.WriteVarint32(message.ByteSizeLong());
    message.SerializeWithCachedSizes(&coded_output);
  }
  return data;
}
// Returns the serialized type_id field of a MessageSet item: a uint32 field
// numbered kMessageSetTypeIdNumber whose value is `extension_number`.
// Only this one field is written; the item start/end tags and the payload
// come from the sibling Build* helpers so tests can order them freely.
std::string BuildMessageSetItemTypeId(int extension_number) {
  std::string data;
  {
    io::StringOutputStream output_stream(&data);
    io::CodedOutputStream coded_output(&output_stream);
    WireFormatLite::WriteUInt32(WireFormatLite::kMessageSetTypeIdNumber,
                                extension_number, &coded_output);
  }
  return data;
}
void ValidateTestMessageSet(const std::string& test_case,
const std::string& data) {
SCOPED_TRACE(test_case);
{
PROTO2_WIREFORMAT_UNITTEST::TestMessageSet message_set;
ASSERT_TRUE(message_set.ParseFromString(data));
@@ -611,6 +637,11 @@ TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
.GetExtension(
UNITTEST::TestMessageSetExtension1::message_set_extension)
.i());

// Make sure it does not contain anything else.
message_set.ClearExtension(
UNITTEST::TestMessageSetExtension1::message_set_extension);
EXPECT_EQ(message_set.SerializeAsString(), "");
}
{
// Test parse the message via Reflection.
@@ -626,6 +657,61 @@ TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
UNITTEST::TestMessageSetExtension1::message_set_extension)
.i());
}
{
// Test parse the message via DynamicMessage.
DynamicMessageFactory factory;
std::unique_ptr<Message> msg(
factory
.GetPrototype(
PROTO2_WIREFORMAT_UNITTEST::TestMessageSet::descriptor())
->New());
msg->ParseFromString(data);
auto* reflection = msg->GetReflection();
std::vector<const FieldDescriptor*> fields;
reflection->ListFields(*msg, &fields);
ASSERT_EQ(fields.size(), 1);
const auto& sub = reflection->GetMessage(*msg, fields[0]);
reflection = sub.GetReflection();
EXPECT_EQ(123, reflection->GetInt32(
sub, sub.GetDescriptor()->FindFieldByName("i")));
}
}
} // namespace

// A MessageSet item must validate the same way whether its type_id field
// precedes or follows its message payload.
TEST(WireFormatTest, ParseMessageSetWithAnyTagOrder) {
  const std::string item_start = BuildMessageSetItemStart();
  const std::string item_end = BuildMessageSetItemEnd();
  const int extension_number =
      UNITTEST::TestMessageSetExtension1::descriptor()->extension(0)->number();
  const std::string type_id = BuildMessageSetItemTypeId(extension_number);
  const std::string payload = BuildMessageSetTestExtension1();

  // type_id first, then payload.
  ValidateTestMessageSet("id + message",
                         item_start + type_id + payload + item_end);
  // Payload first, then type_id.
  ValidateTestMessageSet("message + id",
                         item_start + payload + type_id + item_end);
}

// Feeds items that repeat either the type_id field or the payload field and
// runs the same validation on each. Every case places the real `type_id` /
// `payload` ahead of its duplicate (`extra_type_id` = 123456, `extra_payload`
// with i = 321), so a passing validation means the first occurrence was used.
TEST(WireFormatTest, ParseMessageSetWithDuplicateTags) {
  const std::string item_start = BuildMessageSetItemStart();
  const std::string item_end = BuildMessageSetItemEnd();
  const int extension_number =
      UNITTEST::TestMessageSetExtension1::descriptor()->extension(0)->number();
  const std::string type_id = BuildMessageSetItemTypeId(extension_number);
  const std::string extra_type_id = BuildMessageSetItemTypeId(123456);
  const std::string payload = BuildMessageSetTestExtension1();
  const std::string extra_payload = BuildMessageSetTestExtension1(321);

  // Two type_id fields in one item.
  ValidateTestMessageSet(
      "id + other_id + message",
      item_start + type_id + extra_type_id + payload + item_end);
  ValidateTestMessageSet(
      "id + message + other_id",
      item_start + type_id + payload + extra_type_id + item_end);
  ValidateTestMessageSet(
      "message + id + other_id",
      item_start + payload + type_id + extra_type_id + item_end);

  // Two payload fields in one item.
  ValidateTestMessageSet(
      "id + message + other_message",
      item_start + type_id + payload + extra_payload + item_end);
  ValidateTestMessageSet(
      "message + id + other_message",
      item_start + payload + type_id + extra_payload + item_end);
  ValidateTestMessageSet(
      "message + other_message + id",
      item_start + payload + extra_payload + type_id + item_end);
}

void SerializeReverseOrder(

0 comments on commit d1635e1

Please sign in to comment.