
Commit 6419751

siddharth-agrawal authored and dustinvtran committed

Add regularization losses in implicit_klqp, map, wake_sleep (#823)

* Add regularization losses in implicit_klqp, map, wake_sleep
* Add test for MAP regularization
1 parent 38381ac · commit 6419751

4 files changed: +48, −4 lines


edward/inferences/implicit_klqp.py

Lines changed: 10 additions & 1 deletion

@@ -48,6 +48,9 @@ class ImplicitKLqp(GANInference):
   If `scale` has more than one item, then in order to scale
   its corresponding output, `discriminator` must output a
   dictionary of same size and keys as `scale`.
+
+  The objective function also adds to itself a summation over all
+  tensors in the `REGULARIZATION_LOSSES` collection.
   """
   def __init__(self, latent_vars, data=None, discriminator=None,
                global_vars=None):
@@ -203,8 +206,14 @@ def build_loss_and_gradients(self, var_list):
                      for key in six.iterkeys(self.scale)]
     scaled_ratio = tf.reduce_sum(scaled_ratio)

+    reg_terms_d = tf.losses.get_regularization_losses(scope="Disc")
+    reg_terms_all = tf.losses.get_regularization_losses()
+    reg_terms = [r for r in reg_terms_all if r not in reg_terms_d]
+
     # Form variational objective.
-    loss = -(pbeta_log_prob - qbeta_log_prob + scaled_ratio)
+    loss = -(pbeta_log_prob - qbeta_log_prob + scaled_ratio -
+             tf.reduce_sum(reg_terms))
+    loss_d = loss_d + tf.reduce_sum(reg_terms_d)

     var_list_d = tf.get_collection(
         tf.GraphKeys.TRAINABLE_VARIABLES, scope="Disc")
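The key detail in this hunk is the scope split: penalties registered by discriminator variables, which are built under the "Disc" variable scope (the same scope used to collect `var_list_d`), are charged to the discriminator loss `loss_d`, while every remaining entry in the collection enters the variational objective. A minimal sketch of that filtering under TF 1.x, with hypothetical variable names:

import tensorflow as tf

l2 = tf.contrib.layers.l2_regularizer(scale=0.1)

with tf.variable_scope("Disc"):
  # Registered under "Disc/..."; matched by scope="Disc" below.
  w_d = tf.get_variable("w_d", shape=[2], regularizer=l2)

# Registered outside "Disc"; belongs to the variational objective.
w_q = tf.get_variable("w_q", shape=[2], regularizer=l2)

reg_terms_d = tf.losses.get_regularization_losses(scope="Disc")
reg_terms_all = tf.losses.get_regularization_losses()
reg_terms = [r for r in reg_terms_all if r not in reg_terms_d]

print([r.name for r in reg_terms_d])  # only the Disc/w_d penalty
print([r.name for r in reg_terms])    # only the w_q penalty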

edward/inferences/map.py

Lines changed: 5 additions & 1 deletion

@@ -71,6 +71,9 @@ class MAP(VariationalInference):
   unconstrained; see, e.g., `qsigma` above. This is different than
   performing MAP on the unconstrained space: in general, the MAP of
   the transform is not the transform of the MAP.
+
+  The objective function also adds to itself a summation over all
+  tensors in the `REGULARIZATION_LOSSES` collection.
   """
   def __init__(self, latent_vars=None, data=None):
     """Create an inference algorithm.
@@ -142,7 +145,8 @@ def build_loss_and_gradients(self, var_list):
       p_log_prob += tf.reduce_sum(
           self.scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x]))

-    loss = -p_log_prob
+    reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses())
+    loss = -p_log_prob + reg_penalty

     grads = tf.gradients(loss, var_list)
     grads_and_vars = list(zip(grads, var_list))
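With this change, MAP minimizes the negative log joint plus whatever has accumulated in `REGULARIZATION_LOSSES`. A tensor lands in that collection whenever a `regularizer` is passed to `tf.get_variable` (or a `kernel_regularizer` to a layer). A minimal sketch of the mechanism, with a hypothetical variable:

import tensorflow as tf

# l2_regularizer contributes scale * tf.nn.l2_loss(w), i.e.
# scale * sum(w**2) / 2, to the REGULARIZATION_LOSSES collection.
regularizer = tf.contrib.layers.l2_regularizer(scale=0.5)
w = tf.get_variable("w_example", shape=[], regularizer=regularizer)

reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses())

with tf.Session() as sess:
  sess.run(w.assign(2.0))
  print(sess.run(reg_penalty))  # 0.5 * (2.0 ** 2 / 2) = 1.0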

edward/inferences/wake_sleep.py

Lines changed: 8 additions & 2 deletions

@@ -51,6 +51,9 @@ class WakeSleep(VariationalInference):

   where $z^{(s)} \sim q(z; \lambda)$ and $\\beta^{(s)}
   \sim q(\\beta)$.
+
+  The objective function also adds to itself a summation over all
+  tensors in the `REGULARIZATION_LOSSES` collection.
   """
   def __init__(self, *args, **kwargs):
     super(WakeSleep, self).__init__(*args, **kwargs)
@@ -129,15 +132,18 @@ def build_loss_and_gradients(self, var_list):

     p_log_prob = tf.reduce_mean(p_log_prob)
     q_log_prob = tf.reduce_mean(q_log_prob)
+    reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses())

     if self.logging:
       tf.summary.scalar("loss/p_log_prob", p_log_prob,
                         collections=[self._summary_key])
       tf.summary.scalar("loss/q_log_prob", q_log_prob,
                         collections=[self._summary_key])
+      tf.summary.scalar("loss/reg_penalty", reg_penalty,
+                        collections=[self._summary_key])

-    loss_p = -p_log_prob
-    loss_q = -q_log_prob
+    loss_p = -p_log_prob + reg_penalty
+    loss_q = -q_log_prob + reg_penalty

     q_rvs = list(six.itervalues(self.latent_vars))
     q_vars = [v for v in var_list
tests/inferences/test_map.py

Lines changed: 25 additions & 0 deletions

@@ -26,6 +26,31 @@ def test_normalnormal_run(self):

       self.assertAllClose(qmu.mean().eval(), 0)

+  def test_normalnormal_regularization(self):
+    with self.test_session() as sess:
+      x_data = np.array([5.0] * 50, dtype=np.float32)
+
+      mu = Normal(loc=0.0, scale=1.0)
+      x = Normal(loc=mu, scale=1.0, sample_shape=50)
+
+      qmu = PointMass(params=tf.Variable(1.0))
+
+      inference = ed.MAP({mu: qmu}, data={x: x_data})
+      inference.run(n_iter=1000)
+      mu_val = qmu.mean().eval()
+
+      # regularized solution
+      regularizer = tf.contrib.layers.l2_regularizer(scale=1.0)
+      mu_reg = tf.get_variable("mu_reg", shape=[],
+                               regularizer=regularizer)
+      x_reg = Normal(loc=mu_reg, scale=1.0, sample_shape=50)
+
+      inference_reg = ed.MAP(None, data={x_reg: x_data})
+      inference_reg.run(n_iter=1000)
+
+      mu_reg_val = mu_reg.eval()
+      self.assertAllClose(mu_val, mu_reg_val)
+
 if __name__ == '__main__':
   ed.set_seed(42)
   tf.test.main()
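The test leans on a standard equivalence: MAP with a $\mathcal{N}(0, 1)$ prior on $\mu$ and maximum likelihood with an L2 penalty at `scale=1.0` optimize the same objective. The first run maximizes $\log p(x \mid \mu) + \log \mathcal{N}(\mu; 0, 1) = \log p(x \mid \mu) - \mu^2/2 + \text{const}$, while the second minimizes $-\log p(x \mid \mu) + \mu^2/2$, since `tf.contrib.layers.l2_regularizer(scale=1.0)` contributes $1.0 \cdot \sum_i w_i^2 / 2$ via `tf.nn.l2_loss`. Up to an additive constant the objectives coincide, so both runs should converge to the conjugate posterior mode $\hat{\mu} = n\bar{x}/(n + 1) = 250/51 \approx 4.90$.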
