@@ -5,7 +5,6 @@ use petgraph::visit::Dfs;
5
5
6
6
use crate :: RustDAG ;
7
7
8
-
9
8
#[ derive( Debug , Clone ) ]
10
9
pub struct RustPDAG {
11
10
pub graph : DiGraph < String , f64 > ,
@@ -15,6 +14,7 @@ pub struct RustPDAG {
15
14
pub undirected_edges : HashSet < ( String , String ) > ,
16
15
pub latents : HashSet < String > ,
17
16
}
17
+
18
18
impl RustPDAG {
19
19
pub fn new ( ) -> Self {
20
20
RustPDAG {
@@ -27,9 +27,9 @@ impl RustPDAG {
27
27
}
28
28
}
29
29
30
- /// Get all edges in the graph
30
+ /// Get all edges in the graph - DETERMINISTIC
31
31
pub fn edges ( & self ) -> Vec < ( String , String ) > {
32
- self . graph
32
+ let mut edges : Vec < ( String , String ) > = self . graph
33
33
. edge_indices ( )
34
34
. map ( |edge_idx| {
35
35
let ( source, target) = self . graph . edge_endpoints ( edge_idx) . unwrap ( ) ;
@@ -38,15 +38,18 @@ impl RustPDAG {
38
38
self . reverse_node_map [ & target] . clone ( ) ,
39
39
)
40
40
} )
41
- . collect ( )
41
+ . collect ( ) ;
42
+ edges. sort ( ) ;
43
+ edges
42
44
}
43
45
44
46
/// Get all nodes in the graph
45
47
pub fn nodes ( & self ) -> Vec < String > {
46
48
let mut nodes: Vec < String > = self . node_map . keys ( ) . cloned ( ) . collect ( ) ;
47
- nodes. sort ( ) ; // Sort alphabetically for deterministic order
49
+ nodes. sort ( ) ;
48
50
nodes
49
51
}
52
+
50
53
/// Adds a single node to the PDAG.
51
54
pub fn add_node ( & mut self , node : String , latent : bool ) -> Result < ( ) , String > {
52
55
if !self . node_map . contains_key ( & node) {
@@ -230,19 +233,23 @@ impl RustPDAG {
230
233
}
231
234
}
232
235
233
- /// Returns a subgraph containing only directed edges as a RustDAG.
236
+ /// Returns a subgraph containing only directed edges as a RustDAG - DETERMINISTIC
234
237
pub fn directed_graph ( & self ) -> RustDAG {
235
238
let mut dag = RustDAG :: new ( ) ;
236
239
237
- // Add all nodes with their latent status
238
- for node in self . node_map . keys ( ) {
239
- let is_latent = self . latents . contains ( node) ;
240
+ // Add all nodes with their latent status - DETERMINISTIC ORDER
241
+ let mut nodes: Vec < String > = self . node_map . keys ( ) . cloned ( ) . collect ( ) ;
242
+ nodes. sort ( ) ;
243
+ for node in nodes {
244
+ let is_latent = self . latents . contains ( & node) ;
240
245
dag. add_node ( node. clone ( ) , is_latent) . unwrap ( ) ;
241
246
}
242
247
243
248
// Add only directed edges
244
- for ( u, v) in & self . directed_edges {
245
- dag. add_edge ( u. clone ( ) , v. clone ( ) , None ) . unwrap ( ) ;
249
+ let mut directed_edges: Vec < ( String , String ) > = self . directed_edges . iter ( ) . cloned ( ) . collect ( ) ;
250
+ directed_edges. sort ( ) ;
251
+ for ( u, v) in directed_edges {
252
+ dag. add_edge ( u, v, None ) . unwrap ( ) ;
246
253
}
247
254
248
255
dag
@@ -320,40 +327,50 @@ impl RustPDAG {
320
327
Ok ( false )
321
328
}
322
329
323
-
324
330
/// Apply Meek's rules to orient undirected edges
325
331
pub fn apply_meeks_rules ( & mut self , apply_r4 : bool , inplace : bool ) -> Result < Option < RustPDAG > , String > {
326
- let mut pdag = if inplace {
327
- self
332
+ if inplace {
333
+ // Work directly on self
334
+ self . apply_meeks_rules_internal ( apply_r4) ?;
335
+ Ok ( None )
328
336
} else {
329
- & mut self . copy ( )
330
- } ;
337
+ // Work on a copy
338
+ let mut pdag_copy = self . copy ( ) ;
339
+ pdag_copy. apply_meeks_rules_internal ( apply_r4) ?;
340
+ Ok ( Some ( pdag_copy) )
341
+ }
342
+ }
331
343
344
+ /// Internal method that applies Meek's rules to the current instance
345
+ fn apply_meeks_rules_internal ( & mut self , apply_r4 : bool ) -> Result < ( ) , String > {
332
346
let mut changed = true ;
333
347
while changed {
334
348
changed = false ;
335
- let nodes: Vec < String > = pdag . nodes ( ) ;
349
+ let nodes: Vec < String > = self . nodes ( ) ;
336
350
337
351
// Rule 1: If X -> Y - Z and
338
352
// (X not adj Z) and
339
353
// (adding Y -> Z doesn't create cycle) and
340
354
// (adding Y -> Z doesn't create an unshielded collider) => Y → Z
341
355
for y in & nodes {
342
- if !pdag . node_map . contains_key ( y) {
356
+ if !self . node_map . contains_key ( y) {
343
357
continue ;
344
358
}
345
- let directed_parents = pdag. directed_parents ( y) ?;
346
- let undirected_neighbors = pdag. undirected_neighbors ( y) ?;
359
+ // Convert HashSets to sorted vectors for deterministic iteration
360
+ let mut directed_parents: Vec < String > = self . directed_parents ( y) ?. into_iter ( ) . collect ( ) ;
361
+ directed_parents. sort ( ) ;
362
+ let mut undirected_neighbors: Vec < String > = self . undirected_neighbors ( y) ?. into_iter ( ) . collect ( ) ;
363
+ undirected_neighbors. sort ( ) ;
347
364
348
365
for x in & directed_parents {
349
366
for z in & undirected_neighbors {
350
- if !pdag . is_adjacent ( x, z)
351
- && !pdag . check_new_unshielded_collider ( y, z) ?
352
- && !pdag . has_directed_path ( z, y) ?
367
+ if !self . is_adjacent ( x, z)
368
+ && !self . check_new_unshielded_collider ( y, z) ?
369
+ && !self . has_directed_path ( z, y) ?
353
370
{
354
371
// Ensure x -> y exists
355
- if pdag . has_directed_edge ( x, y) {
356
- if pdag . orient_undirected_edge ( y, z, true ) . is_ok ( ) {
372
+ if self . has_directed_edge ( x, y) {
373
+ if self . orient_undirected_edge ( y, z, true ) . is_ok ( ) {
357
374
changed = true ;
358
375
break ;
359
376
}
@@ -367,18 +384,21 @@ impl RustPDAG {
367
384
368
385
// Rule 2: If X -> Z and Z -> Y and X - Y => X -> Y
369
386
for z in & nodes {
370
- if !pdag . node_map . contains_key ( z) {
387
+ if !self . node_map . contains_key ( z) {
371
388
continue ;
372
389
}
373
- let parents = pdag. directed_parents ( z) ?;
374
- let children = pdag. directed_children ( z) ?;
390
+ // Convert HashSets to sorted vectors for deterministic iteration
391
+ let mut parents: Vec < String > = self . directed_parents ( z) ?. into_iter ( ) . collect ( ) ;
392
+ parents. sort ( ) ;
393
+ let mut children: Vec < String > = self . directed_children ( z) ?. into_iter ( ) . collect ( ) ;
394
+ children. sort ( ) ;
375
395
376
396
for x in & parents {
377
397
for y in & children {
378
- if pdag . has_undirected_edge ( x, y) {
398
+ if self . has_undirected_edge ( x, y) {
379
399
// Ensure x -> z and z -> y exist
380
- if pdag . has_directed_edge ( x, z) && pdag . has_directed_edge ( z, y) {
381
- if pdag . orient_undirected_edge ( x, y, true ) . is_ok ( ) {
400
+ if self . has_directed_edge ( x, z) && self . has_directed_edge ( z, y) {
401
+ if self . orient_undirected_edge ( x, y, true ) . is_ok ( ) {
382
402
changed = true ;
383
403
break ;
384
404
}
@@ -390,12 +410,14 @@ impl RustPDAG {
390
410
if changed { break ; }
391
411
}
392
412
393
- // Rule 3
413
+ // Rule 3: If X - Y, X - Z, X - W and Y -> W, Z -> W => X -> W
394
414
for x in & nodes {
395
- if !pdag . node_map . contains_key ( x) {
415
+ if !self . node_map . contains_key ( x) {
396
416
continue ;
397
417
}
398
- let undirected_nbs: Vec < String > = pdag. undirected_neighbors ( x) ?. into_iter ( ) . collect ( ) ;
418
+ // Convert HashSet to sorted vector for deterministic iteration
419
+ let mut undirected_nbs: Vec < String > = self . undirected_neighbors ( x) ?. into_iter ( ) . collect ( ) ;
420
+ undirected_nbs. sort ( ) ;
399
421
400
422
if undirected_nbs. len ( ) < 3 {
401
423
continue ;
@@ -406,8 +428,8 @@ impl RustPDAG {
406
428
for k in ( j + 1 ) ..undirected_nbs. len ( ) {
407
429
let ( y, z, w) = ( & undirected_nbs[ i] , & undirected_nbs[ j] , & undirected_nbs[ k] ) ;
408
430
409
- if pdag . has_directed_edge ( y, w) && pdag . has_directed_edge ( z, w) {
410
- if pdag . orient_undirected_edge ( x, w, true ) . is_ok ( ) {
431
+ if self . has_directed_edge ( y, w) && self . has_directed_edge ( z, w) {
432
+ if self . orient_undirected_edge ( x, w, true ) . is_ok ( ) {
411
433
changed = true ;
412
434
break ;
413
435
}
@@ -423,25 +445,31 @@ impl RustPDAG {
423
445
// Rule 4
424
446
if apply_r4 {
425
447
for c in & nodes {
426
- if !pdag . node_map . contains_key ( c) {
448
+ if !self . node_map . contains_key ( c) {
427
449
continue ;
428
450
}
429
- let children = pdag. directed_children ( c) ?;
430
- let parents = pdag. directed_parents ( c) ?;
451
+
452
+ let mut children: Vec < String > = self . directed_children ( c) ?. into_iter ( ) . collect ( ) ;
453
+ children. sort ( ) ;
454
+ let mut parents: Vec < String > = self . directed_parents ( c) ?. into_iter ( ) . collect ( ) ;
455
+ parents. sort ( ) ;
431
456
432
457
for b in & children {
433
458
for d in & parents {
434
- if b == d || pdag . is_adjacent ( b, d) {
459
+ if b == d || self . is_adjacent ( b, d) {
435
460
continue ;
436
461
}
437
462
438
- let b_undirected = pdag. undirected_neighbors ( b) ?;
439
- let c_neighbors = pdag. all_neighbors ( c) ?;
440
- let d_undirected = pdag. undirected_neighbors ( d) ?;
463
+ let mut b_undirected: Vec < String > = self . undirected_neighbors ( b) ?. into_iter ( ) . collect ( ) ;
464
+ b_undirected. sort ( ) ;
465
+ let mut c_neighbors: Vec < String > = self . all_neighbors ( c) ?. into_iter ( ) . collect ( ) ;
466
+ c_neighbors. sort ( ) ;
467
+ let mut d_undirected: Vec < String > = self . undirected_neighbors ( d) ?. into_iter ( ) . collect ( ) ;
468
+ d_undirected. sort ( ) ;
441
469
442
470
for a in & b_undirected {
443
471
if c_neighbors. contains ( a) && d_undirected. contains ( a) {
444
- if pdag . orient_undirected_edge ( a, b, true ) . is_ok ( ) {
472
+ if self . orient_undirected_edge ( a, b, true ) . is_ok ( ) {
445
473
changed = true ;
446
474
break ;
447
475
}
@@ -456,14 +484,9 @@ impl RustPDAG {
456
484
}
457
485
}
458
486
459
- if inplace {
460
- Ok ( None )
461
- } else {
462
- Ok ( Some ( pdag. clone ( ) ) )
463
- }
487
+ Ok ( ( ) )
464
488
}
465
489
466
-
467
490
pub fn to_dag ( & self ) -> Result < RustDAG , String > {
468
491
let mut dag = RustDAG :: new ( ) ;
469
492
@@ -473,25 +496,29 @@ impl RustPDAG {
473
496
dag. add_node ( node. clone ( ) , is_latent) ?;
474
497
}
475
498
476
- // Add all directed edges
477
- for ( u, v) in & self . directed_edges {
478
- dag. add_edge ( u. clone ( ) , v. clone ( ) , None ) ?;
499
+ // Add all directed edg
500
+ let mut directed_edges_sorted: Vec < ( String , String ) > = self . directed_edges . iter ( ) . cloned ( ) . collect ( ) ;
501
+ directed_edges_sorted. sort ( ) ;
502
+ for ( u, v) in directed_edges_sorted {
503
+ dag. add_edge ( u, v, None ) ?;
479
504
}
480
505
481
506
let mut pdag_copy = self . copy ( ) ;
482
507
483
508
// Add undirected edges to dag before node removal
484
- for ( u, v) in & self . undirected_edges {
485
- if !dag. has_edge ( u, v) && !dag. has_edge ( v, u) {
509
+ let mut undirected_edges_sorted: Vec < ( String , String ) > = self . undirected_edges . iter ( ) . cloned ( ) . collect ( ) ;
510
+ undirected_edges_sorted. sort ( ) ;
511
+ for ( u, v) in undirected_edges_sorted {
512
+ if !dag. has_edge ( & u, & v) && !dag. has_edge ( & v, & u) {
486
513
// Try adding u -> v, if it creates cycle, add v -> u
487
514
if dag. add_edge ( u. clone ( ) , v. clone ( ) , None ) . is_err ( ) {
488
- dag. add_edge ( v. clone ( ) , u. clone ( ) , None ) ?;
515
+ dag. add_edge ( v, u, None ) ?;
489
516
}
490
517
}
491
518
}
492
519
493
520
while !pdag_copy. nodes ( ) . is_empty ( ) {
494
- let nodes: Vec < String > = pdag_copy. nodes ( ) ; // Get fresh node list
521
+ let nodes: Vec < String > = pdag_copy. nodes ( ) ;
495
522
let mut found = false ;
496
523
497
524
for x in & nodes {
@@ -502,8 +529,10 @@ impl RustPDAG {
502
529
503
530
// Find nodes with no directed outgoing edges
504
531
let directed_children = pdag_copy. directed_children ( x) ?;
505
- let undirected_neighbors = pdag_copy. undirected_neighbors ( x) ?;
506
- let directed_parents = pdag_copy. directed_parents ( x) ?;
532
+ let mut undirected_neighbors: Vec < String > = pdag_copy. undirected_neighbors ( x) ?. into_iter ( ) . collect ( ) ;
533
+ undirected_neighbors. sort ( ) ;
534
+ let mut directed_parents: Vec < String > = pdag_copy. directed_parents ( x) ?. into_iter ( ) . collect ( ) ;
535
+ directed_parents. sort ( ) ;
507
536
508
537
// Check if undirected neighbors + parents form a clique
509
538
let mut neighbors_are_clique = true ;
@@ -521,7 +550,8 @@ impl RustPDAG {
521
550
found = true ;
522
551
523
552
// Add all incoming edges to DAG
524
- let all_predecessors = pdag_copy. all_neighbors ( x) ?;
553
+ let mut all_predecessors: Vec < String > = pdag_copy. all_neighbors ( x) ?. into_iter ( ) . collect ( ) ;
554
+ all_predecessors. sort ( ) ;
525
555
for y in & all_predecessors {
526
556
if pdag_copy. is_adjacent ( y, x) && !dag. has_edge ( x, y) {
527
557
dag. add_edge ( y. clone ( ) , x. clone ( ) , None ) ?;
@@ -536,7 +566,8 @@ impl RustPDAG {
536
566
537
567
if !found {
538
568
// Handle remaining edges arbitrarily, ensuring no cycles
539
- let remaining_edges: Vec < ( String , String ) > = pdag_copy. undirected_edges . iter ( ) . cloned ( ) . collect ( ) ;
569
+ let mut remaining_edges: Vec < ( String , String ) > = pdag_copy. undirected_edges . iter ( ) . cloned ( ) . collect ( ) ;
570
+ remaining_edges. sort ( ) ; // Deterministic order
540
571
for ( u, v) in remaining_edges {
541
572
if pdag_copy. node_map . contains_key ( & u) && pdag_copy. node_map . contains_key ( & v) && !dag. has_edge ( & v, & u) {
542
573
if let Ok ( ( ) ) = dag. add_edge ( u. clone ( ) , v. clone ( ) , None ) {
@@ -579,7 +610,4 @@ impl RustPDAG {
579
610
580
611
Ok ( ( ) )
581
612
}
582
-
583
-
584
-
585
- }
613
+ }
0 commit comments