Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions rstan/rstan/R/loo.R
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Leave-one-out cross-validation
#
#
# The \code{loo} method for stanfit objects ---a wrapper around the
# \code{loo.array} method from the \pkg{loo} package--- computes approximate
# leave-one-out cross-validation using Pareto smoothed importance sampling
# (PSIS-LOO CV).
#
#
# @param x stanfit object
# @param pars Name of parameter, transformed parameter, or generated quantity in
# the Stan program corresponding to the pointwise log-likelihood. If not
Expand All @@ -20,8 +20,9 @@
# @param k_threshold Threshold value for Pareto k values above which
# the moment matching algorithm is used. If \code{moment_match} is \code{FALSE},
# this is ignored.
# @param r_eff Whether to compute r_eff to pass to loo package.
# @param ... Ignored.
#
#
# @details Stan does not automatically compute and store the log-likelihood. It
# is up to the user to incorporate it into the Stan program if it is to be
# extracted after fitting the model. In a Stan model, the pointwise log
Expand All @@ -47,31 +48,37 @@ loo.stanfit <-
cores = getOption("mc.cores", 1),
moment_match = FALSE,
k_threshold = 0.7,
r_eff = FALSE,
...) {
stopifnot(length(pars) == 1L)
stopifnot(is.logical(save_psis))
stopifnot(is.logical(moment_match))
stopifnot(is.numeric(k_threshold))

stopifnot(is.logical(r_eff))

LLarray <- loo::extract_log_lik(stanfit = x,
parameter_name = pars,
merge_chains = FALSE)
r_eff <- loo::relative_eff(x = exp(LLarray), cores = cores)


if (!r_eff) {
r_eff <- NULL
} else {
r_eff <- loo::relative_eff(x = exp(LLarray), cores = cores)
}

if (moment_match) {
loo <- suppressWarnings(loo::loo.array(LLarray,
r_eff = r_eff,
cores = cores,
save_psis = save_psis))

x_array <- as.array(x)
chain_id <- rep(seq(dim(x_array)[2]),each = dim(x_array)[1])
loo <- loo_moment_match.stanfit(
x, loo = loo, chain_id = chain_id, k_threshold = k_threshold,
cores = cores, parameter_name = pars, ...
)
}
else {
} else {
loo <- loo::loo.array(LLarray,
r_eff = r_eff,
cores = cores,
Expand Down
8 changes: 8 additions & 0 deletions rstan/rstan/man/stanfit-method-loo.Rd
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ sampling (Vehtari, Gelman, and Gabry, 2017a,2017b).
\item{k_threshold}{Threshold value for Pareto k values above which
the moment matching algorithm is used. If \code{moment_match} is \code{FALSE},
this is ignored.}
\item{r_eff} \code{TRUE} or \code{FALSE} indicating whether to compute the
\code{r_eff} argument to pass to the \pkg{loo} package. If \code{TRUE},
will call \code{loo::relative_eff()}. If \code{FALSE}
(the default), we avoid computing \code{r_eff}, which can be very slow.
\code{r_eff} measures the amount of autocorrelation in MCMC draws, and is
used to compute more accurate ESS and MCSE estimates for pointwise and
total ELPDs. When \code{r_eff=FALSE}, the reported ESS and MCSE estimates
may be over-optimistic if the posterior draws are far from independent.
\item{\dots}{Ignored.}
}

Expand Down
Loading