
Do not add internal RandomVariables to the RV collection #609


Merged 4 commits on Apr 16, 2017

8 changes: 7 additions & 1 deletion edward/inferences/conjugacy/conjugacy.py
@@ -71,7 +71,11 @@ def complete_conditional(rv, cond_set=None):
   result in unpredictable behavior.
   """
   if cond_set is None:
+    # Default to Markov blanket, excluding conditionals. This is useful if
+    # calling complete_conditional many times without passing in cond_set.
     cond_set = get_blanket(rv)
+    cond_set = [i for i in cond_set if not
+                ('complete_conditional' in i.name and 'cond_dist' in i.name)]
 
   cond_set = set([rv] + list(cond_set))
   with tf.name_scope('complete_conditional_%s' % rv.name) as scope:
@@ -143,7 +147,9 @@ def complete_conditional(rv, cond_set=None):
 
 def get_log_joint(cond_set):
   g = tf.get_default_graph()
-  cond_set_name = 'log_joint_of_' + ('_'.join([i.name[:-1] for i in cond_set]))
+  cond_set_names = [i.name[:-1] for i in cond_set]
+  cond_set_names.sort()
+  cond_set_name = 'log_joint_of_' + '_'.join(cond_set_names)
   with tf.name_scope("conjugate_log_joint/") as scope:
     try:
       # Use log joint tensor if already built in graph.
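
Taken together, these hunks let callers drop the explicit Markov blanket and make the cached log-joint name independent of the order in which the conditioning set is enumerated. A minimal sketch of the cache-key behavior (the RV names below are illustrative, not taken from the patch):

```python
# Same conditioning set, enumerated in two different orders.
names_first = ['Normal_1', 'Gamma', 'Normal']
names_second = ['Gamma', 'Normal', 'Normal_1']

key_first = 'log_joint_of_' + '_'.join(sorted(names_first))
key_second = 'log_joint_of_' + '_'.join(sorted(names_second))
assert key_first == key_second  # both calls now reuse one log-joint tensor
```
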
3 changes: 2 additions & 1 deletion edward/models/dirichlet_process.py
@@ -64,7 +64,8 @@ def __init__(self, alpha, base, validate_args=False, allow_nan_stats=True,
         dtype=self._base.dtype)
 
     # Instantiate distribution for stick breaking proportions.
-    self._betadist = Beta(a=tf.ones_like(self._alpha), b=self._alpha)
+    self._betadist = Beta(a=tf.ones_like(self._alpha), b=self._alpha,
+                          collections=[])
     # Form empty tensor to store stick breaking proportions.
     self._beta = tf.zeros(
         [0] + self.get_batch_shape().as_list(),
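
A hedged sketch of the intended effect: building a DirichletProcess should no longer register its internal stick-breaking Beta in the RV collection. Treating `base` as a Normal with `mu`/`sigma` arguments is an assumption about the Edward API of this era, not something taken from the patch.

```python
import tensorflow as tf
from edward.models import Beta, DirichletProcess, Normal

# Hypothetical construction; the Normal parameterization is an assumption.
dp = DirichletProcess(alpha=0.1, base=Normal(mu=0.0, sigma=1.0))

# The internal Beta is built with collections=[], so only user-facing
# random variables appear in the graph collection.
rvs = tf.get_collection("random_variables")
assert not any(isinstance(rv, Beta) for rv in rvs)
```
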
1 change: 1 addition & 0 deletions edward/models/param_mixture.py
@@ -82,6 +82,7 @@ def __init__(self,
     self._components = component_dist(validate_args=validate_args,
                                       allow_nan_stats=allow_nan_stats,
                                       sample_shape=sample_shape,
+                                      collections=[],
                                       **component_params)
 
     if validate_args:
11 changes: 9 additions & 2 deletions edward/models/random_variable.py
@@ -10,7 +10,7 @@
 except Exception as e:
   raise ImportError("{0}. Your TensorFlow version is not supported.".format(e))
 
-RANDOM_VARIABLE_COLLECTION = "_random_variable_collection_"
+RANDOM_VARIABLE_COLLECTION = "random_variables"
 
 
 class RandomVariable(object):
@@ -78,6 +78,9 @@ def __init__(self, *args, **kwargs):
     value : tf.Tensor, optional
       Fixed tensor to associate with random variable. Must have shape
       ``sample_shape + batch_shape + event_shape``.
+    collections : list, optional
+      Optional list of graph collections keys. The random variable is
+      added to these collections. Defaults to ["random_variables"].
     *args, **kwargs
       Passed into parent ``__init__``.
     """
@@ -88,11 +91,14 @@
     # temporarily pop (then reinsert) before calling parent __init__
     sample_shape = kwargs.pop('sample_shape', ())
     value = kwargs.pop('value', None)
+    collections = kwargs.pop('collections', [RANDOM_VARIABLE_COLLECTION])
     super(RandomVariable, self).__init__(*args, **kwargs)
     if sample_shape != ():
       self._kwargs['sample_shape'] = sample_shape
     if value is not None:
       self._kwargs['value'] = value
+    if collections != [RANDOM_VARIABLE_COLLECTION]:
+      self._kwargs['collections'] = collections
 
     self._sample_shape = tf.TensorShape(sample_shape)
     if value is not None:
@@ -115,7 +121,8 @@ def __init__(self, *args, **kwargs):
             "value argument or implement sample for {0}."
             .format(self.__class__.__name__))
 
-    tf.add_to_collection(RANDOM_VARIABLE_COLLECTION, self)
+    for collection in collections:
+      tf.add_to_collection(collection, self)
 
   @property
   def shape(self):
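
A short usage sketch of the new `collections` argument together with the renamed collection key (the Normal `mu`/`sigma` parameter names are an assumption about the Edward API of this era):

```python
import tensorflow as tf
from edward.models import Normal

n_before = len(tf.get_collection("random_variables"))

z = Normal(mu=0.0, sigma=1.0)                       # registered by default
helper = Normal(mu=0.0, sigma=1.0, collections=[])  # kept out of all collections

n_after = len(tf.get_collection("random_variables"))
assert n_after == n_before + 1  # only `z` was added to the collection
```
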
9 changes: 4 additions & 5 deletions examples/mixture_gaussian_gibbs.py
@@ -41,11 +41,10 @@
 z = x.cat
 
 # Conditionals
-blanket = [x, z, mu, pi, sigmasq]
-mu_cond = ed.complete_conditional(mu, blanket)
-sigmasq_cond = ed.complete_conditional(sigmasq, blanket)
-pi_cond = ed.complete_conditional(pi, blanket)
-z_cond = ed.complete_conditional(z, blanket)
+mu_cond = ed.complete_conditional(mu)
+sigmasq_cond = ed.complete_conditional(sigmasq)
+pi_cond = ed.complete_conditional(pi)
+z_cond = ed.complete_conditional(z)
 
 sess = ed.get_session()
 
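
For context, a minimal Gibbs-sweep sketch of how these conditionals are typically driven, continuing with the variables defined in this example (`sess`, the model RVs, and the `*_cond` conditionals); the `x_data` feed name and the iteration count are assumptions, not lines from the file:

```python
# Initialize the chain from the prior, keeping the observations fixed.
state = {x: x_data}
state[pi], state[mu], state[sigmasq], state[z] = sess.run([pi, mu, sigmasq, z])

for _ in range(500):
  # Resample each variable from its complete conditional, feeding in the
  # current values of all the others.
  state[z] = sess.run(z_cond, state)
  state[pi] = sess.run(pi_cond, state)
  state[mu] = sess.run(mu_cond, state)
  state[sigmasq] = sess.run(sigmasq_cond, state)
```
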
10 changes: 4 additions & 6 deletions tests/test-inferences/test_conjugacy.py
@@ -180,8 +180,7 @@ def test_dirichlet_categorical(self):
     theta = rvs.Dirichlet(alpha)
     x = rvs.Categorical(p=theta, sample_shape=sample_shape)
 
-    blanket = [theta, x]
-    theta_cond = ed.complete_conditional(theta, blanket)
+    theta_cond = ed.complete_conditional(theta, [theta, x])
 
     with self.test_session() as sess:
       alpha_val = sess.run(theta_cond.alpha, {x: x_data})
@@ -208,10 +207,9 @@ def test_mog(self):
                           rvs.Normal, sample_shape=N)
     z = x.cat
 
-    blanket = [x, z, mu, pi]
-    mu_cond = ed.complete_conditional(mu, blanket)
-    pi_cond = ed.complete_conditional(pi, blanket)
-    z_cond = ed.complete_conditional(z, blanket)
+    mu_cond = ed.complete_conditional(mu)
+    pi_cond = ed.complete_conditional(pi)
+    z_cond = ed.complete_conditional(z)
 
     with self.test_session() as sess:
       pi_cond_alpha, mu_cond_mu, mu_cond_sigma, z_cond_p = (