|
| 1 | +from __future__ import absolute_import |
| 2 | +from __future__ import division |
| 3 | +from __future__ import print_function |
| 4 | + |
| 5 | +import tensorflow as tf |
| 6 | + |
| 7 | +from edward.models.random_variable import RandomVariable |
| 8 | +from edward.models.random_variables import Bernoulli, Beta |
| 9 | +from tensorflow.contrib.distributions import Distribution |
| 10 | + |
| 11 | + |
class DirichletProcess(RandomVariable, Distribution):
  def __init__(self, alpha, base_cls, validate_args=False, allow_nan_stats=True,
               name="DirichletProcess", value=None, *args, **kwargs):
    # NOTE: raw docstring (r"""), so :math:`\mathcal{DP}` and :math:`\alpha`
    # need no double-escaping; the previous non-raw string left `\m` as an
    # invalid escape sequence (DeprecationWarning under Python >= 3.6).
    r"""Dirichlet process :math:`\mathcal{DP}(\alpha, H)`.

    It has two parameters: a positive real value :math:`\alpha`,
    known as the concentration parameter (``alpha``), and a base
    distribution :math:`H` (``base_cls(*args, **kwargs)``).

    Parameters
    ----------
    alpha : tf.Tensor
      Concentration parameter. Must be positive real-valued. Its shape
      determines the number of independent DPs (batch shape).
    base_cls : RandomVariable
      Class of base distribution. Its shape (when instantiated)
      determines the shape of an individual DP (event shape).
    *args, **kwargs : optional
      Arguments passed into ``base_cls``.

    Examples
    --------
    >>> # scalar concentration parameter, scalar base distribution
    >>> dp = DirichletProcess(0.1, Normal, mu=0.0, sigma=1.0)
    >>> assert dp.get_shape() == ()
    >>>
    >>> # vector of concentration parameters, matrix of Exponentials
    >>> dp = DirichletProcess(tf.constant([0.1, 0.4]),
    ...                       Exponential, lam=tf.ones([5, 3]))
    >>> assert dp.get_shape() == (2, 5, 3)
    """
    with tf.name_scope(name, values=[alpha]) as ns:
      # Empty control_dependencies keeps the identity op inside this scope;
      # validation ops could be added here if validate_args were enforced.
      with tf.control_dependencies([]):
        self._alpha = tf.identity(alpha, name="alpha")
        self._base_cls = base_cls
        self._base_args = args
        self._base_kwargs = kwargs

        # Instantiate base for use in other methods such as `_get_event_shape`.
        self._base = self._base_cls(*self._base_args, **self._base_kwargs)

        super(DirichletProcess, self).__init__(
            dtype=tf.int32,
            parameters={"alpha": self._alpha,
                        "base_cls": self._base_cls,
                        "args": self._base_args,
                        "kwargs": self._base_kwargs},
            is_continuous=False,
            is_reparameterized=False,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=ns,
            value=value)

  @property
  def alpha(self):
    """Concentration parameter."""
    return self._alpha

  def _batch_shape(self):
    # Dynamic batch shape, derived from the static one.
    return tf.convert_to_tensor(self.get_batch_shape())

  def _get_batch_shape(self):
    # One independent DP per element of ``alpha``.
    return self._alpha.get_shape()

  def _event_shape(self):
    # Dynamic event shape, derived from the static one.
    return tf.convert_to_tensor(self.get_event_shape())

  def _get_event_shape(self):
    # The event shape is the shape of a single base-distribution draw.
    return self._base.get_shape()

  def _sample_n(self, n, seed=None):
    """Sample ``n`` draws from the DP. Draws from the base
    distribution are memoized across ``n``.

    Draws from the base distribution are not memoized across the batch
    shape: i.e., each independent DP in the batch shape has its own
    memoized samples. Similarly, draws are not memoized across calls
    to ``sample()``.

    Returns
    -------
    tf.Tensor
      A ``tf.Tensor`` of shape ``[n] + batch_shape + event_shape``,
      where ``n`` is the number of samples for each DP,
      ``batch_shape`` is the number of independent DPs, and
      ``event_shape`` is the shape of the base distribution.

    Notes
    -----
    The implementation has only one inefficiency, which is that it
    draws (n, batch_shape) samples from the base distribution at each
    iteration of the while loop. Ideally, we would only draw new
    samples for those in the loop returning True.
    """
    if seed is not None:
      raise NotImplementedError("seed is not implemented.")

    batch_shape = self._get_batch_shape().as_list()
    event_shape = self._get_event_shape().as_list()
    rank = 1 + len(batch_shape) + len(event_shape)
    # Note this is for scoping within the while loop's body function.
    self._temp_scope = [n, batch_shape, event_shape, rank]

    # First stick probability, one for each sample and each DP in the
    # batch shape. It has shape (n, batch_shape).
    beta_k = Beta(a=tf.ones_like(self._alpha), b=self._alpha).sample(n)
    # First base distribution.
    # It has shape (n, batch_shape, event_shape).
    theta_k = tf.tile(  # make (batch_shape, event_shape), then memoize across n
        tf.expand_dims(self._base_cls(*self._base_args, **self._base_kwargs).
                       sample(batch_shape), 0),
        [n] + [1] * (rank - 1))

    # Initialize all samples as the first base distribution.
    draws = theta_k
    # Flip coins for each stick probability.
    flips = Bernoulli(p=beta_k)
    # Get boolean tensor, returning True for samples that return tails
    # and are currently equal to theta_k.
    # It has shape (n, batch_shape).
    bools = tf.logical_and(
        tf.cast(1 - flips, tf.bool),
        tf.reduce_all(tf.equal(draws, theta_k),  # reduce event_shape
                      [i for i in range(1 + len(batch_shape), rank)]))

    # Keep drawing new sticks/atoms until every sample has hit heads.
    samples, _ = tf.while_loop(self._sample_n_cond, self._sample_n_body,
                               loop_vars=[draws, bools])
    return samples

  def _sample_n_cond(self, draws, bools):
    # Proceed if at least one bool is True.
    return tf.reduce_any(bools)

  def _sample_n_body(self, draws, bools):
    n, batch_shape, event_shape, rank = self._temp_scope

    # Next stick probability and next memoized base draw, with the same
    # shapes as in ``_sample_n``.
    beta_k = Beta(a=tf.ones_like(self._alpha), b=self._alpha).sample(n)
    theta_k = tf.tile(  # make (batch_shape, event_shape), then memoize across n
        tf.expand_dims(self._base_cls(*self._base_args, **self._base_kwargs).
                       sample(batch_shape), 0),
        [n] + [1] * (rank - 1))

    if len(bools.get_shape()) > 1:
      # ``tf.where`` only index subsets when ``bools`` is at most a
      # vector. In general, ``bools`` has shape (n, batch_shape).
      # Therefore we tile ``bools`` to be of shape
      # (n, batch_shape, event_shape) in order to index per-element.
      bools = tf.tile(tf.reshape(
          bools, [n] + batch_shape + [1] * len(event_shape)),
          [1] + [1] * len(batch_shape) + event_shape)

    # Assign True samples to the new theta_k.
    draws = tf.where(bools, theta_k, draws)

    # Recompute the continue-mask: tails AND still equal to the newest
    # atom (i.e., samples that were just reassigned and flip tails again).
    # The returned ``bools`` has shape (n, batch_shape), matching the
    # loop-variable shape expected by ``tf.while_loop``.
    flips = Bernoulli(p=beta_k)
    bools = tf.logical_and(
        tf.cast(1 - flips, tf.bool),
        tf.reduce_all(tf.equal(draws, theta_k),  # reduce event_shape
                      [i for i in range(1 + len(batch_shape), rank)]))
    return draws, bools
0 commit comments