Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion r/R/compute.R
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ cast_options <- function(safe = TRUE, ...) {
#' @return `NULL`, invisibly
#' @export
#'
#' @examplesIf arrow_with_dataset()
#' @examplesIf arrow_with_dataset() && identical(Sys.getenv("NOT_CRAN"), "true")
#' library(dplyr, warn.conflicts = FALSE)
#'
#' some_model <- lm(mpg ~ disp + cyl, data = mtcars)
Expand Down Expand Up @@ -385,6 +385,13 @@ register_scalar_function <- function(name, fun, in_type, out_type,
update_cache = TRUE
)

# User-defined functions require some special handling
# in the query engine, which currently requires an opt-in via
# the R_ARROW_COLLECT_WITH_UDF environment variable while this
# behaviour is stabilized.
# TODO(ARROW-17178) remove the need for this!
Sys.setenv(R_ARROW_COLLECT_WITH_UDF = "true")

invisible(NULL)
}

Expand Down
9 changes: 8 additions & 1 deletion r/R/table.R
Original file line number Diff line number Diff line change
Expand Up @@ -331,5 +331,12 @@ as_arrow_table.arrow_dplyr_query <- function(x, ...) {
# See query-engine.R for ExecPlan/Nodes
plan <- ExecPlan$create()
final_node <- plan$Build(x)
plan$Run(final_node, as_table = TRUE)

run_with_event_loop <- identical(
Sys.getenv("R_ARROW_COLLECT_WITH_UDF", ""),
"true"
)

result <- plan$Run(final_node, as_table = run_with_event_loop)
as_arrow_table(result)
}
2 changes: 1 addition & 1 deletion r/man/register_scalar_function.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

51 changes: 40 additions & 11 deletions r/tests/testthat/test-compute.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,21 @@ test_that("arrow_scalar_function() works with auto_convert = TRUE", {

test_that("register_scalar_function() adds a compute function to the registry", {
skip_if_not(CanRunWithCapturedR())
# TODO(ARROW-17178): User-defined function-friendly ExecPlan execution has
# occasional valgrind errors
skip_on_linux_devel()

register_scalar_function(
"times_32",
function(context, x) x * 32.0,
int32(), float64(),
auto_convert = TRUE
)
on.exit(unregister_binding("times_32", update_cache = TRUE))
on.exit({
unregister_binding("times_32", update_cache = TRUE)
# TODO(ARROW-17178) remove the need for this!
Sys.unsetenv("R_ARROW_COLLECT_WITH_UDF")
})

expect_true("times_32" %in% names(asNamespace("arrow")$.cache$functions))
expect_true("times_32" %in% list_compute_functions())
Expand Down Expand Up @@ -120,9 +127,11 @@ test_that("arrow_scalar_function() with bad return type errors", {
int32(),
float64()
)
on.exit(
on.exit({
unregister_binding("times_32_bad_return_type_array", update_cache = TRUE)
)
# TODO(ARROW-17178) remove the need for this!
Sys.unsetenv("R_ARROW_COLLECT_WITH_UDF")
})

expect_error(
call_function("times_32_bad_return_type_array", Array$create(1L)),
Expand All @@ -135,17 +144,19 @@ test_that("arrow_scalar_function() with bad return type errors", {
int32(),
float64()
)
on.exit(
on.exit({
unregister_binding("times_32_bad_return_type_scalar", update_cache = TRUE)
)
# TODO(ARROW-17178) remove the need for this!
Sys.unsetenv("R_ARROW_COLLECT_WITH_UDF")
})

expect_error(
call_function("times_32_bad_return_type_scalar", Array$create(1L)),
"Expected return Array or Scalar with type 'double'"
)
})

test_that("register_user_defined_function() can register multiple kernels", {
test_that("register_scalar_function() can register multiple kernels", {
skip_if_not(CanRunWithCapturedR())

register_scalar_function(
Expand All @@ -155,7 +166,11 @@ test_that("register_user_defined_function() can register multiple kernels", {
out_type = function(in_types) in_types[[1]],
auto_convert = TRUE
)
on.exit(unregister_binding("times_32", update_cache = TRUE))
on.exit({
unregister_binding("times_32", update_cache = TRUE)
# TODO(ARROW-17178) remove the need for this!
Sys.unsetenv("R_ARROW_COLLECT_WITH_UDF")
})

expect_equal(
call_function("times_32", Scalar$create(1L, int32())),
Expand All @@ -173,7 +188,10 @@ test_that("register_user_defined_function() can register multiple kernels", {
)
})

test_that("register_user_defined_function() errors for unsupported specifications", {
test_that("register_scalar_function() errors for unsupported specifications", {
# TODO(ARROW-17178) remove the need for this!
on.exit(Sys.unsetenv("R_ARROW_COLLECT_WITH_UDF"))

expect_error(
register_scalar_function(
"no_kernels",
Expand Down Expand Up @@ -208,7 +226,10 @@ test_that("register_user_defined_function() errors for unsupported specification
test_that("user-defined functions work during multi-threaded execution", {
skip_if_not(CanRunWithCapturedR())
skip_if_not_available("dataset")
# Snappy has a UBSan issue: https://github.com/google/snappy/pull/148
# Skip on linux devel because:
# TODO(ARROW-17283): Snappy has a UBSan issue that is fixed in the dev version
# TODO(ARROW-17178): User-defined function-friendly ExecPlan execution has
# occasional valgrind errors
skip_on_linux_devel()

n_rows <- 10000
Expand All @@ -235,7 +256,11 @@ test_that("user-defined functions work during multi-threaded execution", {
float64(),
auto_convert = TRUE
)
on.exit(unregister_binding("times_32", update_cache = TRUE))
on.exit({
unregister_binding("times_32", update_cache = TRUE)
# TODO(ARROW-17178) remove the need for this!
Sys.unsetenv("R_ARROW_COLLECT_WITH_UDF")
})

# check a regular collect()
result <- open_dataset(tf_dataset) %>%
Expand Down Expand Up @@ -268,7 +293,11 @@ test_that("user-defined error when called from an unsupported context", {
float64(),
auto_convert = TRUE
)
on.exit(unregister_binding("times_32", update_cache = TRUE))
on.exit({
unregister_binding("times_32", update_cache = TRUE)
# TODO(ARROW-17178) remove the need for this!
Sys.unsetenv("R_ARROW_COLLECT_WITH_UDF")
})

stream_plan_with_udf <- function() {
record_batch(a = 1:1000) %>%
Expand Down