23
23
*/
24
24
@ CheckReturnValue
25
25
final class ArrayDecoders {
26
+ static final int DEFAULT_RECURSION_LIMIT = 100 ;
27
+
28
+ @ SuppressWarnings ("NonFinalStaticField" )
29
+ private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT ;
26
30
27
31
private ArrayDecoders () {}
28
32
@@ -37,6 +41,7 @@ static final class Registers {
37
41
public long long1 ;
38
42
public Object object1 ;
39
43
public final ExtensionRegistryLite extensionRegistry ;
44
+ public int recursionDepth ;
40
45
41
46
Registers () {
42
47
this .extensionRegistry = ExtensionRegistryLite .getEmptyRegistry ();
@@ -244,7 +249,10 @@ static int mergeMessageField(
244
249
if (length < 0 || length > limit - position ) {
245
250
throw InvalidProtocolBufferException .truncatedMessage ();
246
251
}
252
+ registers .recursionDepth ++;
253
+ checkRecursionLimit (registers .recursionDepth );
247
254
schema .mergeFrom (msg , data , position , position + length , registers );
255
+ registers .recursionDepth --;
248
256
registers .object1 = msg ;
249
257
return position + length ;
250
258
}
@@ -262,8 +270,11 @@ static int mergeGroupField(
262
270
// A group field must has a MessageSchema (the only other subclass of Schema is MessageSetSchema
263
271
// and it can't be used in group fields).
264
272
final MessageSchema messageSchema = (MessageSchema ) schema ;
273
+ registers .recursionDepth ++;
274
+ checkRecursionLimit (registers .recursionDepth );
265
275
final int endPosition =
266
276
messageSchema .parseMessage (msg , data , position , limit , endGroup , registers );
277
+ registers .recursionDepth --;
267
278
registers .object1 = msg ;
268
279
return endPosition ;
269
280
}
@@ -1024,6 +1035,8 @@ static int decodeUnknownField(
1024
1035
final UnknownFieldSetLite child = UnknownFieldSetLite .newInstance ();
1025
1036
final int endGroup = (tag & ~0x7 ) | WireFormat .WIRETYPE_END_GROUP ;
1026
1037
int lastTag = 0 ;
1038
+ registers .recursionDepth ++;
1039
+ checkRecursionLimit (registers .recursionDepth );
1027
1040
while (position < limit ) {
1028
1041
position = decodeVarint32 (data , position , registers );
1029
1042
lastTag = registers .int1 ;
@@ -1032,6 +1045,7 @@ static int decodeUnknownField(
1032
1045
}
1033
1046
position = decodeUnknownField (lastTag , data , position , limit , child , registers );
1034
1047
}
1048
+ registers .recursionDepth --;
1035
1049
if (position > limit || lastTag != endGroup ) {
1036
1050
throw InvalidProtocolBufferException .parseFailure ();
1037
1051
}
@@ -1078,4 +1092,18 @@ static int skipField(int tag, byte[] data, int position, int limit, Registers re
1078
1092
throw InvalidProtocolBufferException .invalidTag ();
1079
1093
}
1080
1094
}
1095
+
1096
+ /**
1097
+ * Set the maximum recursion limit that ArrayDecoders will allow. An exception will be thrown if
1098
+ * the depth of the message exceeds this limit.
1099
+ */
1100
+ public static void setRecursionLimit (int limit ) {
1101
+ recursionLimit = limit ;
1102
+ }
1103
+
1104
+ private static void checkRecursionLimit (int depth ) throws InvalidProtocolBufferException {
1105
+ if (depth >= recursionLimit ) {
1106
+ throw InvalidProtocolBufferException .recursionLimitExceeded ();
1107
+ }
1108
+ }
1081
1109
}
0 commit comments