Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ URL: https://mc-stan.org/cmdstanr/, https://discourse.mc-stan.org
BugReports: https://github.com/stan-dev/cmdstanr/issues
Encoding: UTF-8
LazyData: true
RoxygenNote: 7.3.0
RoxygenNote: 7.3.1
Roxygen: list(markdown = TRUE, r6 = FALSE)
SystemRequirements: CmdStan (https://mc-stan.org/users/interfaces/cmdstan)
Depends:
Expand Down
38 changes: 33 additions & 5 deletions R/args.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ CmdStanArgs <- R6::R6Class(
sig_figs = NULL,
opencl_ids = NULL,
model_variables = NULL,
num_threads = NULL) {
num_threads = NULL,
save_cmdstan_config = NULL) {

self$model_name <- model_name
self$stan_code <- stan_code
Expand All @@ -60,6 +61,7 @@ CmdStanArgs <- R6::R6Class(
self$save_latent_dynamics <- save_latent_dynamics
self$using_tempdir <- is.null(output_dir)
self$model_variables <- model_variables
self$save_cmdstan_config <- save_cmdstan_config
if (os_is_wsl()) {
# Want to ensure that any files under WSL are written to a tempdir within
# WSL to avoid IO performance issues
Expand All @@ -86,6 +88,9 @@ CmdStanArgs <- R6::R6Class(
self$opencl_ids <- opencl_ids
self$num_threads = NULL
self$method_args$validate(num_procs = length(self$proc_ids))
if (is.logical(self$save_cmdstan_config)) {
self$save_cmdstan_config <- as.integer(self$save_cmdstan_config)
}
self$validate()
},
validate = function() {
Expand All @@ -110,9 +115,10 @@ CmdStanArgs <- R6::R6Class(
} else if (type == "profile") {
basename <- paste0(basename, "-profile")
}
if (type == "output" && !is.null(self$output_basename)) {
if (type == "output" && !is.null(self$output_basename)) {
basename <- self$output_basename
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you undo these whitespace changes? There are a few throughout the diff

generate_file_names(
basename = basename,
ext = ".csv",
Expand Down Expand Up @@ -179,12 +185,16 @@ CmdStanArgs <- R6::R6Class(
if (!is.null(profile_file)) {
args$output <- c(args$output, paste0("profile_file=", wsl_safe_path(profile_file)))
}
if (!is.null(self$save_cmdstan_config)) {
args$output <- c(args$output, paste0("save_cmdstan_config=", self$save_cmdstan_config))
}
if (!is.null(self$opencl_ids)) {
args$opencl <- c("opencl", paste0("platform=", self$opencl_ids[1]), paste0("device=", self$opencl_ids[2]))
}
if (!is.null(self$num_threads)) {
num_threads <- c(args$output, paste0("num_threads=", self$num_threads))
}

args <- do.call(c, append(args, list(use.names = FALSE)))
self$method_args$compose(idx, args)
},
Expand Down Expand Up @@ -217,7 +227,8 @@ SampleArgs <- R6::R6Class(
term_buffer = NULL,
window = NULL,
fixed_param = FALSE,
diagnostics = NULL) {
diagnostics = NULL,
save_metric = NULL) {

self$iter_warmup <- iter_warmup
self$iter_sampling <- iter_sampling
Expand All @@ -231,6 +242,7 @@ SampleArgs <- R6::R6Class(
self$inv_metric <- inv_metric
self$fixed_param <- fixed_param
self$diagnostics <- diagnostics
self$save_metric <- save_metric
if (identical(self$diagnostics, "")) {
self$diagnostics <- NULL
}
Expand Down Expand Up @@ -274,6 +286,9 @@ SampleArgs <- R6::R6Class(
if (is.logical(self$save_warmup)) {
self$save_warmup <- as.integer(self$save_warmup)
}
if (is.logical(self$save_metric)) {
self$save_metric <- as.integer(self$save_metric)
}
invisible(self)
},
validate = function(num_procs) {
Expand Down Expand Up @@ -313,7 +328,8 @@ SampleArgs <- R6::R6Class(
.make_arg("adapt_engaged"),
.make_arg("init_buffer"),
.make_arg("term_buffer"),
.make_arg("window")
.make_arg("window"),
.make_arg("save_metric")
)
} else {
new_args <- list(
Expand All @@ -334,7 +350,8 @@ SampleArgs <- R6::R6Class(
.make_arg("adapt_engaged"),
.make_arg("init_buffer"),
.make_arg("term_buffer"),
.make_arg("window")
.make_arg("window"),
.make_arg("save_metric")
)
}
new_args <- do.call(c, new_args)
Expand Down Expand Up @@ -681,6 +698,7 @@ validate_cmdstan_args <- function(self) {
checkmate::assert_flag(self$save_latent_dynamics)
checkmate::assert_integerish(self$refresh, lower = 0, null.ok = TRUE)
checkmate::assert_integerish(self$sig_figs, lower = 1, upper = 18, null.ok = TRUE)
checkmate::assert_integerish(self$save_cmdstan_config, lower = 0, upper = 1, len = 1, null.ok = TRUE)
if (!is.null(self$sig_figs) && cmdstan_version() < "2.25") {
warning("The 'sig_figs' argument is only supported with cmdstan 2.25+ and will be ignored!", call. = FALSE)
}
Expand All @@ -690,6 +708,7 @@ validate_cmdstan_args <- function(self) {
if (!is.null(self$data_file)) {
assert_file_exists(self$data_file, access = "r")
}

num_procs <- length(self$proc_ids)
validate_init(self$init, num_procs)
validate_seed(self$seed, num_procs)
Expand Down Expand Up @@ -793,6 +812,15 @@ validate_sample_args <- function(self, num_procs) {
checkmate::assert_subset(self$diagnostics, empty.ok = FALSE, choices = available_hmc_diagnostics())
}

checkmate::assert_integerish(self$save_metric,
lower = 0, upper = 1,
len = 1,
null.ok = TRUE)

if (is.null(self$adapt_engaged) || (!self$adapt_engaged && !is.null(self$save_metric))) {
self$save_metric <- 0
}

invisible(TRUE)
}

Expand Down
50 changes: 46 additions & 4 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -898,10 +898,13 @@ CmdStanFit$set("public", name = "cmdstan_diagnose", value = cmdstan_diagnose)
#' Save output and data files
#'
#' @name fit-method-save_output_files
#' @aliases fit-method-save_data_file fit-method-save_latent_dynamics_files fit-method-save_profile_files
#' fit-method-output_files fit-method-data_file fit-method-latent_dynamics_files fit-method-profile_files
#' save_output_files save_data_file save_latent_dynamics_files save_profile_files
#' output_files data_file latent_dynamics_files profile_files
#' @aliases fit-method-save_data_file fit-method-save_latent_dynamics_files
#' fit-method-save_profile_files fit-method-output_files fit-method-data_file
#' fit-method-latent_dynamics_files fit-method-profile_files
#' fit-method-save_config_files fit-method-save_metric_files save_output_files
#' save_data_file save_latent_dynamics_files save_profile_files
#' save_config_files save_metric_files output_files data_file
#' latent_dynamics_files profile_files config_files metric_files
#'
#' @description All fitted model objects have methods for saving (moving to a
#' specified location) the files created by CmdStanR to hold CmdStan output
Expand Down Expand Up @@ -936,6 +939,14 @@ CmdStanFit$set("public", name = "cmdstan_diagnose", value = cmdstan_diagnose)
#' `$save_output_files()` except `"-profile-"` is included in the new
#' file name after `basename`.
#'
#' For `$save_metric_files()` everything is the same as for
#' `$save_output_files()` except `"-metric-"` is included in the new
#' file name after `basename`.
#'
#' For `$save_config_files()` everything is the same as for
#' `$save_output_files()` except `"-config-"` is included in the new
#' file name after `basename`.
#'
#' For `$save_data_file()` no `id` is included in the file name because even
#' with multiple MCMC chains the data file is the same.
#'
Expand Down Expand Up @@ -998,6 +1009,26 @@ save_data_file <- function(dir = ".",
}
CmdStanFit$set("public", name = "save_data_file", value = save_data_file)

#' @rdname fit-method-save_output_files
save_config_files <- function(dir = ".",
basename = NULL,
timestamp = TRUE,
random = TRUE) {
self$runset$save_config_files(dir, basename, timestamp, random)
}
CmdStanFit$set("public", name = "save_config_files", value = save_config_files)

#' @rdname fit-method-save_output_files
save_metric_files <- function(dir = ".",
basename = NULL,
timestamp = TRUE,
random = TRUE) {
self$runset$save_metric_files(dir, basename, timestamp, random)
}
CmdStanFit$set("public", name = "save_metric_files", value = save_metric_files)



#' @rdname fit-method-save_output_files
#' @param include_failed (logical) Should CmdStan runs that failed also be
#' included? The default is `FALSE.`
Expand All @@ -1024,6 +1055,17 @@ data_file <- function() {
}
CmdStanFit$set("public", name = "data_file", value = data_file)

#' @rdname fit-method-save_output_files
config_files <- function(include_failed = FALSE) {
self$runset$config_files(include_failed)
}
CmdStanFit$set("public", name = "config_files", value = config_files)

#' @rdname fit-method-save_output_files
metric_files <- function(include_failed = FALSE) {
self$runset$metric_files(include_failed)
}
CmdStanFit$set("public", name = "metric_files", value = metric_files)

#' Report timing of CmdStan runs
#'
Expand Down
36 changes: 25 additions & 11 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,8 @@ sample <- function(data = NULL,
show_messages = TRUE,
show_exceptions = TRUE,
diagnostics = c("divergences", "treedepth", "ebfmi"),
save_metric = TRUE,
save_cmdstan_config = TRUE,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
save_metric = TRUE,
save_cmdstan_config = TRUE,
save_metric = if (cmdstan_version() > "2.34.0") TRUE else NULL,
save_cmdstan_config = if (cmdstan_version() > "2.34.0") TRUE else NULL,

Then the function arguments below would just use the save_metric and save_cmdstan_config values directly

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And the same for the other methods as well (laplace, optimize, etc.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch, missed that when I was looking at it

# deprecated
cores = NULL,
num_cores = NULL,
Expand Down Expand Up @@ -1240,7 +1242,8 @@ sample <- function(data = NULL,
term_buffer = term_buffer,
window = window,
fixed_param = fixed_param,
diagnostics = diagnostics
diagnostics = diagnostics,
save_metric = if (cmdstan_version() > "2.34.0") save_metric else NULL
)
args <- CmdStanArgs$new(
method_args = sample_args,
Expand All @@ -1260,7 +1263,8 @@ sample <- function(data = NULL,
output_basename = output_basename,
sig_figs = sig_figs,
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()),
model_variables = model_variables
model_variables = model_variables,
save_cmdstan_config = if (cmdstan_version() > "2.34.0") save_cmdstan_config
)
runset <- CmdStanRun$new(args, procs)
runset$run_cmdstan()
Expand Down Expand Up @@ -1357,6 +1361,7 @@ sample_mpi <- function(data = NULL,
show_messages = TRUE,
show_exceptions = TRUE,
diagnostics = c("divergences", "treedepth", "ebfmi"),
save_cmdstan_config = TRUE,
# deprecated
validate_csv = TRUE) {

Expand Down Expand Up @@ -1420,7 +1425,8 @@ sample_mpi <- function(data = NULL,
output_dir = output_dir,
output_basename = output_basename,
sig_figs = sig_figs,
model_variables = model_variables
model_variables = model_variables,
save_cmdstan_config = if (cmdstan_version() > "2.34.0") save_cmdstan_config
)
runset <- CmdStanRun$new(args, procs)
runset$run_cmdstan_mpi(mpi_cmd, mpi_args)
Expand Down Expand Up @@ -1500,7 +1506,8 @@ optimize <- function(data = NULL,
tol_param = NULL,
history_size = NULL,
show_messages = TRUE,
show_exceptions = TRUE) {
show_exceptions = TRUE,
save_cmdstan_config = TRUE) {
procs <- CmdStanProcs$new(
num_procs = 1,
show_stderr_messages = show_exceptions,
Expand Down Expand Up @@ -1541,7 +1548,8 @@ optimize <- function(data = NULL,
output_basename = output_basename,
sig_figs = sig_figs,
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()),
model_variables = model_variables
model_variables = model_variables,
save_cmdstan_config = if (cmdstan_version() > "2.34.0") save_cmdstan_config
)
runset <- CmdStanRun$new(args, procs)
runset$run_cmdstan()
Expand Down Expand Up @@ -1632,7 +1640,8 @@ laplace <- function(data = NULL,
jacobian = TRUE, # different default than for optimize!
draws = NULL,
show_messages = TRUE,
show_exceptions = TRUE) {
show_exceptions = TRUE,
save_cmdstan_config = TRUE) {
if (cmdstan_version() < "2.32") {
stop("This method is only available in cmdstan >= 2.32", call. = FALSE)
}
Expand Down Expand Up @@ -1706,7 +1715,8 @@ laplace <- function(data = NULL,
output_basename = output_basename,
sig_figs = sig_figs,
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()),
model_variables = model_variables
model_variables = model_variables,
save_cmdstan_config = if (cmdstan_version() > "2.34.0") save_cmdstan_config
)
runset <- CmdStanRun$new(args, procs)
runset$run_cmdstan()
Expand Down Expand Up @@ -1786,7 +1796,8 @@ variational <- function(data = NULL,
output_samples = NULL,
draws = NULL,
show_messages = TRUE,
show_exceptions = TRUE) {
show_exceptions = TRUE,
save_cmdstan_config = TRUE) {
procs <- CmdStanProcs$new(
num_procs = 1,
show_stderr_messages = show_exceptions,
Expand Down Expand Up @@ -1827,7 +1838,8 @@ variational <- function(data = NULL,
output_basename = output_basename,
sig_figs = sig_figs,
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()),
model_variables = model_variables
model_variables = model_variables,
save_cmdstan_config = if (cmdstan_version() > "2.34.0") save_cmdstan_config
)
runset <- CmdStanRun$new(args, procs)
runset$run_cmdstan()
Expand Down Expand Up @@ -1929,7 +1941,8 @@ pathfinder <- function(data = NULL,
psis_resample = NULL,
calculate_lp = NULL,
show_messages = TRUE,
show_exceptions = TRUE) {
show_exceptions = TRUE,
save_cmdstan_config = TRUE) {
procs <- CmdStanProcs$new(
num_procs = 1,
show_stderr_messages = show_exceptions,
Expand Down Expand Up @@ -1976,7 +1989,8 @@ pathfinder <- function(data = NULL,
sig_figs = sig_figs,
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()),
model_variables = model_variables,
num_threads = num_threads
num_threads = num_threads,
save_cmdstan_config = if (cmdstan_version() > "2.34.0") save_cmdstan_config
)
runset <- CmdStanRun$new(args, procs)
runset$run_cmdstan()
Expand Down
Loading