Skip to content

Commit 0972b3a

Browse files
committed
Merge branch 'wgraham/signature-converting' into wgraham/distributions-are-backend-agnostic
2 parents 5c27b0b + a463a14 commit 0972b3a

File tree

9 files changed

+366
-108
lines changed

9 files changed

+366
-108
lines changed

pyproject.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@ classifiers = [
1515
"Programming Language :: Python :: 3.13",
1616
"Typing :: Typed",
1717
]
18-
dependencies = ["jax", "networkx"]
18+
dependencies = [
19+
"jax",
20+
"networkx",
21+
"numpy",
22+
]
1923
description = "A Python package for causal modelling and inference with stochastic causal programming"
2024
dynamic = ["version"]
2125
keywords = []
@@ -35,6 +39,7 @@ optional-dependencies = {dev = [
3539
"mkdocstrings-python",
3640
], test = [
3741
"distrax",
42+
"numpy",
3843
"numpyro",
3944
"pytest",
4045
"pytest-cov",
@@ -75,6 +80,7 @@ lint.per-file-ignores = {"__init__.py" = [
7580
"ANN",
7681
"D",
7782
"INP001", # File is part of an implicit namespace package.
83+
"PLR0913", # Too many arguments in function definition
7884
"S101", # Use of `assert` detected
7985
]}
8086
lint.select = ["ALL"]

src/causalprog/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
"""causalprog package."""
22

3-
from . import graph
3+
from . import algorithms, distribution, graph, utils
44
from ._version import __version__

src/causalprog/algorithms/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
"""Algorithms."""
2+
3+
from .expectation import expectation, standard_deviation
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""Algorithms for estimating the expectation and standard deviation."""
2+
3+
import numpy as np
4+
import numpy.typing as npt
5+
6+
from causalprog.graph import Graph
7+
8+
9+
def sample(
10+
graph: Graph,
11+
outcome_node_label: str | None = None,
12+
samples: int = 1000,
13+
) -> npt.NDArray[float]:
14+
"""Sample data from a graph."""
15+
if outcome_node_label is None:
16+
outcome_node_label = graph.outcome.label
17+
18+
nodes = graph.roots_down_to_outcome(outcome_node_label)
19+
20+
values: dict[str, npt.NDArray[float]] = {}
21+
for node in nodes:
22+
values[node.label] = node.sample(values, samples)
23+
return values[outcome_node_label]
24+
25+
26+
def expectation(
27+
graph: Graph,
28+
outcome_node_label: str | None = None,
29+
samples: int = 1000,
30+
) -> float:
31+
"""Estimate the expectation of a graph."""
32+
return sample(graph, outcome_node_label, samples).mean()
33+
34+
35+
def standard_deviation(
36+
graph: Graph,
37+
outcome_node_label: str | None = None,
38+
samples: int = 1000,
39+
) -> float:
40+
"""Estimate the standard deviation of a graph."""
41+
return np.std(sample(graph, outcome_node_label, samples))

src/causalprog/graph/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
"""Creation and storage of graphs."""
22

33
from .graph import Graph
4-
from .node import DistributionNode, RootDistributionNode
4+
from .node import DistributionNode, Node

src/causalprog/graph/graph.py

Lines changed: 85 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,98 @@
1010
class Graph(Labelled):
1111
"""A directed acyclic graph that represents a causality tree."""
1212

13-
def __init__(self, graph: nx.Graph, label: str) -> None:
14-
"""Initialise a graph from a NetworkX graph."""
13+
_nodes_by_label: dict[str, Node]
14+
15+
def __init__(self, label: str) -> None:
16+
"""Create end empty graph."""
1517
super().__init__(label=label)
18+
self._graph = nx.DiGraph()
19+
self._nodes_by_label = {}
20+
21+
def get_node(self, label: str) -> Node:
22+
"""Get a node from its label."""
23+
node = self._nodes_by_label.get(label, None)
24+
if not node:
25+
msg = f'Node not found with label "{label}"'
26+
raise KeyError(msg)
27+
return node
28+
29+
def add_node(self, node: Node) -> None:
30+
"""Add a node to the graph."""
31+
if node.label in self._nodes_by_label:
32+
msg = f"Duplicate node label: {node.label}"
33+
raise ValueError(msg)
34+
self._nodes_by_label[node.label] = node
35+
self._graph.add_node(node)
36+
37+
def add_edge(self, first_node: Node | str, second_node: Node | str) -> None:
38+
"""
39+
Add an edge to the graph.
1640
17-
for node in graph.nodes:
18-
if not isinstance(node, Node):
19-
msg = f"Invalid node: {node}"
20-
raise TypeError(msg)
41+
Adding an edge between nodes not currently in the graph,
42+
will cause said nodes to be added to the graph along with
43+
the edge.
44+
"""
45+
if isinstance(first_node, str):
46+
first_node = self.get_node(first_node)
47+
if isinstance(second_node, str):
48+
second_node = self.get_node(second_node)
49+
if first_node.label not in self._nodes_by_label:
50+
self.add_node(first_node)
51+
if second_node.label not in self._nodes_by_label:
52+
self.add_node(second_node)
53+
for node_to_check in (first_node, second_node):
54+
if node_to_check != self._nodes_by_label[node_to_check.label]:
55+
msg = "Invalid node: {node_to_check}"
56+
raise ValueError(msg)
57+
self._graph.add_edge(first_node, second_node)
2158

22-
self._graph = graph.copy()
23-
self._nodes = list(graph.nodes())
24-
self._depth_first_nodes = list(nx.algorithms.dfs_postorder_nodes(graph))
59+
@property
60+
def predecessors(self) -> dict[Node, Node]:
61+
"""Get predecessors of every node."""
62+
return nx.algorithms.dfs_predecessors(self._graph)
2563

26-
outcomes = [node for node in self._nodes if node.is_outcome]
64+
@property
65+
def successors(self) -> dict[Node, list[Node]]:
66+
"""Get successors of every node."""
67+
return nx.algorithms.dfs_successors(self._graph)
68+
69+
@property
70+
def outcome(self) -> Node:
71+
"""The outcome node of the graph."""
72+
outcomes = [node for node in self.nodes if node.is_outcome]
2773
if len(outcomes) == 0:
2874
msg = "Cannot create graph with no outcome nodes"
2975
raise ValueError(msg)
3076
if len(outcomes) > 1:
3177
msg = "Cannot yet create graph with multiple outcome nodes"
3278
raise ValueError(msg)
33-
self._outcome = outcomes[0]
79+
return outcomes[0]
80+
81+
@property
82+
def nodes(self) -> list[Node]:
83+
"""The nodes of the graph."""
84+
return list(self._graph.nodes())
85+
86+
@property
87+
def ordered_nodes(self) -> list[Node]:
88+
"""Nodes ordered so that each node appears after its dependencies."""
89+
if not nx.is_directed_acyclic_graph(self._graph):
90+
msg = "Graph is not acyclic."
91+
raise RuntimeError(msg)
92+
return list(nx.topological_sort(self._graph))
93+
94+
def roots_down_to_outcome(
95+
self,
96+
outcome_node_label: str,
97+
) -> list[Node]:
98+
"""
99+
Get ordered list of nodes that outcome depends on.
100+
101+
Nodes are ordered so that each node appears after its dependencies.
102+
"""
103+
outcome = self.get_node(outcome_node_label)
104+
ancestors = nx.ancestors(self._graph, outcome)
105+
return [
106+
node for node in self.ordered_nodes if node == outcome or node in ancestors
107+
]

src/causalprog/graph/node.py

Lines changed: 60 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -2,97 +2,90 @@
22

33
from __future__ import annotations
44

5-
from abc import abstractmethod
6-
from typing import Protocol, runtime_checkable
5+
import typing
6+
from abc import ABC, abstractmethod
77

8-
from causalprog._abc.labelled import Labelled
8+
import numpy as np
99

10+
if typing.TYPE_CHECKING:
11+
import numpy.typing as npt
1012

11-
class DistributionFamily:
12-
"""Placeholder class."""
13+
from causalprog._abc.labelled import Labelled
1314

1415

15-
class Distribution:
16+
class Distribution(ABC):
1617
"""Placeholder class."""
1718

18-
19-
@runtime_checkable
20-
class Node(Protocol):
21-
"""An abstract node in a graph."""
22-
23-
@property
24-
@abstractmethod
25-
def label(self) -> str:
26-
"""The label of the node."""
27-
28-
@property
2919
@abstractmethod
30-
def is_root(self) -> bool:
31-
"""Identify if the node is a root."""
32-
33-
@property
34-
@abstractmethod
35-
def is_outcome(self) -> bool:
36-
"""Identify if the node is an outcome."""
37-
38-
39-
class RootDistributionNode(Labelled):
40-
"""A root node containing a distribution family."""
20+
def sample(
21+
self, sampled_dependencies: dict[str, npt.NDArray[float]], samples: int
22+
) -> npt.NDArray[float]:
23+
"""Sample."""
24+
25+
26+
class NormalDistribution(Distribution):
27+
"""Normal distribution."""
28+
29+
def __init__(self, mean: str | float = 0.0, std_dev: str | float = 1.0) -> None:
30+
"""Initialise."""
31+
self.mean = mean
32+
self.std_dev = std_dev
33+
34+
def sample(
35+
self, sampled_dependencies: dict[str, npt.NDArray[float]], samples: int
36+
) -> npt.NDArray[float]:
37+
"""Sample a normal distribution with mean 1."""
38+
values = np.random.normal(0.0, 1.0, samples) # noqa: NPY002
39+
if isinstance(self.std_dev, str):
40+
values *= sampled_dependencies[self.std_dev]
41+
else:
42+
values *= self.std_dev
43+
if isinstance(self.mean, str):
44+
values += sampled_dependencies[self.mean]
45+
else:
46+
values += self.mean
47+
return values
48+
49+
50+
class Node(Labelled):
51+
"""An abstract node in a graph."""
4152

42-
def __init__(
43-
self,
44-
family: DistributionFamily,
45-
label: str,
46-
*,
47-
is_outcome: bool = False,
48-
) -> None:
49-
"""Initialise the node."""
53+
def __init__(self, label: str, *, is_outcome: bool = False) -> None:
54+
"""Initialise."""
5055
super().__init__(label=label)
56+
self._is_outcome = is_outcome
5157

52-
self._dfamily = family
53-
self._outcome = is_outcome
54-
55-
def __repr__(self) -> str:
56-
"""Representation."""
57-
return f'RootDistributionNode("{self._label}")'
58-
59-
@property
60-
def is_root(self) -> bool:
61-
"""Identify if the node is a root."""
62-
return True
58+
@abstractmethod
59+
def sample(
60+
self, sampled_dependencies: dict[str, npt.NDArray[float]], samples: int
61+
) -> float:
62+
"""Sample a value from the node."""
6363

6464
@property
6565
def is_outcome(self) -> bool:
6666
"""Identify if the node is an outcome."""
67-
return self._outcome
67+
return self._is_outcome
6868

6969

70-
class DistributionNode(Labelled):
71-
"""A node containing a distribution family that depends on its parents."""
70+
class DistributionNode(Node):
71+
"""A node containing a distribution."""
7272

7373
def __init__(
7474
self,
75-
family: DistributionFamily,
75+
distribution: Distribution,
7676
label: str,
7777
*,
7878
is_outcome: bool = False,
7979
) -> None:
80-
"""Initialise the node."""
81-
super().__init__(label=label)
80+
"""Initialise."""
81+
self._dist = distribution
82+
super().__init__(label, is_outcome=is_outcome)
8283

83-
self._dfamily = family
84-
self._outcome = is_outcome
84+
def sample(
85+
self, sampled_dependencies: dict[str, npt.NDArray[float]], samples: int
86+
) -> float:
87+
"""Sample a value from the node."""
88+
return self._dist.sample(sampled_dependencies, samples)
8589

8690
def __repr__(self) -> str:
87-
"""Representation."""
88-
return f'DistributionNode("{self._label}")'
89-
90-
@property
91-
def is_root(self) -> bool:
92-
"""Identify if the node is a root."""
93-
return False
94-
95-
@property
96-
def is_outcome(self) -> bool:
97-
"""Identify if the node is an outcome."""
98-
return self._outcome
91+
return f'DistributionNode("{self.label}")'

tests/test_backend/test_signature_can_be_cast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@
160160
),
161161
],
162162
)
163-
def test_signature_can_be_cast( # noqa: PLR0913
163+
def test_signature_can_be_cast(
164164
signature_to_convert: Signature,
165165
new_signature: Signature,
166166
old_to_new_names: ParamNameMap,

0 commit comments

Comments
 (0)