@@ -2,7 +2,8 @@
 """Dirichlet process.
 
 We sample from a Dirichlet process (with inputted base distribution)
-via its stick breaking construction.
+via its stick breaking construction. For details on the
+implementation, see the source for ``DirichletProcess``.
 
 References
 ----------
@@ -14,69 +15,24 @@
 
 import tensorflow as tf
 
-from edward.models import Bernoulli, Beta, Normal
+from edward.models import DirichletProcess, Exponential, Normal
 
+base_cls = Normal
+kwargs = {'mu': 0.0, 'sigma': 1.0}
+dp = DirichletProcess(0.1, base_cls, **kwargs)
+print(dp)
 
-def dirichlet_process(alpha, base_cls, sample_n=50, *args, **kwargs):
-  """Dirichlet process DP(``alpha``, ``base_cls(*args, **kwargs)``).
+# ``theta`` is the distribution indirectly returned by the DP.
+theta = base_cls(value=tf.cast(dp, tf.float32), **kwargs)
+print(theta)
 
-  Only works for scalar alpha and scalar base distribution.
-
-  Parameters
-  ----------
-  alpha : tf.Tensor
-    Concentration parameter. Its shape determines the batch shape of the DP.
-  base_cls : RandomVariable
-    Class of base distribution. Its shape (when instantiated)
-    determines the event shape of the DP.
-  sample_n : int, optional
-    Number of samples for each DP in the batch shape.
-  *args, **kwargs : optional
-    Arguments passed into ``base_cls``.
-
-  Returns
-  -------
-  tf.Tensor
-    A ``tf.Tensor`` of shape ``[sample_n] + batch_shape + event_shape``,
-    where ``sample_n`` is the number of samples for each DP,
-    ``batch_shape`` is the number of independent DPs, and
-    ``event_shape`` is the shape of the base distribution.
-  """
-  def cond(k, beta_k, draws, bools):
-    # Proceed if at least one bool is True.
-    return tf.reduce_any(bools)
-
-  def body(k, beta_k, draws, bools):
-    k = k + 1
-    beta_k = beta_k * Beta(a=1.0, b=alpha).sample(sample_n)
-    theta_k = base_cls(*args, **kwargs).sample(sample_n)
-
-    # Assign True samples to the new theta_k.
-    draws = tf.where(bools, theta_k, draws)
-
-    flips = tf.cast(Bernoulli(p=beta_k), tf.bool)
-    bools = tf.logical_and(flips, tf.equal(draws, theta_k))
-    return k, beta_k, draws, bools
-
-  k = 0
-  beta_k = Beta(a=1.0, b=alpha).sample(sample_n)
-  theta_k = base_cls(*args, **kwargs).sample(sample_n)
-
-  # Initialize all samples as theta_k.
-  draws = theta_k
-  # Flip ``sample_n`` coins, one for each sample.
-  flips = tf.cast(Bernoulli(p=beta_k), tf.bool)
-  # Get boolean tensor for samples that return heads
-  # and are currently equal to theta_k.
-  bools = tf.logical_and(flips, tf.equal(draws, theta_k))
-
-  total_sticks, _, samples, _ = tf.while_loop(
-      cond, body, loop_vars=[k, beta_k, draws, bools])
-  return total_sticks, samples
-
-
-dp = dirichlet_process(0.1, Normal, mu=0.0, sigma=1.0)
+# Fetching theta is the same as fetching the Dirichlet process.
 sess = tf.Session()
-print(sess.run(dp))
-print(sess.run(dp))
-print(sess.run(dp))
+print(sess.run([dp, theta]))
+print(sess.run([dp, theta]))
+
+# This also works for non-scalar base distributions.
+base_cls = Exponential
+kwargs = {'lam': tf.ones([5, 2])}
+dp = DirichletProcess(0.1, base_cls, **kwargs)
+print(dp)
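For reference, below is a minimal NumPy-only sketch of the stick-breaking construction that the removed ``dirichlet_process`` helper built with a TensorFlow while loop, and that ``DirichletProcess`` is presumed to perform internally. It is not Edward's implementation; the names ``dp_stick_breaking_samples`` and ``base_sample`` are illustrative only.

# Illustrative sketch only; not Edward's DirichletProcess implementation.
import numpy as np


def dp_stick_breaking_samples(n, alpha, base_sample, rng=None):
  """Draw ``n`` samples from one realization G ~ DP(``alpha``, H).

  Sticks beta_k ~ Beta(1, alpha) and atoms theta_k ~ H are generated
  lazily and shared across samples, so repeated atoms exhibit the DP's
  clustering behavior.
  """
  rng = np.random.default_rng() if rng is None else rng
  betas, atoms = [], []  # lazily grown sticks and their atoms
  samples = []
  for _ in range(n):
    k = 0
    while True:
      if k == len(betas):  # extend the construction on demand
        betas.append(rng.beta(1.0, alpha))
        atoms.append(base_sample(rng))
      # Accept stick k with probability beta_k; otherwise fall through to
      # stick k + 1, so atom k is chosen with probability
      # beta_k * prod_{j < k} (1 - beta_j), the stick-breaking weight pi_k.
      if rng.uniform() < betas[k]:
        samples.append(atoms[k])
        break
      k += 1
  return np.array(samples)


# Ten samples from a single DP draw with a standard normal base distribution;
# with alpha = 0.1, most of them should collapse onto one atom.
print(dp_stick_breaking_samples(10, 0.1, lambda rng: rng.normal(0.0, 1.0)))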