@@ -12,6 +12,11 @@ CmdStanFit <- R6::R6Class(
1212 initialize = function (runset ) {
1313 checkmate :: assert_r6(runset , classes = " CmdStanRun" )
1414 self $ runset <- runset
15+ private $ model_methods_env_ <- runset $ model_methods_env()
16+
17+ if (! is.null(private $ model_methods_env_ $ model_ptr )) {
18+ initialize_model_pointer(private $ model_methods_env_ , self $ data_file(), 0 )
19+ }
1520 invisible (self )
1621 },
1722 num_procs = function () {
@@ -63,7 +68,8 @@ CmdStanFit <- R6::R6Class(
6368 draws_ = NULL ,
6469 metadata_ = NULL ,
6570 init_ = NULL ,
66- profiles_ = NULL
71+ profiles_ = NULL ,
72+ model_methods_env_ = NULL
6773 )
6874)
6975
@@ -272,6 +278,208 @@ init <- function() {
272278}
273279CmdStanFit $ set(" public" , name = " init" , value = init )
274280
281+ # ' Compile additional methods for accessing the model log-probability function
282+ # ' and parameter constraining and unconstraining. This requires the `Rcpp` package.
283+ # '
284+ # ' @name fit-method-init_model_methods
285+ # ' @aliases init_model_methods
286+ # ' @description The `$init_model_methods()` compiles and initializes the
287+ # ' `log_prob`, `grad_log_prob`, `constrain_pars`, and `unconstrain_pars` functions.
288+ # '
289+ # ' @param seed (integer) The random seed to use when initializing the model.
290+ # ' @param verbose (boolean) Whether to show verbose logging during compilation.
291+ # ' @param hessian (boolean) Whether to expose the (experimental) hessian method.
292+ # '
293+ # ' @examples
294+ # ' \dontrun{
295+ # ' fit_mcmc <- cmdstanr_example("logistic", method = "sample")
296+ # ' fit_mcmc$init_model_methods()
297+ # ' }
298+ # '
299+ init_model_methods <- function (seed = 0 , verbose = FALSE , hessian = FALSE ) {
300+ require_suggested_package(" Rcpp" )
301+ require_suggested_package(" RcppEigen" )
302+ if (length(private $ model_methods_env_ $ hpp_code_ ) == 0 ) {
303+ stop(" Model methods cannot be used with a pre-compiled Stan executable, " ,
304+ " the model must be compiled again" , call. = FALSE )
305+ }
306+ if (hessian ) {
307+ message(" The hessian method relies on higher-order autodiff " ,
308+ " which is still experimental. Please report any compilation " ,
309+ " errors that you encounter" ,
310+ call. = FALSE )
311+ }
312+ message(" Compiling additional model methods..." )
313+ if (is.null(private $ model_methods_env_ $ model_ptr )) {
314+ expose_model_methods(private $ model_methods_env_ , verbose , hessian )
315+ }
316+ initialize_model_pointer(private $ model_methods_env_ , self $ data_file(), seed )
317+ invisible (NULL )
318+ }
319+ CmdStanFit $ set(" public" , name = " init_model_methods" , value = init_model_methods )
320+
321+ # ' Calculate the log-probability given a provided vector of unconstrained parameters.
322+ # '
323+ # ' @name fit-method-log_prob
324+ # ' @aliases log_prob
325+ # ' @description The `$log_prob()` method provides access to the Stan model's `log_prob` function
326+ # '
327+ # ' @param upars (numeric) A vector of unconstrained parameters to be passed to `log_prob`
328+ # ' @param jacobian_adjustment (bool) Whether to include the log-density adjustments from
329+ # ' un/constraining variables
330+ # '
331+ # ' @examples
332+ # ' \dontrun{
333+ # ' fit_mcmc <- cmdstanr_example("logistic", method = "sample")
334+ # ' fit_mcmc$log_prob(upars = c(0.5, 1.2, 1.1, 2.2, 1.1))
335+ # ' }
336+ # '
337+ log_prob <- function (upars , jacobian_adjustment = TRUE ) {
338+ if (is.null(private $ model_methods_env_ $ model_ptr )) {
339+ stop(" The method has not been compiled, please call `init_model_methods()` first" ,
340+ call. = FALSE )
341+ }
342+ if (length(upars ) != private $ model_methods_env_ $ num_upars_ ) {
343+ stop(" Model has " , private $ model_methods_env_ $ num_upars_ , " unconstrained parameter(s), but " ,
344+ length(upars ), " were provided!" , call. = FALSE )
345+ }
346+ private $ model_methods_env_ $ log_prob(private $ model_methods_env_ $ model_ptr_ , upars , jacobian_adjustment )
347+ }
348+ CmdStanFit $ set(" public" , name = " log_prob" , value = log_prob )
349+
350+ # ' Calculate the log-probability and the gradient w.r.t. each input for a
351+ # ' given vector of unconstrained parameters
352+ # '
353+ # ' @name fit-method-grad_log_prob
354+ # ' @aliases grad_log_prob
355+ # ' @description The `$grad_log_prob()` method provides access to the
356+ # ' Stan model's `log_prob` function and its derivative
357+ # '
358+ # ' @param upars (numeric) A vector of unconstrained parameters to be passed
359+ # ' to `grad_log_prob`
360+ # ' @param jacobian_adjustment (bool) Whether to include the log-density adjustments from
361+ # ' un/constraining variables
362+ # '
363+ # ' @examples
364+ # ' \dontrun{
365+ # ' fit_mcmc <- cmdstanr_example("logistic", method = "sample")
366+ # ' fit_mcmc$grad_log_prob(upars = c(0.5, 1.2, 1.1, 2.2, 1.1))
367+ # ' }
368+ # '
369+ grad_log_prob <- function (upars , jacobian_adjustment = TRUE ) {
370+ if (is.null(private $ model_methods_env_ $ model_ptr )) {
371+ stop(" The method has not been compiled, please call `init_model_methods()` first" ,
372+ call. = FALSE )
373+ }
374+ if (length(upars ) != private $ model_methods_env_ $ num_upars_ ) {
375+ stop(" Model has " , private $ model_methods_env_ $ num_upars_ , " unconstrained parameter(s), but " ,
376+ length(upars ), " were provided!" , call. = FALSE )
377+ }
378+ private $ model_methods_env_ $ grad_log_prob(private $ model_methods_env_ $ model_ptr_ , upars , jacobian_adjustment )
379+ }
380+ CmdStanFit $ set(" public" , name = " grad_log_prob" , value = grad_log_prob )
381+
382+ # ' Calculate the log-probability , the gradient w.r.t. each input, and the hessian
383+ # ' for a given vector of unconstrained parameters
384+ # '
385+ # ' @name fit-method-hessian
386+ # ' @aliases hessian
387+ # ' @description The `$hessian()` method provides access to the
388+ # ' Stan model's `log_prob`, its derivative, and its hessian
389+ # '
390+ # ' @param upars (numeric) A vector of unconstrained parameters to be passed
391+ # ' to `hessian`
392+ # '
393+ # ' @examples
394+ # ' \dontrun{
395+ # ' fit_mcmc <- cmdstanr_example("logistic", method = "sample")
396+ # ' fit_mcmc$hessian(upars = c(0.5, 1.2, 1.1, 2.2, 1.1))
397+ # ' }
398+ # '
399+ hessian <- function (upars ) {
400+ if (is.null(private $ model_methods_env_ $ model_ptr )) {
401+ stop(" The method has not been compiled, please call `init_model_methods()` first" ,
402+ call. = FALSE )
403+ }
404+ if (length(upars ) != private $ model_methods_env_ $ num_upars_ ) {
405+ stop(" Model has " , private $ model_methods_env_ $ num_upars_ , " unconstrained parameter(s), but " ,
406+ length(upars ), " were provided!" , call. = FALSE )
407+ }
408+ private $ model_methods_env_ $ hessian(private $ model_methods_env_ $ model_ptr_ , upars )
409+ }
410+ CmdStanFit $ set(" public" , name = " hessian" , value = hessian )
411+
412+ # ' Transform a set of parameter values to the unconstrained scale
413+ # '
414+ # ' @name fit-method-unconstrain_pars
415+ # ' @aliases unconstrain_pars
416+ # ' @description The `$unconstrain_pars()` method transforms input parameters to
417+ # ' the unconstrained scale
418+ # '
419+ # ' @param pars (list) A list of parameter values to transform, in the same format as
420+ # ' provided to the `init` argument of the `$sample()` method
421+ # '
422+ # ' @examples
423+ # ' \dontrun{
424+ # ' fit_mcmc <- cmdstanr_example("logistic", method = "sample")
425+ # ' fit_mcmc$unconstrain_pars(list(alpha = 0.5, beta = c(0.7, 1.1, 0.2)))
426+ # ' }
427+ # '
428+ unconstrain_pars <- function (pars ) {
429+ if (is.null(private $ model_methods_env_ $ model_ptr )) {
430+ stop(" The method has not been compiled, please call `init_model_methods()` first" ,
431+ call. = FALSE )
432+ }
433+ model_par_names <- names(self $ runset $ args $ model_variables $ parameters )
434+ prov_par_names <- names(pars )
435+
436+ prov_pars_not_in_model <- which(! (prov_par_names %in% model_par_names ))
437+ if (length(prov_pars_not_in_model ) > 0 ) {
438+ stop(" Provided parameter(s): " , paste(prov_par_names [prov_pars_not_in_model ], collapse = " ," ),
439+ " not present in model!" , call. = FALSE )
440+ }
441+
442+ model_pars_not_prov <- which(! (model_par_names %in% prov_par_names ))
443+ if (length(model_pars_not_prov ) > 0 ) {
444+ stop(" Model parameter(s): " , paste(model_par_names [model_pars_not_prov ], collapse = " ," ),
445+ " not provided!" , call. = FALSE )
446+ }
447+
448+ stan_pars <- process_init_list(list (pars ), num_procs = 1 , self $ runset $ args $ model_variables )
449+ private $ model_methods_env_ $ unconstrain_pars(private $ model_methods_env_ $ model_ptr_ , stan_pars )
450+ }
451+ CmdStanFit $ set(" public" , name = " unconstrain_pars" , value = unconstrain_pars )
452+
453+ # ' Transform a set of unconstrained parameter values to the constrained scale
454+ # '
455+ # ' @name fit-method-constrain_pars
456+ # ' @aliases constrain_pars
457+ # ' @description The `$constrain_pars()` method transforms input parameters to
458+ # ' the constrained scale
459+ # '
460+ # ' @param upars (numeric) A vector of unconstrained parameters to constrain
461+ # '
462+ # ' @examples
463+ # ' \dontrun{
464+ # ' fit_mcmc <- cmdstanr_example("logistic", method = "sample")
465+ # ' fit_mcmc$constrain_pars(upars = c(0.5, 1.2, 1.1, 2.2, 1.1))
466+ # ' }
467+ # '
468+ constrain_pars <- function (upars ) {
469+ if (is.null(private $ model_methods_env_ $ model_ptr )) {
470+ stop(" The method has not been compiled, please call `init_model_methods()` first" ,
471+ call. = FALSE )
472+ }
473+ if (length(upars ) != private $ model_methods_env_ $ num_upars_ ) {
474+ stop(" Model has " , private $ model_methods_env_ $ num_upars_ , " unconstrained parameter(s), but " ,
475+ length(upars ), " were provided!" , call. = FALSE )
476+ }
477+ cpars <- private $ model_methods_env_ $ constrain_pars(private $ model_methods_env_ $ model_ptr_ , private $ model_methods_env_ $ model_rng_ , upars )
478+ skeleton <- create_skeleton(self $ runset $ args $ model_variables )
479+ utils :: relist(cpars , skeleton )
480+ }
481+ CmdStanFit $ set(" public" , name = " constrain_pars" , value = constrain_pars )
482+
275483# ' Extract log probability (target)
276484# '
277485# ' @name fit-method-lp
0 commit comments