Skip to content

Commit ede72e3

Browse files
authored
Merge pull request #861 from stan-dev/complex-support
Add support/tests for exporting functions with complex types
2 parents a2af167 + ccab663 commit ede72e3

File tree

2 files changed

+116
-1
lines changed

2 files changed

+116
-1
lines changed

R/utils.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -833,7 +833,8 @@ get_function_name <- function(fun_start, fun_end, model_lines) {
833833
"double",
834834
"Eigen::Matrix<(.*)>",
835835
"std::vector<(.*)>",
836-
"std::tuple<(.*)>"
836+
"std::tuple<(.*)>",
837+
"std::complex<(.*)>"
837838
)
838839
pattern <- paste0(
839840
# Only match if the type occurs at start of string

tests/testthat/test-model-expose-functions.R

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,36 @@ functions {
3939
tuple(int, tuple(array[] vector, array[] vector)) rtn_nest_tuple_vec_array(tuple(int, tuple(array[] vector, array[] vector)) x) { return x; }
4040
tuple(int, tuple(array[] row_vector, array[] row_vector)) rtn_nest_tuple_rowvec_array(tuple(int, tuple(array[] row_vector, array[] row_vector)) x) { return x; }
4141
tuple(int, tuple(array[] matrix, array[] matrix)) rtn_nest_tuple_matrix_array(tuple(int, tuple(array[] matrix, array[] matrix)) x) { return x; }
42+
43+
complex rtn_complex(complex x) { return x; }
44+
complex_vector rtn_complex_vec(complex_vector x) { return x; }
45+
complex_row_vector rtn_complex_rowvec(complex_row_vector x) { return x; }
46+
complex_matrix rtn_complex_matrix(complex_matrix x) { return x; }
47+
48+
array[] complex rtn_complex_array(array[] complex x) { return x; }
49+
array[] complex_vector rtn_complex_vec_array(array[] complex_vector x) { return x; }
50+
array[] complex_row_vector rtn_complex_rowvec_array(array[] complex_row_vector x) { return x; }
51+
array[] complex_matrix rtn_complex_matrix_array(array[] complex_matrix x) { return x; }
52+
53+
tuple(complex, complex) rtn_tuple_complex(tuple(complex, complex) x) { return x; }
54+
tuple(complex_vector, complex_vector) rtn_tuple_complex_vec(tuple(complex_vector, complex_vector) x) { return x; }
55+
tuple(complex_row_vector, complex_row_vector) rtn_tuple_complex_rowvec(tuple(complex_row_vector, complex_row_vector) x) { return x; }
56+
tuple(complex_matrix, complex_matrix) rtn_tuple_complex_matrix(tuple(complex_matrix, complex_matrix) x) { return x; }
57+
58+
tuple(array[] complex, array[] complex) rtn_tuple_complex_array(tuple(array[] complex, array[] complex) x) { return x; }
59+
tuple(array[] complex_vector, array[] complex_vector) rtn_tuple_complex_vec_array(tuple(array[] complex_vector, array[] complex_vector) x) { return x; }
60+
tuple(array[] complex_row_vector, array[] complex_row_vector) rtn_tuple_complex_rowvec_array(tuple(array[] complex_row_vector, array[] complex_row_vector) x) { return x; }
61+
tuple(array[] complex_matrix, array[] complex_matrix) rtn_tuple_complex_matrix_array(tuple(array[] complex_matrix, array[] complex_matrix) x) { return x; }
62+
63+
tuple(int, tuple(complex, complex)) rtn_nest_tuple_complex(tuple(int, tuple(complex, complex)) x) { return x; }
64+
tuple(int, tuple(complex_vector, complex_vector)) rtn_nest_tuple_complex_vec(tuple(int, tuple(complex_vector, complex_vector)) x) { return x; }
65+
tuple(int, tuple(complex_row_vector, complex_row_vector)) rtn_nest_tuple_complex_rowvec(tuple(int, tuple(complex_row_vector, complex_row_vector)) x) { return x; }
66+
tuple(int, tuple(complex_matrix, complex_matrix)) rtn_nest_tuple_complex_matrix(tuple(int, tuple(complex_matrix, complex_matrix)) x) { return x; }
67+
68+
tuple(int, tuple(array[] complex, array[] complex)) rtn_nest_tuple_complex_array(tuple(int, tuple(array[] complex, array[] complex)) x) { return x; }
69+
tuple(int, tuple(array[] complex_vector, array[] complex_vector)) rtn_nest_tuple_complex_vec_array(tuple(int, tuple(array[] complex_vector, array[] complex_vector)) x) { return x; }
70+
tuple(int, tuple(array[] complex_row_vector, array[] complex_row_vector)) rtn_nest_tuple_complex_rowvec_array(tuple(int, tuple(array[] complex_row_vector, array[] complex_row_vector)) x) { return x; }
71+
tuple(int, tuple(array[] complex_matrix, array[] complex_matrix)) rtn_nest_tuple_complex_matrix_array(tuple(int, tuple(array[] complex_matrix, array[] complex_matrix)) x) { return x; }
4272
}"
4373
stan_prog <- paste(function_decl,
4474
paste(readLines(testing_stan_file("bernoulli")),
@@ -147,6 +177,90 @@ test_that("Functions handle types correctly", {
147177
expect_equal(mod$functions$rtn_nest_tuple_matrix_array(nest_tuple_matrix_array), nest_tuple_matrix_array)
148178
})
149179

180+
test_that("Functions handle complex types correctly", {
181+
skip_if(os_is_wsl())
182+
183+
### Scalar
184+
185+
complex_scalar <- complex(real = 2.1, imaginary = 21.3)
186+
187+
expect_equal(mod$functions$rtn_complex(complex_scalar), complex_scalar)
188+
189+
### Container
190+
191+
complex_vec <- complex(real = c(2,1.5,0.11, 1.2), imaginary = c(11.2,21.5,6.1,3.2))
192+
complex_rowvec <- t(complex_vec)
193+
complex_matrix <- matrix(complex_vec, nrow=2, ncol=2)
194+
195+
expect_equal(mod$functions$rtn_complex_vec(complex_vec), complex_vec)
196+
expect_equal(mod$functions$rtn_complex_rowvec(complex_rowvec), complex_rowvec)
197+
expect_equal(mod$functions$rtn_complex_matrix(complex_matrix), complex_matrix)
198+
expect_equal(mod$functions$rtn_complex_array(complex_vec), complex_vec)
199+
200+
### Array of Container
201+
202+
complex_vec_array <- list(complex_vec, complex_vec * 2, complex_vec + 0.1)
203+
complex_rowvec_array <- list(complex_rowvec, complex_rowvec * 2, complex_rowvec + 0.1)
204+
complex_matrix_array <- list(complex_matrix, complex_matrix * 2, complex_matrix + 0.1)
205+
206+
expect_equal(mod$functions$rtn_complex_vec_array(complex_vec_array), complex_vec_array)
207+
expect_equal(mod$functions$rtn_complex_rowvec_array(complex_rowvec_array), complex_rowvec_array)
208+
expect_equal(mod$functions$rtn_complex_matrix_array(complex_matrix_array), complex_matrix_array)
209+
210+
### Tuple of Scalar
211+
212+
tuple_complex <- list(complex_vec[1], complex_vec[2])
213+
expect_equal(mod$functions$rtn_tuple_complex(tuple_complex), tuple_complex)
214+
215+
### Tuple of Container
216+
217+
tuple_complex_vec <- list(complex_vec, complex_vec * 1.2)
218+
tuple_complex_rowvec <- list(complex_rowvec, complex_rowvec * 0.5)
219+
tuple_complex_matrix <- list(complex_matrix, complex_matrix * 10.2)
220+
221+
expect_equal(mod$functions$rtn_tuple_complex_array(tuple_complex_vec), tuple_complex_vec)
222+
expect_equal(mod$functions$rtn_tuple_complex_vec(tuple_complex_vec), tuple_complex_vec)
223+
expect_equal(mod$functions$rtn_tuple_complex_rowvec(tuple_complex_rowvec), tuple_complex_rowvec)
224+
expect_equal(mod$functions$rtn_tuple_complex_matrix(tuple_complex_matrix), tuple_complex_matrix)
225+
226+
### Tuple of Container Arrays
227+
228+
tuple_complex_vec_array <- list(complex_vec_array, complex_vec_array)
229+
tuple_complex_rowvec_array <- list(complex_rowvec_array, complex_rowvec_array)
230+
tuple_complex_matrix_array <- list(complex_matrix_array, complex_matrix_array)
231+
232+
expect_equal(mod$functions$rtn_tuple_complex_vec_array(tuple_complex_vec_array), tuple_complex_vec_array)
233+
expect_equal(mod$functions$rtn_tuple_complex_rowvec_array(tuple_complex_rowvec_array), tuple_complex_rowvec_array)
234+
expect_equal(mod$functions$rtn_tuple_complex_matrix_array(tuple_complex_matrix_array), tuple_complex_matrix_array)
235+
236+
### Nested Tuple of Scalar
237+
238+
nest_tuple_complex <- list(31, tuple_complex)
239+
expect_equal(mod$functions$rtn_nest_tuple_complex(nest_tuple_complex), nest_tuple_complex)
240+
241+
### Nested Tuple of Container
242+
243+
nest_tuple_complex_vec <- list(12, tuple_complex_vec)
244+
nest_tuple_complex_rowvec <- list(2, tuple_complex_rowvec)
245+
nest_tuple_complex_matrix <- list(-23, tuple_complex_matrix)
246+
nest_tuple_complex_array <- list(21, tuple_complex_vec)
247+
248+
expect_equal(mod$functions$rtn_nest_tuple_complex_array(nest_tuple_complex_vec), nest_tuple_complex_vec)
249+
expect_equal(mod$functions$rtn_nest_tuple_complex_vec(nest_tuple_complex_vec), nest_tuple_complex_vec)
250+
expect_equal(mod$functions$rtn_nest_tuple_complex_rowvec(nest_tuple_complex_rowvec), nest_tuple_complex_rowvec)
251+
expect_equal(mod$functions$rtn_nest_tuple_complex_matrix(nest_tuple_complex_matrix), nest_tuple_complex_matrix)
252+
253+
### Nested Tuple of Container Arrays
254+
255+
nest_tuple_complex_vec_array <- list(-21, tuple_complex_vec_array)
256+
nest_tuple_complex_rowvec_array <- list(1000, tuple_complex_rowvec_array)
257+
nest_tuple_complex_matrix_array <- list(0, tuple_complex_matrix_array)
258+
259+
expect_equal(mod$functions$rtn_nest_tuple_complex_vec_array(nest_tuple_complex_vec_array), nest_tuple_complex_vec_array)
260+
expect_equal(mod$functions$rtn_nest_tuple_complex_rowvec_array(nest_tuple_complex_rowvec_array), nest_tuple_complex_rowvec_array)
261+
expect_equal(mod$functions$rtn_nest_tuple_complex_matrix_array(nest_tuple_complex_matrix_array), nest_tuple_complex_matrix_array)
262+
})
263+
150264
test_that("Functions can be exposed in fit object", {
151265
skip_if(os_is_wsl())
152266
fit$expose_functions(verbose = TRUE)

0 commit comments

Comments
 (0)