-
Notifications
You must be signed in to change notification settings - Fork 751
Description
I recently discovered Edward and I’ve been working on getting some psychometric applications up and running, starting with a simple example, the Rasch model. The Rasch model is a simple model for binary matrices X which decomposes the logit of the probability that each X_{pi} = 1 into the difference between two components, a row component \theta_p and a column component \beta_i.
I’m attempting to implement the Rasch model using a class model wrapper. The model has two latent variables: a 200 x 1 tensor `trait` containing the theta values and a 1 x 25 tensor `thresh` containing the beta values. To track down why variational inference is diverging, I inserted a `print` statement inside the `log_prob` method. Here's the code:
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import edward as ed
import matplotlib.pyplot as plt
from edward.models import Normal, Bernoulli
from edward.stats import norm
from scipy.special import expit
class RaschModel:
    r"""Full Rasch model with N(0, 1) priors on the person parameters
    (\theta_p) and the item parameters (\beta_i).

    .. math::

        \log P(X_{pi} = 1) - \log P(X_{pi} = 0) = \theta_p - \beta_i
    """
    # NOTE: the docstring above is a raw string on purpose; in a normal
    # string literal the leading backslash sequences of the LaTeX names
    # would be read as tab/backspace escapes.

    def log_prob(self, xs, zs):
        """Return the log joint density of the data and the latent variables.

        Args:
          xs: dict of observed data; ``xs['X']`` is the binary response
            matrix (assumed to broadcast against ``trait - thresh`` --
            confirm against the caller).
          zs: dict of latent variable values with the keys ``'trait'``
            (person parameters) and ``'thresh'`` (item parameters).

        Returns:
          A scalar tensor: the summed N(0, 1) log priors of ``trait`` and
          ``thresh`` plus the Bernoulli-logit log likelihood of ``X``.
        """
        X = xs['X']
        trait, thresh = zs['trait'], zs['thresh']
        # Diagnostic output: shows which tensors the inference engine
        # hands to the model for each latent variable.
        print(trait, thresh, flush=True)

        log_prior = tf.reduce_sum(norm.log_pdf(trait, 0.0, 1.0))
        log_prior += tf.reduce_sum(norm.log_pdf(thresh, 0.0, 1.0))

        logit = tf.sub(trait, thresh)
        log_lik = tf.reduce_sum(tf.mul(X, logit))
        # softplus(x) == log(1 + exp(x)), but it does not overflow to inf
        # when the logit becomes large, unlike the literal expression.
        log_lik -= tf.reduce_sum(tf.nn.softplus(logit))
        return log_prior + log_lik
