Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions r/NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
* String operations: `strsplit()` and `str_split()`; `strptime()`; `paste()`, `paste0()`, and `str_c()`; `substr()` and `str_sub()`; `str_like()`; `str_pad()`; `stri_reverse()`
* Date/time operations: `lubridate` methods such as `year()`, `month()`, `wday()`, and so on
* Math: `log()`, trigonometry (`sin()`, `cos()`, et al.), `abs()`, `sign()`, `pmin()`/`pmax()`
* Conditional: `ifelse()` and `if_else()` (fixed-precision decimal numbers do not yet work and factors/dictionaries are converted to character strings); `case_when()` (currently works with numeric data types but not character strings, factors/dictionaries, or lists/structs)
* `is.*` functions are supported and can be used inside `relocate()`

* The print method for `arrow_dplyr_query` now includes the expression and the resulting type of columns derived by `mutate()`.
Expand Down
35 changes: 34 additions & 1 deletion r/R/dplyr-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,39 @@ nse_funcs$if_else <- function(condition, true, false, missing = NULL){

# Although base R ifelse allows `yes` and `no` to be different classes
#
nse_funcs$ifelse <- function(test, yes, no){
nse_funcs$ifelse <- function(test, yes, no) {
nse_funcs$if_else(condition = test, true = yes, false = no)
}

nse_funcs$case_when <- function(...) {
formulas <- list2(...)
n <- length(formulas)
if (n == 0) {
abort("No cases provided in case_when()")
}
query <- vector("list", n)
value <- vector("list", n)
mask <- caller_env()
for (i in seq_len(n)) {
f <- formulas[[i]]
if (!inherits(f, "formula")) {
abort("Each argument to case_when() must be a two-sided formula")
}
query[[i]] <- arrow_eval(f[[2]], mask)
value[[i]] <- arrow_eval(f[[3]], mask)
if (!nse_funcs$is.logical(query[[i]])) {
abort("Left side of each formula in case_when() must be a logical expression")
}
}
build_expr(
"case_when",
args = c(
build_expr(
"make_struct",
args = query,
options = list(field_names = as.character(seq_along(query)))
),
value
)
)
}
7 changes: 7 additions & 0 deletions r/src/compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,13 @@ std::shared_ptr<arrow::compute::FunctionOptions> make_compute_options(
return out;
}

if (func_name == "make_struct") {
using Options = arrow::compute::MakeStructOptions;
// TODO (ARROW-13371): accept `field_nullability` and `field_metadata` options
return std::make_shared<Options>(
cpp11::as_cpp<std::vector<std::string>>(options["field_names"]));
}

if (func_name == "match_substring" || func_name == "match_substring_regex" ||
func_name == "find_substring" || func_name == "find_substring_regex" ||
func_name == "match_like") {
Expand Down
119 changes: 119 additions & 0 deletions r/tests/testthat/test-dplyr.R
Original file line number Diff line number Diff line change
Expand Up @@ -1225,3 +1225,122 @@ test_that("if_else and ifelse", {
tbl
)
})

test_that("case_when()", {
expect_dplyr_equal(
input %>%
transmute(cw = case_when(lgl ~ dbl, !false ~ dbl + dbl2)) %>%
collect(),
tbl
)
expect_dplyr_equal(
input %>%
mutate(cw = case_when(int > 5 ~ 1, TRUE ~ 0)) %>%
collect(),
tbl
)
expect_dplyr_equal(
input %>%
transmute(cw = case_when(chr %in% letters[1:3] ~ 1L) + 41L) %>%
collect(),
tbl
)
expect_dplyr_equal(
input %>%
filter(case_when(
dbl + int - 1.1 == dbl2 ~ TRUE,
NA ~ NA,
TRUE ~ FALSE
) & !is.na(dbl2)) %>%
collect(),
tbl
)

# dplyr::case_when() errors if values on right side of formulas do not have
# exactly the same type, but the Arrow case_when kernel allows compatible types
expect_equal(
tbl %>%
mutate(i64 = as.integer64(1e10)) %>%
Table$create() %>%
transmute(cw = case_when(
is.na(fct) ~ int,
is.na(chr) ~ dbl,
TRUE ~ i64
)) %>%
collect(),
tbl %>%
transmute(
cw = ifelse(is.na(fct), int, ifelse(is.na(chr), dbl, 1e10))
)
)

# expected errors (which are caught by abandon_ship() and changed to warnings)
# TODO: Find a way to test these directly without abandon_ship() interfering
expect_error(
# no cases
expect_warning(
tbl %>%
Table$create() %>%
transmute(cw = case_when()),
"case_when"
)
)
expect_error(
# argument not a formula
expect_warning(
tbl %>%
Table$create() %>%
transmute(cw = case_when(TRUE ~ FALSE, TRUE)),
"case_when"
)
)
expect_error(
# non-logical R scalar on left side of formula
expect_warning(
tbl %>%
Table$create() %>%
transmute(cw = case_when(0L ~ FALSE, TRUE ~ FALSE)),
"case_when"
)
)
expect_error(
# non-logical Arrow column reference on left side of formula
expect_warning(
tbl %>%
Table$create() %>%
transmute(cw = case_when(int ~ FALSE)),
"case_when"
)
)
expect_error(
# non-logical Arrow expression on left side of formula
expect_warning(
tbl %>%
Table$create() %>%
transmute(cw = case_when(dbl + 3.14159 ~ TRUE)),
"case_when"
)
)

skip("case_when does not yet support with variable-width types (ARROW-13222)")
expect_dplyr_equal(
input %>%
transmute(cw = case_when(lgl ~ "abc")) %>%
collect(),
tbl
)
expect_dplyr_equal(
input %>%
transmute(cw = case_when(lgl ~ verses, !false ~ paste(chr, chr))) %>%
collect(),
tbl
)
expect_dplyr_equal(
input %>%
mutate(
cw = paste0(case_when(!(!(!(lgl))) ~ factor(chr), TRUE ~ fct), "!")
) %>%
collect(),
tbl
)
})