Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 39 additions & 9 deletions cpp/src/arrow/compute/kernels/scalar_if_else.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1439,13 +1439,30 @@ struct CaseWhenFunction : ScalarFunction {
}
}

if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
// TODO(ARROW-14105): also apply casts to dictionary indices/values
if (is_dictionary((*values)[1].type->id()) &&
std::all_of(values->begin() + 2, values->end(), [&](const ValueDescr& descr) {
return descr.type->Equals(*(*values)[1].type);
})) {
auto kernel = DispatchExactImpl(this, *values);
DCHECK(kernel);
return kernel;
}

EnsureDictionaryDecoded(values);
if (auto type = CommonNumeric(values->data() + 1, values->size() - 1)) {
for (auto it = values->begin() + 1; it != values->end(); it++) {
it->type = type;
}
ValueDescr* first_arg = &(*values)[1];
const size_t num_args = values->size() - 1;
if (auto type = CommonNumeric(first_arg, num_args)) {
ReplaceTypes(type, first_arg, num_args);
}
if (auto type = CommonBinary(first_arg, num_args)) {
ReplaceTypes(type, first_arg, num_args);
}
if (auto type = CommonTemporal(first_arg, num_args)) {
ReplaceTypes(type, first_arg, num_args);
}
if (HasDecimal(*values)) {
RETURN_NOT_OK(CastDecimalArgs(first_arg, num_args));
}
if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
return arrow::compute::detail::NoMatchingKernel(this, *values);
Expand Down Expand Up @@ -1934,9 +1951,20 @@ struct CoalesceFunction : ScalarFunction {
Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override {
RETURN_NOT_OK(CheckArity(*values));
using arrow::compute::detail::DispatchExactImpl;

// TODO(ARROW-14105): also apply casts to dictionary indices/values
if (is_dictionary((*values)[0].type->id()) &&
std::all_of(values->begin() + 1, values->end(), [&](const ValueDescr& descr) {
return descr.type->Equals(*(*values)[0].type);
})) {
auto kernel = DispatchExactImpl(this, *values);
DCHECK(kernel);
return kernel;
}

// Do not DispatchExact here since we want to rescale decimals if necessary
EnsureDictionaryDecoded(values);
if (auto type = CommonNumeric(*values)) {
if (auto type = CommonNumeric(values->data(), values->size())) {
ReplaceTypes(type, values);
}
if (auto type = CommonBinary(values->data(), values->size())) {
Expand Down Expand Up @@ -2244,7 +2272,7 @@ static Status ExecVarWidthCoalesceImpl(KernelContext* ctx, const ExecBatch& batc
}
ArrayData* output = out->mutable_array();
std::unique_ptr<ArrayBuilder> raw_builder;
RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), out->type(), &raw_builder));
RETURN_NOT_OK(MakeBuilderExactIndex(ctx->memory_pool(), out->type(), &raw_builder));
RETURN_NOT_OK(raw_builder->Reserve(batch.length));
RETURN_NOT_OK(reserve_data(raw_builder.get()));

Expand Down Expand Up @@ -2388,7 +2416,8 @@ struct CoalesceFunctor<Type, enable_if_base_binary<Type>> {

template <typename Type>
struct CoalesceFunctor<
Type, enable_if_t<is_nested_type<Type>::value && !is_union_type<Type>::value>> {
Type, enable_if_t<(is_nested_type<Type>::value || is_dictionary_type<Type>::value) &&
!is_union_type<Type>::value>> {
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
RETURN_NOT_OK(CheckIdenticalTypes(&batch.values[0], batch.values.size()));
for (const auto& datum : batch.values) {
Expand Down Expand Up @@ -2422,7 +2451,7 @@ struct CoalesceFunctor<Type, enable_if_union<Type>> {
static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
ArrayData* output = out->mutable_array();
std::unique_ptr<ArrayBuilder> raw_builder;
RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), out->type(), &raw_builder));
RETURN_NOT_OK(MakeBuilderExactIndex(ctx->memory_pool(), out->type(), &raw_builder));
RETURN_NOT_OK(raw_builder->Reserve(batch.length));

const UnionType& type = checked_cast<const UnionType&>(*out->type());
Expand Down Expand Up @@ -2858,6 +2887,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
AddCoalesceKernel(func, Type::STRUCT, CoalesceFunctor<StructType>::Exec);
AddCoalesceKernel(func, Type::DENSE_UNION, CoalesceFunctor<DenseUnionType>::Exec);
AddCoalesceKernel(func, Type::SPARSE_UNION, CoalesceFunctor<SparseUnionType>::Exec);
AddCoalesceKernel(func, Type::DICTIONARY, CoalesceFunctor<DictionaryType>::Exec);
DCHECK_OK(registry->AddFunction(std::move(func)));
}
{
Expand Down
154 changes: 153 additions & 1 deletion cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1094,7 +1094,8 @@ TYPED_TEST(TestCaseWhenDict, Simple) {
}

TYPED_TEST(TestCaseWhenDict, Mixed) {
auto type = dictionary(default_type_instance<TypeParam>(), utf8());
auto index_type = default_type_instance<TypeParam>();
auto type = dictionary(index_type, utf8());
auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
auto dict = R"(["a", null, "bc", "def"])";
Expand All @@ -1119,6 +1120,17 @@ TYPED_TEST(TestCaseWhenDict, Mixed) {
"case_when",
{MakeStruct({cond1, cond2}), values_null, values2_dict, values1_decoded},
/*result_is_encoded=*/false);

// If we have mismatched dictionary types, we decode (for now)
auto values3_dict =
DictArrayFromJSON(dictionary(index_type, binary()), "[2, 1, null, 0]", dict);
auto values4_dict = DictArrayFromJSON(
dictionary(index_type->id() == Type::UINT8 ? int8() : uint8(), utf8()),
"[2, 1, null, 0]", dict);
CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1_dict, values3_dict},
/*result_is_encoded=*/false);
CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1_dict, values4_dict},
/*result_is_encoded=*/false);
}

