Skip to content

Commit f01934c

Browse files
authored
Merge pull request #701 from andrjohns/add-model-methods
Add optional methods for log_prob, grad_log_prob, hessian, un/constrain pars
2 parents 93d6ea7 + 99edeba commit f01934c

18 files changed

+763
-16
lines changed

DESCRIPTION

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ Authors@R:
1515
email = "[email protected]", comment = c(ORCID = "0000-0003-1878-3253")),
1616
person(given = "Jacob", family = "Socolar", role = "ctb"),
1717
person(given = "Andrew", family = "Johnson", role = "ctb",
18-
comment = c(ORCID = "0000-0001-7000-8065 ")))
18+
comment = c(ORCID = "0000-0001-7000-8065")))
1919
Description: A lightweight interface to 'Stan' <https://mc-stan.org>.
2020
The 'CmdStanR' interface is an alternative to 'RStan' that calls the command
2121
line interface for compilation and running algorithms instead of interfacing
@@ -27,7 +27,7 @@ URL: https://mc-stan.org/cmdstanr/, https://discourse.mc-stan.org
2727
BugReports: https://github.com/stan-dev/cmdstanr/issues
2828
Encoding: UTF-8
2929
LazyData: true
30-
RoxygenNote: 7.2.0
30+
RoxygenNote: 7.2.1
3131
Roxygen: list(markdown = TRUE, r6 = FALSE)
3232
SystemRequirements: CmdStan (https://mc-stan.org/users/interfaces/cmdstan)
3333
Depends:
@@ -46,5 +46,7 @@ Suggests:
4646
loo (>= 2.0.0),
4747
rlang (>= 0.4.7),
4848
rmarkdown,
49-
testthat (>= 2.1.0)
49+
testthat (>= 2.1.0),
50+
Rcpp,
51+
RcppEigen
5052
VignetteBuilder: knitr

R/args.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ CmdStanArgs <- R6::R6Class(
2626
initialize = function(model_name,
2727
stan_file = NULL,
2828
stan_code = NULL,
29+
model_methods_env = NULL,
2930
exe_file,
3031
proc_ids,
3132
method_args,
@@ -43,6 +44,7 @@ CmdStanArgs <- R6::R6Class(
4344
self$model_name <- model_name
4445
self$stan_code <- stan_code
4546
self$exe_file <- exe_file
47+
self$model_methods_env <- model_methods_env
4648
self$proc_ids <- proc_ids
4749
self$data_file <- data_file
4850
self$seed <- seed
@@ -52,6 +54,7 @@ CmdStanArgs <- R6::R6Class(
5254
self$method <- self$method_args$method
5355
self$save_latent_dynamics <- save_latent_dynamics
5456
self$using_tempdir <- is.null(output_dir)
57+
self$model_variables <- model_variables
5558
if (getRversion() < "3.5.0") {
5659
self$output_dir <- output_dir %||% tempdir()
5760
} else {

R/fit.R

Lines changed: 209 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ CmdStanFit <- R6::R6Class(
1212
initialize = function(runset) {
  # Only CmdStanRun objects are accepted as the backing run set.
  checkmate::assert_r6(runset, classes = "CmdStanRun")
  self$runset <- runset

  # Keep a handle to the model-methods environment carried by the run set.
  methods_env <- runset$model_methods_env()
  private$model_methods_env_ <- methods_env

  # If `model_ptr` is already present in that environment, set up the model
  # pointer for this fit's data file straight away (seed fixed at 0).
  methods_already_exposed <- !is.null(methods_env$model_ptr)
  if (methods_already_exposed) {
    initialize_model_pointer(methods_env, self$data_file(), 0)
  }
  invisible(self)
},
1722
num_procs = function() {
@@ -63,7 +68,8 @@ CmdStanFit <- R6::R6Class(
6368
draws_ = NULL,
6469
metadata_ = NULL,
6570
init_ = NULL,
66-
profiles_ = NULL
71+
profiles_ = NULL,
72+
model_methods_env_ = NULL
6773
)
6874
)
6975

@@ -272,6 +278,208 @@ init <- function() {
272278
}
273279
CmdStanFit$set("public", name = "init", value = init)
274280

281+
#' Compile additional methods for accessing the model log-probability function
282+
#' and parameter constraining and unconstraining. This requires the `Rcpp` and `RcppEigen` packages.
283+
#'
284+
#' @name fit-method-init_model_methods
285+
#' @aliases init_model_methods
286+
#' @description The `$init_model_methods()` method compiles and initializes the
287+
#' `log_prob`, `grad_log_prob`, `constrain_pars`, and `unconstrain_pars` functions.
288+
#'
289+
#' @param seed (integer) The random seed to use when initializing the model.
290+
#' @param verbose (boolean) Whether to show verbose logging during compilation.
291+
#' @param hessian (boolean) Whether to expose the (experimental) hessian method.
292+
#'
293+
#' @examples
294+
#' \dontrun{
295+
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample")
296+
#' fit_mcmc$init_model_methods()
297+
#' }
298+
#'
299+
init_model_methods <- function(seed = 0, verbose = FALSE, hessian = FALSE) {
  # Rcpp and RcppEigen are optional dependencies; stop early if either is
  # unavailable.
  require_suggested_package("Rcpp")
  require_suggested_package("RcppEigen")
  # The methods are built from the C++ code held in `hpp_code_`; if that is
  # empty there is nothing to compile from.
  if (length(private$model_methods_env_$hpp_code_) == 0) {
    stop("Model methods cannot be used with a pre-compiled Stan executable, ",
         "the model must be compiled again", call. = FALSE)
  }
  if (hessian) {
    # message() has no `call.` argument (unlike stop()/warning()): anything
    # passed under that name is pasted onto the end of the message text, so
    # it must not be supplied here.
    message("The hessian method relies on higher-order autodiff ",
            "which is still experimental. Please report any compilation ",
            "errors that you encounter")
  }
  message("Compiling additional model methods...")
  # Skip compilation when `model_ptr` is already present in the environment.
  # NOTE(review): in that case `verbose` and `hessian` are not passed on to
  # anything, so requesting the hessian after the methods were exposed
  # without it has no effect -- confirm whether that is intended.
  if (is.null(private$model_methods_env_$model_ptr)) {
    expose_model_methods(private$model_methods_env_, verbose, hessian)
  }
  # The pointer is always (re)initialized with this fit's data file and the
  # requested seed.
  initialize_model_pointer(private$model_methods_env_, self$data_file(), seed)
  invisible(NULL)
}
319+
CmdStanFit$set("public", name = "init_model_methods", value = init_model_methods)
320+
321+
#' Calculate the log-probability given a provided vector of unconstrained parameters.
322+
#'
323+
#' @name fit-method-log_prob
324+
#' @aliases log_prob
325+
#' @description The `$log_prob()` method provides access to the Stan model's `log_prob` function
326+
#'
327+
#' @param upars (numeric) A vector of unconstrained parameters to be passed to `log_prob`
328+
#' @param jacobian_adjustment (bool) Whether to include the log-density adjustments from
329+
#' un/constraining variables
330+
#'
331+
#' @examples
332+
#' \dontrun{
333+
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample")
334+
#' fit_mcmc$log_prob(upars = c(0.5, 1.2, 1.1, 2.2, 1.1))
335+
#' }
336+
#'
337+
log_prob <- function(upars, jacobian_adjustment = TRUE) {
  # The exposed functions and the model pointer are all read from this
  # environment.
  methods_env <- private$model_methods_env_
  if (is.null(methods_env$model_ptr)) {
    stop("The method has not been compiled, please call `init_model_methods()` first",
         call. = FALSE)
  }

  # The input must supply exactly one value per unconstrained parameter.
  n_expected <- methods_env$num_upars_
  n_provided <- length(upars)
  if (n_provided != n_expected) {
    stop("Model has ", n_expected, " unconstrained parameter(s), but ",
         n_provided, " were provided!", call. = FALSE)
  }

  methods_env$log_prob(methods_env$model_ptr_, upars, jacobian_adjustment)
}
348+
CmdStanFit$set("public", name = "log_prob", value = log_prob)
349+
350+
#' Calculate the log-probability and the gradient w.r.t. each input for a
351+
#' given vector of unconstrained parameters
352+
#'
353+
#' @name fit-method-grad_log_prob
354+
#' @aliases grad_log_prob
355+
#' @description The `$grad_log_prob()` method provides access to the
356+
#' Stan model's `log_prob` function and its derivative
357+
#'
358+
#' @param upars (numeric) A vector of unconstrained parameters to be passed
359+
#' to `grad_log_prob`
360+
#' @param jacobian_adjustment (bool) Whether to include the log-density adjustments from
361+
#' un/constraining variables
362+
#'
363+
#' @examples
364+
#' \dontrun{
365+
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample")
366+
#' fit_mcmc$grad_log_prob(upars = c(0.5, 1.2, 1.1, 2.2, 1.1))
367+
#' }
368+
#'
369+
grad_log_prob <- function(upars, jacobian_adjustment = TRUE) {
  # The exposed functions and the model pointer are all read from this
  # environment.
  methods_env <- private$model_methods_env_
  if (is.null(methods_env$model_ptr)) {
    stop("The method has not been compiled, please call `init_model_methods()` first",
         call. = FALSE)
  }

  # The input must supply exactly one value per unconstrained parameter.
  n_expected <- methods_env$num_upars_
  n_provided <- length(upars)
  if (n_provided != n_expected) {
    stop("Model has ", n_expected, " unconstrained parameter(s), but ",
         n_provided, " were provided!", call. = FALSE)
  }

  methods_env$grad_log_prob(methods_env$model_ptr_, upars, jacobian_adjustment)
}
380+
CmdStanFit$set("public", name = "grad_log_prob", value = grad_log_prob)
381+
382+
#' Calculate the log-probability, the gradient w.r.t. each input, and the hessian
383+
#' for a given vector of unconstrained parameters
384+
#'
385+
#' @name fit-method-hessian
386+
#' @aliases hessian
387+
#' @description The `$hessian()` method provides access to the
388+
#' Stan model's `log_prob`, its derivative, and its hessian
389+
#'
390+
#' @param upars (numeric) A vector of unconstrained parameters to be passed
391+
#' to `hessian`
392+
#'
393+
#' @examples
394+
#' \dontrun{
395+
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample")
396+
#' fit_mcmc$hessian(upars = c(0.5, 1.2, 1.1, 2.2, 1.1))
397+
#' }
398+
#'
399+
hessian <- function(upars) {
  # The exposed functions and the model pointer are all read from this
  # environment.
  methods_env <- private$model_methods_env_
  if (is.null(methods_env$model_ptr)) {
    stop("The method has not been compiled, please call `init_model_methods()` first",
         call. = FALSE)
  }
  # Exposing the hessian is opt-in when the model methods are compiled, so
  # the function may be absent even though the other methods exist. Fail
  # with a clear message instead of "attempt to apply non-function".
  if (is.null(methods_env$hessian)) {
    stop("The hessian method has not been compiled, the model methods must ",
         "be compiled with the hessian method enabled", call. = FALSE)
  }
  # The input must supply exactly one value per unconstrained parameter.
  if (length(upars) != methods_env$num_upars_) {
    stop("Model has ", methods_env$num_upars_, " unconstrained parameter(s), but ",
         length(upars), " were provided!", call. = FALSE)
  }
  methods_env$hessian(methods_env$model_ptr_, upars)
}
410+
CmdStanFit$set("public", name = "hessian", value = hessian)
411+
412+
#' Transform a set of parameter values to the unconstrained scale
413+
#'
414+
#' @name fit-method-unconstrain_pars
415+
#' @aliases unconstrain_pars
416+
#' @description The `$unconstrain_pars()` method transforms input parameters to
417+
#' the unconstrained scale
418+
#'
419+
#' @param pars (list) A list of parameter values to transform, in the same format as
420+
#' provided to the `init` argument of the `$sample()` method
421+
#'
422+
#' @examples
423+
#' \dontrun{
424+
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample")
425+
#' fit_mcmc$unconstrain_pars(list(alpha = 0.5, beta = c(0.7, 1.1, 0.2)))
426+
#' }
427+
#'
428+
unconstrain_pars <- function(pars) {
  # The exposed functions and the model pointer are all read from this
  # environment.
  methods_env <- private$model_methods_env_
  if (is.null(methods_env$model_ptr)) {
    stop("The method has not been compiled, please call `init_model_methods()` first",
         call. = FALSE)
  }

  model_variables <- self$runset$args$model_variables
  expected_names <- names(model_variables$parameters)
  supplied_names <- names(pars)

  # Reject names that the model does not declare as parameters ...
  unknown_names <- supplied_names[!(supplied_names %in% expected_names)]
  if (length(unknown_names) > 0) {
    stop("Provided parameter(s): ", paste(unknown_names, collapse = ","),
         " not present in model!", call. = FALSE)
  }

  # ... and require a value for every parameter the model does declare.
  absent_names <- expected_names[!(expected_names %in% supplied_names)]
  if (length(absent_names) > 0) {
    stop("Model parameter(s): ", paste(absent_names, collapse = ","),
         " not provided!", call. = FALSE)
  }

  # `pars` is wrapped in a list and handed to process_init_list() as the
  # inits for a single process before being passed to the exposed function.
  stan_pars <- process_init_list(list(pars), num_procs = 1, model_variables)
  methods_env$unconstrain_pars(methods_env$model_ptr_, stan_pars)
}
451+
CmdStanFit$set("public", name = "unconstrain_pars", value = unconstrain_pars)
452+
453+
#' Transform a set of unconstrained parameter values to the constrained scale
454+
#'
455+
#' @name fit-method-constrain_pars
456+
#' @aliases constrain_pars
457+
#' @description The `$constrain_pars()` method transforms input parameters to
458+
#' the constrained scale
459+
#'
460+
#' @param upars (numeric) A vector of unconstrained parameters to constrain
461+
#'
462+
#' @examples
463+
#' \dontrun{
464+
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample")
465+
#' fit_mcmc$constrain_pars(upars = c(0.5, 1.2, 1.1, 2.2, 1.1))
466+
#' }
467+
#'
468+
constrain_pars <- function(upars) {
  # The exposed functions, model pointer and RNG are all read from this
  # environment.
  methods_env <- private$model_methods_env_
  if (is.null(methods_env$model_ptr)) {
    stop("The method has not been compiled, please call `init_model_methods()` first",
         call. = FALSE)
  }

  # The input must supply exactly one value per unconstrained parameter.
  n_expected <- methods_env$num_upars_
  n_provided <- length(upars)
  if (n_provided != n_expected) {
    stop("Model has ", n_expected, " unconstrained parameter(s), but ",
         n_provided, " were provided!", call. = FALSE)
  }

  constrained_values <- methods_env$constrain_pars(
    methods_env$model_ptr_,
    methods_env$model_rng_,
    upars
  )
  # Reshape the returned values to match the skeleton built from the model
  # variables.
  utils::relist(
    constrained_values,
    create_skeleton(self$runset$args$model_variables)
  )
}
481+
CmdStanFit$set("public", name = "constrain_pars", value = constrain_pars)
482+
275483
#' Extract log probability (target)
276484
#'
277485
#' @name fit-method-lp

R/model.R

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ CmdStanModel <- R6::R6Class(
210210
model_name_ = character(),
211211
exe_file_ = character(),
212212
hpp_file_ = character(),
213+
model_methods_env_ = NULL,
213214
dir_ = NULL,
214215
cpp_options_ = list(),
215216
stanc_options_ = list(),
@@ -382,6 +383,10 @@ CmdStanModel <- R6::R6Class(
382383
#' @param force_recompile (logical) Should the model be recompiled even if was
383384
#' not modified since last compiled. The default is `FALSE`. Can also be set
384385
#' via a global `cmdstanr_force_recompile` option.
386+
#' @param compile_model_methods (logical) Compile additional model methods
387+
#' (`log_prob()`, `grad_log_prob()`, `constrain_pars()`, `unconstrain_pars()`)
388+
#' @param compile_hessian_method (logical) Should the (experimental) `hessian()` method be
389+
#' compiled with the model methods?
385390
#' @param threads Deprecated and will be removed in a future release. Please
386391
#' turn on threading via `cpp_options = list(stan_threads = TRUE)` instead.
387392
#'
@@ -431,6 +436,8 @@ compile <- function(quiet = TRUE,
431436
cpp_options = list(),
432437
stanc_options = list(),
433438
force_recompile = getOption("cmdstanr_force_recompile", default = FALSE),
439+
compile_model_methods = FALSE,
440+
compile_hessian_method = FALSE,
434441
#deprecated
435442
threads = FALSE) {
436443
if (length(self$stan_file()) == 0) {
@@ -613,6 +620,13 @@ compile <- function(quiet = TRUE,
613620
private$precompile_cpp_options_ <- NULL
614621
private$precompile_stanc_options_ <- NULL
615622
private$precompile_include_paths_ <- NULL
623+
private$model_methods_env_ <- new.env()
624+
private$model_methods_env_$hpp_code_ <- readLines(private$hpp_file_)
625+
if (compile_model_methods) {
626+
expose_model_methods(env = private$model_methods_env_,
627+
verbose = !quiet,
628+
hessian = compile_hessian_method)
629+
}
616630
invisible(self)
617631
}
618632
CmdStanModel$set("public", name = "compile", value = compile)
@@ -1102,6 +1116,7 @@ sample <- function(data = NULL,
11021116
method_args = sample_args,
11031117
stan_file = self$stan_file(),
11041118
stan_code = suppressWarnings(self$code()),
1119+
model_methods_env = private$model_methods_env_,
11051120
model_name = self$model_name(),
11061121
exe_file = self$exe_file(),
11071122
proc_ids = checkmate::assert_integerish(chain_ids, lower = 1, len = chains, unique = TRUE, null.ok = FALSE),
@@ -1259,6 +1274,7 @@ sample_mpi <- function(data = NULL,
12591274
method_args = sample_args,
12601275
stan_file = self$stan_file(),
12611276
stan_code = suppressWarnings(self$code()),
1277+
model_methods_env = private$model_methods_env_,
12621278
model_name = self$model_name(),
12631279
exe_file = self$exe_file(),
12641280
proc_ids = checkmate::assert_integerish(chain_ids, lower = 1, len = chains, unique = TRUE, null.ok = FALSE),
@@ -1370,6 +1386,7 @@ optimize <- function(data = NULL,
13701386
method_args = optimize_args,
13711387
stan_file = self$stan_file(),
13721388
stan_code = suppressWarnings(self$code()),
1389+
model_methods_env = private$model_methods_env_,
13731390
model_name = self$model_name(),
13741391
exe_file = self$exe_file(),
13751392
proc_ids = 1,
@@ -1487,6 +1504,7 @@ variational <- function(data = NULL,
14871504
method_args = variational_args,
14881505
stan_file = self$stan_file(),
14891506
stan_code = suppressWarnings(self$code()),
1507+
model_methods_env = private$model_methods_env_,
14901508
model_name = self$model_name(),
14911509
exe_file = self$exe_file(),
14921510
proc_ids = 1,
@@ -1603,6 +1621,7 @@ generate_quantities <- function(fitted_params,
16031621
method_args = gq_args,
16041622
stan_file = self$stan_file(),
16051623
stan_code = suppressWarnings(self$code()),
1624+
model_methods_env = private$model_methods_env_,
16061625
model_name = self$model_name(),
16071626
exe_file = self$exe_file(),
16081627
proc_ids = seq_along(fitted_params_files),
@@ -1666,6 +1685,7 @@ diagnose <- function(data = NULL,
16661685
method_args = diagnose_args,
16671686
stan_file = self$stan_file(),
16681687
stan_code = suppressWarnings(self$code()),
1688+
model_methods_env = private$model_methods_env_,
16691689
model_name = self$model_name(),
16701690
exe_file = self$exe_file(),
16711691
proc_ids = 1,

R/run.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ CmdStanRun <- R6::R6Class(
3535
proc_ids = function() self$procs$proc_ids(),
3636
exe_file = function() self$args$exe_file,
3737
stan_code = function() self$args$stan_code,
38+
model_methods_env = function() self$args$model_methods_env,
3839
model_name = function() self$args$model_name,
3940
method = function() self$args$method,
4041
data_file = function() self$args$data_file,

0 commit comments

Comments
 (0)