Skip to content
This repository was archived by the owner on Mar 25, 2024. It is now read-only.

Commit b93aff6

Browse files
committed
Prevent too deep recursion
1 parent 49e1e6f commit b93aff6

File tree

3 files changed

+84
-8
lines changed

3 files changed

+84
-8
lines changed

src/de.rs

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ struct Deserializer<'a> {
7979
aliases: &'a BTreeMap<usize, usize>,
8080
pos: &'a mut usize,
8181
path: Path<'a>,
82+
remaining_depth: u8,
8283
}
8384

8485
impl<'a> Deserializer<'a> {
@@ -109,6 +110,7 @@ impl<'a> Deserializer<'a> {
109110
aliases: self.aliases,
110111
pos: pos,
111112
path: Path::Alias { parent: &self.path },
113+
remaining_depth: self.remaining_depth,
112114
})
113115
}
114116
None => panic!("unresolved alias: {}", *pos),
@@ -161,11 +163,11 @@ impl<'a> Deserializer<'a> {
161163
where
162164
V: Visitor<'de>,
163165
{
164-
let (value, len) = {
165-
let mut seq = SeqAccess { de: self, len: 0 };
166+
let (value, len) = self.recursion_check(|de| {
167+
let mut seq = SeqAccess { de: de, len: 0 };
166168
let value = visitor.visit_seq(&mut seq)?;
167-
(value, seq.len)
168-
};
169+
Ok((value, seq.len))
170+
})?;
169171
self.end_sequence(len)?;
170172
Ok(value)
171173
}
@@ -174,15 +176,15 @@ impl<'a> Deserializer<'a> {
174176
where
175177
V: Visitor<'de>,
176178
{
177-
let (value, len) = {
179+
let (value, len) = self.recursion_check(|de| {
178180
let mut map = MapAccess {
179-
de: &mut *self,
181+
de: de,
180182
len: 0,
181183
key: None,
182184
};
183185
let value = visitor.visit_map(&mut map)?;
184-
(value, map.len)
185-
};
186+
Ok((value, map.len))
187+
})?;
186188
self.end_mapping(len)?;
187189
Ok(value)
188190
}
@@ -238,6 +240,16 @@ impl<'a> Deserializer<'a> {
238240
Err(de::Error::invalid_length(total, &ExpectedMap(len)))
239241
}
240242
}
243+
244+
fn recursion_check<F: FnOnce(&mut Self) -> Result<T>, T>(&mut self, f: F) -> Result<T> {
245+
let previous_depth = self.remaining_depth;
246+
self.remaining_depth = previous_depth
247+
.checked_sub(1)
248+
.ok_or_else(Error::recursion_limit_exceeded)?;
249+
let result = f(self);
250+
self.remaining_depth = previous_depth;
251+
result
252+
}
241253
}
242254