def create_data(nsubj, nitem, seed=None):
    """Simulate a binary response matrix from the Rasch model.

    Person parameters and item parameters are drawn from N(0, 1); each
    response is then a Bernoulli draw with success probability
    ``expit(trait - thresh)``.

    Args:
      nsubj: number of persons (rows of ``X``).
      nitem: number of items (columns of ``X``).
      seed: optional seed for a private ``RandomState``. When ``None``
        (the default) the global NumPy random state is used, exactly as
        before this parameter existed.

    Returns:
      dict with the keys ``'nsubj'``, ``'nitem'``, ``'trait'`` (array of
      shape ``[nsubj, 1]``), ``'thresh'`` (array of shape ``[1, nitem]``)
      and ``'X'`` (0/1 array of shape ``[nsubj, nitem]``).
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    trait = rng.normal(size=[nsubj, 1])
    thresh = rng.normal(size=[1, nitem])
    # Broadcasting (nsubj, 1) - (1, nitem) yields one logit per cell.
    X = rng.binomial(1, expit(trait - thresh))
    # Explicit dict instead of locals(): same keys, but adding a helper
    # variable to this function can no longer leak into the result.
    return {
        'nsubj': nsubj,
        'nitem': nitem,
        'trait': trait,
        'thresh': thresh,
        'X': X,
    }
# Problem size: number of persons (rows of X) and items (columns of X).
nsubj = 200
nitem = 25
# Simulated responses together with the trait/thresh values that
# generated them (kept for the comparison plot at the end).
dataset = create_data(nsubj, nitem)
model = RaschModel()
# Variational factors: one Normal per latent variable, each with a free
# mean and a softplus-transformed free parameter as its scale. Shapes
# mirror the simulated trait ([nsubj, 1]) and thresh ([1, nitem]).
q_trait = Normal(mu=tf.Variable(tf.random_normal([nsubj, 1])),
sigma=tf.nn.softplus(tf.Variable(tf.random_normal([nsubj, 1]))))
q_thresh = Normal(mu=tf.Variable(tf.random_normal([1, nitem])),
sigma=tf.nn.softplus(tf.Variable(tf.random_normal([1, nitem]))))
# The dict keys must match the names RaschModel.log_prob reads from its
# `zs` ('trait', 'thresh') and `xs` ('X') arguments.
# NOTE(review): ed.MFVI is presumably mean-field variational inference --
# confirm against the installed Edward version.
inference = ed.MFVI({'trait': q_trait, 'thresh': q_thresh},
{'X': dataset['X']}, model)
#inference.initialize()
inference.run(n_iter=2500, n_samples=10, n_minibatch=20)
# Scatter the generating person parameters against the fitted
# variational means of q_trait.
plt.scatter(dataset['trait'], q_trait.mu.eval())
plt.show()
The output from my system is below. I would expect the first Tensor to have shape (200, 1) rather than (1, 25).
Tensor("inference_0/Normal_1/sample/sample:0", shape=(1, 25), dtype=float32) Tensor("inference_0/Normal_1/sample/sample:0", shape=(1, 25), dtype=float32)
Tensor("inference_1/Normal_1/sample/sample:0", shape=(1, 25), dtype=float32) Tensor("inference_1/Normal_1/sample/sample:0", shape=(1, 25), dtype=float32)
Tensor("inference_2/Normal_1/sample/sample:0", shape=(1, 25), dtype=float32) Tensor("inference_2/Normal_1/sample/sample:0", shape=(1, 25), dtype=float32)
Tensor("inference_3/Normal_1/sample/sample:0", shape=(1, 25), dtype=float32) Tensor("inference_3/Normal_1/sample/sample:0", shape=(1, 25), dtype=float32)
Tensor("inference_4/Normal_1/sample/sample:0", shape=(1, 25), dtype=float32) Tensor("inference_4/Normal_1/sample/sample:0", shape=(1, 25), dtype=float32)
Tensor("inference_5/Normal_1/sample/sample:0", shape=(1, 25), dtype=float32) Tensor("inference_5/Normal_1/sample/sample:0", shape=(1, 25), dtype=float32)
Tensor("inference_6/Normal_1/sample/sample:0", shape=(1, 25), dtype=float32) Tensor("inference_6/Normal_1/sample/sample:0", shape=(1, 25), dtype=float32)
Tensor("inference_7/Normal_1/sample/sample:0", shape=(1, 25), dtype=float32) Tensor("inference_7/Normal_1/sample/sample:0", shape=(1, 25), dtype=float32)
Tensor("inference_8/Normal_1/sample/sample:0", shape=(1, 25), dtype=float32) Tensor("inference_8/Normal_1/sample/sample:0", shape=(1, 25), dtype=float32)
Tensor("inference_9/Normal_1/sample/sample:0", shape=(1, 25), dtype=float32) Tensor("inference_9/Normal_1/sample/sample:0", shape=(1, 25), dtype=float32)
iter 1 loss 59674.46
iter 100 loss 7335925087847055360.00
W tensorflow/core/framework/op_kernel.cc:940] Invalid argument: assertion failed: [] [Condition x > 0 did not hold element-wise: x = ] [Softplus_37:0] [1.9073468e-06 1.7881378e-06 3.5762781e-07...]
[[Node: Normal_660/assert_positive/assert_less/Assert = Assert[T=[DT_STRING, DT_STRING, DT_STRING, DT_FLOAT], summarize=3, _device="/job:localhost/replica:0/task:0/cpu:0"](Normal_660/assert_positive/assert_less/All, Normal_660/assert_positive/assert_less/Assert/data_0, Normal_660/assert_positive/assert_less/Assert/data_1, Normal_660/assert_positive/assert_less/Assert/data_2, Softplus_37)]]
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
964 try:
--> 965 return fn(*args)
966 except errors.OpError as e:
/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
946 feed_dict, fetch_list, target_list,
--> 947 status, run_metadata)
948
/~/Documents/mzeigenfuse/anaconda/lib/python3.5/contextlib.py in __exit__(self, type, value, traceback)
65 try:
---> 66 next(self.gen)
67 except StopIteration:
/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/tensorflow/python/framework/errors.py in raise_exception_on_not_ok_status()
449 compat.as_text(pywrap_tensorflow.TF_Message(status)),
--> 450 pywrap_tensorflow.TF_GetCode(status))
451 finally:
InvalidArgumentError: assertion failed: [] [Condition x > 0 did not hold element-wise: x = ] [Softplus_37:0] [1.9073468e-06 1.7881378e-06 3.5762781e-07...]
[[Node: Normal_660/assert_positive/assert_less/Assert = Assert[T=[DT_STRING, DT_STRING, DT_STRING, DT_FLOAT], summarize=3, _device="/job:localhost/replica:0/task:0/cpu:0"](Normal_660/assert_positive/assert_less/All, Normal_660/assert_positive/assert_less/Assert/data_0, Normal_660/assert_positive/assert_less/Assert/data_1, Normal_660/assert_positive/assert_less/Assert/data_2, Softplus_37)]]
During handling of the above exception, another exception occurred:
InvalidArgumentError Traceback (most recent call last)
/~/Documents/mzeigenfuse/Projects/irtvb/raschmodel.py in <module>()
55 {'X': dataset['X']}, model)
56 #inference.initialize()
---> 57 inference.run(n_iter=2500, n_samples=10, n_minibatch=20)
58
59 plt.scatter(dataset['trait'], q_trait.mu.eval())
/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/edward/inferences/inference.py in run(self, logdir, variables, use_coordinator, *args, **kwargs)
173
174 for _ in range(self.n_iter):
--> 175 info_dict = self.update()
176 self.print_progress(info_dict)
177
/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/edward/inferences/variational_inference.py in update(self, feed_dict)
139
140 sess = get_session()
--> 141 _, t, loss = sess.run([self.train, self.increment_t, self.loss], feed_dict)
142 return {'t': t, 'loss': loss}
143
/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
708 try:
709 result = self._run(None, fetches, feed_dict, options_ptr,
--> 710 run_metadata_ptr)
711 if run_metadata:
712 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
906 if final_fetches or final_targets:
907 results = self._do_run(handle, final_targets, final_fetches,
--> 908 feed_dict_string, options, run_metadata)
909 else:
910 results = []
/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
956 if handle is None:
957 return self._do_call(_run_fn, self._session, feed_dict, fetch_list,
--> 958 target_list, options, run_metadata)
959 else:
960 return self._do_call(_prun_fn, self._session, handle, feed_dict,
/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
976 except KeyError:
977 pass
--> 978 raise type(e)(node_def, op, message)
979
980 def _extend_graph(self):
InvalidArgumentError: assertion failed: [] [Condition x > 0 did not hold element-wise: x = ] [Softplus_37:0] [1.9073468e-06 1.7881378e-06 3.5762781e-07...]
[[Node: Normal_660/assert_positive/assert_less/Assert = Assert[T=[DT_STRING, DT_STRING, DT_STRING, DT_FLOAT], summarize=3, _device="/job:localhost/replica:0/task:0/cpu:0"](Normal_660/assert_positive/assert_less/All, Normal_660/assert_positive/assert_less/Assert/data_0, Normal_660/assert_positive/assert_less/Assert/data_1, Normal_660/assert_positive/assert_less/Assert/data_2, Softplus_37)]]
Caused by op 'Normal_660/assert_positive/assert_less/Assert', defined at:
File "/~/Documents/mzeigenfuse/anaconda/bin/ipython", line 6, in <module>
sys.exit(IPython.start_ipython())
File "/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/IPython/__init__.py", line 119, in start_ipython
return launch_new_instance(argv=argv, **kwargs)
File "/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/traitlets/config/application.py", line 653, in launch_instance
app.start()
File "/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/IPython/terminal/ipapp.py", line 348, in start
self.shell.mainloop()
File "/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/IPython/terminal/interactiveshell.py", line 440, in mainloop
self.interact()
File "/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/IPython/terminal/interactiveshell.py", line 431, in interact
self.run_cell(code, store_history=True)
File "/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 2717, in run_cell
interactivity=interactivity, compiler=compiler, result=result)
File "/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 2827, in run_ast_nodes
if self.run_code(code, result):
File "/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 2881, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-49-57e6ba9c8ee1>", line 1, in <module>
get_ipython().magic('run raschmodel.py')
File "/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 2158, in magic
return self.run_line_magic(magic_name, magic_arg_s)
File "/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 2079, in run_line_magic
result = fn(*args,**kwargs)
File "<decorator-gen-57>", line 2, in run
File "/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/IPython/core/magic.py", line 188, in <lambda>
call = lambda f, *a, **k: f(*a, **k)
File "/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/IPython/core/magics/execution.py", line 742, in run
run()
File "/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/IPython/core/magics/execution.py", line 728, in run
exit_ignore=exit_ignore)
File "/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 2481, in safe_execfile
self.compile if kw['shell_futures'] else None)
File "/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/IPython/utils/py3compat.py", line 186, in execfile
exec(compiler(f.read(), fname, 'exec'), glob, loc)
File "/~/Documents/mzeigenfuse/Projects/irtvb/raschmodel.py", line 52, in <module>
sigma=tf.nn.softplus(tf.Variable(tf.random_normal([1, nitem]))))
File "/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/edward/models/random_variables.py", line 36, in __init__
RandomVariable.__init__(self, *args, **kwargs)
File "/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/edward/models/random_variable.py", line 58, in __init__
super(RandomVariable, self).__init__(*args, **kwargs)
File "/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/tensorflow/contrib/distributions/python/ops/normal.py", line 115, in __init__
validate_args else []):
File "/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/tensorflow/python/ops/check_ops.py", line 178, in assert_positive
return assert_less(zero, x, data=data, summarize=summarize)
File "/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/tensorflow/python/ops/check_ops.py", line 354, in assert_less
return logging_ops.Assert(condition, data, summarize=summarize)
File "/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/tensorflow/python/ops/logging_ops.py", line 58, in Assert
return gen_logging_ops._assert(condition, data, summarize, name)
File "/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/tensorflow/python/ops/gen_logging_ops.py", line 37, in _assert
summarize=summarize, name=name)
File "/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 710, in apply_op
op_def=op_def)
File "/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 2317, in create_op
original_op=self._default_original_op, op_def=op_def)
File "/~/Documents/mzeigenfuse/anaconda/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1239, in __init__
self._traceback = _extract_stack()