Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -1054,6 +1054,7 @@ sample <- function(data = NULL,
window = NULL,
fixed_param = FALSE,
show_messages = TRUE,
show_exceptions = TRUE,
diagnostics = c("divergences", "treedepth", "ebfmi"),
# deprecated
cores = NULL,
Expand Down Expand Up @@ -1123,7 +1124,8 @@ sample <- function(data = NULL,
num_procs = checkmate::assert_integerish(chains, lower = 1, len = 1),
parallel_procs = checkmate::assert_integerish(parallel_chains, lower = 1, null.ok = TRUE),
threads_per_proc = assert_valid_threads(threads_per_chain, self$cpp_options(), multiple_chains = TRUE),
show_stderr_messages = show_messages
show_stderr_messages = show_exceptions,
show_stdout_messages = show_messages
)
model_variables <- NULL
if (is_variables_method_supported(self)) {
Expand Down Expand Up @@ -1260,6 +1262,7 @@ sample_mpi <- function(data = NULL,
fixed_param = FALSE,
sig_figs = NULL,
show_messages = TRUE,
show_exceptions = TRUE,
diagnostics = c("divergences", "treedepth", "ebfmi"),
# deprecated
validate_csv = TRUE) {
Expand All @@ -1282,7 +1285,8 @@ sample_mpi <- function(data = NULL,
procs <- CmdStanMCMCProcs$new(
num_procs = checkmate::assert_integerish(chains, lower = 1, len = 1),
parallel_procs = 1,
show_stderr_messages = show_messages
show_stderr_messages = show_exceptions,
show_stdout_messages = show_messages
)
model_variables <- NULL
if (is_variables_method_supported(self)) {
Expand Down
36 changes: 29 additions & 7 deletions R/run.R
Original file line number Diff line number Diff line change
Expand Up @@ -362,9 +362,13 @@ check_target_exe <- function(exe) {
}
}
if (is.null(procs$threads_per_proc())) {
cat(paste0(start_msg, "...\n\n"))
if (procs$show_stdout_messages()) {
cat(paste0(start_msg, "...\n\n"))
}
} else {
cat(paste0(start_msg, ", with ", procs$threads_per_proc(), " thread(s) per chain...\n\n"))
if (procs$show_stdout_messages()) {
cat(paste0(start_msg, ", with ", procs$threads_per_proc(), " thread(s) per chain...\n\n"))
}
Sys.setenv("STAN_NUM_THREADS" = as.integer(procs$threads_per_proc()))
# Windows environment variables have to be explicitly exported to WSL
if (os_is_wsl()) {
Expand Down Expand Up @@ -425,9 +429,13 @@ CmdStanRun$set("private", name = "run_sample_", value = .run_sample)
}
}
if (is.null(procs$threads_per_proc())) {
cat(paste0(start_msg, "...\n\n"))
if (procs$show_stdout_messages()) {
cat(paste0(start_msg, "...\n\n"))
}
} else {
cat(paste0(start_msg, ", with ", procs$threads_per_proc(), " thread(s) per chain...\n\n"))
if (procs$show_stdout_messages()) {
cat(paste0(start_msg, ", with ", procs$threads_per_proc(), " thread(s) per chain...\n\n"))
}
Sys.setenv("STAN_NUM_THREADS" = as.integer(procs$threads_per_proc()))
# Windows environment variables have to be explicitly exported to WSL
if (os_is_wsl()) {
Expand Down Expand Up @@ -612,6 +620,12 @@ CmdStanProcs <- R6::R6Class(
private$show_stdout_messages_ <- show_stdout_messages
invisible(self)
},
show_stdout_messages = function () {
private$show_stdout_messages_
},
show_stderr_messages = function () {
private$show_stderr_messages_
},
num_procs = function() {
private$num_procs_
},
Expand Down Expand Up @@ -927,7 +941,7 @@ CmdStanMCMCProcs <- R6::R6Class(
|| grepl("stancflags", line, fixed = TRUE)) {
ignore_line <- TRUE
}
if ((state > 1.5 && state < 5 && !ignore_line) || is_verbose_mode()) {
if ((state > 1.5 && state < 5 && !ignore_line && private$show_stdout_messages_) || is_verbose_mode()) {
if (state == 2) {
message("Chain ", id, " ", line)
} else {
Expand All @@ -939,7 +953,9 @@ CmdStanMCMCProcs <- R6::R6Class(
if (state == 1) {
state <- 2;
}
message("Chain ", id, " ", line)
if (private$show_stderr_messages_) {
message("Chain ", id, " ", line)
}
}
private$proc_state_[[id]] <- next_state
} else {
Expand All @@ -951,6 +967,9 @@ CmdStanMCMCProcs <- R6::R6Class(
invisible(self)
},
report_time = function(id = NULL) {
if (!private$show_stdout_messages_) {
return(invisible(NULL))
}
if (!is.null(id)) {
if (self$proc_state(id) == 7) {
warning("Chain ", id, " finished unexpectedly!\n", immediate. = TRUE, call. = FALSE)
Expand Down Expand Up @@ -1030,7 +1049,7 @@ CmdStanGQProcs <- R6::R6Class(
if (nzchar(line)) {
if (self$proc_state(id) == 1 && grepl("refresh = ", line, perl = TRUE)) {
self$set_proc_state(id, new_state = 1.5)
} else if (self$proc_state(id) >= 2) {
} else if (self$proc_state(id) >= 2 && private$show_stdout_messages_) {
cat("Chain", id, line, "\n")
}
} else {
Expand All @@ -1044,6 +1063,9 @@ CmdStanGQProcs <- R6::R6Class(
invisible(self)
},
report_time = function(id = NULL) {
if (!private$show_stdout_messages_) {
return(invisible(NULL))
}
if (!is.null(id)) {
if (self$proc_state(id) == 7) {
warning("Chain ", id, " finished unexpectedly!\n", immediate. = TRUE, call. = FALSE)
Expand Down
4 changes: 4 additions & 0 deletions man-roxygen/model-sample-args.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@
#' `fixed_param=TRUE` is mandatory. When `fixed_param=TRUE` the `chains` and
#' `parallel_chains` arguments will be set to `1`.
#' @param show_messages (logical) When `TRUE` (the default), prints all
#' output during the sampling process, such as iteration numbers and elapsed times.
#' If the output is silenced then the [`$output()`][fit-method-output] method of
#' the resulting fit object can be used to display the silenced messages.
#' @param show_exceptions (logical) When `TRUE` (the default), prints all
#' informational messages, for example rejection of the current proposal.
#' Disable if you wish to silence these messages, but this is not usually
#' recommended unless you are very confident that the model is correct up to
Expand Down
6 changes: 6 additions & 0 deletions man/model-method-sample.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions man/model-method-sample_mpi.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

45 changes: 43 additions & 2 deletions tests/testthat/test-model-sample.R
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ test_that("sample() method runs when the stan file is removed", {
)
})

test_that("sample() prints informational messages depening on show_messages", {
test_that("sample() prints informational messages depening on show_exceptions", {
mod_info_msg <- testing_model("info_message")
expect_sample_output(
expect_message(
Expand All @@ -127,7 +127,7 @@ test_that("sample() prints informational messages depening on show_messages", {
)
)
expect_sample_output(
expect_message(mod_info_msg$sample(show_messages = FALSE), regexp = NA)
expect_message(mod_info_msg$sample(show_exceptions = FALSE), regexp = NA)
)
})

Expand Down Expand Up @@ -321,3 +321,44 @@ test_that("sig_figs warning if version less than 2.25", {
)
reset_cmdstan_version()
})

test_that("Errors are suppressed with show_exceptions", {
errmodcode <- "
data {
real y_mean;
}
transformed data {
vector[1] small;
small[2] = 1.0;
}
parameters {
real y;
}
model {
y ~ normal(y_mean, 1);
}
"
errmod <- cmdstan_model(write_stan_file(errmodcode), force_recompile = TRUE)

expect_message(
suppressWarnings(errmod$sample(data = list(y_mean = 1), chains = 1)),
"Chain 1 Exception: vector[uni] assign: accessing element out of range",
fixed = TRUE
)

expect_no_message(
suppressWarnings(errmod$sample(data = list(y_mean = 1), chains = 1, show_exceptions = FALSE))
)
})

test_that("All output can be suppressed by show_messages", {
stan_program <- testing_stan_file("bernoulli")
data_list <- testing_data("bernoulli")
mod <- cmdstan_model(stan_program, force_recompile = TRUE)
options("cmdstanr_verbose" = FALSE)
output <- capture.output(
fit <- mod$sample(data = data_list, show_messages = FALSE)
)

expect_length(output, 0)
})