4 changes: 3 additions & 1 deletion cpp/src/arrow/compute/kernels/codegen_internal.cc
@@ -185,7 +185,9 @@ const std::vector<std::shared_ptr<DataType>>& ExampleParametricTypes() {
// work above

Result<ValueDescr> FirstType(KernelContext*, const std::vector<ValueDescr>& descrs) {
return descrs[0];
ValueDescr result = descrs.front();
result.shape = GetBroadcastShape(descrs);
return result;
}

void EnsureDictionaryDecoded(std::vector<ValueDescr>* descrs) {
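Illustrative note (not part of this diff): with the change above, `FirstType` keeps the first argument's type but broadcasts the shape across all arguments, so a varargs call mixing scalars and arrays resolves to an array output. A minimal sketch of the assumed broadcast rule:

// Sketch only -- the rule assumed behind GetBroadcastShape, not the actual helper:
// the result shape is ARRAY if any argument is an ARRAY, otherwise SCALAR.
ValueDescr::Shape BroadcastShapeSketch(const std::vector<ValueDescr>& descrs) {
  for (const auto& descr : descrs) {
    if (descr.shape == ValueDescr::ARRAY) return ValueDescr::ARRAY;
  }
  return ValueDescr::SCALAR;
}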
273 changes: 273 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_if_else.cc
@@ -1163,6 +1163,22 @@ void CopyValues(const Datum& in_values, const int64_t in_offset, const int64_t l
}
}

// Specialized helper to copy a single value from a source array. It avoids
// repeatedly calling MayHaveNulls and Buffer::data(), whose internal checks
// add up when called in a loop.
template <typename Type>
void CopyOneArrayValue(const DataType& type, const uint8_t* in_valid,
const uint8_t* in_values, const int64_t in_offset,
uint8_t* out_valid, uint8_t* out_values,
const int64_t out_offset) {
if (out_valid) {
BitUtil::SetBitTo(out_valid, out_offset,
!in_valid || BitUtil::GetBit(in_valid, in_offset));
}
CopyFixedWidth<Type>::CopyArray(type, in_values, in_offset, /*length=*/1, out_values,
out_offset);
}

struct CaseWhenFunction : ScalarFunction {
using ScalarFunction::ScalarFunction;

@@ -1372,6 +1388,221 @@ struct CaseWhenFunctor<NullType> {
}
};

struct CoalesceFunction : ScalarFunction {
using ScalarFunction::ScalarFunction;

Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override {
RETURN_NOT_OK(CheckArity(*values));
using arrow::compute::detail::DispatchExactImpl;
if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
EnsureDictionaryDecoded(values);
if (auto type = CommonNumeric(*values)) {
ReplaceTypes(type, values);
}
if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
return arrow::compute::detail::NoMatchingKernel(this, *values);
}
};
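A usage sketch of the dispatch logic above (illustrative, not part of this change): mixed numeric arguments are promoted through `CommonNumeric` before kernel lookup, so int32 and float64 inputs should resolve to the float64 kernel. `ArrayFromJSON` is the Arrow testing helper, and the expected result is an assumption based on that promotion rule.

// Assumes arrow/compute/api.h and arrow/testing/gtest_util.h are included.
arrow::Status DispatchExample() {
  auto ints = arrow::ArrayFromJSON(arrow::int32(), "[null, 2]");
  arrow::Datum fallback(std::make_shared<arrow::DoubleScalar>(1.5));
  ARROW_ASSIGN_OR_RAISE(arrow::Datum out,
                        arrow::compute::CallFunction("coalesce", {ints, fallback}));
  // Expected (assumption): a float64 array equal to [1.5, 2.0].
  return arrow::Status::OK();
}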

// Implement a 'coalesce' (SQL) operator for any number of scalar inputs
Status ExecScalarCoalesce(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
for (const auto& datum : batch.values) {
if (datum.scalar()->is_valid) {
*out = datum;
break;
}
}
return Status::OK();
}

// Helper: copy from a source datum into all null slots of the output
template <typename Type>
void CopyValuesAllValid(Datum source, uint8_t* out_valid, uint8_t* out_values,
const int64_t out_offset, const int64_t length) {
BitBlockCounter counter(out_valid, out_offset, length);
int64_t offset = 0;
while (offset < length) {
const auto block = counter.NextWord();
if (block.NoneSet()) {
CopyValues<Type>(source, offset, block.length, out_valid, out_values,
out_offset + offset);
} else if (!block.AllSet()) {
for (int64_t j = 0; j < block.length; ++j) {
if (!BitUtil::GetBit(out_valid, out_offset + offset + j)) {
CopyValues<Type>(source, offset + j, 1, out_valid, out_values,
out_offset + offset + j);
}
}
}
offset += block.length;
}
}

// Helper: zero the values buffer of the output wherever the slot is null
void InitializeNullSlots(const DataType& type, uint8_t* out_valid, uint8_t* out_values,
const int64_t out_offset, const int64_t length) {
BitBlockCounter counter(out_valid, out_offset, length);
int64_t offset = 0;
auto bit_width = checked_cast<const FixedWidthType&>(type).bit_width();
auto byte_width = BitUtil::BytesForBits(bit_width);
while (offset < length) {
const auto block = counter.NextWord();
if (block.NoneSet()) {
if (bit_width == 1) {
BitUtil::SetBitsTo(out_values, out_offset + offset, block.length, false);
} else {
std::memset(out_values + (out_offset + offset) * byte_width, 0x00,
byte_width * block.length);
}
} else if (!block.AllSet()) {
for (int64_t j = 0; j < block.length; ++j) {
if (BitUtil::GetBit(out_valid, out_offset + offset + j)) continue;
if (bit_width == 1) {
BitUtil::ClearBit(out_values, out_offset + offset + j);
} else {
std::memset(out_values + (out_offset + offset + j) * byte_width, 0x00,
byte_width);
}
}
}
offset += block.length;
}
}

// Implement 'coalesce' for any mix of scalar/array arguments for any fixed-width type
template <typename Type>
Status ExecArrayCoalesce(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
ArrayData* output = out->mutable_array();
const int64_t out_offset = output->offset;
// Use output validity buffer as mask to decide what values to copy
uint8_t* out_valid = output->buffers[0]->mutable_data();
// Clear output buffer - no values are set initially
BitUtil::SetBitsTo(out_valid, out_offset, batch.length, false);
uint8_t* out_values = output->buffers[1]->mutable_data();

for (const auto& datum : batch.values) {
if ((datum.is_scalar() && datum.scalar()->is_valid) ||
(datum.is_array() && !datum.array()->MayHaveNulls())) {
// Valid scalar, or all-valid array
CopyValuesAllValid<Type>(datum, out_valid, out_values, out_offset, batch.length);
break;
} else if (datum.is_array()) {
// Array with nulls
const ArrayData& arr = *datum.array();
const DataType& type = *datum.type();
const uint8_t* in_valid = arr.buffers[0]->data();
const uint8_t* in_values = arr.buffers[1]->data();
BinaryBitBlockCounter counter(in_valid, arr.offset, out_valid, out_offset,
batch.length);
int64_t offset = 0;
while (offset < batch.length) {
const auto block = counter.NextAndNotWord();
if (block.AllSet()) {
CopyValues<Type>(datum, offset, block.length, out_valid, out_values,
out_offset + offset);
} else if (block.popcount) {
for (int64_t j = 0; j < block.length; ++j) {
if (!BitUtil::GetBit(out_valid, out_offset + offset + j) &&
BitUtil::GetBit(in_valid, arr.offset + offset + j)) {
// This version lets us avoid calling MayHaveNulls() on every iteration
// (which does an atomic load and can add up)
CopyOneArrayValue<Type>(type, in_valid, in_values, arr.offset + offset + j,
out_valid, out_values, out_offset + offset + j);
}
}
}
offset += block.length;
}
}
}

// Initialize any remaining null slots (uninitialized memory)
InitializeNullSlots(*out->type(), out_valid, out_values, out_offset, batch.length);
return Status::OK();
}
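A worked example of the array path above (illustrative only, not part of this diff): each output slot takes the value from the first argument that is non-null at that slot. `ArrayFromJSON` is the Arrow testing helper; includes are assumed.

arrow::Status ArrayCoalesceExample() {
  auto a = arrow::ArrayFromJSON(arrow::int64(), "[1, null, null]");
  auto b = arrow::ArrayFromJSON(arrow::int64(), "[null, 2, null]");
  ARROW_ASSIGN_OR_RAISE(arrow::Datum out,
                        arrow::compute::CallFunction("coalesce", {a, b}));
  // Expected: [1, 2, null] -- slot 0 from `a`, slot 1 from `b`, slot 2 null in both.
  return arrow::Status::OK();
}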

template <typename Type, typename Enable = void>
struct CoalesceFunctor {
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
for (const auto& datum : batch.values) {
if (datum.is_array()) {
return ExecArrayCoalesce<Type>(ctx, batch, out);
}
}
return ExecScalarCoalesce(ctx, batch, out);
}
};

template <>
struct CoalesceFunctor<NullType> {
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
return Status::OK();
}
};

template <typename Type>
struct CoalesceFunctor<Type, enable_if_base_binary<Type>> {
using offset_type = typename Type::offset_type;
using BuilderType = typename TypeTraits<Type>::BuilderType;
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
for (const auto& datum : batch.values) {
if (datum.is_array()) {
return ExecArray(ctx, batch, out);
}
}
return ExecScalarCoalesce(ctx, batch, out);
}

static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
// Special case: grab any leading non-null scalar or array arguments
for (const auto& datum : batch.values) {
if (datum.is_scalar()) {
if (!datum.scalar()->is_valid) continue;
ARROW_ASSIGN_OR_RAISE(
*out, MakeArrayFromScalar(*datum.scalar(), batch.length, ctx->memory_pool()));
return Status::OK();
} else if (datum.is_array() && !datum.array()->MayHaveNulls()) {
*out = datum;
return Status::OK();
}
break;
}
ArrayData* output = out->mutable_array();
BuilderType builder(batch[0].type(), ctx->memory_pool());
RETURN_NOT_OK(builder.Reserve(batch.length));
for (int64_t i = 0; i < batch.length; i++) {
bool set = false;
for (const auto& datum : batch.values) {
if (datum.is_scalar()) {
if (datum.scalar()->is_valid) {
RETURN_NOT_OK(builder.Append(UnboxScalar<Type>::Unbox(*datum.scalar())));
set = true;
break;
}
} else {
const ArrayData& source = *datum.array();
if (!source.MayHaveNulls() ||
BitUtil::GetBit(source.buffers[0]->data(), source.offset + i)) {
const uint8_t* data = source.buffers[2]->data();
const offset_type* offsets = source.GetValues<offset_type>(1);
const offset_type offset0 = offsets[i];
const offset_type offset1 = offsets[i + 1];
RETURN_NOT_OK(builder.Append(data + offset0, offset1 - offset0));
set = true;
break;
}
}
}
if (!set) RETURN_NOT_OK(builder.AppendNull());
}
ARROW_ASSIGN_OR_RAISE(auto temp_output, builder.Finish());
*output = *temp_output->data();
// Builder type != logical type due to GenerateTypeAgnosticVarBinaryBase
output->type = batch[0].type();
return Status::OK();
}
};
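And a small sketch of the builder-based string/binary path above (illustrative; the expected values follow the same slot-wise rule): value `i` is sliced out of its source offsets buffer as `[offsets[i], offsets[i+1])` and appended.

arrow::Status BinaryCoalesceExample() {
  auto a = arrow::ArrayFromJSON(arrow::utf8(), R"(["a", null, null])");
  auto b = arrow::ArrayFromJSON(arrow::utf8(), R"([null, "b", null])");
  ARROW_ASSIGN_OR_RAISE(arrow::Datum out,
                        arrow::compute::CallFunction("coalesce", {a, b}));
  // Expected: ["a", "b", null].
  return arrow::Status::OK();
}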

Result<ValueDescr> LastType(KernelContext*, const std::vector<ValueDescr>& descrs) {
ValueDescr result = descrs.back();
result.shape = GetBroadcastShape(descrs);
@@ -1399,6 +1630,25 @@ void AddPrimitiveCaseWhenKernels(const std::shared_ptr<CaseWhenFunction>& scalar
}
}

void AddCoalesceKernel(const std::shared_ptr<ScalarFunction>& scalar_function,
detail::GetTypeId get_id, ArrayKernelExec exec) {
ScalarKernel kernel(KernelSignature::Make({InputType(get_id.id)}, OutputType(FirstType),
/*is_varargs=*/true),
exec);
kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE;
kernel.mem_allocation = MemAllocation::PREALLOCATE;
kernel.can_write_into_slices = is_fixed_width(get_id.id);
DCHECK_OK(scalar_function->AddKernel(std::move(kernel)));
}

void AddPrimitiveCoalesceKernels(const std::shared_ptr<ScalarFunction>& scalar_function,
const std::vector<std::shared_ptr<DataType>>& types) {
for (auto&& type : types) {
auto exec = GenerateTypeAgnosticPrimitive<CoalesceFunctor>(*type);
AddCoalesceKernel(scalar_function, type, std::move(exec));
}
}

const FunctionDoc if_else_doc{"Choose values based on a condition",
("`cond` must be a Boolean scalar/ array. \n`left` or "
"`right` must be of the same type scalar/ array.\n"
@@ -1419,6 +1669,13 @@ const FunctionDoc case_when_doc{
"Essentially, this implements a switch-case or if-else, if-else... "
"statement."),
{"cond", "*cases"}};

const FunctionDoc coalesce_doc{
"Select the first non-null value in each slot",
("Each row of the output will be the value from the first corresponding input "
"for which the value is not null. If all inputs are null in a row, the output "
"will be null."),
{"*values"}};
} // namespace

void RegisterScalarIfElse(FunctionRegistry* registry) {
@@ -1447,6 +1704,22 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
AddCaseWhenKernel(func, Type::DECIMAL256, CaseWhenFunctor<Decimal256Type>::Exec);
DCHECK_OK(registry->AddFunction(std::move(func)));
}
{
auto func = std::make_shared<CoalesceFunction>(
"coalesce", Arity::VarArgs(/*min_args=*/1), &coalesce_doc);
AddPrimitiveCoalesceKernels(func, NumericTypes());
AddPrimitiveCoalesceKernels(func, TemporalTypes());
AddPrimitiveCoalesceKernels(
func, {boolean(), null(), day_time_interval(), month_interval()});
AddCoalesceKernel(func, Type::FIXED_SIZE_BINARY,
CoalesceFunctor<FixedSizeBinaryType>::Exec);
AddCoalesceKernel(func, Type::DECIMAL128, CoalesceFunctor<Decimal128Type>::Exec);
AddCoalesceKernel(func, Type::DECIMAL256, CoalesceFunctor<Decimal256Type>::Exec);
for (const auto& ty : BaseBinaryTypes()) {
AddCoalesceKernel(func, ty, GenerateTypeAgnosticVarBinaryBase<CoalesceFunctor>(ty));
}
DCHECK_OK(registry->AddFunction(std::move(func)));
}
}

} // namespace internal
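A minimal check (a sketch, assuming `arrow/compute/registry.h`) that the registration above makes the function visible through the default registry:

arrow::Status LookupCoalesce() {
  ARROW_ASSIGN_OR_RAISE(auto func,
                        arrow::compute::GetFunctionRegistry()->GetFunction("coalesce"));
  // `func` is the varargs scalar function registered above (at least one argument).
  (void)func;
  return arrow::Status::OK();
}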
61 changes: 61 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc
@@ -227,6 +227,61 @@ static void CaseWhenBench64Contiguous(benchmark::State& state) {
return CaseWhenBenchContiguous<UInt64Type>(state);
}

template <typename Type>
static void CoalesceBench(benchmark::State& state) {
using CType = typename Type::c_type;
auto type = TypeTraits<Type>::type_singleton();

int64_t len = state.range(0);
int64_t offset = state.range(1);

random::RandomArrayGenerator rand(/*seed=*/0);

std::vector<Datum> arguments;
for (int i = 0; i < 4; i++) {
arguments.emplace_back(
rand.ArrayOf(type, len, /*null_probability=*/0.25)->Slice(offset));
}

for (auto _ : state) {
ABORT_NOT_OK(CallFunction("coalesce", arguments));
}

state.SetBytesProcessed(state.iterations() * arguments.size() * (len - offset) *
sizeof(CType));
}

template <typename Type>
static void CoalesceNonNullBench(benchmark::State& state) {
using CType = typename Type::c_type;
auto type = TypeTraits<Type>::type_singleton();

int64_t len = state.range(0);
int64_t offset = state.range(1);

random::RandomArrayGenerator rand(/*seed=*/0);

std::vector<Datum> arguments;
arguments.emplace_back(
rand.ArrayOf(type, len, /*null_probability=*/0.25)->Slice(offset));
arguments.emplace_back(rand.ArrayOf(type, len, /*null_probability=*/0)->Slice(offset));

for (auto _ : state) {
ABORT_NOT_OK(CallFunction("coalesce", arguments));
}

state.SetBytesProcessed(state.iterations() * arguments.size() * (len - offset) *
sizeof(CType));
}

static void CoalesceBench64(benchmark::State& state) {
return CoalesceBench<Int64Type>(state);
}

static void CoalesceNonNullBench64(benchmark::State& state) {
  return CoalesceNonNullBench<Int64Type>(state);
}

BENCHMARK(IfElseBench32)->Args({elems, 0});
BENCHMARK(IfElseBench64)->Args({elems, 0});

@@ -251,5 +306,11 @@ BENCHMARK(CaseWhenBench64)->Args({elems, 99});
BENCHMARK(CaseWhenBench64Contiguous)->Args({elems, 0});
BENCHMARK(CaseWhenBench64Contiguous)->Args({elems, 99});

BENCHMARK(CoalesceBench64)->Args({elems, 0});
BENCHMARK(CoalesceBench64)->Args({elems, 99});

BENCHMARK(CoalesceNonNullBench64)->Args({elems, 0});
BENCHMARK(CoalesceNonNullBench64)->Args({elems, 99});

} // namespace compute
} // namespace arrow