Binary file added docs/images/dirichlet-process-fig0.png
Binary file added docs/images/dirichlet-process-fig1.png
60 changes: 48 additions & 12 deletions docs/tex/api/model-compositionality.tex
@@ -141,21 +141,57 @@ \subsubsection{Bayesian Nonparametrics}
collapsing the infinite-dimensional space and lazily defining the
infinite-dimensional space.

-For the collapsed approach, see the
+In the collapsed approach, we specify a distribution over its
+instantiation, and the stochastic process is implicitly marginalized
+out. For example, we can represent a distribution over the function
+evaluations of a Gaussian process, and not explicitly represent the
+function draw.

\begin{lstlisting}[language=Python]
import tensorflow as tf
from edward.models import MultivariateNormalFull

N = 100  # number of data points
D = 5    # data dimensionality

def kernel(X):
  """Evaluate a kernel over each pair of rows (data points) in the matrix.
  One concrete choice (an assumption here) is an RBF kernel, with a
  small jitter term for numerical stability."""
  sq_norms = tf.reduce_sum(tf.square(X), 1, keep_dims=True)
  sq_dists = sq_norms - 2.0 * tf.matmul(X, X, transpose_b=True) + \
      tf.transpose(sq_norms)
  return tf.exp(-0.5 * sq_dists) + 1e-6 * tf.eye(N)

X = tf.placeholder(tf.float32, [N, D])
y = MultivariateNormalFull(mu=tf.zeros(N), sigma=kernel(X))
\end{lstlisting}
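
As a quick sanity check, we can draw function evaluations from the
prior (a hypothetical usage sketch; \texttt{X\_train} is an assumed
toy input array):

\begin{lstlisting}[language=Python]
import numpy as np

X_train = np.random.randn(N, D).astype(np.float32)  # assumed toy inputs
with tf.Session() as sess:
  # One draw of the N latent function values at the observed inputs.
  f_draw = sess.run(y.value(), feed_dict={X: X_train})
\end{lstlisting}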

+For more details, see the
\href{/tutorials/supervised-classification}{Gaussian process classification}
-tutorial as an example. We specify distributions over the function
-evaluations of the Gaussian process, and the Gaussian process is
-implicitly marginalized out. This approach is also useful for Poisson
-process models.
+tutorial. This approach is also useful for Poisson process models.

-To work directly on the infinite-dimensional space, one can leverage
-random variables with
+In the lazy approach, we work directly on the infinite-dimensional space via
\href{https://www.tensorflow.org/versions/master/api_docs/python/control_flow_ops.html}{control flow operations}
-in TensorFlow. At runtime, the control flow will lazily define any
-parameters in the space necessary in order to generate samples. As an
-example, we use a while loop to define a
-\href{https://github.com/blei-lab/edward/blob/master/examples/pp_dirichlet_process.py}
-{Dirichlet process} according to its stick breaking representation.
+in TensorFlow. At runtime, the control flow will execute only the
+necessary computation in order to terminate. As an example, Edward
+provides a \texttt{DirichletProcess} random variable.

\begin{lstlisting}[language=Python]
import matplotlib.pyplot as plt
import tensorflow as tf
from edward.models import DirichletProcess, Normal

def plot_dirichlet_process(alpha):
  with tf.Session() as sess:
    dp = DirichletProcess(alpha, Normal, mu=0.0, sigma=1.0)
    samples = sess.run(dp.sample(1000))

  plt.hist(samples, bins=100, range=(-3.0, 3.0))
  plt.title("DP({0}, N(0, 1))".format(alpha))
  plt.show()

# Dirichlet process with low concentration: mass concentrates on a few atoms.
plot_dirichlet_process(alpha=1.0)
# Dirichlet process with high concentration: draws spread out, approaching
# the base distribution N(0, 1).
plot_dirichlet_process(alpha=50.0)
\end{lstlisting}

\includegraphics[width=350px]{/images/dirichlet-process-fig0.png}
\includegraphics[width=350px]{/images/dirichlet-process-fig1.png}

To see the essential component defining the \texttt{DirichletProcess}, see
\href{https://github.com/blei-lab/edward/blob/master/examples/pp_dirichlet_process.py}{\texttt{examples/pp_dirichlet_process.py}}
in the GitHub repository. Its source implementation can be found at
\href{https://github.com/blei-lab/edward/blob/master/edward/models/dirichlet_process.py}{\texttt{edward/models/dirichlet_process.py}}.
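
For intuition, here is a minimal NumPy sketch of the same sequential
stick-breaking loop (illustrative only; unlike Edward's vectorized
implementation, it does not memoize atoms across draws):

\begin{lstlisting}[language=Python]
import numpy as np

def dp_draw(alpha, base_sample):
  """One draw from G ~ DP(alpha, H): walk down the sticks, flipping a
  coin with each stick's probability; the first heads picks the atom."""
  while True:
    beta_k = np.random.beta(1.0, alpha)  # stick proportion
    if np.random.rand() < beta_k:        # heads: stop at this atom
      return base_sample()               # atom drawn from the base H

samples = [dp_draw(1.0, np.random.randn) for _ in range(1000)]
\end{lstlisting}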

\subsubsection{Probabilistic Programs}

2 changes: 2 additions & 0 deletions docs/tex/api/model.tex
@@ -83,6 +83,8 @@ \subsubsection{Model}
.. autoclass:: edward.models.RandomVariable
:members:

.. autoclass:: edward.models.DirichletProcess

.. autoclass:: edward.models.Empirical

.. autoclass:: edward.models.PointMass
3 changes: 2 additions & 1 deletion docs/tex/api/reference.tex
@@ -12,6 +12,7 @@ \subsubsection{Models}
{%sphinx

* :class:`edward.models.RandomVariable`
* :class:`edward.models.DirichletProcess`
* :class:`edward.models.Empirical`
* :class:`edward.models.PointMass`
* {{tensorflow_distributions}}
@@ -62,4 +63,4 @@ \subsubsection{Criticism}
* :func:`edward.criticisms.evaluate`
* :func:`edward.criticisms.ppc`

%}
21 changes: 11 additions & 10 deletions docs/tex/iclr2017.tex
@@ -12,7 +12,7 @@ \subsection{Deep Probabilistic Programming}
The code snippets assume the following versions.

\begin{lstlisting}[language=bash]
-pip install edward==1.2.3
+pip install -e "git+https://github.com/blei-lab/edward.git#egg=edward"
pip install tensorflow==1.0.0 # alternatively, tensorflow-gpu==1.0.0
pip install keras==1.0.0
\end{lstlisting}
@@ -331,19 +331,20 @@ \subsubsection{Appendix A. Model Examples}
\textbf{Figure 13}. Dirichlet process mixture model \citep{antoniak1974mixtures}.
\begin{lstlisting}[language=python]
import tensorflow as tf
-from edward.models import Normal
+from edward.models import DirichletProcess, Normal

N = 1000 # number of data points
D = 5 # data dimensionality

-mu = DirichletProcess(  # see script for implementation
-  alpha=0.1, base_cls=Normal, mu=tf.zeros(D), sigma=tf.ones(D), sample_n=N)
+dp = DirichletProcess(
+  alpha=1.0, base_cls=Normal, mu=tf.zeros(D), sigma=tf.ones(D))
+mu = dp.sample(N)
x = Normal(mu=mu, sigma=tf.ones([N, D]))
\end{lstlisting}
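
As a shape sanity check (an illustrative aside, relying on
\texttt{sample} prepending the sample dimension to the batch and
event shapes):
\begin{lstlisting}[language=python]
# Each data point draws its cluster mean from the same random measure.
assert mu.get_shape() == (N, D)
assert x.get_shape() == (N, D)
\end{lstlisting}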
-To see the essential component defining the \texttt{DirichletProcess}
-random variable, see
+To see the essential component defining the \texttt{DirichletProcess}, see
\href{https://github.com/blei-lab/edward/blob/master/examples/pp_dirichlet_process.py}{\texttt{examples/pp_dirichlet_process.py}}
-in the Github repository.
-A more involved version with a base distribution is available at
-\href{https://github.com/blei-lab/edward/blob/master/examples/pp_dirichlet_process_base.py}{\texttt{examples/pp_dirichlet_process_base.py}}
-in the Github repository.
+in the GitHub repository. Its source implementation can be found at
+\href{https://github.com/blei-lab/edward/blob/master/edward/models/dirichlet_process.py}{\texttt{edward/models/dirichlet_process.py}}.

\subsubsection{Appendix B. Inference Examples}

1 change: 1 addition & 0 deletions edward/models/__init__.py
@@ -2,6 +2,7 @@
from __future__ import division
from __future__ import print_function

from edward.models.dirichlet_process import *
from edward.models.empirical import *
from edward.models.point_mass import *
from edward.models.random_variable import *
172 changes: 172 additions & 0 deletions edward/models/dirichlet_process.py
@@ -0,0 +1,172 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from edward.models.random_variable import RandomVariable
from edward.models.random_variables import Bernoulli, Beta
from tensorflow.contrib.distributions import Distribution


class DirichletProcess(RandomVariable, Distribution):
  def __init__(self, alpha, base_cls, validate_args=False, allow_nan_stats=True,
               name="DirichletProcess", value=None, *args, **kwargs):
    """Dirichlet process :math:`\\mathcal{DP}(\\alpha, H)`.

    It has two parameters: a positive real value :math:`\\alpha`,
    known as the concentration parameter (``alpha``), and a base
    distribution :math:`H` (``base_cls(*args, **kwargs)``).

    Parameters
    ----------
    alpha : tf.Tensor
      Concentration parameter. Must be positive real-valued. Its shape
      determines the number of independent DPs (batch shape).
    base_cls : RandomVariable
      Class of base distribution. Its shape (when instantiated)
      determines the shape of an individual DP (event shape).
    *args, **kwargs : optional
      Arguments passed into ``base_cls``.

    Examples
    --------
    >>> # scalar concentration parameter, scalar base distribution
    >>> dp = DirichletProcess(0.1, Normal, mu=0.0, sigma=1.0)
    >>> assert dp.get_shape() == ()
    >>>
    >>> # vector of concentration parameters, matrix of Exponentials
    >>> dp = DirichletProcess(tf.constant([0.1, 0.4]),
    ...                       Exponential, lam=tf.ones([5, 3]))
    >>> assert dp.get_shape() == (2, 5, 3)
    """
    with tf.name_scope(name, values=[alpha]) as ns:
      with tf.control_dependencies([]):
        self._alpha = tf.identity(alpha, name="alpha")
        self._base_cls = base_cls
        self._base_args = args
        self._base_kwargs = kwargs

        # Instantiate base for use in other methods such as `_get_event_shape`.
        self._base = self._base_cls(*self._base_args, **self._base_kwargs)

        super(DirichletProcess, self).__init__(
            dtype=self._base.dtype,  # draws take the dtype of the base distribution
            parameters={"alpha": self._alpha,
                        "base_cls": self._base_cls,
                        "args": self._base_args,
                        "kwargs": self._base_kwargs},
            is_continuous=False,
            is_reparameterized=False,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=ns,
            value=value)

  @property
  def alpha(self):
    """Concentration parameter."""
    return self._alpha

  def _batch_shape(self):
    return tf.convert_to_tensor(self.get_batch_shape())

  def _get_batch_shape(self):
    return self._alpha.get_shape()

  def _event_shape(self):
    return tf.convert_to_tensor(self.get_event_shape())

  def _get_event_shape(self):
    return self._base.get_shape()

  def _sample_n(self, n, seed=None):
    """Sample ``n`` draws from the DP. Draws from the base
    distribution are memoized across ``n``.

    Draws from the base distribution are not memoized across the batch
    shape: i.e., each independent DP in the batch shape has its own
    memoized samples. Similarly, draws are not memoized across calls
    to ``sample()``.

    Returns
    -------
    tf.Tensor
      A ``tf.Tensor`` of shape ``[n] + batch_shape + event_shape``,
      where ``n`` is the number of samples for each DP,
      ``batch_shape`` is the number of independent DPs, and
      ``event_shape`` is the shape of the base distribution.
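
    For example (an illustrative aside, matching the shapes in the
    class docstring), with ``alpha`` of shape ``(2,)`` and a base of
    shape ``(5, 3)``, ``sample(10)`` returns a tensor of shape
    ``(10, 2, 5, 3)``.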

    Notes
    -----
    The implementation has one main inefficiency: each iteration of the
    while loop draws new stick probabilities for all (n, batch_shape)
    samples and a new atom for every DP in the batch, even though new
    values are only needed for the samples still active in the loop
    (those whose condition returns True).
    """
"""
    if seed is not None:
      raise NotImplementedError("seed is not implemented.")

    batch_shape = self._get_batch_shape().as_list()
    event_shape = self._get_event_shape().as_list()
    rank = 1 + len(batch_shape) + len(event_shape)
    # Note this is for scoping within the while loop's body function.
    self._temp_scope = [n, batch_shape, event_shape, rank]

    # First stick probability, one for each sample and each DP in the
    # batch shape. It has shape (n, batch_shape).
    beta_k = Beta(a=tf.ones_like(self._alpha), b=self._alpha).sample(n)
    # First atom, drawn from the base distribution.
    # It has shape (n, batch_shape, event_shape).
    theta_k = tf.tile(  # make (batch_shape, event_shape), then memoize across n
        tf.expand_dims(self._base_cls(*self._base_args, **self._base_kwargs).
                       sample(batch_shape), 0),
        [n] + [1] * (rank - 1))

    # Initialize all samples as the first atom.
    draws = theta_k
    # Flip coins for each stick probability.
    flips = Bernoulli(p=beta_k)
    # Get boolean tensor, returning True for samples that flip tails
    # and are currently equal to theta_k.
    # It has shape (n, batch_shape).
    bools = tf.logical_and(
        tf.cast(1 - flips, tf.bool),
        tf.reduce_all(tf.equal(draws, theta_k),  # reduce event_shape
                      [i for i in range(1 + len(batch_shape), rank)]))
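
    # Loop invariant (explanatory aside): each sample walks down the
    # sticks, moving to the next atom while its coin lands tails; the
    # first heads freezes its value. The equality check distinguishes
    # still-moving samples from frozen ones (exact with probability one
    # for a continuous base distribution).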

    samples, _ = tf.while_loop(self._sample_n_cond, self._sample_n_body,
                               loop_vars=[draws, bools])
    return samples

  def _sample_n_cond(self, draws, bools):
    # Proceed if at least one bool is True.
    return tf.reduce_any(bools)

  def _sample_n_body(self, draws, bools):
    n, batch_shape, event_shape, rank = self._temp_scope

    beta_k = Beta(a=tf.ones_like(self._alpha), b=self._alpha).sample(n)
    theta_k = tf.tile(  # make (batch_shape, event_shape), then memoize across n
        tf.expand_dims(self._base_cls(*self._base_args, **self._base_kwargs).
                       sample(batch_shape), 0),
        [n] + [1] * (rank - 1))

    if len(bools.get_shape()) > 1:
      # ``tf.where`` can only index subsets when ``bools`` is at most a
      # vector. In general, ``bools`` has shape (n, batch_shape).
      # Therefore we tile ``bools`` to shape
      # (n, batch_shape, event_shape) in order to index per-element.
      bools = tf.tile(tf.reshape(
          bools, [n] + batch_shape + [1] * len(event_shape)),
          [1] + [1] * len(batch_shape) + event_shape)

    # Reassign the still-active (True) samples to the new theta_k.
    draws = tf.where(bools, theta_k, draws)

    flips = Bernoulli(p=beta_k)
    bools = tf.logical_and(
        tf.cast(1 - flips, tf.bool),
        tf.reduce_all(tf.equal(draws, theta_k),  # reduce event_shape
                      [i for i in range(1 + len(batch_shape), rank)]))
    return draws, bools