Skip to content

Commit 558b0cb

Browse files
authored
Core Implementation: d-separation & Independencies class (closure, reduce, SG axioms) (#5)
* rust initial commit * update some rust * add python sdk * added nodes and edges fetch in rust * add additional rust implementations * Rename crate to causalgraphs and wire up Python bindings * Add initial implementation of causalgraphs with RustDAG and Python bindings * add requirements.txt * update README.md * conditional compilation: create python, js cargo features * fix get random wasm compilation * fix wasm compilation issue - removed the faulty default python flag - does not compile python code during wasm compilation now * js runtime tests * Refactor project structure: separate Rust core and Python bindings, remove unused files, and update dependencies * clean up * Restructure wasm_bindings module with initial setup and dependencies * Refactors: * rename WasmDAG -> RustDAG * Added a simple node test file(needs to be fixed) * Refactor ancestor retrieval in RustDAG * refactor `add_nodes_from` & test-wasm.js * Update pyo3 version to 0.21.2 and clean up imports in RustDAG * remove old js folder * update README.md's * initial r bindings changes * R bindings: extendr setup https://extendr.github.io/user-guide/r-pkgs/package-setup.html * R build + Ugly Temp Fixes * Revert dirty fixes by updating R bindings and configuration for compatibility with R 4.2+ * revert config.R temp change * add readme * update README.md * R remote installable by updating rust_core dependency * Add MakeFile & test setup * ci initial setup * Enable manual triggering of CI workflow * Update Python bindings installation in Makefile to use maturin release build * Add installation step for wasm-pack in CI workflow * r deps change * r deps update 2 * R devtools install step * add system dependencies * add pytest installation * macos test * refactor * refactor: ci: update CI configuration for macOS * update README.md * update README.md * test * active_trail_nodes & is_dconnected impl * core: d-separation impl & tests * Independencies initial commit * fix closure & sg3 * add comments * fix(mac 
build) reduce method: sort assertions on event2 to prevent non generic to be filled first * add sort e2 func * add comments * add more tests * Fix bogus assertion generations + add large test * minor s3contraction fix + test refactors
1 parent e945514 commit 558b0cb

File tree

8 files changed

+1493
-6
lines changed

8 files changed

+1493
-6
lines changed

Cargo.lock

Lines changed: 11 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust_core/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ path = "src/lib.rs"
1111
petgraph = "0.6"
1212
ahash = "0.8"
1313
indexmap = "2.0"
14-
rustworkx-core = "0.14"
14+
rustworkx-core = "0.14"
15+
itertools = "0.14.0"

rust_core/src/dag.rs

Lines changed: 222 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@ use petgraph::Direction;
22
use rustworkx_core::petgraph::graph::{DiGraph, NodeIndex};
33
use std::collections::{HashMap, HashSet, VecDeque};
44

5-
// Remove #[pyclass] here. This is a pure Rust struct.
5+
66
#[derive(Debug, Clone)] // Add Debug for easier printing in Rust tests
77
pub struct RustDAG {
88
pub graph: DiGraph<String, f64>, // Make fields public if bindings need direct access,
9-
pub node_map: HashMap<String, NodeIndex>, // or provide internal methods.
9+
pub node_map: HashMap<String, NodeIndex>,
1010
pub reverse_node_map: HashMap<NodeIndex, String>,
1111
pub latents: HashSet<String>,
1212
}
@@ -65,6 +65,26 @@ impl RustDAG {
6565
Ok(())
6666
}
6767

