@@ -511,6 +511,13 @@ impl<'a> CodedInputStream<'a> {
511
511
}
512
512
513
513
fn skip_group ( & mut self ) -> crate :: Result < ( ) > {
514
+ self . incr_recursion ( ) ?;
515
+ let ret = self . skip_group_no_depth_check ( ) ;
516
+ self . decr_recursion ( ) ;
517
+ ret
518
+ }
519
+
520
+ fn skip_group_no_depth_check ( & mut self ) -> crate :: Result < ( ) > {
514
521
while !self . eof ( ) ? {
515
522
let wire_type = self . read_tag_unpack ( ) ?. 1 ;
516
523
if wire_type == WireType :: EndGroup {
@@ -631,19 +638,16 @@ impl<'a> CodedInputStream<'a> {
631
638
/// Read message, do not check if message is initialized
632
639
pub fn merge_message < M : Message > ( & mut self , message : & mut M ) -> crate :: Result < ( ) > {
633
640
self . incr_recursion ( ) ?;
634
- struct DecrRecursion < ' a , ' b > ( & ' a mut CodedInputStream < ' b > ) ;
635
- impl < ' a , ' b > Drop for DecrRecursion < ' a , ' b > {
636
- fn drop ( & mut self ) {
637
- self . 0 . decr_recursion ( ) ;
638
- }
639
- }
640
-
641
- let mut decr = DecrRecursion ( self ) ;
641
+ let ret = self . merge_message_no_depth_check ( message) ;
642
+ self . decr_recursion ( ) ;
643
+ ret
644
+ }
642
645
643
- let len = decr. 0 . read_raw_varint64 ( ) ?;
644
- let old_limit = decr. 0 . push_limit ( len) ?;
645
- message. merge_from ( & mut decr. 0 ) ?;
646
- decr. 0 . pop_limit ( old_limit) ;
646
+ fn merge_message_no_depth_check < M : Message > ( & mut self , message : & mut M ) -> crate :: Result < ( ) > {
647
+ let len = self . read_raw_varint64 ( ) ?;
648
+ let old_limit = self . push_limit ( len) ?;
649
+ message. merge_from ( self ) ?;
650
+ self . pop_limit ( old_limit) ;
647
651
Ok ( ( ) )
648
652
}
649
653
@@ -982,4 +986,47 @@ mod test {
982
986
) ;
983
987
assert_eq ! ( "field 3" , input. read_string( ) . unwrap( ) ) ;
984
988
}
989
+
990
+ #[ test]
991
+ fn test_shallow_nested_unknown_groups ( ) {
992
+ // Test skip_group() succeeds on a start group tag 50 times
993
+ // followed by end group tag 50 times. We should be able to
994
+ // successfully skip the outermost group.
995
+ let mut vec = Vec :: new ( ) ;
996
+ let mut os = CodedOutputStream :: new ( & mut vec) ;
997
+ for _ in 0 ..50 {
998
+ os. write_tag ( 1 , WireType :: StartGroup ) . unwrap ( ) ;
999
+ }
1000
+ for _ in 0 ..50 {
1001
+ os. write_tag ( 1 , WireType :: EndGroup ) . unwrap ( ) ;
1002
+ }
1003
+ drop ( os) ;
1004
+
1005
+ let mut input = CodedInputStream :: from_bytes ( & vec) ;
1006
+ assert ! ( input. skip_group( ) . is_ok( ) ) ;
1007
+ }
1008
+
1009
+ #[ test]
1010
+ fn test_deeply_nested_unknown_groups ( ) {
1011
+ // Create an output stream that has groups nested recursively 1000
1012
+ // deep, and try to skip the group.
1013
+ // This should fail the default depth limit of 100 which ensures we
1014
+ // don't blow the stack on adversial input.
1015
+ let mut vec = Vec :: new ( ) ;
1016
+ let mut os = CodedOutputStream :: new ( & mut vec) ;
1017
+ for _ in 0 ..1000 {
1018
+ os. write_tag ( 1 , WireType :: StartGroup ) . unwrap ( ) ;
1019
+ }
1020
+ for _ in 0 ..1000 {
1021
+ os. write_tag ( 1 , WireType :: EndGroup ) . unwrap ( ) ;
1022
+ }
1023
+ drop ( os) ;
1024
+
1025
+ let mut input = CodedInputStream :: from_bytes ( & vec) ;
1026
+ assert ! ( input
1027
+ . skip_group( )
1028
+ . unwrap_err( )
1029
+ . to_string( )
1030
+ . contains( "Over recursion limit" ) ) ;
1031
+ }
985
1032
}
0 commit comments