Skip to content

Commit 2b54e4c

Browse files
committed
Fix unconstraining variables with zero-length containers
1 parent 91da53c commit 2b54e4c

File tree

2 files changed

+35
-3
lines changed

2 files changed

+35
-3
lines changed

R/fit.R

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ CmdStanFit$set("public", name = "init", value = init)
306306
#' @name fit-method-init_model_methods
307307
#' @aliases init_model_methods
308308
#' @description The `$init_model_methods()` compiles and initializes the
309-
#' `log_prob`, `grad_log_prob`, `constrain_pars`, and `unconstrain_pars` functions.
309+
#' `log_prob`, `grad_log_prob`, `constrain_variables`, and `unconstrain_variables` functions.
310310
#'
311311
#' @param seed (integer) The random seed to use when initializing the model.
312312
#' @param verbose (boolean) Whether to show verbose logging during compilation.
@@ -465,7 +465,7 @@ unconstrain_variables <- function(variables) {
465465
stop("The method has not been compiled, please call `init_model_methods()` first",
466466
call. = FALSE)
467467
}
468-
model_par_names <- names(self$runset$args$model_variables$parameters)
468+
model_par_names <- self$metadata()$stan_variables[self$metadata()$stan_variables != "lp__"]
469469
prov_par_names <- names(variables)
470470

471471
model_pars_not_prov <- which(!(model_par_names %in% prov_par_names))
@@ -477,7 +477,17 @@ unconstrain_variables <- function(variables) {
477477
# Ignore extraneous parameters
478478
model_pars_only <- variables[model_par_names]
479479

480-
stan_pars <- process_init_list(list(variables), num_procs = 1, self$runset$args$model_variables)
480+
model_variables <- self$runset$args$model_variables
481+
482+
# If zero-length parameters are present, they will be listed in model_variables
483+
# but not in metadata()$variables
484+
nonzero_length_params <- names(model_variables$parameters) %in% model_par_names
485+
486+
# Remove zero-length parameters from model_variables, otherwise process_init_list
487+
# warns about missing inputs
488+
model_variables$parameters <- model_variables$parameters[nonzero_length_params]
489+
490+
stan_pars <- process_init_list(list(variables), num_procs = 1, model_variables)
481491
private$model_methods_env_$unconstrain_variables(private$model_methods_env_$model_ptr_, stan_pars)
482492
}
483493
CmdStanFit$set("public", name = "unconstrain_variables", value = unconstrain_variables)

tests/testthat/test-model-methods.R

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,3 +179,25 @@ test_that("Methods can be compiled with model", {
179179
unconstrained_variables <- fit$unconstrain_variables(cpars)
180180
expect_equal(unconstrained_variables, c(0.6))
181181
})
182+
183+
test_that("unconstrain_variables correctly handles zero-length containers", {
184+
skip_if(os_is_wsl())
185+
model_code <- "
186+
data {
187+
int N;
188+
}
189+
parameters {
190+
vector[N] y;
191+
real x;
192+
}
193+
model {
194+
x ~ std_normal();
195+
y ~ std_normal();
196+
}
197+
"
198+
mod <- cmdstan_model(write_stan_file(model_code),
199+
compile_model_methods = TRUE)
200+
fit <- mod$sample(data = list(N = 0), chains = 1)
201+
unconstrained <- fit$unconstrain_variables(variables = list(x = 5))
202+
expect_equal(unconstrained, 5)
203+
})

0 commit comments

Comments
 (0)