Skip to content

Commit ceb537d

Browse files
authored
Make base inferences abstract classes and decorate static methods (#582)
* make base inferences abstract classes; decorate static methods * update docstrings
1 parent d6ef083 commit ceb537d

File tree

4 files changed

+50
-8
lines changed

4 files changed

+50
-8
lines changed

edward/inferences/inference.py

Lines changed: 21 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
from __future__ import division
33
from __future__ import print_function
44

5+
import abc
56
import numpy as np
67
import six
78
import tensorflow as tf
@@ -10,8 +11,21 @@
1011
from edward.util import check_data, check_latent_vars, get_session, Progbar
1112

1213

14+
@six.add_metaclass(abc.ABCMeta)
1315
class Inference(object):
14-
"""Base class for Edward inference methods.
16+
"""Abstract base class for inference. All inference algorithms in
17+
Edward inherit from ``Inference``, sharing common methods and
18+
properties via a class hierarchy.
19+
20+
Specific algorithms typically inherit from other subclasses of
21+
``Inference`` rather than ``Inference`` directly. For example, one
22+
might inherit from the abstract classes ``MonteCarlo`` or
23+
``VariationalInference``.
24+
25+
To build an algorithm inheriting from ``Inference``, one must at the
26+
minimum implement ``initialize`` and ``update``: the former builds
27+
the computational graph for the algorithm; the latter runs the
28+
computational graph for the algorithm.
1529
"""
1630
def __init__(self, latent_vars=None, data=None):
1731
"""Initialization.
@@ -130,12 +144,15 @@ def run(self, variables=None, use_coordinator=True, *args, **kwargs):
130144
self.coord.request_stop()
131145
self.coord.join(self.threads)
132146

147+
@abc.abstractmethod
133148
def initialize(self, n_iter=1000, n_print=None, scale=None, logdir=None,
134149
debug=False):
135150
"""Initialize inference algorithm. It initializes hyperparameters
136151
and builds ops for the algorithm's computational graph. No ops
137152
should be created outside the call to ``initialize()``.
138153
154+
Any derived class of ``Inference`` **must** implement this method.
155+
139156
Parameters
140157
----------
141158
n_iter : int, optional
@@ -186,9 +203,12 @@ def initialize(self, n_iter=1000, n_print=None, scale=None, logdir=None,
186203
if self.debug:
187204
self.op_check = tf.add_check_numerics_ops()
188205

206+
@abc.abstractmethod
189207
def update(self, feed_dict=None):
190208
"""Run one iteration of inference.
191209
210+
Any derived class of ``Inference`` **must** implement this method.
211+
192212
Parameters
193213
----------
194214
feed_dict : dict, optional

edward/inferences/monte_carlo.py

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
from __future__ import division
33
from __future__ import print_function
44

5+
import abc
56
import numpy as np
67
import six
78
import tensorflow as tf
@@ -11,8 +12,14 @@
1112
from edward.util import get_session
1213

1314

15+
@six.add_metaclass(abc.ABCMeta)
1416
class MonteCarlo(Inference):
15-
"""Base class for Monte Carlo inference methods.
17+
"""Abstract base class for Monte Carlo. Specific Monte Carlo methods
18+
inherit from ``MonteCarlo``, sharing methods in this class.
19+
20+
To build an algorithm inheriting from ``MonteCarlo``, one must at the
21+
minimum implement ``build_update``: it determines how to assign
22+
the samples in the ``Empirical`` approximations.
1623
"""
1724
def __init__(self, latent_vars=None, data=None):
1825
"""Initialization.
@@ -138,12 +145,12 @@ def print_progress(self, info_dict):
138145
if t == 1 or t % self.n_print == 0:
139146
self.progbar.update(t, {'Acceptance Rate': info_dict['accept_rate']})
140147

148+
@abc.abstractmethod
141149
def build_update(self):
142-
"""Build update, which returns an assign op for parameters in
143-
the Empirical random variables.
150+
"""Build update rules, returning an assign op for parameters in
151+
the ``Empirical`` random variables.
144152
145-
Any derived class of ``MonteCarlo`` **must** implement
146-
this method.
153+
Any derived class of ``MonteCarlo`` **must** implement this method.
147154
148155
Raises
149156
------

edward/inferences/variational_inference.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
from __future__ import division
33
from __future__ import print_function
44

5+
import abc
56
import numpy as np
67
import six
78
import tensorflow as tf
@@ -16,8 +17,16 @@
1617
pass
1718

1819

20+
@six.add_metaclass(abc.ABCMeta)
1921
class VariationalInference(Inference):
20-
"""Base class for variational inference methods.
22+
"""Abstract base class for variational inference. Specific
23+
variational inference methods inherit from ``VariationalInference``,
24+
sharing methods such as a default optimizer.
25+
26+
To build an algorithm inheriting from ``VariationalInference``, one
27+
must at the minimum implement ``build_loss_and_gradients``: it
28+
determines the loss function and gradients to apply for a given
29+
optimizer.
2130
"""
2231
def __init__(self, *args, **kwargs):
2332
super(VariationalInference, self).__init__(*args, **kwargs)
@@ -150,8 +159,10 @@ def print_progress(self, info_dict):
150159
if t == 1 or t % self.n_print == 0:
151160
self.progbar.update(t, {'Loss': info_dict['loss']})
152161

162+
@abc.abstractmethod
153163
def build_loss_and_gradients(self, var_list):
154-
"""Build loss function.
164+
"""Build loss function and its gradients. They will be leveraged
165+
in an optimizer to update the model and variational parameters.
155166
156167
Any derived class of ``VariationalInference`` **must** implement
157168
this method.

edward/models/random_variable.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -291,15 +291,19 @@ def get_shape(self):
291291
"""Get shape of random variable."""
292292
return self.shape
293293

294+
@staticmethod
294295
def _session_run_conversion_fetch_function(tensor):
295296
return ([tensor.value()], lambda val: val[0])
296297

298+
@staticmethod
297299
def _session_run_conversion_feed_function(feed, feed_val):
298300
return [(feed.value(), feed_val)]
299301

302+
@staticmethod
300303
def _session_run_conversion_feed_function_for_partial_run(feed):
301304
return [feed.value()]
302305

306+
@staticmethod
303307
def _tensor_conversion_function(v, dtype=None, name=None, as_ref=False):
304308
_ = name
305309
if dtype and not dtype.is_compatible_with(v.dtype):

0 commit comments

Comments (0)