TYPED_TEST(TestCaseWhenDict, NestedSimple) {
Expand Down Expand Up @@ -2088,6 +2100,17 @@ TEST(TestCaseWhen, UnionBoolString) {
TEST(TestCaseWhen, DispatchBest) {
CheckDispatchBest("case_when", {struct_({field("", boolean())}), int64(), int32()},
{struct_({field("", boolean())}), int64(), int64()});
CheckDispatchBest("case_when",
{struct_({field("", boolean())}), binary(), large_utf8()},
{struct_({field("", boolean())}), large_binary(), large_binary()});
CheckDispatchBest(
"case_when",
{struct_({field("", boolean())}), timestamp(TimeUnit::SECOND), date32()},
{struct_({field("", boolean())}), timestamp(TimeUnit::SECOND),
timestamp(TimeUnit::SECOND)});
CheckDispatchBest(
"case_when", {struct_({field("", boolean())}), decimal128(38, 0), decimal128(1, 1)},
{struct_({field("", boolean())}), decimal256(39, 1), decimal256(39, 1)});

ASSERT_RAISES(Invalid, CallFunction("case_when", {}));
// Too many/too few conditions
Expand Down Expand Up @@ -2360,6 +2383,132 @@ TYPED_TEST(TestCoalesceList, Errors) {
}));
}

// Typed fixture for the "coalesce" tests on dictionary-encoded inputs.
// The fixture itself carries no state; the tests in this suite use TypeParam
// as the dictionary index type via default_type_instance<TypeParam>().
template <typename Type>
class TestCoalesceDict : public ::testing::Test {};

// Instantiate the suite once per type listed in IntegralArrowTypes.
TYPED_TEST_SUITE(TestCoalesceDict, IntegralArrowTypes);

// Coalesce over dictionary-encoded arguments that all share one dictionary,
// repeated for string, integer and decimal dictionary value types.
TYPED_TEST(TestCoalesceDict, Simple) {
  auto index_type = default_type_instance<TypeParam>();
  for (const auto& json_dict :
       {JsonDict{utf8(), R"(["a", null, "bc", "def"])"},
        JsonDict{int64(), "[1, null, 2, 3]"},
        JsonDict{decimal256(3, 2), R"(["1.23", null, "3.45", "6.78"])"}}) {
    auto dict_type = dictionary(index_type, json_dict.type);

    // Inputs: an all-null array, two partially-null arrays and a scalar, all
    // built on top of the same dictionary values.
    auto all_null =
        DictArrayFromJSON(dict_type, "[null, null, null, null]", json_dict.value);
    auto first = DictArrayFromJSON(dict_type, "[0, null, 3, null]", json_dict.value);
    auto second = DictArrayFromJSON(dict_type, "[2, 1, null, null]", json_dict.value);
    auto dict_scalar = DictScalarFromJSON(dict_type, "2", json_dict.value);

    // Every argument uses the identical dictionary, which is the easy case.
    CheckDictionary("coalesce", {first, second});
    CheckDictionary("coalesce", {first, second, first});
    CheckDictionary("coalesce", {all_null, first});
    CheckDictionary("coalesce", {first, all_null});
    CheckDictionary("coalesce", {first, dict_scalar});
    CheckDictionary("coalesce", {all_null, dict_scalar});
    CheckDictionary("coalesce", {dict_scalar, first});
  }
}

