Skip to content

Commit 2b609ee

Browse files
ilslvtyranron
andauthored
Merge pull request from GHSA-4rx6-g5vg-5f3j
* Replace recursions with heap allocations * Some corrections [skip ci] * Add recursive nested fragments test case * Docs and small corrections * Corrections Co-authored-by: Kai Ren <[email protected]>
1 parent 6d6c71f commit 2b609ee

File tree

8 files changed

+292
-101
lines changed

8 files changed

+292
-101
lines changed

juniper/src/validation/rules/no_fragment_cycles.rs

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,6 @@ use crate::{
77
value::ScalarValue,
88
};
99

10-
pub struct NoFragmentCycles<'a> {
11-
current_fragment: Option<&'a str>,
12-
spreads: HashMap<&'a str, Vec<Spanning<&'a str>>>,
13-
fragment_order: Vec<&'a str>,
14-
}
15-
16-
struct CycleDetector<'a> {
17-
visited: HashSet<&'a str>,
18-
spreads: &'a HashMap<&'a str, Vec<Spanning<&'a str>>>,
19-
path_indices: HashMap<&'a str, usize>,
20-
errors: Vec<RuleError>,
21-
}
22-
2310
pub fn factory<'a>() -> NoFragmentCycles<'a> {
2411
NoFragmentCycles {
2512
current_fragment: None,
@@ -28,6 +15,12 @@ pub fn factory<'a>() -> NoFragmentCycles<'a> {
2815
}
2916
}
3017

18+
pub struct NoFragmentCycles<'a> {
19+
current_fragment: Option<&'a str>,
20+
spreads: HashMap<&'a str, Vec<Spanning<&'a str>>>,
21+
fragment_order: Vec<&'a str>,
22+
}
23+
3124
impl<'a, S> Visitor<'a, S> for NoFragmentCycles<'a>
3225
where
3326
S: ScalarValue,
@@ -38,14 +31,12 @@ where
3831
let mut detector = CycleDetector {
3932
visited: HashSet::new(),
4033
spreads: &self.spreads,
41-
path_indices: HashMap::new(),
4234
errors: Vec::new(),
4335
};
4436

4537
for frag in &self.fragment_order {
4638
if !detector.visited.contains(frag) {
47-
let mut path = Vec::new();
48-
detector.detect_from(frag, &mut path);
39+
detector.detect_from(frag);
4940
}
5041
}
5142

@@ -91,19 +82,46 @@ where
9182
}
9283
}
9384

