Skip to content

Commit ac9fb5b

Browse files
protobuf-github-botzhangskz
authored andcommitted
Add recursion check when parsing unknown fields in Java.
PiperOrigin-RevId: 675657198
1 parent 9a5f5fe commit ac9fb5b

File tree

8 files changed

+458
-12
lines changed

8 files changed

+458
-12
lines changed

java/core/BUILD.bazel

+2
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,7 @@ junit_tests(
616616
"src/test/java/com/google/protobuf/DescriptorsTest.java",
617617
"src/test/java/com/google/protobuf/DebugFormatTest.java",
618618
"src/test/java/com/google/protobuf/CodedOutputStreamTest.java",
619+
"src/test/java/com/google/protobuf/CodedInputStreamTest.java",
619620
"src/test/java/com/google/protobuf/ProtobufToStringOutputTest.java",
620621
# Excluded in core_tests
621622
"src/test/java/com/google/protobuf/DecodeUtf8Test.java",
@@ -664,6 +665,7 @@ junit_tests(
664665
"src/test/java/com/google/protobuf/DescriptorsTest.java",
665666
"src/test/java/com/google/protobuf/DebugFormatTest.java",
666667
"src/test/java/com/google/protobuf/CodedOutputStreamTest.java",
668+
"src/test/java/com/google/protobuf/CodedInputStreamTest.java",
667669
"src/test/java/com/google/protobuf/ProtobufToStringOutputTest.java",
668670
# Excluded in core_tests
669671
"src/test/java/com/google/protobuf/DecodeUtf8Test.java",

java/core/src/main/java/com/google/protobuf/ArrayDecoders.java

+28
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
*/
2424
@CheckReturnValue
2525
final class ArrayDecoders {
26+
static final int DEFAULT_RECURSION_LIMIT = 100;
27+
28+
@SuppressWarnings("NonFinalStaticField")
29+
private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT;
2630

2731
private ArrayDecoders() {}
2832

@@ -37,6 +41,7 @@ static final class Registers {
3741
public long long1;
3842
public Object object1;
3943
public final ExtensionRegistryLite extensionRegistry;
44+
public int recursionDepth;
4045

4146
Registers() {
4247
this.extensionRegistry = ExtensionRegistryLite.getEmptyRegistry();
@@ -244,7 +249,10 @@ static int mergeMessageField(
244249
if (length < 0 || length > limit - position) {
245250
throw InvalidProtocolBufferException.truncatedMessage();
246251
}
252+
registers.recursionDepth++;
253+
checkRecursionLimit(registers.recursionDepth);
247254
schema.mergeFrom(msg, data, position, position + length, registers);
255+
registers.recursionDepth--;
248256
registers.object1 = msg;
249257
return position + length;
250258
}
@@ -262,8 +270,11 @@ static int mergeGroupField(
262270
// A group field must has a MessageSchema (the only other subclass of Schema is MessageSetSchema
263271
// and it can't be used in group fields).
264272
final MessageSchema messageSchema = (MessageSchema) schema;
273+
registers.recursionDepth++;
274+
checkRecursionLimit(registers.recursionDepth);
265275
final int endPosition =
266276
messageSchema.parseMessage(msg, data, position, limit, endGroup, registers);
277+
registers.recursionDepth--;
267278
registers.object1 = msg;
268279
return endPosition;
269280
}
@@ -1024,6 +1035,8 @@ static int decodeUnknownField(
10241035
final UnknownFieldSetLite child = UnknownFieldSetLite.newInstance();
10251036
final int endGroup = (tag & ~0x7) | WireFormat.WIRETYPE_END_GROUP;
10261037
int lastTag = 0;
1038+
registers.recursionDepth++;
1039+
checkRecursionLimit(registers.recursionDepth);
10271040
while (position < limit) {
10281041
position = decodeVarint32(data, position, registers);
10291042
lastTag = registers.int1;
@@ -1032,6 +1045,7 @@ static int decodeUnknownField(
10321045
}
10331046
position = decodeUnknownField(lastTag, data, position, limit, child, registers);
10341047
}
1048+
registers.recursionDepth--;
10351049
if (position > limit || lastTag != endGroup) {
10361050
throw InvalidProtocolBufferException.parseFailure();
10371051
}
@@ -1078,4 +1092,18 @@ static int skipField(int tag, byte[] data, int position, int limit, Registers re
10781092
throw InvalidProtocolBufferException.invalidTag();
10791093
}
10801094
}
1095+
1096+
/**
1097+
* Set the maximum recursion limit that ArrayDecoders will allow. An exception will be thrown if
1098+
* the depth of the message exceeds this limit.
1099+
*/
1100+
public static void setRecursionLimit(int limit) {
1101+
recursionLimit = limit;
1102+
}
1103+
1104+
private static void checkRecursionLimit(int depth) throws InvalidProtocolBufferException {
1105+
if (depth >= recursionLimit) {
1106+
throw InvalidProtocolBufferException.recursionLimitExceeded();
1107+
}
1108+
}
10811109
}

java/core/src/main/java/com/google/protobuf/CodedInputStream.java

+6
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,10 @@ public void skipMessage() throws IOException {
230230
if (tag == 0) {
231231
return;
232232
}
233+
checkRecursionLimit();
234+
++recursionDepth;
233235
boolean fieldSkipped = skipField(tag);
236+
--recursionDepth;
234237
if (!fieldSkipped) {
235238
return;
236239
}
@@ -247,7 +250,10 @@ public void skipMessage(CodedOutputStream output) throws IOException {
247250
if (tag == 0) {
248251
return;
249252
}
253+
checkRecursionLimit();
254+
++recursionDepth;
250255
boolean fieldSkipped = skipField(tag, output);
256+
--recursionDepth;
251257
if (!fieldSkipped) {
252258
return;
253259
}

java/core/src/main/java/com/google/protobuf/MessageSchema.java

+6-6
Original file line numberDiff line numberDiff line change
@@ -3006,8 +3006,8 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
30063006
unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
30073007
}
30083008
// Unknown field.
3009-
3010-
if (unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
3009+
if (unknownFieldSchema.mergeOneFieldFrom(
3010+
unknownFields, reader, /* currentDepth= */ 0)) {
30113011
continue;
30123012
}
30133013
}
@@ -3382,8 +3382,8 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
33823382
if (unknownFields == null) {
33833383
unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
33843384
}
3385-
3386-
if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
3385+
if (!unknownFieldSchema.mergeOneFieldFrom(
3386+
unknownFields, reader, /* currentDepth= */ 0)) {
33873387
return;
33883388
}
33893389
break;
@@ -3399,8 +3399,8 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
33993399
if (unknownFields == null) {
34003400
unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
34013401
}
3402-
3403-
if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
3402+
if (!unknownFieldSchema.mergeOneFieldFrom(
3403+
unknownFields, reader, /* currentDepth= */ 0)) {
34043404
return;
34053405
}
34063406
}

java/core/src/main/java/com/google/protobuf/MessageSetSchema.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -278,8 +278,7 @@ boolean parseMessageSetItemOrUnknownField(
278278
reader, extension, extensionRegistry, extensions);
279279
return true;
280280
} else {
281-
282-
return unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader);
281+
return unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader, /* currentDepth= */ 0);
283282
}
284283
} else {
285284
return reader.skipField();

java/core/src/main/java/com/google/protobuf/UnknownFieldSchema.java

+25-4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
@CheckReturnValue
1414
abstract class UnknownFieldSchema<T, B> {
1515

16+
static final int DEFAULT_RECURSION_LIMIT = 100;
17+
18+
@SuppressWarnings("NonFinalStaticField")
19+
private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT;
20+
1621
/** Whether unknown fields should be dropped. */
1722
abstract boolean shouldDiscardUnknownFields(Reader reader);
1823

@@ -55,7 +60,9 @@ abstract class UnknownFieldSchema<T, B> {
5560
/** Marks unknown fields as immutable. */
5661
abstract void makeImmutable(Object message);
5762

58-
final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOException {
63+
/** Merges one field into the unknown fields. */
64+
final boolean mergeOneFieldFrom(B unknownFields, Reader reader, int currentDepth)
65+
throws IOException {
5966
int tag = reader.getTag();
6067
int fieldNumber = WireFormat.getTagFieldNumber(tag);
6168
switch (WireFormat.getTagWireType(tag)) {
@@ -74,7 +81,12 @@ final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOExcepti
7481
case WireFormat.WIRETYPE_START_GROUP:
7582
final B subFields = newBuilder();
7683
int endGroupTag = WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP);
77-
mergeFrom(subFields, reader);
84+
currentDepth++;
85+
if (currentDepth >= recursionLimit) {
86+
throw InvalidProtocolBufferException.recursionLimitExceeded();
87+
}
88+
mergeFrom(subFields, reader, currentDepth);
89+
currentDepth--;
7890
if (endGroupTag != reader.getTag()) {
7991
throw InvalidProtocolBufferException.invalidEndTag();
8092
}
@@ -87,10 +99,11 @@ final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOExcepti
8799
}
88100
}
89101

90-
private final void mergeFrom(B unknownFields, Reader reader) throws IOException {
102+
private final void mergeFrom(B unknownFields, Reader reader, int currentDepth)
103+
throws IOException {
91104
while (true) {
92105
if (reader.getFieldNumber() == Reader.READ_DONE
93-
|| !mergeOneFieldFrom(unknownFields, reader)) {
106+
|| !mergeOneFieldFrom(unknownFields, reader, currentDepth)) {
94107
break;
95108
}
96109
}
@@ -107,4 +120,12 @@ private final void mergeFrom(B unknownFields, Reader reader) throws IOException
107120
abstract int getSerializedSizeAsMessageSet(T message);
108121

109122
abstract int getSerializedSize(T unknowns);
123+
124+
/**
125+
* Set the maximum recursion limit that ArrayDecoders will allow. An exception will be thrown if
126+
* the depth of the message exceeds this limit.
127+
*/
128+
public void setRecursionLimit(int limit) {
129+
recursionLimit = limit;
130+
}
110131
}

0 commit comments

Comments
 (0)