1111import static com .google .common .truth .Truth .assertWithMessage ;
1212import static org .junit .Assert .assertArrayEquals ;
1313import static org .junit .Assert .assertThrows ;
14+
15+ import com .google .common .primitives .Bytes ;
16+ import map_test .MapTestProto .MapContainer ;
1417import protobuf_unittest .UnittestProto .BoolMessage ;
1518import protobuf_unittest .UnittestProto .Int32Message ;
1619import protobuf_unittest .UnittestProto .Int64Message ;
@@ -35,6 +38,13 @@ public class CodedInputStreamTest {
3538
3639 private static final int DEFAULT_BLOCK_SIZE = 4096 ;
3740
41+ private static final int GROUP_TAP = WireFormat .makeTag (3 , WireFormat .WIRETYPE_START_GROUP );
42+
43+ private static final byte [] NESTING_SGROUP = generateSGroupTags ();
44+
45+ private static final byte [] NESTING_SGROUP_WITH_INITIAL_BYTES = generateSGroupTagsForMapField ();
46+
47+
3848 private enum InputType {
3949 ARRAY {
4050 @ Override
@@ -117,6 +127,17 @@ private byte[] bytes(int... bytesAsInts) {
117127 return bytes ;
118128 }
119129
130+ private static byte [] generateSGroupTags () {
131+ byte [] bytes = new byte [100000 ];
132+ Arrays .fill (bytes , (byte ) GROUP_TAP );
133+ return bytes ;
134+ }
135+
136+ private static byte [] generateSGroupTagsForMapField () {
137+ byte [] initialBytes = {18 , 1 , 75 , 26 , (byte ) 198 , (byte ) 154 , 12 };
138+ return Bytes .concat (initialBytes , NESTING_SGROUP );
139+ }
140+
120141 /**
121142 * An InputStream which limits the number of bytes it reads at a time. We use this to make sure
122143 * that CodedInputStream doesn't screw up when reading in small blocks.
@@ -740,6 +761,143 @@ public void testMaliciousRecursion() throws Exception {
740761 }
741762 }
742763
764+ @ Test
765+ public void testMaliciousRecursion_unknownFields () throws Exception {
766+ Throwable thrown =
767+ assertThrows (
768+ InvalidProtocolBufferException .class ,
769+ () -> TestRecursiveMessage .parseFrom (NESTING_SGROUP ));
770+
771+ assertThat (thrown ).hasMessageThat ().contains ("Protocol message had too many levels of nesting" );
772+ }
773+
774+ @ Test
775+ public void testMaliciousRecursion_skippingUnknownField () throws Exception {
776+ Throwable thrown =
777+ assertThrows (
778+ InvalidProtocolBufferException .class ,
779+ () ->
780+ DiscardUnknownFieldsParser .wrap (TestRecursiveMessage .parser ())
781+ .parseFrom (NESTING_SGROUP ));
782+
783+ assertThat (thrown ).hasMessageThat ().contains ("Protocol message had too many levels of nesting" );
784+ }
785+
786+ @ Test
787+ public void testMaliciousSGroupTagsWithMapField_fromInputStream () throws Exception {
788+ Throwable parseFromThrown =
789+ assertThrows (
790+ InvalidProtocolBufferException .class ,
791+ () ->
792+ MapContainer .parseFrom (
793+ new ByteArrayInputStream (NESTING_SGROUP_WITH_INITIAL_BYTES )));
794+ Throwable mergeFromThrown =
795+ assertThrows (
796+ InvalidProtocolBufferException .class ,
797+ () ->
798+ MapContainer .newBuilder ()
799+ .mergeFrom (new ByteArrayInputStream (NESTING_SGROUP_WITH_INITIAL_BYTES )));
800+
801+ assertThat (parseFromThrown )
802+ .hasMessageThat ()
803+ .contains ("Protocol message had too many levels of nesting" );
804+ assertThat (mergeFromThrown )
805+ .hasMessageThat ()
806+ .contains ("Protocol message had too many levels of nesting" );
807+ }
808+
809+ @ Test
810+ public void testMaliciousSGroupTags_inputStream_skipMessage () throws Exception {
811+ ByteArrayInputStream inputSteam = new ByteArrayInputStream (NESTING_SGROUP );
812+ CodedInputStream input = CodedInputStream .newInstance (inputSteam );
813+ CodedOutputStream output = CodedOutputStream .newInstance (new byte [NESTING_SGROUP .length ]);
814+
815+ Throwable thrown = assertThrows (InvalidProtocolBufferException .class , input ::skipMessage );
816+ Throwable thrown2 =
817+ assertThrows (InvalidProtocolBufferException .class , () -> input .skipMessage (output ));
818+
819+ assertThat (thrown ).hasMessageThat ().contains ("Protocol message had too many levels of nesting" );
820+ assertThat (thrown2 )
821+ .hasMessageThat ()
822+ .contains ("Protocol message had too many levels of nesting" );
823+ }
824+
825+ @ Test
826+ public void testMaliciousSGroupTagsWithMapField_fromByteArray () throws Exception {
827+ Throwable parseFromThrown =
828+ assertThrows (
829+ InvalidProtocolBufferException .class ,
830+ () -> MapContainer .parseFrom (NESTING_SGROUP_WITH_INITIAL_BYTES ));
831+ Throwable mergeFromThrown =
832+ assertThrows (
833+ InvalidProtocolBufferException .class ,
834+ () -> MapContainer .newBuilder ().mergeFrom (NESTING_SGROUP_WITH_INITIAL_BYTES ));
835+
836+ assertThat (parseFromThrown )
837+ .hasMessageThat ()
838+ .contains ("the input ended unexpectedly in the middle of a field" );
839+ assertThat (mergeFromThrown )
840+ .hasMessageThat ()
841+ .contains ("the input ended unexpectedly in the middle of a field" );
842+ }
843+
844+ @ Test
845+ public void testMaliciousSGroupTags_arrayDecoder_skipMessage () throws Exception {
846+ CodedInputStream input = CodedInputStream .newInstance (NESTING_SGROUP );
847+ CodedOutputStream output = CodedOutputStream .newInstance (new byte [NESTING_SGROUP .length ]);
848+
849+ Throwable thrown = assertThrows (InvalidProtocolBufferException .class , input ::skipMessage );
850+ Throwable thrown2 =
851+ assertThrows (InvalidProtocolBufferException .class , () -> input .skipMessage (output ));
852+
853+ assertThat (thrown ).hasMessageThat ().contains ("Protocol message had too many levels of nesting" );
854+ assertThat (thrown2 )
855+ .hasMessageThat ()
856+ .contains ("Protocol message had too many levels of nesting" );
857+ }
858+
859+ @ Test
860+ public void testMaliciousSGroupTagsWithMapField_fromByteBuffer () throws Exception {
861+ Throwable thrown =
862+ assertThrows (
863+ InvalidProtocolBufferException .class ,
864+ () -> MapContainer .parseFrom (ByteBuffer .wrap (NESTING_SGROUP_WITH_INITIAL_BYTES )));
865+
866+ assertThat (thrown )
867+ .hasMessageThat ()
868+ .contains ("the input ended unexpectedly in the middle of a field" );
869+ }
870+
871+ @ Test
872+ public void testMaliciousSGroupTags_byteBuffer_skipMessage () throws Exception {
873+ CodedInputStream input = InputType .NIO_DIRECT .newDecoder (NESTING_SGROUP );
874+ CodedOutputStream output = CodedOutputStream .newInstance (new byte [NESTING_SGROUP .length ]);
875+
876+ Throwable thrown = assertThrows (InvalidProtocolBufferException .class , input ::skipMessage );
877+ Throwable thrown2 =
878+ assertThrows (InvalidProtocolBufferException .class , () -> input .skipMessage (output ));
879+
880+ assertThat (thrown ).hasMessageThat ().contains ("Protocol message had too many levels of nesting" );
881+ assertThat (thrown2 )
882+ .hasMessageThat ()
883+ .contains ("Protocol message had too many levels of nesting" );
884+ }
885+
886+ @ Test
887+ public void testMaliciousSGroupTags_iterableByteBuffer () throws Exception {
888+ CodedInputStream input = InputType .ITER_DIRECT .newDecoder (NESTING_SGROUP );
889+ CodedOutputStream output = CodedOutputStream .newInstance (new byte [NESTING_SGROUP .length ]);
890+
891+ Throwable thrown = assertThrows (InvalidProtocolBufferException .class , input ::skipMessage );
892+ Throwable thrown2 =
893+ assertThrows (InvalidProtocolBufferException .class , () -> input .skipMessage (output ));
894+
895+ assertThat (thrown ).hasMessageThat ().contains ("Protocol message had too many levels of nesting" );
896+ assertThat (thrown2 )
897+ .hasMessageThat ()
898+ .contains ("Protocol message had too many levels of nesting" );
899+ }
900+
743901 private void checkSizeLimitExceeded (InvalidProtocolBufferException e ) {
744902 assertThat (e )
745903 .hasMessageThat ()
0 commit comments