Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ WASM_BINDINGS := wasm_bindings
R_BINDINGS := r_bindings/causalgraphs

# Build targets
.PHONY: all core python wasm r install test clean
.PHONY: all core python wasm r install test format clean

all: core python wasm r

Expand Down Expand Up @@ -56,9 +56,13 @@ test-wasm:
test-r:
cd $(R_BINDINGS) && Rscript -e 'devtools::test()'

format:
@echo "\n=== Formatting Code ==="
cargo fmt --all

clean:
@echo "\n=== Cleaning All Build Artifacts ==="
cd $(RUST_CORE) && cargo clean
cd $(PY_BINDINGS) && rm -rf target/ *.so
cd $(WASM_BINDINGS) && rm -rf js/pkg-* node_modules
cd $(R_BINDINGS) && rm -rf src/rust/target src/.cargo
cd $(R_BINDINGS) && rm -rf src/rust/target src/.cargo
86 changes: 57 additions & 29 deletions python_bindings/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use pyo3::prelude::*;
use pyo3::exceptions::{PyKeyError, PyValueError};
use rust_core::{RustDAG, IndependenceAssertion, Independencies};
use pyo3::prelude::*;
use rust_core::{IndependenceAssertion, Independencies, RustDAG};
use std::collections::HashSet;

#[pyclass(name = "DAG")]
Expand All @@ -13,21 +13,30 @@ pub struct PyRustDAG {
impl PyRustDAG {
#[new]
pub fn new() -> Self {
PyRustDAG { inner: RustDAG::new() }
PyRustDAG {
inner: RustDAG::new(),
}
}

pub fn add_node(&mut self, node: String, latent: Option<bool>) -> PyResult<()> {
self.inner.add_node(node, latent.unwrap_or(false))
self.inner
.add_node(node, latent.unwrap_or(false))
.map_err(PyValueError::new_err)
}

pub fn add_nodes_from(&mut self, nodes: Vec<String>, latent: Option<Vec<bool>>) -> PyResult<()> {
self.inner.add_nodes_from(nodes, latent)
pub fn add_nodes_from(
&mut self,
nodes: Vec<String>,
latent: Option<Vec<bool>>,
) -> PyResult<()> {
self.inner
.add_nodes_from(nodes, latent)
.map_err(PyValueError::new_err)
}

pub fn add_edge(&mut self, u: String, v: String, weight: Option<f64>) -> PyResult<()> {
self.inner.add_edge(u, v, weight)
self.inner
.add_edge(u, v, weight)
.map_err(PyValueError::new_err)
}

Expand All @@ -36,22 +45,22 @@ impl PyRustDAG {
ebunch: Vec<(String, String)>,
weights: Option<Vec<f64>>,
) -> PyResult<()> {
self.inner.add_edges_from(ebunch, weights)
self.inner
.add_edges_from(ebunch, weights)
.map_err(PyValueError::new_err)
}

pub fn get_parents(&self, node: String) -> PyResult<Vec<String>> {
self.inner.get_parents(&node)
.map_err(PyKeyError::new_err)
self.inner.get_parents(&node).map_err(PyKeyError::new_err)
}

pub fn get_children(&self, node: String) -> PyResult<Vec<String>> {
self.inner.get_children(&node)
.map_err(PyKeyError::new_err)
self.inner.get_children(&node).map_err(PyKeyError::new_err)
}

pub fn get_ancestors_of(&self, nodes: Vec<String>) -> PyResult<HashSet<String>> {
self.inner.get_ancestors_of(nodes)
self.inner
.get_ancestors_of(nodes)
.map_err(PyValueError::new_err)
}

Expand All @@ -78,7 +87,8 @@ impl PyRustDAG {
observed: Option<Vec<String>>,
include_latents: bool,
) -> PyResult<std::collections::HashMap<String, std::collections::HashSet<String>>> {
self.inner.active_trail_nodes(variables, observed, include_latents)
self.inner
.active_trail_nodes(variables, observed, include_latents)
.map_err(PyValueError::new_err)
}

Expand All @@ -90,17 +100,20 @@ impl PyRustDAG {
observed: Option<Vec<String>>,
include_latents: bool,
) -> PyResult<bool> {
self.inner.is_dconnected(&start, &end, observed, include_latents)
self.inner
.is_dconnected(&start, &end, observed, include_latents)
.map_err(PyValueError::new_err)
}

pub fn are_neighbors(&self, start: String, end: String) -> PyResult<bool> {
self.inner.are_neighbors(&start, &end)
self.inner
.are_neighbors(&start, &end)
.map_err(PyValueError::new_err)
}

pub fn get_ancestral_graph(&self, nodes: Vec<String>) -> PyResult<PyRustDAG> {
self.inner.get_ancestral_graph(nodes)
self.inner
.get_ancestral_graph(nodes)
.map(|dag| PyRustDAG { inner: dag })
.map_err(PyValueError::new_err)
}
Expand All @@ -112,7 +125,8 @@ impl PyRustDAG {
end: String,
include_latents: bool,
) -> PyResult<Option<std::collections::HashSet<String>>> {
self.inner.minimal_dseparator(&start, &end, include_latents)
self.inner
.minimal_dseparator(&start, &end, include_latents)
.map_err(PyValueError::new_err)
}
}
Expand All @@ -126,12 +140,15 @@ pub struct PyIndependenceAssertion {
#[pymethods]
impl PyIndependenceAssertion {
#[new]
pub fn new(event1: Vec<String>, event2: Vec<String>, event3: Option<Vec<String>>) -> PyResult<Self> {
pub fn new(
event1: Vec<String>,
event2: Vec<String>,
event3: Option<Vec<String>>,
) -> PyResult<Self> {
let e1: HashSet<String> = event1.into_iter().collect();
let e2: HashSet<String> = event2.into_iter().collect();
let e3: Option<HashSet<String>> = event3.map(|v| v.into_iter().collect());
let assertion = IndependenceAssertion::new(e1, e2, e3)
.map_err(PyValueError::new_err)?;
let assertion = IndependenceAssertion::new(e1, e2, e3).map_err(PyValueError::new_err)?;
Ok(PyIndependenceAssertion { inner: assertion })
}

Expand Down Expand Up @@ -194,28 +211,36 @@ pub struct PyIndependencies {
impl PyIndependencies {
#[new]
pub fn new() -> Self {
PyIndependencies { inner: Independencies::new() }
PyIndependencies {
inner: Independencies::new(),
}
}

pub fn add_assertion(&mut self, assertion: &PyIndependenceAssertion) {
self.inner.add_assertion(assertion.inner.clone());
}

pub fn add_assertions_from_tuples(&mut self, tuples: Vec<(Vec<String>, Vec<String>, Option<Vec<String>>)>) -> PyResult<()> {
self.inner.add_assertions_from_tuples(tuples)
pub fn add_assertions_from_tuples(
&mut self,
tuples: Vec<(Vec<String>, Vec<String>, Option<Vec<String>>)>,
) -> PyResult<()> {
self.inner
.add_assertions_from_tuples(tuples)
.map_err(PyValueError::new_err)
}

pub fn get_assertions(&self) -> Vec<PyIndependenceAssertion> {
self.inner.get_assertions()
self.inner
.get_assertions()
.iter()
.map(|a| PyIndependenceAssertion { inner: a.clone() })
.collect()
}

#[getter(independencies)]
pub fn get_independencies(&self) -> Vec<PyIndependenceAssertion> {
self.inner.get_assertions()
self.inner
.get_assertions()
.iter()
.map(|a| PyIndependenceAssertion { inner: a.clone() })
.collect()
Expand All @@ -230,7 +255,9 @@ impl PyIndependencies {
}

pub fn closure(&self) -> PyIndependencies {
PyIndependencies { inner: self.inner.closure() }
PyIndependencies {
inner: self.inner.closure(),
}
}

#[pyo3(signature = (inplace = false))]
Expand All @@ -239,7 +266,9 @@ impl PyIndependencies {
self.inner.reduce_inplace();
Ok(None)
} else {
Ok(Some(PyIndependencies { inner: self.inner.reduce() }))
Ok(Some(PyIndependencies {
inner: self.inner.reduce(),
}))
}
}

Expand Down Expand Up @@ -267,4 +296,3 @@ fn causalgraphs(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
m.add_class::<PyIndependencies>()?;
Ok(())
}

45 changes: 30 additions & 15 deletions r_bindings/causalgraphs/src/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,37 @@ impl RDAG {
/// Create a new DAG
/// @export
fn new() -> Self {
RDAG { inner: RustDAG::new() }
RDAG {
inner: RustDAG::new(),
}
}

/// Add a single node to the DAG
/// @param node The node name
/// @param latent Whether the node is latent (default: FALSE)
/// @export
fn add_node(&mut self, node: String, latent: Option<bool>) -> extendr_api::Result<()> {
self.inner.add_node(node, latent.unwrap_or(false))
self.inner
.add_node(node, latent.unwrap_or(false))
.map_err(Error::from)
}


/// Add multiple nodes to the DAG
/// @param nodes Vector of node names
/// @param latent Optional vector of latent flags
/// @export
fn add_nodes_from(&mut self, nodes: Strings, latent: Nullable<Logicals>) -> extendr_api::Result<()> {
fn add_nodes_from(
&mut self,
nodes: Strings,
latent: Nullable<Logicals>,
) -> extendr_api::Result<()> {
let node_vec: Vec<String> = nodes.iter().map(|s| s.to_string()).collect();
let latent_opt: Option<Vec<bool>> = latent.into_option().map(|v| v.iter().map(|x| x.is_true()).collect());

self.inner.add_nodes_from(node_vec, latent_opt)
let latent_opt: Option<Vec<bool>> = latent
.into_option()
.map(|v| v.iter().map(|x| x.is_true()).collect());

self.inner
.add_nodes_from(node_vec, latent_opt)
.map_err(|e| Error::Other(e))
}

Expand All @@ -44,24 +53,24 @@ impl RDAG {
/// @export
fn add_edge(&mut self, u: String, v: String, weight: Nullable<f64>) -> extendr_api::Result<()> {
let w = weight.into_option();
self.inner.add_edge(u, v, w)
.map_err(|e| Error::Other(e))
self.inner.add_edge(u, v, w).map_err(|e| Error::Other(e))
}

/// Get parents of a node
/// @param node The node name
/// @export
fn get_parents(&self, node: String) -> extendr_api::Result<Strings> {
let parents = self.inner.get_parents(&node)
.map_err(|e| Error::Other(e))?;
let parents = self.inner.get_parents(&node).map_err(|e| Error::Other(e))?;
Ok(parents.iter().map(|s| s.as_str()).collect::<Strings>())
}

/// Get children of a node
/// @param node The node name
/// @export
fn get_children(&self, node: String) -> extendr_api::Result<Strings> {
let children = self.inner.get_children(&node)
let children = self
.inner
.get_children(&node)
.map_err(|e| Error::Other(e))?;
Ok(children.iter().map(|s| s.as_str()).collect::<Strings>())
}
Expand All @@ -71,7 +80,9 @@ impl RDAG {
/// @export
fn get_ancestors_of(&self, nodes: Strings) -> extendr_api::Result<Strings> {
let node_vec: Vec<String> = nodes.iter().map(|s| s.to_string()).collect();
let ancestors = self.inner.get_ancestors_of(node_vec)
let ancestors = self
.inner
.get_ancestors_of(node_vec)
.map_err(|e| Error::Other(e))?;
Ok(ancestors.iter().map(|s| s.as_str()).collect::<Strings>())
}
Expand Down Expand Up @@ -106,7 +117,11 @@ impl RDAG {
/// Get latent nodes
/// @export
fn latents(&self) -> Strings {
self.inner.latents.iter().map(|s| s.as_str()).collect::<Strings>()
self.inner
.latents
.iter()
.map(|s| s.as_str())
.collect::<Strings>()
}
}

Expand All @@ -116,4 +131,4 @@ impl RDAG {
extendr_module! {
mod causalgraphs;
impl RDAG;
}
}
Loading
Loading