Skip to content
Merged
4 changes: 1 addition & 3 deletions src/causalprog/backend/_convert_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ def _check_variable_length_params(
parameter of that type exists in the signature.

"""
named_args: dict[ParamKind, str | None] = {
kind: None for kind in _VARLENGTH_PARAM_TYPES
}
named_args: dict[ParamKind, str | None] = dict.fromkeys(_VARLENGTH_PARAM_TYPES)
for kind in _VARLENGTH_PARAM_TYPES:
possible_parameters = [
p_name for p_name, p in sig.parameters.items() if p.kind == kind
Expand Down
2 changes: 1 addition & 1 deletion src/causalprog/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Creation and storage of graphs."""

from .graph import Graph
from .node import DistributionNode, Node
from .node import DistributionNode, Node, ParameterNode
39 changes: 37 additions & 2 deletions src/causalprog/graph/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,17 @@
class Node(Labelled):
"""An abstract node in a graph."""

def __init__(self, label: str, *, is_outcome: bool = False) -> None:
def __init__(
self,
label: str,
*,
is_outcome: bool = False,
is_parameter: bool = False,
) -> None:
"""Initialise."""
super().__init__(label=label)
self._is_outcome = is_outcome
self._is_parameter = is_parameter

@abstractmethod
def sample(
Expand All @@ -38,6 +45,11 @@ def is_outcome(self) -> bool:
"""Identify if the node is an outcome."""
return self._is_outcome

@property
def is_parameter(self) -> bool:
"""Identify if the node is a parameter."""
return self._is_parameter


class DistributionNode(Node):
"""A node containing a distribution."""
Expand All @@ -55,7 +67,7 @@ def __init__(
self._dist = distribution
self._constant_parameters = constant_parameters if constant_parameters else {}
self._parameters = parameters if parameters else {}
super().__init__(label, is_outcome=is_outcome)
super().__init__(label, is_outcome=is_outcome, is_parameter=False)

def sample(
self,
Expand All @@ -81,3 +93,26 @@ def sample(

def __repr__(self) -> str:
return f'DistributionNode("{self.label}")'


class ParameterNode(Node):
"""A node containing a parameter."""

def __init__(self, label: str, *, is_outcome: bool = False) -> None:
"""Initialise."""
super().__init__(label, is_outcome=is_outcome, is_parameter=True)

def sample(
self,
sampled_dependencies: dict[str, npt.NDArray[float]],
_samples: int,
_rng_key: jax.Array,
) -> npt.NDArray[float]:
"""Sample a value from the node."""
if self.label not in sampled_dependencies:
msg = "Cannot sample an undetermined parameter node."
raise ValueError(msg)
return sampled_dependencies[self.label]

def __repr__(self) -> str:
return f'ParameterNode("{self.label}")'
11 changes: 10 additions & 1 deletion tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import causalprog
from causalprog.distribution.normal import NormalFamily
from causalprog.graph import DistributionNode, Graph
from causalprog.graph import DistributionNode, Graph, ParameterNode


def test_label():
Expand Down Expand Up @@ -236,3 +236,12 @@ def test_two_node_graph(samples, rtol, mean, stdev, stdev2, rng_key):
np.sqrt(stdev**2 + stdev2**2),
rtol=rtol,
)


def test_paramater_node(rng_key):
node = ParameterNode("mu")

with pytest.raises(ValueError, match="Cannot sample"):
node.sample({}, 1, rng_key)

assert np.isclose(node.sample({"mu": np.array([0.3])}, 1, rng_key)[0], 0.3)