// Coalesce over arguments where the dictionary encoding is not kept: either
// dictionary arguments mixed with plain ones, or dictionary arguments whose
// dictionary types differ.
TYPED_TEST(TestCoalesceDict, Mixed) {
  auto index_type = default_type_instance<TypeParam>();
  auto dict_type = dictionary(index_type, utf8());
  auto dict_json = R"(["a", null, "bc", "def"])";

  // Dictionary-encoded inputs
  auto all_null = DictArrayFromJSON(dict_type, "[null, null, null, null]", dict_json);
  auto encoded1 = DictArrayFromJSON(dict_type, "[0, null, 3, 1]", dict_json);
  auto encoded2 = DictArrayFromJSON(dict_type, "[2, 1, null, 0]", dict_json);
  // Plain utf8 inputs
  auto plain1 = ArrayFromJSON(utf8(), R"(["a", null, "def", null])");
  auto plain2 = ArrayFromJSON(utf8(), R"(["bc", null, null, "a"])");
  auto plain_scalar = ScalarFromJSON(utf8(), R"("bc")");

  // Dictionary and non-dictionary arguments together: dictionaries get decoded
  CheckDictionary("coalesce", {encoded1, plain2},
                  /*result_is_encoded=*/false);
  CheckDictionary("coalesce", {plain1, encoded2},
                  /*result_is_encoded=*/false);
  CheckDictionary("coalesce", {encoded1, encoded2, plain1},
                  /*result_is_encoded=*/false);
  CheckDictionary("coalesce", {all_null, encoded2, plain1},
                  /*result_is_encoded=*/false);
  CheckDictionary("coalesce", {all_null, plain_scalar}, /*result_is_encoded=*/false);
  CheckDictionary("coalesce", {plain_scalar, all_null}, /*result_is_encoded=*/false);
  CheckDictionary("coalesce", {encoded1, plain_scalar}, /*result_is_encoded=*/false);
  CheckDictionary("coalesce", {plain_scalar, encoded2}, /*result_is_encoded=*/false);

  // Dictionary types that do not match each other are decoded as well (for now):
  // one argument differs in value type, the other in index type.
  auto other_index_type = index_type->id() == Type::UINT8 ? int8() : uint8();
  auto binary_valued =
      DictArrayFromJSON(dictionary(index_type, binary()), "[2, 1, null, 0]", dict_json);
  auto other_indexed = DictArrayFromJSON(dictionary(other_index_type, utf8()),
                                         "[2, 1, null, 0]", dict_json);
  CheckDictionary("coalesce", {encoded1, binary_valued}, /*result_is_encoded=*/false);
  CheckDictionary("coalesce", {encoded1, other_indexed}, /*result_is_encoded=*/false);
}

// Coalesce over list arguments whose child values are dictionary-encoded.
// All checks pass result_is_encoded=false to CheckDictionary.
TYPED_TEST(TestCoalesceDict, NestedSimple) {
  auto index_type = default_type_instance<TypeParam>();
  auto inner_type = dictionary(index_type, utf8());
  auto dict = R"(["a", null, "bc", "def"])";

  // MakeListOfDict's first argument is presumably the list offsets array and
  // the second the dictionary-encoded child values - confirm against the helper.
  auto values_null = MakeListOfDict(ArrayFromJSON(int32(), "[null, null, null, null, 0]"),
                                    DictArrayFromJSON(inner_type, "[]", dict));
  auto values1_backing = DictArrayFromJSON(inner_type, "[0, null, 3, 1]", dict);
  auto values2_backing = DictArrayFromJSON(inner_type, "[2, 1, null, 0]", dict);
  auto values1 =
      MakeListOfDict(ArrayFromJSON(int32(), "[0, 2, 2, 3, 4]"), values1_backing);
  auto values2 =
      MakeListOfDict(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), values2_backing);
  // List scalar wrapping a two-element dictionary array
  auto scalar =
      Datum(std::make_shared<ListScalar>(DictArrayFromJSON(inner_type, "[0, 1]", dict)));

  CheckDictionary("coalesce", {values1, values2}, /*result_is_encoded=*/false);
  CheckDictionary("coalesce", {values1, scalar}, /*result_is_encoded=*/false);
  CheckDictionary("coalesce", {scalar, values2}, /*result_is_encoded=*/false);
  CheckDictionary("coalesce", {values_null, values2}, /*result_is_encoded=*/false);
  CheckDictionary("coalesce", {values1, values_null}, /*result_is_encoded=*/false);
}

