Skip to content

Commit 7fc6689

Browse files
committed
Support custom recursion limits at build time
1 parent 1e93f56 commit 7fc6689

File tree

13 files changed

+119
-29
lines changed

13 files changed

+119
-29
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ configured with the required dependencies to compile the whole project.
467467
- `std`: Enable integration with standard library. Disable this feature for `no_std` support. This feature is enabled by default.
468468
- `derive`: Enable integration with `prost-derive`. Disable this feature to reduce compile times. This feature is enabled by default.
469469
- `prost-derive`: Deprecated. Alias for `derive` feature.
470-
- `no-recursion-limit`: Disable the recursion limit. The recursion limit is 100 and cannot be customized.
470+
- `no-recursion-limit`: Disable the recursion limit. The recursion limit is 100, and can be changed with `prost_build::recursion_limit`.
471471

472472
## FAQ
473473

prost-build/src/code_generator.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ impl<'b> CodeGenerator<'_, 'b> {
237237
self.context.prost_path()
238238
));
239239
self.append_skip_debug(&fq_message_name);
240+
self.append_recursion_limit(&fq_message_name);
240241
self.push_indent();
241242
self.buf.push_str("pub struct ");
242243
self.buf.push_str(&to_upper_camel(&message_name));
@@ -377,6 +378,21 @@ impl<'b> CodeGenerator<'_, 'b> {
377378
}
378379
}
379380

