Skip to content

Commit 2822e2b

Browse files
committed
Skeleton functionality of the CausalProblem class
1 parent 9b94674 commit 2822e2b

File tree

1 file changed

+75
-3
lines changed

1 file changed

+75
-3
lines changed

src/causalprog/graph/causal_problem.py

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from typing import TYPE_CHECKING
1+
"""Extension of the Graph class providing features for solving causal problems."""
2+
3+
from collections.abc import Callable
4+
from typing import TYPE_CHECKING, Literal
25

36
from .graph import Graph
47

@@ -7,13 +10,62 @@
710

811

912
class CausalProblem(Graph):
10-
""""""
13+
"""f."""
14+
15+
_sigma: Callable[..., float]
16+
_constraints: Callable[..., float]
1117

12-
def __init__(self, label: str):
18+
def __init__(self, label: str) -> None:
19+
"""Set up a new CausalProblem."""
1320
super().__init__(label)
1421

1522
self.reset_parameters()
1623

24+
def _call_callable_attribute(
25+
self, which: Literal["sigma", "constraints"], *parameter_values: float
26+
) -> float:
27+
"""
28+
Evaluate the causal estimand or the constraints function.
29+
30+
parameter_values should be passed in the order they appear in
31+
self.parameter_nodes.
32+
"""
33+
# Set parameter value as per the inputs.
34+
# Order of *parameter_values is assumed to match the order of
35+
# self.parameter_nodes.
36+
self.set_parameters(
37+
**{
38+
self.parameter_nodes[i].label: value
39+
for i, value in enumerate(parameter_values)
40+
}
41+
)
42+
# Call underlying function
43+
return getattr(self, f"_{which}")()
44+
45+
def _set_callable_attribute(
46+
self,
47+
which: Literal["sigma", "constraints"],
48+
fn: Callable[..., float],
49+
name_map: dict[str, str],
50+
) -> None:
51+
"""
52+
Set either the causal estimand (sigma) or constraints function.
53+
54+
Input ``fn`` is assumed to take random variables as arguments. These are
55+
transformed, via the ``name_map``, into the corresponding ``Node``s in the
56+
``Graph`` describing this causal problem.
57+
"""
58+
setattr(
59+
self,
60+
f"_{which}",
61+
lambda: fn(
62+
**{
63+
rv_name: self.get_node(node_name)
64+
for rv_name, node_name in name_map.items()
65+
}
66+
),
67+
)
68+
1769
def reset_parameters(self) -> None:
1870
"""Clear all current values of parameter nodes."""
1971
self.set_parameters(**{node.label: None for node in self.parameter_nodes})
@@ -28,3 +80,23 @@ def set_parameters(self, **parameter_values: float | None) -> None:
2880
for name, new_value in parameter_values.items():
2981
node: ParameterNode = self.get_node(name)
3082
node.current_value = new_value
83+
84+
def set_causal_estimand(
85+
self, sigma: Callable[..., float], rvs_to_nodes: dict[str, str]
86+
) -> None:
87+
"""Set the causal estimand of this CausalProblem."""
88+
self._set_callable_attribute("sigma", sigma, rvs_to_nodes)
89+
90+
def set_constraints(
91+
self, constraints: Callable[..., float], rvs_to_nodes: dict[str, str]
92+
) -> None:
93+
"""Set the constraints of this CausalProblem."""
94+
self._set_callable_attribute("constraints", constraints, rvs_to_nodes)
95+
96+
def causal_estimand(self, *parameter_values: float) -> float:
97+
"""Evaluate the causal estimand."""
98+
return self._call_callable_attribute("sigma", *parameter_values)
99+
100+
def constraints(self, *parameter_values: float) -> float:
101+
"""Evaluate the constraints function."""
102+
return self._call_callable_attribute("constraints", *parameter_values)

0 commit comments

Comments
 (0)