Skip to content

Commit a66c6f4

Browse files
committed
deterministic sorting
1 parent 47a09dd commit a66c6f4

File tree

1 file changed

+93
-65
lines changed

1 file changed

+93
-65
lines changed

rust_core/src/pdag.rs

Lines changed: 93 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ use petgraph::visit::Dfs;
55

66
use crate::RustDAG;
77

8-
98
#[derive(Debug, Clone)]
109
pub struct RustPDAG {
1110
pub graph: DiGraph<String, f64>,
@@ -15,6 +14,7 @@ pub struct RustPDAG {
1514
pub undirected_edges: HashSet<(String, String)>,
1615
pub latents: HashSet<String>,
1716
}
17+
1818
impl RustPDAG {
1919
pub fn new() -> Self {
2020
RustPDAG {
@@ -27,9 +27,9 @@ impl RustPDAG {
2727
}
2828
}
2929

30-
/// Get all edges in the graph
30+
/// Get all edges in the graph - DETERMINISTIC
3131
pub fn edges(&self) -> Vec<(String, String)> {
32-
self.graph
32+
let mut edges: Vec<(String, String)> = self.graph
3333
.edge_indices()
3434
.map(|edge_idx| {
3535
let (source, target) = self.graph.edge_endpoints(edge_idx).unwrap();
@@ -38,15 +38,18 @@ impl RustPDAG {
3838
self.reverse_node_map[&target].clone(),
3939
)
4040
})
41-
.collect()
41+
.collect();
42+
edges.sort();
43+
edges
4244
}
4345

4446
/// Get all nodes in the graph
4547
pub fn nodes(&self) -> Vec<String> {
4648
let mut nodes: Vec<String> = self.node_map.keys().cloned().collect();
47-
nodes.sort(); // Sort alphabetically for deterministic order
49+
nodes.sort();
4850
nodes
4951
}
52+
5053
/// Adds a single node to the PDAG.
5154
pub fn add_node(&mut self, node: String, latent: bool) -> Result<(), String> {
5255
if !self.node_map.contains_key(&node) {
@@ -230,19 +233,23 @@ impl RustPDAG {
230233
}
231234
}
232235

233-
/// Returns a subgraph containing only directed edges as a RustDAG.
236+
/// Returns a subgraph containing only directed edges as a RustDAG - DETERMINISTIC
234237
pub fn directed_graph(&self) -> RustDAG {
235238
let mut dag = RustDAG::new();
236239

237-
// Add all nodes with their latent status
238-
for node in self.node_map.keys() {
239-
let is_latent = self.latents.contains(node);
240+
// Add all nodes with their latent status - DETERMINISTIC ORDER
241+
let mut nodes: Vec<String> = self.node_map.keys().cloned().collect();
242+
nodes.sort();
243+
for node in nodes {
244+
let is_latent = self.latents.contains(&node);
240245
dag.add_node(node.clone(), is_latent).unwrap();
241246
}
242247

243248
// Add only directed edges
244-
for (u, v) in &self.directed_edges {
245-
dag.add_edge(u.clone(), v.clone(), None).unwrap();
249+
let mut directed_edges: Vec<(String, String)> = self.directed_edges.iter().cloned().collect();
250+
directed_edges.sort();
251+
for (u, v) in directed_edges {
252+
dag.add_edge(u, v, None).unwrap();
246253
}
247254

248255
dag
@@ -320,40 +327,50 @@ impl RustPDAG {
320327
Ok(false)
321328
}
322329

323-
324330
/// Apply Meek's rules to orient undirected edges
325331
pub fn apply_meeks_rules(&mut self, apply_r4: bool, inplace: bool) -> Result<Option<RustPDAG>, String> {
326-
let mut pdag = if inplace {
327-
self
332+
if inplace {
333+
// Work directly on self
334+
self.apply_meeks_rules_internal(apply_r4)?;
335+
Ok(None)
328336
} else {
329-
&mut self.copy()
330-
};
337+
// Work on a copy
338+
let mut pdag_copy = self.copy();
339+
pdag_copy.apply_meeks_rules_internal(apply_r4)?;
340+
Ok(Some(pdag_copy))
341+
}
342+
}
331343

344+
/// Internal method that applies Meek's rules to the current instance
345+
fn apply_meeks_rules_internal(&mut self, apply_r4: bool) -> Result<(), String> {
332346
let mut changed = true;
333347
while changed {
334348
changed = false;
335-
let nodes: Vec<String> = pdag.nodes();
349+
let nodes: Vec<String> = self.nodes();
336350

337351
// Rule 1: If X -> Y - Z and
338352
// (X not adj Z) and
339353
// (adding Y -> Z doesn't create cycle) and
340354
// (adding Y -> Z doesn't create an unshielded collider) => Y → Z
341355
for y in &nodes {
342-
if !pdag.node_map.contains_key(y) {
356+
if !self.node_map.contains_key(y) {
343357
continue;
344358
}
345-
let directed_parents = pdag.directed_parents(y)?;
346-
let undirected_neighbors = pdag.undirected_neighbors(y)?;
359+
// Convert HashSets to sorted vectors for deterministic iteration
360+
let mut directed_parents: Vec<String> = self.directed_parents(y)?.into_iter().collect();
361+
directed_parents.sort();
362+
let mut undirected_neighbors: Vec<String> = self.undirected_neighbors(y)?.into_iter().collect();
363+
undirected_neighbors.sort();
347364

348365
for x in &directed_parents {
349366
for z in &undirected_neighbors {
350-
if !pdag.is_adjacent(x, z)
351-
&& !pdag.check_new_unshielded_collider(y, z)?
352-
&& !pdag.has_directed_path(z, y)?
367+
if !self.is_adjacent(x, z)
368+
&& !self.check_new_unshielded_collider(y, z)?
369+
&& !self.has_directed_path(z, y)?
353370
{
354371
// Ensure x -> y exists
355-
if pdag.has_directed_edge(x, y) {
356-
if pdag.orient_undirected_edge(y, z, true).is_ok() {
372+
if self.has_directed_edge(x, y) {
373+
if self.orient_undirected_edge(y, z, true).is_ok() {
357374
changed = true;
358375
break;
359376
}
@@ -367,18 +384,21 @@ impl RustPDAG {
367384

368385
// Rule 2: If X -> Z and Z -> Y and X - Y => X -> Y
369386
for z in &nodes {
370-
if !pdag.node_map.contains_key(z) {
387+
if !self.node_map.contains_key(z) {
371388
continue;
372389
}
373-
let parents = pdag.directed_parents(z)?;
374-
let children = pdag.directed_children(z)?;
390+
// Convert HashSets to sorted vectors for deterministic iteration
391+
let mut parents: Vec<String> = self.directed_parents(z)?.into_iter().collect();
392+
parents.sort();
393+
let mut children: Vec<String> = self.directed_children(z)?.into_iter().collect();
394+
children.sort();
375395

376396
for x in &parents {
377397
for y in &children {
378-
if pdag.has_undirected_edge(x, y) {
398+
if self.has_undirected_edge(x, y) {
379399
// Ensure x -> z and z -> y exist
380-
if pdag.has_directed_edge(x, z) && pdag.has_directed_edge(z, y) {
381-
if pdag.orient_undirected_edge(x, y, true).is_ok() {
400+
if self.has_directed_edge(x, z) && self.has_directed_edge(z, y) {
401+
if self.orient_undirected_edge(x, y, true).is_ok() {
382402
changed = true;
383403
break;
384404
}
@@ -390,12 +410,14 @@ impl RustPDAG {
390410
if changed { break; }
391411
}
392412

393-
// Rule 3
413+
// Rule 3: If X - Y, X - Z, X - W and Y -> W, Z -> W => X -> W
394414
for x in &nodes {
395-
if !pdag.node_map.contains_key(x) {
415+
if !self.node_map.contains_key(x) {
396416
continue;
397417
}
398-
let undirected_nbs: Vec<String> = pdag.undirected_neighbors(x)?.into_iter().collect();
418+
// Convert HashSet to sorted vector for deterministic iteration
419+
let mut undirected_nbs: Vec<String> = self.undirected_neighbors(x)?.into_iter().collect();
420+
undirected_nbs.sort();
399421

400422
if undirected_nbs.len() < 3 {
401423
continue;
@@ -406,8 +428,8 @@ impl RustPDAG {
406428
for k in (j + 1)..undirected_nbs.len() {
407429
let (y, z, w) = (&undirected_nbs[i], &undirected_nbs[j], &undirected_nbs[k]);
408430

409-
if pdag.has_directed_edge(y, w) && pdag.has_directed_edge(z, w) {
410-
if pdag.orient_undirected_edge(x, w, true).is_ok() {
431+
if self.has_directed_edge(y, w) && self.has_directed_edge(z, w) {
432+
if self.orient_undirected_edge(x, w, true).is_ok() {
411433
changed = true;
412434
break;
413435
}
@@ -423,25 +445,31 @@ impl RustPDAG {
423445
// Rule 4
424446
if apply_r4 {
425447
for c in &nodes {
426-
if !pdag.node_map.contains_key(c) {
448+
if !self.node_map.contains_key(c) {
427449
continue;
428450
}
429-
let children = pdag.directed_children(c)?;
430-
let parents = pdag.directed_parents(c)?;
451+
452+
let mut children: Vec<String> = self.directed_children(c)?.into_iter().collect();
453+
children.sort();
454+
let mut parents: Vec<String> = self.directed_parents(c)?.into_iter().collect();
455+
parents.sort();
431456

432457
for b in &children {
433458
for d in &parents {
434-
if b == d || pdag.is_adjacent(b, d) {
459+
if b == d || self.is_adjacent(b, d) {
435460
continue;
436461
}
437462

438-
let b_undirected = pdag.undirected_neighbors(b)?;
439-
let c_neighbors = pdag.all_neighbors(c)?;
440-
let d_undirected = pdag.undirected_neighbors(d)?;
463+
let mut b_undirected: Vec<String> = self.undirected_neighbors(b)?.into_iter().collect();
464+
b_undirected.sort();
465+
let mut c_neighbors: Vec<String> = self.all_neighbors(c)?.into_iter().collect();
466+
c_neighbors.sort();
467+
let mut d_undirected: Vec<String> = self.undirected_neighbors(d)?.into_iter().collect();
468+
d_undirected.sort();
441469

442470
for a in &b_undirected {
443471
if c_neighbors.contains(a) && d_undirected.contains(a) {
444-
if pdag.orient_undirected_edge(a, b, true).is_ok() {
472+
if self.orient_undirected_edge(a, b, true).is_ok() {
445473
changed = true;
446474
break;
447475
}
@@ -456,14 +484,9 @@ impl RustPDAG {
456484
}
457485
}
458486

459-
if inplace {
460-
Ok(None)
461-
} else {
462-
Ok(Some(pdag.clone()))
463-
}
487+
Ok(())
464488
}
465489

466-
467490
pub fn to_dag(&self) -> Result<RustDAG, String> {
468491
let mut dag = RustDAG::new();
469492

@@ -473,25 +496,29 @@ impl RustPDAG {
473496
dag.add_node(node.clone(), is_latent)?;
474497
}
475498

476-
// Add all directed edges
477-
for (u, v) in &self.directed_edges {
478-
dag.add_edge(u.clone(), v.clone(), None)?;
499+
// Add all directed edg
500+
let mut directed_edges_sorted: Vec<(String, String)> = self.directed_edges.iter().cloned().collect();
501+
directed_edges_sorted.sort();
502+
for (u, v) in directed_edges_sorted {
503+
dag.add_edge(u, v, None)?;
479504
}
480505

481506
let mut pdag_copy = self.copy();
482507

483508
// Add undirected edges to dag before node removal
484-
for (u, v) in &self.undirected_edges {
485-
if !dag.has_edge(u, v) && !dag.has_edge(v, u) {
509+
let mut undirected_edges_sorted: Vec<(String, String)> = self.undirected_edges.iter().cloned().collect();
510+
undirected_edges_sorted.sort();
511+
for (u, v) in undirected_edges_sorted {
512+
if !dag.has_edge(&u, &v) && !dag.has_edge(&v, &u) {
486513
// Try adding u -> v, if it creates cycle, add v -> u
487514
if dag.add_edge(u.clone(), v.clone(), None).is_err() {
488-
dag.add_edge(v.clone(), u.clone(), None)?;
515+
dag.add_edge(v, u, None)?;
489516
}
490517
}
491518
}
492519

493520
while !pdag_copy.nodes().is_empty() {
494-
let nodes: Vec<String> = pdag_copy.nodes(); // Get fresh node list
521+
let nodes: Vec<String> = pdag_copy.nodes();
495522
let mut found = false;
496523

497524
for x in &nodes {
@@ -502,8 +529,10 @@ impl RustPDAG {
502529

503530
// Find nodes with no directed outgoing edges
504531
let directed_children = pdag_copy.directed_children(x)?;
505-
let undirected_neighbors = pdag_copy.undirected_neighbors(x)?;
506-
let directed_parents = pdag_copy.directed_parents(x)?;
532+
let mut undirected_neighbors: Vec<String> = pdag_copy.undirected_neighbors(x)?.into_iter().collect();
533+
undirected_neighbors.sort();
534+
let mut directed_parents: Vec<String> = pdag_copy.directed_parents(x)?.into_iter().collect();
535+
directed_parents.sort();
507536

508537
// Check if undirected neighbors + parents form a clique
509538
let mut neighbors_are_clique = true;
@@ -521,7 +550,8 @@ impl RustPDAG {
521550
found = true;
522551

523552
// Add all incoming edges to DAG
524-
let all_predecessors = pdag_copy.all_neighbors(x)?;
553+
let mut all_predecessors: Vec<String> = pdag_copy.all_neighbors(x)?.into_iter().collect();
554+
all_predecessors.sort();
525555
for y in &all_predecessors {
526556
if pdag_copy.is_adjacent(y, x) && !dag.has_edge(x, y) {
527557
dag.add_edge(y.clone(), x.clone(), None)?;
@@ -536,7 +566,8 @@ impl RustPDAG {
536566

537567
if !found {
538568
// Handle remaining edges arbitrarily, ensuring no cycles
539-
let remaining_edges: Vec<(String, String)> = pdag_copy.undirected_edges.iter().cloned().collect();
569+
let mut remaining_edges: Vec<(String, String)> = pdag_copy.undirected_edges.iter().cloned().collect();
570+
remaining_edges.sort(); // Deterministic order
540571
for (u, v) in remaining_edges {
541572
if pdag_copy.node_map.contains_key(&u) && pdag_copy.node_map.contains_key(&v) && !dag.has_edge(&v, &u) {
542573
if let Ok(()) = dag.add_edge(u.clone(), v.clone(), None) {
@@ -579,7 +610,4 @@ impl RustPDAG {
579610

580611
Ok(())
581612
}
582-
583-
584-
585-
}
613+
}

0 commit comments

Comments
 (0)