Skip to content

Commit a6564c0

Browse files
bkayalibaydustinvtran
authored andcommitted
Utility function to assess conditional independence (#791)
* Adds utility functions to assess conditional independence. * Do not return random variables that are part of the query * Added unit tests for edward.util.get_irrelevant * Fixed a typo in the docstring for get_irrelevant * Added unit tests for ed.is_independent * Removed helper function get_irrelevant * Update comments to respect PEP8 * More PEP-8 related changes. * Updated docstrings and added bibtex entry for bayes-ball.
1 parent dec20fc commit a6564c0

File tree

5 files changed

+156
-1
lines changed

5 files changed

+156
-1
lines changed

docs/tex/bib.bib

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,14 @@ @article{lecun1998gradient
265265
publisher={IEEE}
266266
}
267267

268+
@inproceedings{schachter1998bayes,
269+
title={Bayes-Ball: The Rational Pastime (for Determining Irrelevance and Requisite Information in Belief Networks and Influence Diagrams)},
270+
author={Schachter, Ross D},
271+
booktitle={Proceedings of the Fourteenth Conference in Uncertainty in Artificial Intelligence},
272+
pages={480--487},
273+
year={1998}
274+
}
275+
268276
@article{watts1998collective,
269277
title={Collective dynamics of ‘small-world’networks},
270278
author={Watts, Duncan J and Strogatz, Steven H},

edward/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
from edward.util import check_data, check_latent_vars, copy, dot, \
2121
get_ancestors, get_blanket, get_children, get_control_variate_coef, \
2222
get_descendants, get_parents, get_session, get_siblings, get_variables, \
23-
Progbar, random_variables, rbf, set_seed, to_simplex, transform
23+
is_independent, Progbar, random_variables, rbf, set_seed, \
24+
to_simplex, transform
2425
from edward.version import __version__, VERSION
2526

2627
from tensorflow.python.util.all_util import remove_undocumented
@@ -74,6 +75,7 @@
7475
'get_session',
7576
'get_siblings',
7677
'get_variables',
78+
'is_independent',
7779
'Progbar',
7880
'random_variables',
7981
'rbf',

edward/util/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
'get_session',
2828
'get_siblings',
2929
'get_variables',
30+
'is_independent',
3031
'Progbar',
3132
'random_variables',
3233
'rbf',

edward/util/random_variables.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from copy import deepcopy
1010
from edward.models.random_variable import RandomVariable
1111
from edward.models.random_variables import TransformedDistribution
12+
from edward.models import PointMass
1213
from edward.util.graphs import random_variables
1314
from tensorflow.contrib.distributions import bijectors
1415
from tensorflow.core.framework import attr_value_pb2
@@ -714,6 +715,85 @@ def get_variables(x, collection=None):
714715
return list(output)
715716

716717

718+
def is_independent(a, b, condition=None):
719+
"""Assess whether a is independent of b given the random variables in
720+
condition.
721+
722+
Implemented using the Bayes-Ball algorithm [@schachter1998bayes].
723+
724+
Args:
725+
a: RandomVariable or list of RandomVariable.
726+
Query node(s).
727+
b: RandomVariable or list of RandomVariable.
728+
Query node(s).
729+
condition: RandomVariable or list of RandomVariable, optional.
730+
Random variable(s) to condition on.
731+
732+
Returns:
733+
bool.
734+
True if a is independent of b given the random variables in condition.
735+
736+
#### Examples
737+
738+
```python
739+
a = Normal(0.0, 1.0)
740+
b = Normal(a, 1.0)
741+
c = Normal(a, 1.0)
742+
assert ed.is_independent(b, c, condition=a)
743+
```
744+
"""
745+
if condition is None:
746+
condition = []
747+
if not isinstance(a, list):
748+
a = [a]
749+
if not isinstance(b, list):
750+
b = [b]
751+
if not isinstance(condition, list):
752+
condition = [condition]
753+
A = set(a)
754+
B = set(b)
755+
condition = set(condition)
756+
757+
top_marked = set()
758+
# The Bayes-Ball algorithm will traverse the belief network
759+
# and add each node that is relevant to B given condition
760+
# to the set bottom_marked. A and B are conditionally
761+
# independent if no node in A is in bottom_marked.
762+
bottom_marked = set()
763+
764+
schedule = [(node, "child") for node in B]
765+
while schedule:
766+
node, came_from = schedule.pop()
767+
768+
if node not in condition and came_from == "child":
769+
if node not in top_marked:
770+
top_marked.add(node)
771+
for parent in get_parents(node):
772+
schedule.append((parent, "child"))
773+
774+
if not isinstance(node, PointMass) and node not in bottom_marked:
775+
bottom_marked.add(node)
776+
if node in A:
777+
return False # node in A is relevant to B
778+
for child in get_children(node):
779+
schedule.append((child, "parent"))
780+
781+
elif came_from == "parent":
782+
if node in condition and node not in top_marked:
783+
top_marked.add(node)
784+
for parent in get_parents(node):
785+
schedule.append((parent, "child"))
786+
787+
elif node not in condition and node not in bottom_marked:
788+
bottom_marked.add(node)
789+
if node in A:
790+
return False # node in A is relevant to B
791+
for child in get_children(node):
792+
schedule.append((child, "parent"))
793+
794+
return True
795+
796+
717797
def transform(x, *args, **kwargs):
718798
"""Transform a continuous random variable to the unconstrained space.
719799

tests/util/test_is_independent.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from __future__ import absolute_import
2+
from __future__ import division
3+
from __future__ import print_function
4+
5+
import tensorflow as tf
6+
7+
from edward.models import Normal
8+
from edward.util import is_independent
9+
10+
11+
class test_is_independent_class(tf.test.TestCase):
12+
13+
def test_chain_structure(self):
14+
"""a -> b -> c -> d -> e"""
15+
a = Normal(0.0, 1.0)
16+
b = Normal(a, 1.0)
17+
c = Normal(b, 1.0)
18+
d = Normal(c, 1.0)
19+
e = Normal(d, 1.0)
20+
self.assertTrue(is_independent(c, e, d))
21+
self.assertTrue(is_independent([a, b, c], e, d))
22+
self.assertTrue(is_independent([a, b], [d, e], c))
23+
self.assertFalse(is_independent([a, b, e], d, c))
24+
25+
def test_binary_structure(self):
26+
"""f <- c <- a -> b -> d
27+
| |
28+
v v
29+
g e
30+
"""
31+
a = Normal(0.0, 1.0)
32+
b = Normal(a, 1.0)
33+
c = Normal(a, 1.0)
34+
d = Normal(b, 1.0)
35+
e = Normal(b, 1.0)
36+
f = Normal(c, 1.0)
37+
g = Normal(c, 1.0)
38+
self.assertFalse(is_independent(b, c))
39+
self.assertTrue(is_independent(b, c, a))
40+
self.assertTrue(is_independent(d, [a, c, e, f, g], b))
41+
self.assertFalse(is_independent(b, [e, d], a))
42+
self.assertFalse(is_independent(a, [b, c, d, e, f, g]))
43+
44+
def test_grid_structure(self):
45+
"""a -> b -> c
46+
| | |
47+
v v v
48+
d -> e -> f
49+
"""
50+
a = Normal(0.0, 1.0)
51+
b = Normal(a, 1.0)
52+
c = Normal(b, 1.0)
53+
d = Normal(a, 1.0)
54+
e = Normal(b + d, 1.0)
55+
f = Normal(e + c, 1.0)
56+
self.assertFalse(is_independent(f, [a, b, d]))
57+
self.assertTrue(is_independent(f, [a, b, d], [e, c]))
58+
self.assertTrue(is_independent(e, [a, c], [b, d]))
59+
self.assertFalse(is_independent(e, f, [b, d]))
60+
self.assertFalse(is_independent(e, f, [a, b, c, d]))
61+
62+
63+
if __name__ == '__main__':
64+
tf.test.main()

0 commit comments

Comments
 (0)