381+
fn append_recursion_limit(&mut self, fq_message_name: &str) {
382+
assert_eq!(b'.', fq_message_name.as_bytes()[0]);
383+
if let Some(limit) = self
384+
.config()
385+
.recursion_limits
386+
.get_first(fq_message_name)
387+
.cloned()
388+
{
389+
push_indent(self.buf, self.depth);
390+
self.buf
391+
.push_str(&format!("#[prost(recursion_limit={})]", limit));
392+
self.buf.push('\n');
393+
}
394+
}
395+
380396
fn append_field_attributes(&mut self, fq_message_name: &str, field_name: &str) {
381397
assert_eq!(b'.', fq_message_name.as_bytes()[0]);
382398
for attribute in self.context.field_attributes(fq_message_name, field_name) {

prost-build/src/config.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ pub struct Config {
5050
pub(crate) skip_debug: PathMap<()>,
5151
pub(crate) skip_protoc_run: bool,
5252
pub(crate) skip_source_info: bool,
53+
pub(crate) recursion_limits: PathMap<u32>,
5354
pub(crate) include_file: Option<PathBuf>,
5455
pub(crate) prost_path: Option<String>,
5556
#[cfg(feature = "format")]
@@ -1030,6 +1031,25 @@ impl Config {
10301031
self.compile_fds(file_descriptor_set)
10311032
}
10321033

1034+
/// Configure a custom recursion limit for certain messages.
1035+
///
1036+
/// This defaults to 100, and can be disabled with the no-recursion-limit crate feature.
1037+
///
1038+
/// # Example
1039+
///
1040+
/// ```rust
1041+
/// # let mut config = prost_build::Config::new();
1042+
/// config.recursion_limit("my_messages.MyMessageType", 1000);
1043+
/// ```
1044+
pub fn recursion_limit<P>(&mut self, path: P, limit: u32) -> &mut Self
1045+
where
1046+
P: AsRef<str>,
1047+
{
1048+
self.recursion_limits
1049+
.insert(path.as_ref().to_string(), limit);
1050+
self
1051+
}
1052+
10331053
pub(crate) fn write_includes(
10341054
&self,
10351055
mut modules: Vec<&Module>,
@@ -1192,6 +1212,7 @@ impl default::Default for Config {
11921212
skip_debug: PathMap::default(),
11931213
skip_protoc_run: false,
11941214
skip_source_info: false,
1215+
recursion_limits: PathMap::default(),
11951216
include_file: None,
11961217
prost_path: None,
11971218
#[cfg(feature = "format")]

prost-derive/src/field/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ impl fmt::Display for Label {
225225
}
226226

227227
/// Get the items belonging to the 'prost' list attribute, e.g. `#[prost(foo, bar="baz")]`.
228-
fn prost_attrs(attrs: Vec<Attribute>) -> Result<Vec<Meta>, Error> {
228+
pub(super) fn prost_attrs(attrs: Vec<Attribute>) -> Result<Vec<Meta>, Error> {
229229
let mut result = Vec::new();
230230
for attr in attrs.iter() {
231231
if let Meta::List(meta_list) = &attr.meta {

prost-derive/src/lib.rs

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,36 @@ use proc_macro2::{Span, TokenStream};
1111
use quote::quote;
1212
use syn::{
1313
punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed,
14-
FieldsUnnamed, Ident, Index, Variant,
14+
FieldsUnnamed, Ident, Index, Lit, Meta, Variant,
1515
};
1616

1717
mod field;
1818
use crate::field::Field;
1919

2020
fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
2121
let input: DeriveInput = syn::parse2(input)?;
22-
2322
let ident = input.ident;
2423

2524
syn::custom_keyword!(skip_debug);
2625
let skip_debug = input
2726
.attrs
28-
.into_iter()
27+
.iter()
2928
.any(|a| a.path().is_ident("prost") && a.parse_args::<skip_debug>().is_ok());
3029

30+
let mut recursion_limit: u32 = 100;
31+
for attr in field::prost_attrs(input.attrs.clone())? {
32+
match attr {
33+
Meta::NameValue(ref meta) if meta.path.is_ident("recursion_limit") => {
34+
let Expr::Lit(ref lit) = meta.value else {
35+
continue;
36+
};
37+
let Lit::Int(ref int) = lit.lit else { continue };
38+
recursion_limit = int.base10_parse()?;
39+
}
40+
_ => (),
41+
}
42+
}
43+
3144
let variant_data = match input.data {
3245
Data::Struct(variant_data) => variant_data,
3346
Data::Enum(..) => bail!("Message can not be derived for an enum"),
@@ -203,6 +216,10 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
203216
fn clear(&mut self) {
204217
#(#clear;)*
205218
}
219+
220+
fn recursion_limit() -> u32 {
221+
#recursion_limit
222+
}
206223
}
207224

208225
impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {

prost/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ configured with the required dependencies to compile the whole project.
467467
- `std`: Enable integration with standard library. Disable this feature for `no_std` support. This feature is enabled by default.
468468
- `derive`: Enable integration with `prost-derive`. Disable this feature to reduce compile times. This feature is enabled by default.
469469
- `prost-derive`: Deprecated. Alias for `derive` feature.
470-
- `no-recursion-limit`: Disable the recursion limit. The recursion limit is 100 and cannot be customized.
470+
- `no-recursion-limit`: Disable the recursion limit. The recursion limit is 100, and can be changed with `prost_build::recursion_limit`.
471471

472472
## FAQ
473473

prost/src/encoding.rs

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,29 +37,27 @@ pub struct DecodeContext {
3737
/// How many times we can recurse in the current decode stack before we hit
3838
/// the recursion limit.
3939
///
40-
/// The recursion limit is defined by `RECURSION_LIMIT` and cannot be
41-
/// customized. The recursion limit can be ignored by building the Prost
42-
/// crate with the `no-recursion-limit` feature.
40+
/// It defaults to 100 and can be changed using `prost_build::recursion_limit`,
41+
/// or it can be disabled entirely using the `no-recursion-limit` feature.
4342
#[cfg(not(feature = "no-recursion-limit"))]
44-
recurse_count: u32,
43+
#[doc(hidden)]
44+
pub recurse_count: u32,
4545
}
4646

47-
#[cfg(not(feature = "no-recursion-limit"))]
48-
impl Default for DecodeContext {
49-
#[inline]
50-
fn default() -> DecodeContext {
47+
impl DecodeContext {
48+
#[allow(unused_variables)]
49+
pub fn new(recursion_limit: u32) -> DecodeContext {
5150
DecodeContext {
52-
recurse_count: crate::RECURSION_LIMIT,
51+
#[cfg(not(feature = "no-recursion-limit"))]
52+
recurse_count: recursion_limit,
5353
}
5454
}
55-
}
5655

57-
impl DecodeContext {
5856
/// Call this function before recursively decoding.
5957
///
6058
/// There is no `exit` function since this function creates a new `DecodeContext`
6159
/// to be used at the next level of recursion. Continue to use the old context
62-
// at the previous level of recursion.
60+
/// at the previous level of recursion.
6361
#[cfg(not(feature = "no-recursion-limit"))]
6462
#[inline]
6563
pub(crate) fn enter_recursion(&self) -> DecodeContext {
@@ -1225,7 +1223,7 @@ mod test {
12251223
wire_type,
12261224
&mut roundtrip_value,
12271225
&mut buf,
1228-
DecodeContext::default(),
1226+
DecodeContext::new(100),
12291227
)
12301228
.map_err(|error| TestCaseError::fail(error.to_string()))?;
12311229

@@ -1297,7 +1295,7 @@ mod test {
12971295
wire_type,
12981296
&mut roundtrip_value,
12991297
&mut buf,
1300-
DecodeContext::default(),
1298+
DecodeContext::new(100),
13011299
)
13021300
.map_err(|error| TestCaseError::fail(error.to_string()))?;
13031301
}
@@ -1316,7 +1314,7 @@ mod test {
13161314
WireType::LengthDelimited,
13171315
&mut s,
13181316
&mut &buf[..],
1319-
DecodeContext::default(),
1317+
DecodeContext::new(100),
13201318
);
13211319
r.expect_err("must be an error");
13221320
assert!(s.is_empty());

prost/src/lib.rs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,6 @@ pub use crate::error::{DecodeError, EncodeError, UnknownEnumValue};
2424
pub use crate::message::Message;
2525
pub use crate::name::Name;
2626

27-
// See `encoding::DecodeContext` for more info.
28-
// 100 is the default recursion limit in the C++ implementation.
29-
#[cfg(not(feature = "no-recursion-limit"))]
30-
const RECURSION_LIMIT: u32 = 100;
31-
3227
// Re-export #[derive(Message, Enumeration, Oneof)].
3328
// Based on serde's equivalent re-export [1], but enabled by default.
3429
//

prost/src/message.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ pub trait Message: Send + Sync {
127127
where
128128
Self: Sized,
129129
{
130-
let ctx = DecodeContext::default();
130+
let ctx = DecodeContext::new(Self::recursion_limit());
131131
while buf.has_remaining() {
132132
let (tag, wire_type) = decode_key(&mut buf)?;
133133
self.merge_field(tag, wire_type, &mut buf, ctx.clone())?;
@@ -145,12 +145,22 @@ pub trait Message: Send + Sync {
145145
WireType::LengthDelimited,
146146
self,
147147
&mut buf,
148-
DecodeContext::default(),
148+
DecodeContext::new(Self::recursion_limit()),
149149
)
150150
}
151151

152152
/// Clears the message, resetting all fields to their default.
153153
fn clear(&mut self);
154+
155+
/// The recursion limit for decoding protobuf messages.
156+
///
157+
/// Defaults to 100. Can be customized in your build.rs or by using the no-recursion-limit crate feature.
158+
fn recursion_limit() -> u32
159+
where
160+
Self: Sized,
161+
{
162+
100
163+
}
154164
}
155165

156166
impl<M> Message for Box<M>
@@ -175,6 +185,9 @@ where
175185
fn clear(&mut self) {
176186
(**self).clear()
177187
}
188+
fn recursion_limit() -> u32 {
189+
M::recursion_limit()
190+
}
178191
}
179192

180193
#[cfg(test)]

tests/build.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ fn main() {
4848

4949
prost_build::Config::new()
5050
.btree_map(["."])
51+
.recursion_limit("nesting.E", 200)
5152
.compile_protos(&[src.join("nesting.proto")], includes)
5253
.unwrap();
5354

0 commit comments

Comments
 (0)