68+
pub fn add_edges_from(
69+
&mut self,
70+
ebunch: Vec<(String, String)>,
71+
weights: Option<Vec<f64>>,
72+
) -> Result<(), String> {
73+
if let Some(ws) = &weights {
74+
if ebunch.len() != ws.len() {
75+
return Err("The number of elements in ebunch and weights should be equal".to_string());
76+
}
77+
for (i, (u, v)) in ebunch.iter().enumerate() {
78+
self.add_edge(u.clone(), v.clone(), Some(ws[i]))?;
79+
}
80+
} else {
81+
for (u, v) in ebunch {
82+
self.add_edge(u, v, None)?;
83+
}
84+
}
85+
Ok(())
86+
}
87+
6888
/// Get parents of a node
6989
pub fn get_parents(&self, node: &str) -> Result<Vec<String>, String> {
7090
let node_idx = self.node_map.get(node)
@@ -122,6 +142,199 @@ impl RustDAG {
122142
Ok(ancestors)
123143
}
124144

145+
146+
pub fn active_trail_nodes(&self, variables: Vec<String>, observed: Option<Vec<String>>, include_latents: bool) -> Result<HashMap<String, HashSet<String>>, String> {
147+
let observed_list: HashSet<String> = observed.unwrap_or_default().into_iter().collect();
148+
// Precompute ancestors of observed nodes (needed for collider rule)
149+
// Example: If C is observed in A→B←C→D, ancestors_list = {A, B, C}
150+
let ancestors_list: HashSet<String> = self.get_ancestors_of(observed_list.iter().cloned().collect())?;
151+
152+
let mut active_trails: HashMap<String, HashSet<String>> = HashMap::new();
153+
// For each starting variable, find all nodes reachable via active trails
154+
for start in variables {
155+
// BFS with direction tracking: (node, direction_of_arrival)
156+
// "up" = coming from child toward parents, "down" = coming from parent toward children
157+
let mut visit_list: HashSet<(String, &str)> = HashSet::new();
158+
let mut traversed_list: HashSet<(String, &str)> = HashSet::new();
159+
let mut active_nodes: HashSet<String> = HashSet::new();
160+
161+
if !self.node_map.contains_key(&start) {
162+
return Err(format!("Node {} not in graph", start));
163+
}
164+
165+
visit_list.insert((start.clone(), "up"));
166+
while let Some((node, direction)) = visit_list.iter().next().map(|x| x.clone()) {
167+
visit_list.remove(&(node.clone(), direction));
168+
if !traversed_list.contains(&(node.clone(), direction)) {
169+
// Add to active trail if not observed (observed nodes block but aren't "reachable")
170+
if !observed_list.contains(&node) {
171+
active_nodes.insert(node.clone());
172+
}
173+
traversed_list.insert((node.clone(), direction));
174+
175+
// If arriving "up" at unobserved B, can continue to parents and switch to children
176+
if direction == "up" && !observed_list.contains(&node) {
177+
for parent in self.get_parents(&node)? {
178+
visit_list.insert((parent, "up")); // Continue up the chain
179+
}
180+
for child in self.get_children(&node)? {
181+
visit_list.insert((child, "down")); // Switch direction
182+
}
183+
}
184+
185+
// If arriving "down", can continue down if unobserved, or go up if it's a collider
186+
else if direction == "down" {
187+
if !observed_list.contains(&node) {
188+
for child in self.get_children(&node)? {
189+
visit_list.insert((child, "down"));
190+
}
191+
}
192+
if ancestors_list.contains(&node) {
193+
for parent in self.get_parents(&node)? {
194+
visit_list.insert((parent, "up"));
195+
}
196+
}
197+
}
198+
}
199+
}
200+
201+
let final_nodes: HashSet<String> = if include_latents {
202+
active_nodes
203+
} else {
204+
active_nodes.difference(&self.latents).cloned().collect()
205+
};
206+
active_trails.insert(start, final_nodes);
207+
}
208+
209+
Ok(active_trails)
210+
}
211+
212+
pub fn is_dconnected(&self, start: &str, end: &str, observed: Option<Vec<String>>, include_latents: bool) -> Result<bool, String> {
213+
let trails = self.active_trail_nodes(vec![start.to_string()], observed, include_latents)?;
214+
Ok(trails.get(start).map(|nodes| nodes.contains(end)).unwrap_or(false))
215+
}
216+
217+
pub fn minimal_dseparator(
218+
&self,
219+
start: &str,
220+
end: &str,
221+
include_latents: bool
222+
) -> Result<Option<HashSet<String>>, String> {
223+
// Example: For DAG A→B←C, B→D, trying to separate A and C
224+
// Adjacent nodes can't be separated by any conditioning set
225+
if self.has_edge(start, end) || self.has_edge(end, start) {
226+
return Err("No possible separators because start and end are adjacent".to_string());
227+
}
228+
229+
// Create ancestral graph containing only ancestors of start and end
230+
// Example: For separating A and D in A→B←C, B→D, ancestral graph = {A, B, C, D}
231+
let ancestral_graph = self.get_ancestral_graph(vec![start.to_string(), end.to_string()])?;
232+
233+
// Initial separator: all parents of both nodes (theoretical upper bound)
234+
// Example: parents(A)={} ∪ parents(D)={B} → separator = {B}
235+
let mut separator: HashSet<String> = self.get_parents(start)?
236+
.into_iter()
237+
.chain(self.get_parents(end)?.into_iter())
238+
.collect();
239+
240+
// Replace latent variables with their observable parents
241+
// Example: If B were latent with parent L, replace B with L in separator
242+
if !include_latents {
243+
let mut changed = true;
244+
while changed {
245+
changed = false;
246+
let mut new_separator: HashSet<String> = HashSet::new();
247+
248+
for node in &separator {
249+
if self.latents.contains(node) {
250+
new_separator.extend(self.get_parents(node)?);
251+
changed = true;
252+
} else {
253+
new_separator.insert(node.clone());
254+
}
255+
}
256+
separator = new_separator;
257+
}
258+
}
259+
260+
separator.remove(start);
261+
separator.remove(end);
262+
263+
// Sanity check: if our "guaranteed" separator doesn't work, no separator exists
264+
if ancestral_graph.is_dconnected(start, end, Some(separator.iter().cloned().collect()), include_latents)? {
265+
return Ok(None);
266+
}
267+
268+
// Greedy minimization: remove each node if separation still holds without it
269+
// Example: If separator = {B, C} but {B} alone separates A from D, remove C
270+
let mut minimal_separator = separator.clone();
271+
for u in separator {
272+
let test_separator: Vec<String> = minimal_separator.iter().cloned().filter(|x| x != &u).collect();
273+
274+
// If still d-separated WITHOUT this node, we can remove it
275+
if !ancestral_graph.is_dconnected(start, end, Some(test_separator), include_latents)? {
276+
minimal_separator.remove(&u);
277+
}
278+
}
279+
280+
Ok(Some(minimal_separator))
281+
}
282+
283+
/// Check if two nodes are neighbors (directly connected in either direction)
284+
pub fn are_neighbors(&self, start: &str, end: &str) -> Result<bool, String> {
285+
let start_idx = self.node_map.get(start)
286+
.ok_or_else(|| format!("Node {} not found", start))?;
287+
let end_idx = self.node_map.get(end)
288+
.ok_or_else(|| format!("Node {} not found", end))?;
289+
290+
// Check for edge in either direction
291+
let has_edge = self.graph.find_edge(*start_idx, *end_idx).is_some() ||
292+
self.graph.find_edge(*end_idx, *start_idx).is_some();
293+
294+
Ok(has_edge)
295+
}
296+
297+
/// Get ancestral graph containing only ancestors of the given nodes
298+
pub fn get_ancestral_graph(&self, nodes: Vec<String>) -> Result<RustDAG, String> {
299+
let ancestors = self.get_ancestors_of(nodes)?;
300+
let mut ancestral_graph = RustDAG::new();
301+
302+
// Add all ancestor nodes with their latent status
303+
for node in &ancestors {
304+
let is_latent = self.latents.contains(node);
305+
ancestral_graph.add_node(node.clone(), is_latent)?;
306+
}
307+
308+
// Add edges between ancestors only
309+
for (source, target) in self.edges() {
310+
if ancestors.contains(&source) && ancestors.contains(&target) {
311+
ancestral_graph.add_edge(source, target, None)?;
312+
}
313+
}
314+
315+
Ok(ancestral_graph)
316+
}
317+
318+
319+
320+
/// Returns a list of leaves (nodes with out-degree 0)
321+
pub fn get_leaves(&self) -> Vec<String> {
322+
self.graph
323+
.node_indices()
324+
.filter(|&idx| self.graph.neighbors_directed(idx, Direction::Outgoing).next().is_none())
325+
.map(|idx| self.reverse_node_map[&idx].clone())
326+
.collect()
327+
}
328+
329+
/// Returns a list of roots (nodes with in-degree 0)
330+
pub fn get_roots(&self) -> Vec<String> {
331+
self.graph
332+
.node_indices()
333+
.filter(|&idx| self.graph.neighbors_directed(idx, Direction::Incoming).next().is_none())
334+
.map(|idx| self.reverse_node_map[&idx].clone())
335+
.collect()
336+
}
337+
125338
/// Get all nodes in the graph
126339
pub fn nodes(&self) -> Vec<String> {
127340
self.node_map.keys().cloned().collect()
@@ -141,6 +354,13 @@ impl RustDAG {
141354
.collect()
142355
}
143356

357+
pub fn has_edge(&self, u: &str, v: &str) -> bool {
358+
match (self.node_map.get(u), self.node_map.get(v)) {
359+
(Some(u_idx), Some(v_idx)) => self.graph.find_edge(*u_idx, *v_idx).is_some(),
360+
_ => false,
361+
}
362+
}
363+
144364
/// Get number of nodes
145365
pub fn node_count(&self) -> usize {
146366
self.graph.node_count()

0 commit comments

Comments
 (0)