Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions cpp/src/arrow/compute/api_scalar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,14 @@ Result<Datum> IfElse(const Datum& cond, const Datum& if_true, const Datum& if_fa
return CallFunction("if_else", {cond, if_true, if_false}, ctx);
}

Result<Datum> CaseWhen(const Datum& cond, const std::vector<Datum>& cases,
ExecContext* ctx) {
std::vector<Datum> args = {cond};
args.reserve(cases.size() + 1);
args.insert(args.end(), cases.begin(), cases.end());
return CallFunction("case_when", args, ctx);
}

// ----------------------------------------------------------------------
// Temporal functions

Expand Down
17 changes: 17 additions & 0 deletions cpp/src/arrow/compute/api_scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,23 @@ ARROW_EXPORT
Result<Datum> IfElse(const Datum& cond, const Datum& left, const Datum& right,
ExecContext* ctx = NULLPTR);

/// \brief CaseWhen behaves like a switch/case or if-else if-else statement: for
/// each row, select the first value for which the corresponding condition is
/// true, or (if given) select the 'else' value, else emit null. Note that a
/// null condition is the same as false.
///
/// \param[in] cond Conditions (Boolean)
/// \param[in] cases Values (any type), along with an optional 'else' value.
/// \param[in] ctx the function execution context, optional
///
/// \return the resulting datum
///
/// \since 5.0.0
/// \note API not yet finalized
ARROW_EXPORT
Result<Datum> CaseWhen(const Datum& cond, const std::vector<Datum>& cases,
ExecContext* ctx = NULLPTR);

