Skip to content

ABC submodule with Labelled mixin #17

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/causalprog/_abc/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Abstract Base Classes (ABCs) for ``causalprog`` package."""
23 changes: 23 additions & 0 deletions src/causalprog/_abc/labelled.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from abc import ABC


class Labelled(ABC):
"""
ABC for objects that carry a label. This class can be used as a MixIn.

Objects must be passed an explicit ``label`` parameter on instantiation,
which provides a name for the object. This value is stored in the
private ``_label`` attribute, and is only intended to be accessed via the
``label`` property of the class.
"""

__slots__ = ("_label",)
_label: str

@property
def label(self) -> str:
"""Label of this object."""
return self._label

def __init__(self, *, label: str) -> None:
self._label = str(label)
7 changes: 6 additions & 1 deletion src/causalprog/distribution/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from numpy.typing import ArrayLike

from causalprog._abc.labelled import Labelled
from causalprog.utils.translator import Translator

SupportsRNG = TypeVar("SupportsRNG")
Expand All @@ -30,7 +31,7 @@ def compulsory_frontend_args(self) -> set[str]:
return {"rng_key", "sample_shape"}


class Distribution(Generic[SupportsSampling]):
class Distribution(Generic[SupportsSampling], Labelled):
"""A (backend-agnostic) distribution that can be sampled from."""

_dist: SupportsSampling
Expand All @@ -45,6 +46,8 @@ def __init__(
self,
backend_distribution: SupportsSampling,
backend_translator: SampleTranslator | None = None,
*,
label: str = "Distribution",
) -> None:
"""
Create a new Distribution.
Expand All @@ -56,6 +59,8 @@ def __init__(
sampling function to frontend arguments.

"""
super().__init__(label=label)

self._dist = backend_distribution

# Setup sampling calls, and perform one-time check for compatibility
Expand Down
10 changes: 8 additions & 2 deletions src/causalprog/distribution/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from numpy.typing import ArrayLike

from causalprog._abc.labelled import Labelled
from causalprog.distribution.base import Distribution, SupportsSampling
from causalprog.utils.translator import Translator

Expand All @@ -13,7 +14,7 @@
)


class DistributionFamily(Generic[CreatesDistribution]):
class DistributionFamily(Generic[CreatesDistribution], Labelled):
r"""
A family of ``Distributions``, that share the same parameters.

Expand All @@ -39,13 +40,16 @@ class DistributionFamily(Generic[CreatesDistribution]):
def _member(self) -> Callable[..., Distribution]:
"""Constructor method for family members, given parameters."""
return lambda *parameters: Distribution(
self._family(*parameters), backend_translator=self._family_translator
self._family(*parameters),
backend_translator=self._family_translator,
)

def __init__(
self,
backend_family: CreatesDistribution,
backend_translator: Translator | None = None,
*,
family_name: str = "DistributionFamily",
) -> None:
"""
Create a new family of distributions.
Expand All @@ -58,6 +62,8 @@ def __init__(
passed to the ``Distribution`` constructor.

"""
super().__init__(label=family_name)

self._family = backend_family
self._family_translator = backend_translator

Expand Down
4 changes: 2 additions & 2 deletions src/causalprog/distribution/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(self, mean: ArrayCompatible, cov: ArrayCompatible) -> None:
cov (ArrayCompatible): Matrix of covariates, $\Sigma$.

"""
super().__init__(_Normal(mean, cov))
super().__init__(_Normal(mean, cov), label=f"({mean.ndim}-dim) Normal")


class NormalFamily(DistributionFamily):
Expand All @@ -74,7 +74,7 @@ class NormalFamily(DistributionFamily):

def __init__(self) -> None:
"""Create a family of normal distributions."""
super().__init__(Normal)
super().__init__(Normal, family_name="Normal")

def construct(self, mean: ArrayCompatible, cov: ArrayCompatible) -> Normal:
r"""
Expand Down
13 changes: 5 additions & 8 deletions src/causalprog/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,23 @@

import networkx as nx

from causalprog._abc.labelled import Labelled

from .node import Node


class Graph:
class Graph(Labelled):
"""A directed acyclic graph that represents a causality tree."""

def __init__(self, graph: nx.Graph, label: str) -> None:
"""Initialise a graph from a NetworkX graph."""
super().__init__(label=label)

for node in graph.nodes:
if not isinstance(node, Node):
msg = f"Invalid node: {node}"
raise TypeError(msg)

self._label = label

self._graph = graph.copy()
self._nodes = list(graph.nodes())
self._depth_first_nodes = list(nx.algorithms.dfs_postorder_nodes(graph))
Expand All @@ -29,8 +31,3 @@ def __init__(self, graph: nx.Graph, label: str) -> None:
msg = "Cannot yet create graph with multiple outcome nodes"
raise ValueError(msg)
self._outcome = outcomes[0]

@property
def label(self) -> str:
"""The label of the graph."""
return self._label
22 changes: 8 additions & 14 deletions src/causalprog/graph/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from abc import abstractmethod
from typing import Protocol, runtime_checkable

from causalprog._abc.labelled import Labelled


class DistributionFamily:
"""Placeholder class."""
Expand Down Expand Up @@ -34,7 +36,7 @@ def is_outcome(self) -> bool:
"""Identify if the node is an outcome."""


class RootDistributionNode:
class RootDistributionNode(Labelled):
"""A root node containing a distribution family."""

def __init__(
Expand All @@ -45,19 +47,15 @@ def __init__(
is_outcome: bool = False,
) -> None:
"""Initialise the node."""
super().__init__(label=label)

self._dfamily = family
self._label = label
self._outcome = is_outcome

def __repr__(self) -> str:
"""Representation."""
return f'RootDistributionNode("{self._label}")'

@property
def label(self) -> str:
"""The label of the node."""
return self._label

@property
def is_root(self) -> bool:
"""Identify if the node is a root."""
Expand All @@ -69,7 +67,7 @@ def is_outcome(self) -> bool:
return self._outcome


class DistributionNode:
class DistributionNode(Labelled):
"""A node containing a distribution family that depends on its parents."""

def __init__(
Expand All @@ -80,19 +78,15 @@ def __init__(
is_outcome: bool = False,
) -> None:
"""Initialise the node."""
super().__init__(label=label)

self._dfamily = family
self._label = label
self._outcome = is_outcome

def __repr__(self) -> str:
"""Representation."""
return f'DistributionNode("{self._label}")'

@property
def label(self) -> str:
"""The label of the node."""
return self._label

@property
def is_root(self) -> bool:
"""Identify if the node is a root."""
Expand Down