85+
type CycleDetectorState<'a> = (&'a str, Vec<&'a Spanning<&'a str>>, HashMap<&'a str, usize>);
86+
87+
struct CycleDetector<'a> {
88+
visited: HashSet<&'a str>,
89+
spreads: &'a HashMap<&'a str, Vec<Spanning<&'a str>>>,
90+
errors: Vec<RuleError>,
91+
}
92+
9493
impl<'a> CycleDetector<'a> {
95-
fn detect_from(&mut self, from: &'a str, path: &mut Vec<&'a Spanning<&'a str>>) {
94+
fn detect_from(&mut self, from: &'a str) {
95+
let mut to_visit = Vec::new();
96+
to_visit.push((from, Vec::new(), HashMap::new()));
97+
98+
while let Some((from, path, path_indices)) = to_visit.pop() {
99+
to_visit.extend(self.detect_from_inner(from, path, path_indices));
100+
}
101+
}
102+
103+
/// This function should be called only inside [`Self::detect_from()`], as
104+
/// it's a recursive function using heap instead of a stack. So, instead of
105+
/// the recursive call, we return a [`Vec`] that is visited inside
106+
/// [`Self::detect_from()`].
107+
fn detect_from_inner(
108+
&mut self,
109+
from: &'a str,
110+
path: Vec<&'a Spanning<&'a str>>,
111+
mut path_indices: HashMap<&'a str, usize>,
112+
) -> Vec<CycleDetectorState<'a>> {
96113
self.visited.insert(from);
97114

98115
if !self.spreads.contains_key(from) {
99-
return;
116+
return Vec::new();
100117
}
101118

102-
self.path_indices.insert(from, path.len());
119+
path_indices.insert(from, path.len());
103120

121+
let mut to_visit = Vec::new();
104122
for node in &self.spreads[from] {
105-
let name = &node.item;
106-
let index = self.path_indices.get(name).cloned();
123+
let name = node.item;
124+
let index = path_indices.get(name).cloned();
107125

108126
if let Some(index) = index {
109127
let err_pos = if index < path.len() {
@@ -114,14 +132,14 @@ impl<'a> CycleDetector<'a> {
114132

115133
self.errors
116134
.push(RuleError::new(&error_message(name), &[err_pos.start]));
117-
} else if !self.visited.contains(name) {
135+
} else {
136+
let mut path = path.clone();
118137
path.push(node);
119-
self.detect_from(name, path);
120-
path.pop();
138+
to_visit.push((name, path, path_indices.clone()));
121139
}
122140
}
123141

124-
self.path_indices.remove(from);
142+
to_visit
125143
}
126144
}
127145

juniper/src/validation/rules/no_undefined_variables.rs

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,6 @@ pub enum Scope<'a> {
1212
Fragment(&'a str),
1313
}
1414

15-
pub struct NoUndefinedVariables<'a> {
16-
defined_variables: HashMap<Option<&'a str>, (SourcePosition, HashSet<&'a str>)>,
17-
used_variables: HashMap<Scope<'a>, Vec<Spanning<&'a str>>>,
18-
current_scope: Option<Scope<'a>>,
19-
spreads: HashMap<Scope<'a>, Vec<&'a str>>,
20-
}
21-
2215
pub fn factory<'a>() -> NoUndefinedVariables<'a> {
2316
NoUndefinedVariables {
2417
defined_variables: HashMap::new(),
@@ -28,6 +21,13 @@ pub fn factory<'a>() -> NoUndefinedVariables<'a> {
2821
}
2922
}
3023

24+
pub struct NoUndefinedVariables<'a> {
25+
defined_variables: HashMap<Option<&'a str>, (SourcePosition, HashSet<&'a str>)>,
26+
used_variables: HashMap<Scope<'a>, Vec<Spanning<&'a str>>>,
27+
current_scope: Option<Scope<'a>>,
28+
spreads: HashMap<Scope<'a>, Vec<&'a str>>,
29+
}
30+
3131
impl<'a> NoUndefinedVariables<'a> {
3232
fn find_undef_vars(
3333
&'a self,
@@ -36,8 +36,34 @@ impl<'a> NoUndefinedVariables<'a> {
3636
unused: &mut Vec<&'a Spanning<&'a str>>,
3737
visited: &mut HashSet<Scope<'a>>,
3838
) {
39+
let mut to_visit = Vec::new();
40+
if let Some(spreads) = self.find_undef_vars_inner(scope, defined, unused, visited) {
41+
to_visit.push(spreads);
42+
}
43+
while let Some(spreads) = to_visit.pop() {
44+
for spread in spreads {
45+
if let Some(spreads) =
46+
self.find_undef_vars_inner(&Scope::Fragment(spread), defined, unused, visited)
47+
{
48+
to_visit.push(spreads);
49+
}
50+
}
51+
}
52+
}
53+
54+
/// This function should be called only inside [`Self::find_undef_vars()`],
55+
/// as it's a recursive function using heap instead of a stack. So, instead
56+
/// of the recursive call, we return a [`Vec`] that is visited inside
57+
/// [`Self::find_undef_vars()`].
58+
fn find_undef_vars_inner(
59+
&'a self,
60+
scope: &Scope<'a>,
61+
defined: &HashSet<&'a str>,
62+
unused: &mut Vec<&'a Spanning<&'a str>>,
63+
visited: &mut HashSet<Scope<'a>>,
64+
) -> Option<&'a Vec<&'a str>> {
3965
if visited.contains(scope) {
40-
return;
66+
return None;
4167
}
4268

4369
visited.insert(scope.clone());
@@ -50,11 +76,7 @@ impl<'a> NoUndefinedVariables<'a> {
5076
}
5177
}
5278

53-
if let Some(spreads) = self.spreads.get(scope) {
54-
for spread in spreads {
55-
self.find_undef_vars(&Scope::Fragment(spread), defined, unused, visited);
56-
}
57-
}
79+
self.spreads.get(scope)
5880
}
5981
}
6082

juniper/src/validation/rules/no_unused_fragments.rs

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,6 @@ pub enum Scope<'a> {
1313
Fragment(&'a str),
1414
}
1515