// Coalesce over arguments that share one dictionary type but carry different
// dictionary values.
TYPED_TEST(TestCoalesceDict, DifferentDictionaries) {
  auto type = dictionary(default_type_instance<TypeParam>(), utf8());
  auto dict1 = R"(["a", "", "bc", "def"])";
  auto dict2 = R"(["bc", "foo", "", "a"])";
  auto values1_null = DictArrayFromJSON(type, "[null, null, null, null]", dict1);
  auto values2_null = DictArrayFromJSON(type, "[null, null, null, null]", dict2);
  auto values1 = DictArrayFromJSON(type, "[null, 0, 3, 1]", dict1);
  auto values2 = DictArrayFromJSON(type, "[2, 1, 0, null]", dict2);
  auto scalar1 = DictScalarFromJSON(type, "0", dict1);
  auto scalar2 = DictScalarFromJSON(type, "0", dict2);

  // Dictionaries without null values: each argument combination is checked once
  CheckDictionary("coalesce", {values1, values2});
  CheckDictionary("coalesce", {values1, scalar2});
  CheckDictionary("coalesce", {scalar1, values2});
  CheckDictionary("coalesce", {values1_null, values2});
  CheckDictionary("coalesce", {values1, values2_null});

  // Test dictionaries with nulls (where decoding before/after calling coalesce changes
  // the results)
  dict1 = R"(["a", null, "bc", "def"])";
  dict2 = R"(["bc", "foo", null, "a"])";
  values1 = DictArrayFromJSON(type, "[null, 0, 3, 1]", dict1);
  values2 = DictArrayFromJSON(type, "[2, 1, 0, null]", dict2);
  scalar1 = DictScalarFromJSON(type, "0", dict1);

  // Note this is sensitive to the implementation. Nulls are emitted here
  // because a non-null index mapped to a null dictionary value and was emitted
  // as a null (instead of encoding null in the dictionary)
  CheckScalarNonRecursive(
      "coalesce", {values1, values2},
      DictArrayFromJSON(type, "[null, 0, 1, null]", R"(["a", "def"])"));
  CheckScalarNonRecursive("coalesce", {values1, scalar1},
                          DictArrayFromJSON(type, "[0, 0, 1, null]", R"(["a", "def"])"));
  // The dictionary gets preserved since a leading non-null scalar just gets
  // broadcasted and returned without going through the rest of the kernel
  // implementation
  CheckScalarNonRecursive("coalesce", {scalar1, values1},
                          DictArrayFromJSON(type, "[0, 0, 0, 0]", dict1));
}

TEST(TestCoalesce, Null) {
auto type = null();
auto scalar_null = ScalarFromJSON(type, "null");
Expand Down Expand Up @@ -2716,6 +2865,9 @@ TEST(TestCoalesce, DispatchBest) {
sparse_union({field("a", boolean())}),
dense_union({field("a", boolean())}),
});
CheckDispatchBest("coalesce",
{dictionary(int8(), binary()), dictionary(int16(), large_utf8())},
{large_binary(), large_binary()});
}

template <typename Type>
Expand Down
6 changes: 0 additions & 6 deletions r/R/dplyr-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,6 @@ nse_funcs$coalesce <- function(...) {
arg <- Expression$scalar(arg)
}

# coalesce doesn't yet support factors/dictionaries
# TODO: remove this after ARROW-14167 is merged
if (nse_funcs$is.factor(arg)) {
warning("Dictionaries (in R: factors) are currently converted to strings (characters) in coalesce", call. = FALSE)
}

if (last_arg && arg$type_id() %in% TYPES_WITH_NAN) {
# store the NA_real_ in the same type as arg to avoid casting
# smaller float types to larger float types
Expand Down
37 changes: 20 additions & 17 deletions r/tests/testthat/test-dplyr-funcs-conditional.R
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,26 @@ test_that("coalesce()", {
df
)

# factor
df_fct <- df %>%
transmute(across(everything(), ~ factor(.x, levels = c("a", "b", "c"))))
compare_dplyr_binding(
.input %>%
mutate(
cw = coalesce(w),
cz = coalesce(z),
cwx = coalesce(w, x),
cwxy = coalesce(w, x, y),
cwxyz = coalesce(w, x, y, z)
) %>%
collect() %>%
# Arrow coalesce() kernel does not preserve unused factor levels,
# so reset the levels of all the factor columns to make the test pass
# (ARROW-14649)
transmute(across(where(is.factor), ~ factor(.x, levels = c("a", "b", "c")))),
df_fct
)

# integer
df <- tibble(
w = c(NA_integer_, NA_integer_, NA_integer_),
Expand Down Expand Up @@ -383,23 +403,6 @@ test_that("coalesce()", {
df
)

# factors
# TODO: remove the mutate + warning after ARROW-14167 is merged and Arrow
# supports factors in coalesce
df <- tibble(
x = factor("a", levels = c("a", "z")),
y = factor("b", levels = c("a", "b", "c"))
)
compare_dplyr_binding(
.input %>%
mutate(c = coalesce(x, y)) %>%
collect() %>%
# This is a no-op on the Arrow side, but necessary to make the results equal
mutate(c = as.character(c)),
df,
warning = "Dictionaries .* are currently converted to strings .* in coalesce"
)

# no arguments
expect_error(
nse_funcs$coalesce(),
Expand Down