Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Authors@R:
email = "[email protected]", comment = c(ORCID = "0000-0003-1878-3253")),
person(given = "Jacob", family = "Socolar", role = "ctb"),
person(given = "Andrew", family = "Johnson", role = "ctb",
comment = c(ORCID = "0000-0001-7000-8065 ")))
comment = c(ORCID = "0000-0001-7000-8065")))
Description: A lightweight interface to 'Stan' <https://mc-stan.org>.
The 'CmdStanR' interface is an alternative to 'RStan' that calls the command
line interface for compilation and running algorithms instead of interfacing
Expand All @@ -27,7 +27,7 @@ URL: https://mc-stan.org/cmdstanr/, https://discourse.mc-stan.org
BugReports: https://github.com/stan-dev/cmdstanr/issues
Encoding: UTF-8
LazyData: true
RoxygenNote: 7.2.0
RoxygenNote: 7.2.1
Roxygen: list(markdown = TRUE, r6 = FALSE)
SystemRequirements: CmdStan (https://mc-stan.org/users/interfaces/cmdstan)
Depends:
Expand All @@ -46,5 +46,7 @@ Suggests:
loo (>= 2.0.0),
rlang (>= 0.4.7),
rmarkdown,
testthat (>= 2.1.0)
testthat (>= 2.1.0),
Rcpp,
RcppEigen
VignetteBuilder: knitr
3 changes: 3 additions & 0 deletions R/args.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ CmdStanArgs <- R6::R6Class(
initialize = function(model_name,
stan_file = NULL,
stan_code = NULL,
model_methods_env = NULL,
exe_file,
proc_ids,
method_args,
Expand All @@ -43,6 +44,7 @@ CmdStanArgs <- R6::R6Class(
self$model_name <- model_name
self$stan_code <- stan_code
self$exe_file <- exe_file
self$model_methods_env <- model_methods_env
self$proc_ids <- proc_ids
self$data_file <- data_file
self$seed <- seed
Expand All @@ -52,6 +54,7 @@ CmdStanArgs <- R6::R6Class(
self$method <- self$method_args$method
self$save_latent_dynamics <- save_latent_dynamics
self$using_tempdir <- is.null(output_dir)
self$model_variables <- model_variables
if (getRversion() < "3.5.0") {
self$output_dir <- output_dir %||% tempdir()
} else {
Expand Down
210 changes: 209 additions & 1 deletion R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ CmdStanFit <- R6::R6Class(
initialize = function(runset) {
checkmate::assert_r6(runset, classes = "CmdStanRun")
self$runset <- runset
private$model_methods_env_ <- runset$model_methods_env()

if (!is.null(private$model_methods_env_$model_ptr)) {
initialize_model_pointer(private$model_methods_env_, self$data_file(), 0)
}
invisible(self)
},
num_procs = function() {
Expand Down Expand Up @@ -63,7 +68,8 @@ CmdStanFit <- R6::R6Class(
draws_ = NULL,
metadata_ = NULL,
init_ = NULL,
profiles_ = NULL
profiles_ = NULL,
model_methods_env_ = NULL
)
)

Expand Down Expand Up @@ -272,6 +278,208 @@ init <- function() {
}
CmdStanFit$set("public", name = "init", value = init)

#' Compile additional methods for accessing the model log-probability function
#' and parameter constraining and unconstraining. This requires the `Rcpp` package.
#'
#' @name fit-method-init_model_methods
#' @aliases init_model_methods
#' @description The `$init_model_methods()` compiles and initializes the
#' `log_prob`, `grad_log_prob`, `constrain_pars`, and `unconstrain_pars` functions.
#'
#' @param seed (integer) The random seed to use when initializing the model.
#' @param verbose (boolean) Whether to show verbose logging during compilation.
#' @param hessian (boolean) Whether to expose the (experimental) hessian method.
#'
#' @examples
#' \dontrun{
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample")
#' fit_mcmc$init_model_methods()
#' }
#'
init_model_methods <- function(seed = 0, verbose = FALSE, hessian = FALSE) {
require_suggested_package("Rcpp")
require_suggested_package("RcppEigen")
if (length(private$model_methods_env_$hpp_code_) == 0) {
stop("Model methods cannot be used with a pre-compiled Stan executable, ",
"the model must be compiled again", call. = FALSE)
}
if (hessian) {
message("The hessian method relies on higher-order autodiff ",
"which is still experimental. Please report any compilation ",
"errors that you encounter",
call. = FALSE)
}
message("Compiling additional model methods...")
if (is.null(private$model_methods_env_$model_ptr)) {
expose_model_methods(private$model_methods_env_, verbose, hessian)
}
initialize_model_pointer(private$model_methods_env_, self$data_file(), seed)
invisible(NULL)
}
CmdStanFit$set("public", name = "init_model_methods", value = init_model_methods)

#' Calculate the log-probability given a provided vector of unconstrained parameters.
#'
#' @name fit-method-log_prob
#' @aliases log_prob
#' @description The `$log_prob()` method provides access to the Stan model's `log_prob` function
#'
#' @param upars (numeric) A vector of unconstrained parameters to be passed to `log_prob`
#' @param jacobian_adjustment (bool) Whether to include the log-density adjustments from
#' un/constraining variables
#'
#' @examples
#' \dontrun{
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample")
#' fit_mcmc$log_prob(upars = c(0.5, 1.2, 1.1, 2.2, 1.1))
#' }
#'
log_prob <- function(upars, jacobian_adjustment = TRUE) {
if (is.null(private$model_methods_env_$model_ptr)) {
stop("The method has not been compiled, please call `init_model_methods()` first",
call. = FALSE)
}
if (length(upars) != private$model_methods_env_$num_upars_) {
stop("Model has ", private$model_methods_env_$num_upars_, " unconstrained parameter(s), but ",
length(upars), " were provided!", call. = FALSE)
}
private$model_methods_env_$log_prob(private$model_methods_env_$model_ptr_, upars, jacobian_adjustment)
}
CmdStanFit$set("public", name = "log_prob", value = log_prob)

#' Calculate the log-probability and the gradient w.r.t. each input for a
#' given vector of unconstrained parameters
#'
#' @name fit-method-grad_log_prob
#' @aliases grad_log_prob
#' @description The `$grad_log_prob()` method provides access to the
#' Stan model's `log_prob` function and its derivative
#'
#' @param upars (numeric) A vector of unconstrained parameters to be passed
#' to `grad_log_prob`
#' @param jacobian_adjustment (bool) Whether to include the log-density adjustments from
#' un/constraining variables
#'
#' @examples
#' \dontrun{
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample")
#' fit_mcmc$grad_log_prob(upars = c(0.5, 1.2, 1.1, 2.2, 1.1))
#' }
#'
grad_log_prob <- function(upars, jacobian_adjustment = TRUE) {
if (is.null(private$model_methods_env_$model_ptr)) {
stop("The method has not been compiled, please call `init_model_methods()` first",
call. = FALSE)
}
if (length(upars) != private$model_methods_env_$num_upars_) {
stop("Model has ", private$model_methods_env_$num_upars_, " unconstrained parameter(s), but ",
length(upars), " were provided!", call. = FALSE)
}
private$model_methods_env_$grad_log_prob(private$model_methods_env_$model_ptr_, upars, jacobian_adjustment)
}
CmdStanFit$set("public", name = "grad_log_prob", value = grad_log_prob)

#' Calculate the log-probability , the gradient w.r.t. each input, and the hessian
#' for a given vector of unconstrained parameters
#'
#' @name fit-method-hessian
#' @aliases hessian
#' @description The `$hessian()` method provides access to the
#' Stan model's `log_prob`, its derivative, and its hessian
#'
#' @param upars (numeric) A vector of unconstrained parameters to be passed
#' to `hessian`
#'
#' @examples
#' \dontrun{
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample")
#' fit_mcmc$hessian(upars = c(0.5, 1.2, 1.1, 2.2, 1.1))
#' }
#'
hessian <- function(upars) {
if (is.null(private$model_methods_env_$model_ptr)) {
stop("The method has not been compiled, please call `init_model_methods()` first",
call. = FALSE)
}
if (length(upars) != private$model_methods_env_$num_upars_) {
stop("Model has ", private$model_methods_env_$num_upars_, " unconstrained parameter(s), but ",
length(upars), " were provided!", call. = FALSE)
}
private$model_methods_env_$hessian(private$model_methods_env_$model_ptr_, upars)
}
CmdStanFit$set("public", name = "hessian", value = hessian)

#' Transform a set of parameter values to the unconstrained scale
#'
#' @name fit-method-unconstrain_pars
#' @aliases unconstrain_pars
#' @description The `$unconstrain_pars()` method transforms input parameters to
#' the unconstrained scale
#'
#' @param pars (list) A list of parameter values to transform, in the same format as
#' provided to the `init` argument of the `$sample()` method
#'
#' @examples
#' \dontrun{
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample")
#' fit_mcmc$unconstrain_pars(list(alpha = 0.5, beta = c(0.7, 1.1, 0.2)))
#' }
#'
unconstrain_pars <- function(pars) {
if (is.null(private$model_methods_env_$model_ptr)) {
stop("The method has not been compiled, please call `init_model_methods()` first",
call. = FALSE)
}
model_par_names <- names(self$runset$args$model_variables$parameters)
prov_par_names <- names(pars)

prov_pars_not_in_model <- which(!(prov_par_names %in% model_par_names))
if (length(prov_pars_not_in_model) > 0) {
stop("Provided parameter(s): ", paste(prov_par_names[prov_pars_not_in_model], collapse = ","),
" not present in model!", call. = FALSE)
}

model_pars_not_prov <- which(!(model_par_names %in% prov_par_names))
if (length(model_pars_not_prov) > 0) {
stop("Model parameter(s): ", paste(model_par_names[model_pars_not_prov], collapse = ","),
" not provided!", call. = FALSE)
}

stan_pars <- process_init_list(list(pars), num_procs = 1, self$runset$args$model_variables)
private$model_methods_env_$unconstrain_pars(private$model_methods_env_$model_ptr_, stan_pars)
}
CmdStanFit$set("public", name = "unconstrain_pars", value = unconstrain_pars)

#' Transform a set of unconstrained parameter values to the constrained scale
#'
#' @name fit-method-constrain_pars
#' @aliases constrain_pars
#' @description The `$constrain_pars()` method transforms input parameters to
#' the constrained scale
#'
#' @param upars (numeric) A vector of unconstrained parameters to constrain
#'
#' @examples
#' \dontrun{
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample")
#' fit_mcmc$constrain_pars(upars = c(0.5, 1.2, 1.1, 2.2, 1.1))
#' }
#'
constrain_pars <- function(upars) {
if (is.null(private$model_methods_env_$model_ptr)) {
stop("The method has not been compiled, please call `init_model_methods()` first",
call. = FALSE)
}
if (length(upars) != private$model_methods_env_$num_upars_) {
stop("Model has ", private$model_methods_env_$num_upars_, " unconstrained parameter(s), but ",
length(upars), " were provided!", call. = FALSE)
}
cpars <- private$model_methods_env_$constrain_pars(private$model_methods_env_$model_ptr_, private$model_methods_env_$model_rng_, upars)
skeleton <- create_skeleton(self$runset$args$model_variables)
utils::relist(cpars, skeleton)
}
CmdStanFit$set("public", name = "constrain_pars", value = constrain_pars)

#' Extract log probability (target)
#'
#' @name fit-method-lp
Expand Down
20 changes: 20 additions & 0 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ CmdStanModel <- R6::R6Class(
model_name_ = character(),
exe_file_ = character(),
hpp_file_ = character(),
model_methods_env_ = NULL,
dir_ = NULL,
cpp_options_ = list(),
stanc_options_ = list(),
Expand Down Expand Up @@ -382,6 +383,10 @@ CmdStanModel <- R6::R6Class(
#' @param force_recompile (logical) Should the model be recompiled even if was
#' not modified since last compiled. The default is `FALSE`. Can also be set
#' via a global `cmdstanr_force_recompile` option.
#' @param compile_model_methods (logical) Compile additional model methods
#' (`log_prob()`, `grad_log_prob()`, `constrain_pars()`, `unconstrain_pars()`)
#' @param compile_hessian_method (logical) Should the (experimental) `hessian()` method be
#' be compiled with the model methods?
#' @param threads Deprecated and will be removed in a future release. Please
#' turn on threading via `cpp_options = list(stan_threads = TRUE)` instead.
#'
Expand Down Expand Up @@ -431,6 +436,8 @@ compile <- function(quiet = TRUE,
cpp_options = list(),
stanc_options = list(),
force_recompile = getOption("cmdstanr_force_recompile", default = FALSE),
compile_model_methods = FALSE,
compile_hessian_method = FALSE,
#deprecated
threads = FALSE) {
if (length(self$stan_file()) == 0) {
Expand Down Expand Up @@ -613,6 +620,13 @@ compile <- function(quiet = TRUE,
private$precompile_cpp_options_ <- NULL
private$precompile_stanc_options_ <- NULL
private$precompile_include_paths_ <- NULL
private$model_methods_env_ <- new.env()
private$model_methods_env_$hpp_code_ <- readLines(private$hpp_file_)
if (compile_model_methods) {
expose_model_methods(env = private$model_methods_env_,
verbose = !quiet,
hessian = compile_hessian_method)
}
invisible(self)
}
CmdStanModel$set("public", name = "compile", value = compile)
Expand Down Expand Up @@ -1102,6 +1116,7 @@ sample <- function(data = NULL,
method_args = sample_args,
stan_file = self$stan_file(),
stan_code = suppressWarnings(self$code()),
model_methods_env = private$model_methods_env_,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = checkmate::assert_integerish(chain_ids, lower = 1, len = chains, unique = TRUE, null.ok = FALSE),
Expand Down Expand Up @@ -1259,6 +1274,7 @@ sample_mpi <- function(data = NULL,
method_args = sample_args,
stan_file = self$stan_file(),
stan_code = suppressWarnings(self$code()),
model_methods_env = private$model_methods_env_,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = checkmate::assert_integerish(chain_ids, lower = 1, len = chains, unique = TRUE, null.ok = FALSE),
Expand Down Expand Up @@ -1370,6 +1386,7 @@ optimize <- function(data = NULL,
method_args = optimize_args,
stan_file = self$stan_file(),
stan_code = suppressWarnings(self$code()),
model_methods_env = private$model_methods_env_,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = 1,
Expand Down Expand Up @@ -1487,6 +1504,7 @@ variational <- function(data = NULL,
method_args = variational_args,
stan_file = self$stan_file(),
stan_code = suppressWarnings(self$code()),
model_methods_env = private$model_methods_env_,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = 1,
Expand Down Expand Up @@ -1603,6 +1621,7 @@ generate_quantities <- function(fitted_params,
method_args = gq_args,
stan_file = self$stan_file(),
stan_code = suppressWarnings(self$code()),
model_methods_env = private$model_methods_env_,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = seq_along(fitted_params_files),
Expand Down Expand Up @@ -1666,6 +1685,7 @@ diagnose <- function(data = NULL,
method_args = diagnose_args,
stan_file = self$stan_file(),
stan_code = suppressWarnings(self$code()),
model_methods_env = private$model_methods_env_,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = 1,
Expand Down
1 change: 1 addition & 0 deletions R/run.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ CmdStanRun <- R6::R6Class(
proc_ids = function() self$procs$proc_ids(),
exe_file = function() self$args$exe_file,
stan_code = function() self$args$stan_code,
model_methods_env = function() self$args$model_methods_env,
model_name = function() self$args$model_name,
method = function() self$args$method,
data_file = function() self$args$data_file,
Expand Down
Loading