/// \brief Year returns year for each element of `values`
///
/// \param[in] values input to extract year from
Expand Down
23 changes: 13 additions & 10 deletions cpp/src/arrow/compute/kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -402,8 +402,7 @@ KernelSignature::KernelSignature(std::vector<InputType> in_types, OutputType out
out_type_(std::move(out_type)),
is_varargs_(is_varargs),
hash_code_(0) {
// VarArgs sigs must have only a single input type to use for argument validation
DCHECK(!is_varargs || (is_varargs && (in_types_.size() == 1)));
DCHECK(!is_varargs || (is_varargs && (in_types_.size() >= 1)));
}

std::shared_ptr<KernelSignature> KernelSignature::Make(std::vector<InputType> in_types,
Expand All @@ -430,8 +429,8 @@ bool KernelSignature::Equals(const KernelSignature& other) const {

bool KernelSignature::MatchesInputs(const std::vector<ValueDescr>& args) const {
if (is_varargs_) {
for (const auto& arg : args) {
if (!in_types_[0].Matches(arg)) {
for (size_t i = 0; i < args.size(); ++i) {
if (!in_types_[std::min(i, in_types_.size() - 1)].Matches(args[i])) {
return false;
}
}
Expand Down Expand Up @@ -464,15 +463,19 @@ std::string KernelSignature::ToString() const {
std::stringstream ss;

if (is_varargs_) {
ss << "varargs[" << in_types_[0].ToString() << "]";
ss << "varargs[";
} else {
ss << "(";
for (size_t i = 0; i < in_types_.size(); ++i) {
if (i > 0) {
ss << ", ";
}
ss << in_types_[i].ToString();
}
for (size_t i = 0; i < in_types_.size(); ++i) {
if (i > 0) {
ss << ", ";
}
ss << in_types_[i].ToString();
}
if (is_varargs_) {
ss << "]";
} else {
ss << ")";
}
ss << " -> " << out_type_.ToString();
Expand Down
6 changes: 4 additions & 2 deletions cpp/src/arrow/compute/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -366,8 +366,10 @@ class ARROW_EXPORT OutputType {

/// \brief Holds the input types and output type of the kernel.
///
/// VarArgs functions should pass a single input type to be used to validate
/// the input types of a function invocation.
/// VarArgs functions with minimum N arguments should pass up to N input types to be
/// used to validate the input types of a function invocation. The first N-1 types
/// will be matched against the first N-1 arguments, and the last type will be
/// matched against the remaining arguments.
class ARROW_EXPORT KernelSignature {
public:
KernelSignature(std::vector<InputType> in_types, OutputType out_type,
Expand Down
31 changes: 22 additions & 9 deletions cpp/src/arrow/compute/kernel_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -468,15 +468,28 @@ TEST(KernelSignature, MatchesInputs) {
}

TEST(KernelSignature, VarArgsMatchesInputs) {
KernelSignature sig({int8()}, utf8(), /*is_varargs=*/true);

std::vector<ValueDescr> args = {int8()};
ASSERT_TRUE(sig.MatchesInputs(args));
args.push_back(ValueDescr::Scalar(int8()));
args.push_back(ValueDescr::Array(int8()));
ASSERT_TRUE(sig.MatchesInputs(args));
args.push_back(int32());
ASSERT_FALSE(sig.MatchesInputs(args));
{
KernelSignature sig({int8()}, utf8(), /*is_varargs=*/true);

std::vector<ValueDescr> args = {int8()};
ASSERT_TRUE(sig.MatchesInputs(args));
args.push_back(ValueDescr::Scalar(int8()));
args.push_back(ValueDescr::Array(int8()));
ASSERT_TRUE(sig.MatchesInputs(args));
args.push_back(int32());
ASSERT_FALSE(sig.MatchesInputs(args));
}
{
KernelSignature sig({int8(), utf8()}, utf8(), /*is_varargs=*/true);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working this out


std::vector<ValueDescr> args = {int8()};
ASSERT_TRUE(sig.MatchesInputs(args));
args.push_back(ValueDescr::Scalar(utf8()));
args.push_back(ValueDescr::Array(utf8()));
ASSERT_TRUE(sig.MatchesInputs(args));
args.push_back(int32());
ASSERT_FALSE(sig.MatchesInputs(args));
}
}

TEST(KernelSignature, ToString) {
Expand Down
20 changes: 14 additions & 6 deletions cpp/src/arrow/compute/kernels/codegen_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,14 @@ void ReplaceTypes(const std::shared_ptr<DataType>& type,
}

std::shared_ptr<DataType> CommonNumeric(const std::vector<ValueDescr>& descrs) {
DCHECK(!descrs.empty()) << "tried to find CommonNumeric type of an empty set";
return CommonNumeric(descrs.data(), descrs.size());
}

for (const auto& descr : descrs) {
std::shared_ptr<DataType> CommonNumeric(const ValueDescr* begin, size_t count) {
DCHECK_GT(count, 0) << "tried to find CommonNumeric type of an empty set";

for (size_t i = 0; i < count; i++) {
const auto& descr = *(begin + i);
auto id = descr.type->id();
if (!is_floating(id) && !is_integer(id)) {
// a common numeric type is only possible if all types are numeric
Expand All @@ -232,17 +237,20 @@ std::shared_ptr<DataType> CommonNumeric(const std::vector<ValueDescr>& descrs) {
}
}

for (const auto& descr : descrs) {
for (size_t i = 0; i < count; i++) {
const auto& descr = *(begin + i);
if (descr.type->id() == Type::DOUBLE) return float64();
}

for (const auto& descr : descrs) {
for (size_t i = 0; i < count; i++) {
const auto& descr = *(begin + i);
if (descr.type->id() == Type::FLOAT) return float32();
}

int max_width_signed = 0, max_width_unsigned = 0;

for (const auto& descr : descrs) {
for (size_t i = 0; i < count; i++) {
const auto& descr = *(begin + i);
auto id = descr.type->id();
auto max_width = &(is_signed_integer(id) ? max_width_signed : max_width_unsigned);
*max_width = std::max(bit_width(id), *max_width);
Expand All @@ -253,7 +261,7 @@ std::shared_ptr<DataType> CommonNumeric(const std::vector<ValueDescr>& descrs) {
if (max_width_unsigned == 32) return uint32();
if (max_width_unsigned == 16) return uint16();
DCHECK_EQ(max_width_unsigned, 8);
return int8();
return uint8();
}

if (max_width_signed <= max_width_unsigned) {
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/arrow/compute/kernels/codegen_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1367,6 +1367,9 @@ void ReplaceTypes(const std::shared_ptr<DataType>&, std::vector<ValueDescr>* des
ARROW_EXPORT
std::shared_ptr<DataType> CommonNumeric(const std::vector<ValueDescr>& descrs);

ARROW_EXPORT
std::shared_ptr<DataType> CommonNumeric(const ValueDescr* begin, size_t count);

ARROW_EXPORT
std::shared_ptr<DataType> CommonTimestamp(const std::vector<ValueDescr>& descrs);

Expand Down
Loading