16-
pub struct NoUnusedFragments<'a> {
17-
spreads: HashMap<Scope<'a>, Vec<&'a str>>,
18-
defined_fragments: HashSet<Spanning<&'a str>>,
19-
current_scope: Option<Scope<'a>>,
20-
}
21-
2216
pub fn factory<'a>() -> NoUnusedFragments<'a> {
2317
NoUnusedFragments {
2418
spreads: HashMap::new(),
@@ -27,22 +21,43 @@ pub fn factory<'a>() -> NoUnusedFragments<'a> {
2721
}
2822
}
2923

24+
pub struct NoUnusedFragments<'a> {
25+
spreads: HashMap<Scope<'a>, Vec<&'a str>>,
26+
defined_fragments: HashSet<Spanning<&'a str>>,
27+
current_scope: Option<Scope<'a>>,
28+
}
29+
3030
impl<'a> NoUnusedFragments<'a> {
31-
fn find_reachable_fragments(&self, from: &Scope<'a>, result: &mut HashSet<&'a str>) {
31+
fn find_reachable_fragments(&'a self, from: &Scope<'a>, result: &mut HashSet<&'a str>) {
32+
let mut to_visit = Vec::new();
3233
if let Scope::Fragment(name) = *from {
33-
if result.contains(name) {
34-
return;
35-
} else {
36-
result.insert(name);
37-
}
34+
to_visit.push(name);
3835
}
3936

40-
if let Some(spreads) = self.spreads.get(from) {
41-
for spread in spreads {
42-
self.find_reachable_fragments(&Scope::Fragment(spread), result)
37+
while let Some(from) = to_visit.pop() {
38+
if let Some(next) = self.find_reachable_fragments_inner(from, result) {
39+
to_visit.extend(next);
4340
}
4441
}
4542
}
43+
44+
/// This function should be called only inside
45+
/// [`Self::find_reachable_fragments()`], as it's a recursive function using
46+
/// heap instead of a stack. So, instead of the recursive call, we return a
47+
/// [`Vec`] that is visited inside [`Self::find_reachable_fragments()`].
48+
fn find_reachable_fragments_inner(
49+
&'a self,
50+
from: &'a str,
51+
result: &mut HashSet<&'a str>,
52+
) -> Option<&'a Vec<&'a str>> {
53+
if result.contains(from) {
54+
return None;
55+
} else {
56+
result.insert(from);
57+
}
58+
59+
self.spreads.get(&Scope::Fragment(from))
60+
}
4661
}
4762

4863
impl<'a, S> Visitor<'a, S> for NoUnusedFragments<'a>

juniper/src/validation/rules/no_unused_variables.rs

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,6 @@ pub enum Scope<'a> {
1212
Fragment(&'a str),
1313
}
1414

15-
pub struct NoUnusedVariables<'a> {
16-
defined_variables: HashMap<Option<&'a str>, HashSet<&'a Spanning<&'a str>>>,
17-
used_variables: HashMap<Scope<'a>, Vec<&'a str>>,
18-
current_scope: Option<Scope<'a>>,
19-
spreads: HashMap<Scope<'a>, Vec<&'a str>>,
20-
}
21-
2215
pub fn factory<'a>() -> NoUnusedVariables<'a> {
2316
NoUnusedVariables {
2417
defined_variables: HashMap::new(),
@@ -28,16 +21,49 @@ pub fn factory<'a>() -> NoUnusedVariables<'a> {
2821
}
2922
}
3023

24+
pub struct NoUnusedVariables<'a> {
25+
defined_variables: HashMap<Option<&'a str>, HashSet<&'a Spanning<&'a str>>>,
26+
used_variables: HashMap<Scope<'a>, Vec<&'a str>>,
27+
current_scope: Option<Scope<'a>>,
28+
spreads: HashMap<Scope<'a>, Vec<&'a str>>,
29+
}
30+
3131
impl<'a> NoUnusedVariables<'a> {
3232
fn find_used_vars(
33-
&self,
33+
&'a self,
3434
from: &Scope<'a>,
3535
defined: &HashSet<&'a str>,
3636
used: &mut HashSet<&'a str>,
3737
visited: &mut HashSet<Scope<'a>>,
3838
) {
39+
let mut to_visit = Vec::new();
40+
if let Some(spreads) = self.find_used_vars_inner(from, defined, used, visited) {
41+
to_visit.push(spreads);
42+
}
43+
while let Some(spreads) = to_visit.pop() {
44+
for spread in spreads {
45+
if let Some(spreads) =
46+
self.find_used_vars_inner(&Scope::Fragment(spread), defined, used, visited)
47+
{
48+
to_visit.push(spreads);
49+
}
50+
}
51+
}
52+
}
53+
54+
/// This function should be called only inside [`Self::find_used_vars()`],
55+
/// as it's a recursive function using heap instead of a stack. So, instead
56+
/// of the recursive call, we return a [`Vec`] that is visited inside
57+
/// [`Self::find_used_vars()`].
58+
fn find_used_vars_inner(
59+
&'a self,
60+
from: &Scope<'a>,
61+
defined: &HashSet<&'a str>,
62+
used: &mut HashSet<&'a str>,
63+
visited: &mut HashSet<Scope<'a>>,
64+
) -> Option<&'a Vec<&'a str>> {
3965
if visited.contains(from) {
40-
return;
66+
return None;
4167
}
4268

4369
visited.insert(from.clone());
@@ -50,11 +76,7 @@ impl<'a> NoUnusedVariables<'a> {
5076
}
5177
}
5278

53-
if let Some(spreads) = self.spreads.get(from) {
54-
for spread in spreads {
55-
self.find_used_vars(&Scope::Fragment(spread), defined, used, visited);
56-
}
57-
}
79+
self.spreads.get(from)
5880
}
5981
}
6082

0 commit comments

Comments
 (0)