@@ -39,6 +39,36 @@ functions {
3939 tuple(int, tuple(array[] vector, array[] vector)) rtn_nest_tuple_vec_array(tuple(int, tuple(array[] vector, array[] vector)) x) { return x; }
4040 tuple(int, tuple(array[] row_vector, array[] row_vector)) rtn_nest_tuple_rowvec_array(tuple(int, tuple(array[] row_vector, array[] row_vector)) x) { return x; }
4141 tuple(int, tuple(array[] matrix, array[] matrix)) rtn_nest_tuple_matrix_array(tuple(int, tuple(array[] matrix, array[] matrix)) x) { return x; }
42+
43+ complex rtn_complex(complex x) { return x; }
44+ complex_vector rtn_complex_vec(complex_vector x) { return x; }
45+ complex_row_vector rtn_complex_rowvec(complex_row_vector x) { return x; }
46+ complex_matrix rtn_complex_matrix(complex_matrix x) { return x; }
47+
48+ array[] complex rtn_complex_array(array[] complex x) { return x; }
49+ array[] complex_vector rtn_complex_vec_array(array[] complex_vector x) { return x; }
50+ array[] complex_row_vector rtn_complex_rowvec_array(array[] complex_row_vector x) { return x; }
51+ array[] complex_matrix rtn_complex_matrix_array(array[] complex_matrix x) { return x; }
52+
53+ tuple(complex, complex) rtn_tuple_complex(tuple(complex, complex) x) { return x; }
54+ tuple(complex_vector, complex_vector) rtn_tuple_complex_vec(tuple(complex_vector, complex_vector) x) { return x; }
55+ tuple(complex_row_vector, complex_row_vector) rtn_tuple_complex_rowvec(tuple(complex_row_vector, complex_row_vector) x) { return x; }
56+ tuple(complex_matrix, complex_matrix) rtn_tuple_complex_matrix(tuple(complex_matrix, complex_matrix) x) { return x; }
57+
58+ tuple(array[] complex, array[] complex) rtn_tuple_complex_array(tuple(array[] complex, array[] complex) x) { return x; }
59+ tuple(array[] complex_vector, array[] complex_vector) rtn_tuple_complex_vec_array(tuple(array[] complex_vector, array[] complex_vector) x) { return x; }
60+ tuple(array[] complex_row_vector, array[] complex_row_vector) rtn_tuple_complex_rowvec_array(tuple(array[] complex_row_vector, array[] complex_row_vector) x) { return x; }
61+ tuple(array[] complex_matrix, array[] complex_matrix) rtn_tuple_complex_matrix_array(tuple(array[] complex_matrix, array[] complex_matrix) x) { return x; }
62+
63+ tuple(int, tuple(complex, complex)) rtn_nest_tuple_complex(tuple(int, tuple(complex, complex)) x) { return x; }
64+ tuple(int, tuple(complex_vector, complex_vector)) rtn_nest_tuple_complex_vec(tuple(int, tuple(complex_vector, complex_vector)) x) { return x; }
65+ tuple(int, tuple(complex_row_vector, complex_row_vector)) rtn_nest_tuple_complex_rowvec(tuple(int, tuple(complex_row_vector, complex_row_vector)) x) { return x; }
66+ tuple(int, tuple(complex_matrix, complex_matrix)) rtn_nest_tuple_complex_matrix(tuple(int, tuple(complex_matrix, complex_matrix)) x) { return x; }
67+
68+ tuple(int, tuple(array[] complex, array[] complex)) rtn_nest_tuple_complex_array(tuple(int, tuple(array[] complex, array[] complex)) x) { return x; }
69+ tuple(int, tuple(array[] complex_vector, array[] complex_vector)) rtn_nest_tuple_complex_vec_array(tuple(int, tuple(array[] complex_vector, array[] complex_vector)) x) { return x; }
70+ tuple(int, tuple(array[] complex_row_vector, array[] complex_row_vector)) rtn_nest_tuple_complex_rowvec_array(tuple(int, tuple(array[] complex_row_vector, array[] complex_row_vector)) x) { return x; }
71+ tuple(int, tuple(array[] complex_matrix, array[] complex_matrix)) rtn_nest_tuple_complex_matrix_array(tuple(int, tuple(array[] complex_matrix, array[] complex_matrix)) x) { return x; }
4272}"
4373stan_prog <- paste(function_decl ,
4474 paste(readLines(testing_stan_file(" bernoulli" )),
@@ -147,6 +177,90 @@ test_that("Functions handle types correctly", {
147177 expect_equal(mod $ functions $ rtn_nest_tuple_matrix_array(nest_tuple_matrix_array ), nest_tuple_matrix_array )
148178})
149179
180+ test_that(" Functions handle complex types correctly" , {
181+ skip_if(os_is_wsl())
182+
183+ # ## Scalar
184+
185+ complex_scalar <- complex (real = 2.1 , imaginary = 21.3 )
186+
187+ expect_equal(mod $ functions $ rtn_complex(complex_scalar ), complex_scalar )
188+
189+ # ## Container
190+
191+ complex_vec <- complex (real = c(2 ,1.5 ,0.11 , 1.2 ), imaginary = c(11.2 ,21.5 ,6.1 ,3.2 ))
192+ complex_rowvec <- t(complex_vec )
193+ complex_matrix <- matrix (complex_vec , nrow = 2 , ncol = 2 )
194+
195+ expect_equal(mod $ functions $ rtn_complex_vec(complex_vec ), complex_vec )
196+ expect_equal(mod $ functions $ rtn_complex_rowvec(complex_rowvec ), complex_rowvec )
197+ expect_equal(mod $ functions $ rtn_complex_matrix(complex_matrix ), complex_matrix )
198+ expect_equal(mod $ functions $ rtn_complex_array(complex_vec ), complex_vec )
199+
200+ # ## Array of Container
201+
202+ complex_vec_array <- list (complex_vec , complex_vec * 2 , complex_vec + 0.1 )
203+ complex_rowvec_array <- list (complex_rowvec , complex_rowvec * 2 , complex_rowvec + 0.1 )
204+ complex_matrix_array <- list (complex_matrix , complex_matrix * 2 , complex_matrix + 0.1 )
205+
206+ expect_equal(mod $ functions $ rtn_complex_vec_array(complex_vec_array ), complex_vec_array )
207+ expect_equal(mod $ functions $ rtn_complex_rowvec_array(complex_rowvec_array ), complex_rowvec_array )
208+ expect_equal(mod $ functions $ rtn_complex_matrix_array(complex_matrix_array ), complex_matrix_array )
209+
210+ # ## Tuple of Scalar
211+
212+ tuple_complex <- list (complex_vec [1 ], complex_vec [2 ])
213+ expect_equal(mod $ functions $ rtn_tuple_complex(tuple_complex ), tuple_complex )
214+
215+ # ## Tuple of Container
216+
217+ tuple_complex_vec <- list (complex_vec , complex_vec * 1.2 )
218+ tuple_complex_rowvec <- list (complex_rowvec , complex_rowvec * 0.5 )
219+ tuple_complex_matrix <- list (complex_matrix , complex_matrix * 10.2 )
220+
221+ expect_equal(mod $ functions $ rtn_tuple_complex_array(tuple_complex_vec ), tuple_complex_vec )
222+ expect_equal(mod $ functions $ rtn_tuple_complex_vec(tuple_complex_vec ), tuple_complex_vec )
223+ expect_equal(mod $ functions $ rtn_tuple_complex_rowvec(tuple_complex_rowvec ), tuple_complex_rowvec )
224+ expect_equal(mod $ functions $ rtn_tuple_complex_matrix(tuple_complex_matrix ), tuple_complex_matrix )
225+
226+ # ## Tuple of Container Arrays
227+
228+ tuple_complex_vec_array <- list (complex_vec_array , complex_vec_array )
229+ tuple_complex_rowvec_array <- list (complex_rowvec_array , complex_rowvec_array )
230+ tuple_complex_matrix_array <- list (complex_matrix_array , complex_matrix_array )
231+
232+ expect_equal(mod $ functions $ rtn_tuple_complex_vec_array(tuple_complex_vec_array ), tuple_complex_vec_array )
233+ expect_equal(mod $ functions $ rtn_tuple_complex_rowvec_array(tuple_complex_rowvec_array ), tuple_complex_rowvec_array )
234+ expect_equal(mod $ functions $ rtn_tuple_complex_matrix_array(tuple_complex_matrix_array ), tuple_complex_matrix_array )
235+
236+ # ## Nested Tuple of Scalar
237+
238+ nest_tuple_complex <- list (31 , tuple_complex )
239+ expect_equal(mod $ functions $ rtn_nest_tuple_complex(nest_tuple_complex ), nest_tuple_complex )
240+
241+ # ## Nested Tuple of Container
242+
243+ nest_tuple_complex_vec <- list (12 , tuple_complex_vec )
244+ nest_tuple_complex_rowvec <- list (2 , tuple_complex_rowvec )
245+ nest_tuple_complex_matrix <- list (- 23 , tuple_complex_matrix )
246+ nest_tuple_complex_array <- list (21 , tuple_complex_vec )
247+
248+ expect_equal(mod $ functions $ rtn_nest_tuple_complex_array(nest_tuple_complex_vec ), nest_tuple_complex_vec )
249+ expect_equal(mod $ functions $ rtn_nest_tuple_complex_vec(nest_tuple_complex_vec ), nest_tuple_complex_vec )
250+ expect_equal(mod $ functions $ rtn_nest_tuple_complex_rowvec(nest_tuple_complex_rowvec ), nest_tuple_complex_rowvec )
251+ expect_equal(mod $ functions $ rtn_nest_tuple_complex_matrix(nest_tuple_complex_matrix ), nest_tuple_complex_matrix )
252+
253+ # ## Nested Tuple of Container Arrays
254+
255+ nest_tuple_complex_vec_array <- list (- 21 , tuple_complex_vec_array )
256+ nest_tuple_complex_rowvec_array <- list (1000 , tuple_complex_rowvec_array )
257+ nest_tuple_complex_matrix_array <- list (0 , tuple_complex_matrix_array )
258+
259+ expect_equal(mod $ functions $ rtn_nest_tuple_complex_vec_array(nest_tuple_complex_vec_array ), nest_tuple_complex_vec_array )
260+ expect_equal(mod $ functions $ rtn_nest_tuple_complex_rowvec_array(nest_tuple_complex_rowvec_array ), nest_tuple_complex_rowvec_array )
261+ expect_equal(mod $ functions $ rtn_nest_tuple_complex_matrix_array(nest_tuple_complex_matrix_array ), nest_tuple_complex_matrix_array )
262+ })
263+
150264test_that(" Functions can be exposed in fit object" , {
151265 skip_if(os_is_wsl())
152266 fit $ expose_functions(verbose = TRUE )
0 commit comments