243255
fn visit_scalar<'de, V>(
@@ -303,6 +315,7 @@ impl<'de, 'a, 'r> de::SeqAccess<'de> for SeqAccess<'a, 'r> {
303315
parent: &self.de.path,
304316
index: self.len,
305317
},
318+
remaining_depth: self.de.remaining_depth,
306319
};
307320
self.len += 1;
308321
seed.deserialize(&mut element_de).map(Some)
@@ -357,6 +370,7 @@ impl<'de, 'a, 'r> de::MapAccess<'de> for MapAccess<'a, 'r> {
357370
parent: &self.de.path,
358371
}
359372
},
373+
remaining_depth: self.de.remaining_depth,
360374
};
361375
seed.deserialize(&mut value_de)
362376
}
@@ -409,6 +423,7 @@ impl<'de, 'a, 'r> de::EnumAccess<'de> for EnumAccess<'a, 'r> {
409423
parent: &self.de.path,
410424
key: variant,
411425
},
426+
remaining_depth: self.de.remaining_depth,
412427
};
413428
Ok((ret, variant_visitor))
414429
}
@@ -949,6 +964,7 @@ where
949964
aliases: &loader.aliases,
950965
pos: &mut pos,
951966
path: Path::Root,
967+
remaining_depth: 128,
952968
})?;
953969
if pos == loader.events.len() {
954970
Ok(t)

src/error.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ pub enum ErrorImpl {
4141

4242
EndOfStream,
4343
MoreThanOneDocument,
44+
RecursionLimitExceeded,
4445
}
4546

4647
#[derive(Debug)]
@@ -157,6 +158,12 @@ impl Error {
157158
Error(Box::new(ErrorImpl::FromUtf8(err)))
158159
}
159160

161+
// Not public API. Should be pub(crate).
162+
#[doc(hidden)]
163+
pub fn recursion_limit_exceeded() -> Error {
164+
Error(Box::new(ErrorImpl::RecursionLimitExceeded))
165+
}
166+
160167
// Not public API. Should be pub(crate).
161168
#[doc(hidden)]
162169
pub fn fix_marker(mut self, marker: Marker, path: Path) -> Self {
@@ -183,6 +190,7 @@ impl error::Error for Error {
183190
ErrorImpl::MoreThanOneDocument => {
184191
"deserializing from YAML containing more than one document is not supported"
185192
}
193+
ErrorImpl::RecursionLimitExceeded => "recursion limit exceeded",
186194
}
187195
}
188196

@@ -218,6 +226,7 @@ impl Display for Error {
218226
ErrorImpl::MoreThanOneDocument => f.write_str(
219227
"deserializing from YAML containing more than one document is not supported",
220228
),
229+
ErrorImpl::RecursionLimitExceeded => f.write_str("recursion limit exceeded"),
221230
}
222231
}
223232
}
@@ -241,6 +250,9 @@ impl Debug for Error {
241250
}
242251
ErrorImpl::EndOfStream => formatter.debug_tuple("EndOfStream").finish(),
243252
ErrorImpl::MoreThanOneDocument => formatter.debug_tuple("MoreThanOneDocument").finish(),
253+
ErrorImpl::RecursionLimitExceeded => {
254+
formatter.debug_tuple("RecursionLimitExceeded").finish()
255+
}
244256
}
245257
}
246258
}

tests/test_error.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,3 +257,51 @@ fn test_invalid_scalar_type() {
257257
let expected = "x: invalid type: unit value, expected an array of length 1 at line 2 column 1";
258258
test_error::<S>(yaml, expected);
259259
}
260+
261+
#[test]
262+
fn test_infinite_recursion_objects() {
263+
#[derive(Deserialize, Debug)]
264+
struct S {
265+
x: Option<Box<S>>,
266+
}
267+
268+
let yaml = "&a {x: *a}";
269+
let expected = "recursion limit exceeded";
270+
test_error::<S>(yaml, expected);
271+
}
272+
273+
#[test]
274+
fn test_infinite_recursion_arrays() {
275+
#[derive(Deserialize, Debug)]
276+
struct S {
277+
x: Option<Box<S>>,
278+
}
279+
280+
let yaml = "&a [*a]";
281+
let expected = "recursion limit exceeded";
282+
test_error::<S>(yaml, expected);
283+
}
284+
285+
#[test]
286+
fn test_finite_recursion_objects() {
287+
#[derive(Deserialize, Debug)]
288+
struct S {
289+
x: Option<Box<S>>,
290+
}
291+
292+
let yaml = "{x:".repeat(1_000) + &"}".repeat(1_000);
293+
let expected = "recursion limit exceeded";
294+
test_error::<i32>(&yaml, expected);
295+
}
296+
297+
#[test]
298+
fn test_finite_recursion_arrays() {
299+
#[derive(Deserialize, Debug)]
300+
struct S {
301+
x: Option<Box<S>>,
302+
}
303+
304+
let yaml = "[".repeat(1_000) + &"]".repeat(1_000);
305+
let expected = "recursion limit exceeded";
306+
test_error::<S>(&yaml, expected);
307+
}

0 commit comments

Comments
 (0)