Skip to content

Commit 49d5f27

Browse files
committed
replace pp_dirichlet_process_{rv,base}.py
1 parent 5604da3 commit 49d5f27

File tree

2 files changed

+19
-92
lines changed

2 files changed

+19
-92
lines changed

examples/pp_dirichlet_process_base.py

Lines changed: 19 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
"""Dirichlet process.
33
44
We sample from a Dirichlet process (with inputted base distribution)
5-
via its stick breaking construction.
5+
via its stick breaking construction. For details on the
6+
implementation, see the source for ``DirichletProcess``.
67
78
References
89
----------
@@ -14,69 +15,24 @@
1415

1516
import tensorflow as tf
1617

17-
from edward.models import Bernoulli, Beta, Normal
18+
from edward.models import DirichletProcess, Exponential, Normal
1819

20+
base_cls = Normal
21+
kwargs = {'mu': 0.0, 'sigma': 1.0}
22+
dp = DirichletProcess(0.1, base_cls, **kwargs)
23+
print(dp)
1924

20-
def dirichlet_process(alpha, base_cls, sample_n=50, *args, **kwargs):
21-
"""Dirichlet process DP(``alpha``, ``base_cls(*args, **kwargs)``).
25+
# ``theta`` is the distribution indirectly returned by the DP.
26+
theta = base_cls(value=tf.cast(dp, tf.float32), **kwargs)
27+
print(theta)
2228

23-
Only works for scalar alpha and scalar base distribution.
24-
25-
Parameters
26-
----------
27-
alpha : tf.Tensor
28-
Concentration parameter. Its shape determines the batch shape of the DP.
29-
base_cls : RandomVariable
30-
Class of base distribution. Its shape (when instantiated)
31-
determines the event shape of the DP.
32-
sample_n : int, optional
33-
Number of samples for each DP in the batch shape.
34-
*args, **kwargs : optional
35-
Arguments passed into ``base_cls``.
36-
37-
Returns
38-
-------
39-
tf.Tensor
40-
A ``tf.Tensor`` of shape ``[sample_n] + batch_shape + event_shape``,
41-
where ``sample_n`` is the number of samples for each DP,
42-
``batch_shape`` is the number of independent DPs, and
43-
``event_shape`` is the shape of the base distribution.
44-
"""
45-
def cond(k, beta_k, draws, bools):
46-
# Proceed if at least one bool is True.
47-
return tf.reduce_any(bools)
48-
49-
def body(k, beta_k, draws, bools):
50-
k = k + 1
51-
beta_k = beta_k * Beta(a=1.0, b=alpha).sample(sample_n)
52-
theta_k = base_cls(*args, **kwargs).sample(sample_n)
53-
54-
# Assign True samples to the new theta_k.
55-
draws = tf.where(bools, theta_k, draws)
56-
57-
flips = tf.cast(Bernoulli(p=beta_k), tf.bool)
58-
bools = tf.logical_and(flips, tf.equal(draws, theta_k))
59-
return k, beta_k, draws, bools
60-
61-
k = 0
62-
beta_k = Beta(a=1.0, b=alpha).sample(sample_n)
63-
theta_k = base_cls(*args, **kwargs).sample(sample_n)
64-
65-
# Initialize all samples as theta_k.
66-
draws = theta_k
67-
# Flip ``sample_n`` coins, one for each sample.
68-
flips = tf.cast(Bernoulli(p=beta_k), tf.bool)
69-
# Get boolean tensor for samples that return heads
70-
# and are currently equal to theta_k.
71-
bools = tf.logical_and(flips, tf.equal(draws, theta_k))
72-
73-
total_sticks, _, samples, _ = tf.while_loop(
74-
cond, body, loop_vars=[k, beta_k, draws, bools])
75-
return total_sticks, samples
76-
77-
78-
dp = dirichlet_process(0.1, Normal, mu=0.0, sigma=1.0)
29+
# Fetching theta is the same as fetching the Dirichlet process.
7930
sess = tf.Session()
80-
print(sess.run(dp))
81-
print(sess.run(dp))
82-
print(sess.run(dp))
31+
print(sess.run([dp, theta]))
32+
print(sess.run([dp, theta]))
33+
34+
# This also works for non-scalar base distributions.
35+
base_cls = Exponential
36+
kwargs = {'lam': tf.ones([5, 2])}
37+
dp = DirichletProcess(0.1, base_cls, **kwargs)
38+
print(dp)

examples/pp_dirichlet_process_rv.py

Lines changed: 0 additions & 29 deletions
This file was deleted.

0 commit comments

Comments
 (0)