@@ -184,7 +184,10 @@ def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
184184 clear_if_default = False ):
185185 if is_packed :
186186 local_DecodeVarint = _DecodeVarint
187- def DecodePackedField (buffer , pos , end , message , field_dict ):
187+ def DecodePackedField (
188+ buffer , pos , end , message , field_dict , current_depth = 0
189+ ):
190+ del current_depth # unused
188191 value = field_dict .get (key )
189192 if value is None :
190193 value = field_dict .setdefault (key , new_default (message ))
@@ -199,11 +202,15 @@ def DecodePackedField(buffer, pos, end, message, field_dict):
199202 del value [- 1 ] # Discard corrupt value.
200203 raise _DecodeError ('Packed element was truncated.' )
201204 return pos
205+
202206 return DecodePackedField
203207 elif is_repeated :
204208 tag_bytes = encoder .TagBytes (field_number , wire_type )
205209 tag_len = len (tag_bytes )
206- def DecodeRepeatedField (buffer , pos , end , message , field_dict ):
210+ def DecodeRepeatedField (
211+ buffer , pos , end , message , field_dict , current_depth = 0
212+ ):
213+ del current_depth # unused
207214 value = field_dict .get (key )
208215 if value is None :
209216 value = field_dict .setdefault (key , new_default (message ))
@@ -218,9 +225,12 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
218225 if new_pos > end :
219226 raise _DecodeError ('Truncated message.' )
220227 return new_pos
228+
221229 return DecodeRepeatedField
222230 else :
223- def DecodeField (buffer , pos , end , message , field_dict ):
231+
232+ def DecodeField (buffer , pos , end , message , field_dict , current_depth = 0 ):
233+ del current_depth # unused
224234 (new_value , pos ) = decode_value (buffer , pos )
225235 if pos > end :
226236 raise _DecodeError ('Truncated message.' )
@@ -229,6 +239,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
229239 else :
230240 field_dict [key ] = new_value
231241 return pos
242+
232243 return DecodeField
233244
234245 return SpecificDecoder
@@ -364,7 +375,9 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
364375 enum_type = key .enum_type
365376 if is_packed :
366377 local_DecodeVarint = _DecodeVarint
367- def DecodePackedField (buffer , pos , end , message , field_dict ):
378+ def DecodePackedField (
379+ buffer , pos , end , message , field_dict , current_depth = 0
380+ ):
368381 """Decode serialized packed enum to its value and a new position.
369382
370383 Args:
@@ -377,6 +390,7 @@ def DecodePackedField(buffer, pos, end, message, field_dict):
377390 Returns:
378391 int, new position in serialized data.
379392 """
393+ del current_depth # unused
380394 value = field_dict .get (key )
381395 if value is None :
382396 value = field_dict .setdefault (key , new_default (message ))
@@ -407,11 +421,14 @@ def DecodePackedField(buffer, pos, end, message, field_dict):
407421 # pylint: enable=protected-access
408422 raise _DecodeError ('Packed element was truncated.' )
409423 return pos
424+
410425 return DecodePackedField
411426 elif is_repeated :
412427 tag_bytes = encoder .TagBytes (field_number , wire_format .WIRETYPE_VARINT )
413428 tag_len = len (tag_bytes )
414- def DecodeRepeatedField (buffer , pos , end , message , field_dict ):
429+ def DecodeRepeatedField (
430+ buffer , pos , end , message , field_dict , current_depth = 0
431+ ):
415432 """Decode serialized repeated enum to its value and a new position.
416433
417434 Args:
@@ -424,6 +441,7 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
424441 Returns:
425442 int, new position in serialized data.
426443 """
444+ del current_depth # unused
427445 value = field_dict .get (key )
428446 if value is None :
429447 value = field_dict .setdefault (key , new_default (message ))
@@ -446,9 +464,11 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
446464 if new_pos > end :
447465 raise _DecodeError ('Truncated message.' )
448466 return new_pos
467+
449468 return DecodeRepeatedField
450469 else :
451- def DecodeField (buffer , pos , end , message , field_dict ):
470+
471+ def DecodeField (buffer , pos , end , message , field_dict , current_depth = 0 ):
452472 """Decode serialized repeated enum to its value and a new position.
453473
454474 Args:
@@ -461,6 +481,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
461481 Returns:
462482 int, new position in serialized data.
463483 """
484+ del current_depth # unused
464485 value_start_pos = pos
465486 (enum_value , pos ) = _DecodeSignedVarint32 (buffer , pos )
466487 if pos > end :
@@ -480,6 +501,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
480501 (tag_bytes , buffer [value_start_pos :pos ].tobytes ()))
481502 # pylint: enable=protected-access
482503 return pos
504+
483505 return DecodeField
484506
485507
@@ -538,7 +560,10 @@ def _ConvertToUnicode(memview):
538560 tag_bytes = encoder .TagBytes (field_number ,
539561 wire_format .WIRETYPE_LENGTH_DELIMITED )
540562 tag_len = len (tag_bytes )
541- def DecodeRepeatedField (buffer , pos , end , message , field_dict ):
563+ def DecodeRepeatedField (
564+ buffer , pos , end , message , field_dict , current_depth = 0
565+ ):
566+ del current_depth # unused
542567 value = field_dict .get (key )
543568 if value is None :
544569 value = field_dict .setdefault (key , new_default (message ))
@@ -553,9 +578,12 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
553578 if buffer [new_pos :pos ] != tag_bytes or new_pos == end :
554579 # Prediction failed. Return.
555580 return new_pos
581+
556582 return DecodeRepeatedField
557583 else :
558- def DecodeField (buffer , pos , end , message , field_dict ):
584+
585+ def DecodeField (buffer , pos , end , message , field_dict , current_depth = 0 ):
586+ del current_depth # unused
559587 (size , pos ) = local_DecodeVarint (buffer , pos )
560588 new_pos = pos + size
561589 if new_pos > end :
@@ -565,6 +593,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
565593 else :
566594 field_dict [key ] = _ConvertToUnicode (buffer [pos :new_pos ])
567595 return new_pos
596+
568597 return DecodeField
569598
570599
@@ -579,7 +608,10 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
579608 tag_bytes = encoder .TagBytes (field_number ,
580609 wire_format .WIRETYPE_LENGTH_DELIMITED )
581610 tag_len = len (tag_bytes )
582- def DecodeRepeatedField (buffer , pos , end , message , field_dict ):
611+ def DecodeRepeatedField (
612+ buffer , pos , end , message , field_dict , current_depth = 0
613+ ):
614+ del current_depth # unused
583615 value = field_dict .get (key )
584616 if value is None :
585617 value = field_dict .setdefault (key , new_default (message ))
@@ -594,9 +626,12 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
594626 if buffer [new_pos :pos ] != tag_bytes or new_pos == end :
595627 # Prediction failed. Return.
596628 return new_pos
629+
597630 return DecodeRepeatedField
598631 else :
599- def DecodeField (buffer , pos , end , message , field_dict ):
632+
633+ def DecodeField (buffer , pos , end , message , field_dict , current_depth = 0 ):
634+ del current_depth # unused
600635 (size , pos ) = local_DecodeVarint (buffer , pos )
601636 new_pos = pos + size
602637 if new_pos > end :
@@ -606,6 +641,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
606641 else :
607642 field_dict [key ] = buffer [pos :new_pos ].tobytes ()
608643 return new_pos
644+
609645 return DecodeField
610646
611647
@@ -621,7 +657,9 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
621657 tag_bytes = encoder .TagBytes (field_number ,
622658 wire_format .WIRETYPE_START_GROUP )
623659 tag_len = len (tag_bytes )
624- def DecodeRepeatedField (buffer , pos , end , message , field_dict ):
660+ def DecodeRepeatedField (
661+ buffer , pos , end , message , field_dict , current_depth = 0
662+ ):
625663 value = field_dict .get (key )
626664 if value is None :
627665 value = field_dict .setdefault (key , new_default (message ))
@@ -630,7 +668,13 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
630668 if value is None :
631669 value = field_dict .setdefault (key , new_default (message ))
632670 # Read sub-message.
633- pos = value .add ()._InternalParse (buffer , pos , end )
671+ current_depth += 1
672+ if current_depth > _recursion_limit :
673+ raise _DecodeError (
674+ 'Error parsing message: too many levels of nesting.'
675+ )
676+ pos = value .add ()._InternalParse (buffer , pos , end , current_depth )
677+ current_depth -= 1
634678 # Read end tag.
635679 new_pos = pos + end_tag_len
636680 if buffer [pos :new_pos ] != end_tag_bytes or new_pos > end :
@@ -640,19 +684,26 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
640684 if buffer [new_pos :pos ] != tag_bytes or new_pos == end :
641685 # Prediction failed. Return.
642686 return new_pos
687+
643688 return DecodeRepeatedField
644689 else :
645- def DecodeField (buffer , pos , end , message , field_dict ):
690+
691+ def DecodeField (buffer , pos , end , message , field_dict , current_depth = 0 ):
646692 value = field_dict .get (key )
647693 if value is None :
648694 value = field_dict .setdefault (key , new_default (message ))
649695 # Read sub-message.
650- pos = value ._InternalParse (buffer , pos , end )
696+ current_depth += 1
697+ if current_depth > _recursion_limit :
698+ raise _DecodeError ('Error parsing message: too many levels of nesting.' )
699+ pos = value ._InternalParse (buffer , pos , end , current_depth )
700+ current_depth -= 1
651701 # Read end tag.
652702 new_pos = pos + end_tag_len
653703 if buffer [pos :new_pos ] != end_tag_bytes or new_pos > end :
654704 raise _DecodeError ('Missing group end tag.' )
655705 return new_pos
706+
656707 return DecodeField
657708
658709
@@ -666,7 +717,9 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
666717 tag_bytes = encoder .TagBytes (field_number ,
667718 wire_format .WIRETYPE_LENGTH_DELIMITED )
668719 tag_len = len (tag_bytes )
669- def DecodeRepeatedField (buffer , pos , end , message , field_dict ):
720+ def DecodeRepeatedField (
721+ buffer , pos , end , message , field_dict , current_depth = 0
722+ ):
670723 value = field_dict .get (key )
671724 if value is None :
672725 value = field_dict .setdefault (key , new_default (message ))
@@ -677,18 +730,29 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
677730 if new_pos > end :
678731 raise _DecodeError ('Truncated message.' )
679732 # Read sub-message.
680- if value .add ()._InternalParse (buffer , pos , new_pos ) != new_pos :
733+ current_depth += 1
734+ if current_depth > _recursion_limit :
735+ raise _DecodeError (
736+ 'Error parsing message: too many levels of nesting.'
737+ )
738+ if (
739+ value .add ()._InternalParse (buffer , pos , new_pos , current_depth )
740+ != new_pos
741+ ):
681742 # The only reason _InternalParse would return early is if it
682743 # encountered an end-group tag.
683744 raise _DecodeError ('Unexpected end-group tag.' )
745+ current_depth -= 1
684746 # Predict that the next tag is another copy of the same repeated field.
685747 pos = new_pos + tag_len
686748 if buffer [new_pos :pos ] != tag_bytes or new_pos == end :
687749 # Prediction failed. Return.
688750 return new_pos
751+
689752 return DecodeRepeatedField
690753 else :
691- def DecodeField (buffer , pos , end , message , field_dict ):
754+
755+ def DecodeField (buffer , pos , end , message , field_dict , current_depth = 0 ):
692756 value = field_dict .get (key )
693757 if value is None :
694758 value = field_dict .setdefault (key , new_default (message ))
@@ -698,11 +762,16 @@ def DecodeField(buffer, pos, end, message, field_dict):
698762 if new_pos > end :
699763 raise _DecodeError ('Truncated message.' )
700764 # Read sub-message.
701- if value ._InternalParse (buffer , pos , new_pos ) != new_pos :
765+ current_depth += 1
766+ if current_depth > _recursion_limit :
767+ raise _DecodeError ('Error parsing message: too many levels of nesting.' )
768+ if value ._InternalParse (buffer , pos , new_pos , current_depth ) != new_pos :
702769 # The only reason _InternalParse would return early is if it encountered
703770 # an end-group tag.
704771 raise _DecodeError ('Unexpected end-group tag.' )
772+ current_depth -= 1
705773 return new_pos
774+
706775 return DecodeField
707776
708777
@@ -851,7 +920,8 @@ def MapDecoder(field_descriptor, new_default, is_message_map):
851920 # Can't read _concrete_class yet; might not be initialized.
852921 message_type = field_descriptor .message_type
853922
854- def DecodeMap (buffer , pos , end , message , field_dict ):
923+ def DecodeMap (buffer , pos , end , message , field_dict , current_depth = 0 ):
924+ del current_depth # Unused.
855925 submsg = message_type ._concrete_class ()
856926 value = field_dict .get (key )
857927 if value is None :
@@ -934,7 +1004,16 @@ def _SkipGroup(buffer, pos, end):
9341004 pos = new_pos
9351005
9361006
937- def _DecodeUnknownFieldSet (buffer , pos , end_pos = None ):
1007+ DEFAULT_RECURSION_LIMIT = 100
1008+ _recursion_limit = DEFAULT_RECURSION_LIMIT
1009+
1010+
1011+ def SetRecursionLimit (new_limit ):
1012+ global _recursion_limit
1013+ _recursion_limit = new_limit
1014+
1015+
1016+ def _DecodeUnknownFieldSet (buffer , pos , end_pos = None , current_depth = 0 ):
9381017 """Decode UnknownFieldSet. Returns the UnknownFieldSet and new position."""
9391018
9401019 unknown_field_set = containers .UnknownFieldSet ()
@@ -944,14 +1023,16 @@ def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
9441023 field_number , wire_type = wire_format .UnpackTag (tag )
9451024 if wire_type == wire_format .WIRETYPE_END_GROUP :
9461025 break
947- (data , pos ) = _DecodeUnknownField (buffer , pos , wire_type )
1026+ (data , pos ) = _DecodeUnknownField (buffer , pos , wire_type , current_depth )
9481027 # pylint: disable=protected-access
9491028 unknown_field_set ._add (field_number , wire_type , data )
9501029
9511030 return (unknown_field_set , pos )
9521031
9531032
954- def _DecodeUnknownField (buffer , pos , wire_type ):
1033+ def _DecodeUnknownField (
1034+ buffer , pos , wire_type , current_depth = 0
1035+ ):
9551036 """Decode a unknown field. Returns the UnknownField and new position."""
9561037
9571038 if wire_type == wire_format .WIRETYPE_VARINT :
@@ -965,7 +1046,11 @@ def _DecodeUnknownField(buffer, pos, wire_type):
9651046 data = buffer [pos :pos + size ].tobytes ()
9661047 pos += size
9671048 elif wire_type == wire_format .WIRETYPE_START_GROUP :
968- (data , pos ) = _DecodeUnknownFieldSet (buffer , pos )
1049+ current_depth += 1
1050+ if current_depth >= _recursion_limit :
1051+ raise _DecodeError ('Error parsing message: too many levels of nesting.' )
1052+ data , pos = _DecodeUnknownFieldSet (buffer , pos , None , current_depth )
1053+ current_depth -= 1
9691054 elif wire_type == wire_format .WIRETYPE_END_GROUP :
9701055 return (0 , - 1 )
9711056 else :
0 commit comments