1
- from typing import TYPE_CHECKING
1
+ """Extension of the Graph class providing features for solving causal problems."""
2
+
3
+ from collections .abc import Callable
4
+ from typing import TYPE_CHECKING , Literal
2
5
3
6
from .graph import Graph
4
7
7
10
8
11
9
12
class CausalProblem (Graph ):
10
- """"""
13
+ """f."""
14
+
15
+ _sigma : Callable [..., float ]
16
+ _constraints : Callable [..., float ]
11
17
12
- def __init__ (self , label : str ):
18
+ def __init__ (self , label : str ) -> None :
19
+ """Set up a new CausalProblem."""
13
20
super ().__init__ (label )
14
21
15
22
self .reset_parameters ()
16
23
24
+ def _call_callable_attribute (
25
+ self , which : Literal ["sigma" , "constraints" ], * parameter_values : float
26
+ ) -> float :
27
+ """
28
+ Evaluate the causal estimand or the constraints function.
29
+
30
+ parameter_values should be passed in the order they appear in
31
+ self.parameter_nodes.
32
+ """
33
+ # Set parameter value as per the inputs.
34
+ # Order of *parameter_values is assumed to match the order of
35
+ # self.parameter_nodes.
36
+ self .set_parameters (
37
+ ** {
38
+ self .parameter_nodes [i ].label : value
39
+ for i , value in enumerate (parameter_values )
40
+ }
41
+ )
42
+ # Call underlying function
43
+ return getattr (self , f"_{ which } " )()
44
+
45
+ def _set_callable_attribute (
46
+ self ,
47
+ which : Literal ["sigma" , "constraints" ],
48
+ fn : Callable [..., float ],
49
+ name_map : dict [str , str ],
50
+ ) -> None :
51
+ """
52
+ Set either the causal estimand (sigma) or constraints function.
53
+
54
+ Input ``fn`` is assumed to take random variables as arguments. These are
55
+ transformed, via the ``name_map``, into the corresponding ``Node``s in the
56
+ ``Graph`` describing this causal problem.
57
+ """
58
+ setattr (
59
+ self ,
60
+ f"_{ which } " ,
61
+ lambda : fn (
62
+ ** {
63
+ rv_name : self .get_node (node_name )
64
+ for rv_name , node_name in name_map .items ()
65
+ }
66
+ ),
67
+ )
68
+
17
69
def reset_parameters (self ) -> None :
18
70
"""Clear all current values of parameter nodes."""
19
71
self .set_parameters (** {node .label : None for node in self .parameter_nodes })
@@ -28,3 +80,23 @@ def set_parameters(self, **parameter_values: float | None) -> None:
28
80
for name , new_value in parameter_values .items ():
29
81
node : ParameterNode = self .get_node (name )
30
82
node .current_value = new_value
83
+
84
+ def set_causal_estimand (
85
+ self , sigma : Callable [..., float ], rvs_to_nodes : dict [str , str ]
86
+ ) -> None :
87
+ """Set the causal estimand of this CausalProblem."""
88
+ self ._set_callable_attribute ("sigma" , sigma , rvs_to_nodes )
89
+
90
+ def set_constraints (
91
+ self , constraints : Callable [..., float ], rvs_to_nodes : dict [str , str ]
92
+ ) -> None :
93
+ """Set the constraints of this CausalProblem."""
94
+ self ._set_callable_attribute ("constraints" , constraints , rvs_to_nodes )
95
+
96
+ def causal_estimand (self , * parameter_values : float ) -> float :
97
+ """Evaluate the causal estimand."""
98
+ return self ._call_callable_attribute ("sigma" , * parameter_values )
99
+
100
+ def constraints (self , * parameter_values : float ) -> float :
101
+ """Evaluate the constraints function."""
102
+ return self ._call_callable_attribute ("constraints" , * parameter_values )
0 commit comments