Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion c_glib/test/test-dense-union-scalar.rb
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_equal
end

def test_to_s
assert_equal("...", @scalar.to_s)
assert_equal("union{number: int8 = -29}", @scalar.to_s)
end

def test_value
Expand Down
2 changes: 1 addition & 1 deletion c_glib/test/test-sparse-union-scalar.rb
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_equal
end

def test_to_s
assert_equal("...", @scalar.to_s)
assert_equal("union{number: int8 = -29}", @scalar.to_s)
end

def test_value
Expand Down
5 changes: 5 additions & 0 deletions cpp/src/arrow/compute/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,8 @@ add_arrow_compute_test(aggregate_test
hash_aggregate_test.cc
test_util.cc)
add_arrow_benchmark(aggregate_benchmark PREFIX "arrow-compute")

# ----------------------------------------------------------------------
# Utilities

add_arrow_compute_test(kernel_utility_test SOURCES codegen_internal_test.cc)
102 changes: 95 additions & 7 deletions cpp/src/arrow/compute/kernels/codegen_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,14 @@ void ReplaceNullWithOtherType(std::vector<ValueDescr>* descrs) {

void ReplaceTypes(const std::shared_ptr<DataType>& type,
std::vector<ValueDescr>* descrs) {
for (auto& descr : *descrs) {
descr.type = type;
ReplaceTypes(type, descrs->data(), descrs->size());
}

void ReplaceTypes(const std::shared_ptr<DataType>& type, ValueDescr* begin,
size_t count) {
auto* end = begin + count;
for (auto* it = begin; it != end; it++) {
it->type = type;
}
}

Expand Down Expand Up @@ -158,26 +164,47 @@ std::shared_ptr<DataType> CommonNumeric(const ValueDescr* begin, size_t count) {
return int8();
}

std::shared_ptr<DataType> CommonTimestamp(const std::vector<ValueDescr>& descrs) {
std::shared_ptr<DataType> CommonTemporal(const std::vector<ValueDescr>& descrs) {
TimeUnit::type finest_unit = TimeUnit::SECOND;
const std::string* timezone = nullptr;
bool saw_date32 = false;
bool saw_date64 = false;

for (const auto& descr : descrs) {
auto id = descr.type->id();
// a common timestamp is only possible if all types are timestamp like
switch (id) {
case Type::DATE32:
// Date32's unit is days, but the coarsest we have is seconds
saw_date32 = true;
continue;
case Type::DATE64:
finest_unit = std::max(finest_unit, TimeUnit::MILLI);
saw_date64 = true;
continue;
case Type::TIMESTAMP:
finest_unit =
std::max(finest_unit, checked_cast<const TimestampType&>(*descr.type).unit());
case Type::TIMESTAMP: {
const auto& ty = checked_cast<const TimestampType&>(*descr.type);
// Don't cast to common timezone by default (may not make
// sense for all kernels)
if (timezone && *timezone != ty.timezone()) return nullptr;
timezone = &ty.timezone();
finest_unit = std::max(finest_unit, ty.unit());
continue;
}
default:
return nullptr;
}
}

return timestamp(finest_unit);
if (timezone) {
// At least one timestamp seen
return timestamp(finest_unit, *timezone);
} else if (saw_date64) {
return date64();
} else if (saw_date32) {
return date32();
}
return nullptr;
}

std::shared_ptr<DataType> CommonBinary(const std::vector<ValueDescr>& descrs) {
Expand Down Expand Up @@ -290,6 +317,67 @@ Status CastBinaryDecimalArgs(DecimalPromotion promotion,
return Status::OK();
}

/// Rewrite the types of the `count` descriptors starting at `begin` so that
/// they all share one type.
///
/// Rules, in order:
///  - if any argument is neither floating point, integer nor decimal, return
///    OK and leave every descriptor untouched;
///  - if any argument is floating point, every argument becomes float64;
///  - otherwise every argument becomes a decimal whose scale is the largest
///    scale among the decimal arguments (never below 0) and whose precision is
///    the largest precision any argument needs once rescaled to that scale.
///    The result is Decimal256 if any input was Decimal256 or if the precision
///    does not fit in Decimal128.
///
/// \return Status::Invalid when the required precision exceeds the Decimal256
///         maximum; any error from MaxDecimalDigitsForInteger or
///         DecimalType::Make is propagated.
Status CastDecimalArgs(ValueDescr* begin, size_t count) {
  bool saw_floating = false;
  bool saw_decimal256 = false;
  int32_t widest_scale = 0;

  // Pass 1: classify the arguments and find the common scale.
  for (size_t i = 0; i < count; ++i) {
    const auto& type = *begin[i].type;
    const auto id = type.id();
    if (is_floating(id)) {
      // Decimal + float = float
      saw_floating = true;
      continue;
    }
    if (is_integer(id)) {
      // Integers do not influence the common scale.
      continue;
    }
    if (!is_decimal(id)) {
      // Non-numeric argument: no common type, leave everything as is.
      return Status::OK();
    }
    widest_scale =
        std::max(widest_scale, checked_cast<const DecimalType&>(type).scale());
    if (id == Type::DECIMAL256) {
      saw_decimal256 = true;
    }
  }

  if (saw_floating) {
    ReplaceTypes(float64(), begin, count);
    return Status::OK();
  }

  // Pass 2: only integers and decimals remain; find the precision each one
  // needs after being rescaled to `widest_scale`.
  int32_t required_precision = 0;
  for (size_t i = 0; i < count; ++i) {
    const auto& type = *begin[i].type;
    const auto id = type.id();
    if (is_integer(id)) {
      ARROW_ASSIGN_OR_RAISE(auto digits, MaxDecimalDigitsForInteger(id));
      digits += widest_scale;
      required_precision = std::max(required_precision, digits);
    } else if (is_decimal(id)) {
      const auto& decimal = checked_cast<const DecimalType&>(type);
      auto digits = decimal.precision();
      digits += widest_scale - decimal.scale();
      required_precision = std::max(required_precision, digits);
    }
  }

  if (required_precision > BasicDecimal256::kMaxPrecision) {
    return Status::Invalid("Result precision (", required_precision,
                           ") exceeds max precision of Decimal256 (",
                           BasicDecimal256::kMaxPrecision, ")");
  }

  // Widen to Decimal256 when an input already was, or when Decimal128 cannot
  // hold the required precision.
  const Type::type target_id =
      (saw_decimal256 || required_precision > BasicDecimal128::kMaxPrecision)
          ? Type::DECIMAL256
          : Type::DECIMAL128;

  // Pass 3: assign the common decimal type to every argument.
  for (size_t i = 0; i < count; ++i) {
    ARROW_ASSIGN_OR_RAISE(
        begin[i].type,
        DecimalType::Make(target_id, required_precision, widest_scale));
  }

  return Status::OK();
}

bool HasDecimal(const std::vector<ValueDescr>& descrs) {
for (const auto& descr : descrs) {
if (is_decimal(descr.type->id())) {
Expand Down
13 changes: 12 additions & 1 deletion cpp/src/arrow/compute/kernels/codegen_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1279,14 +1279,17 @@ void ReplaceNullWithOtherType(std::vector<ValueDescr>* descrs);
ARROW_EXPORT
void ReplaceTypes(const std::shared_ptr<DataType>&, std::vector<ValueDescr>* descrs);

ARROW_EXPORT
void ReplaceTypes(const std::shared_ptr<DataType>&, ValueDescr* descrs, size_t count);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably a reminder that we'd like a std::span backport at some point ;-)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


ARROW_EXPORT
std::shared_ptr<DataType> CommonNumeric(const std::vector<ValueDescr>& descrs);

ARROW_EXPORT
std::shared_ptr<DataType> CommonNumeric(const ValueDescr* begin, size_t count);

ARROW_EXPORT
std::shared_ptr<DataType> CommonTimestamp(const std::vector<ValueDescr>& descrs);
std::shared_ptr<DataType> CommonTemporal(const std::vector<ValueDescr>& descrs);

ARROW_EXPORT
std::shared_ptr<DataType> CommonBinary(const std::vector<ValueDescr>& descrs);
Expand All @@ -1298,9 +1301,17 @@ enum class DecimalPromotion : uint8_t {
kDivide,
};

/// Given two arguments, at least one of which is decimal, promote all
/// to not necessarily identical types, but types which are compatible
/// for the given operator (add/multiply/divide).
ARROW_EXPORT
Status CastBinaryDecimalArgs(DecimalPromotion promotion, std::vector<ValueDescr>* descrs);

/// Given one or more arguments, at least one of which is decimal,
/// promote all to an identical type.
ARROW_EXPORT
Status CastDecimalArgs(ValueDescr* begin, size_t count);

ARROW_EXPORT
bool HasDecimal(const std::vector<ValueDescr>& descrs);

Expand Down
155 changes: 155 additions & 0 deletions cpp/src/arrow/compute/kernels/codegen_internal_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "arrow/compute/kernels/codegen_internal.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/type.h"
#include "arrow/type_fwd.h"

namespace arrow {
namespace compute {
namespace internal {

// Checks the argument types CastBinaryDecimalArgs produces for a
// (decimal, other) pair under the add/multiply/divide promotion modes.
TEST(TestDispatchBest, CastBinaryDecimalArgs) {
std::vector<ValueDescr> args;
std::vector<DecimalPromotion> modes = {
DecimalPromotion::kAdd, DecimalPromotion::kMultiply, DecimalPromotion::kDivide};

// Any float -> all float, whatever the promotion mode
for (auto mode : modes) {
args = {decimal128(3, 2), float64()};
ASSERT_OK(CastBinaryDecimalArgs(mode, &args));
AssertTypeEqual(args[0].type, float64());
AssertTypeEqual(args[1].type, float64());
}

// Integer -> decimal with common scale; the decimal argument keeps its own
// precision, so the two resulting types are not identical
args = {decimal128(1, 0), int64()};
ASSERT_OK(CastBinaryDecimalArgs(DecimalPromotion::kAdd, &args));
AssertTypeEqual(args[0].type, decimal128(1, 0));
AssertTypeEqual(args[1].type, decimal128(19, 0));

// Add: a negative scale is rejected with NotImplemented rather than rescaled
args = {decimal128(3, 2), decimal128(3, -2)};
EXPECT_RAISES_WITH_MESSAGE_THAT(
NotImplemented, ::testing::HasSubstr("Decimals with negative scales not supported"),
CastBinaryDecimalArgs(DecimalPromotion::kAdd, &args));
}

// Checks that CastDecimalArgs promotes every argument to one identical type
// (or leaves all of them untouched when no common type exists).
TEST(TestDispatchBest, CastDecimalArgs) {
std::vector<ValueDescr> args;

// Any float -> all float
args = {decimal128(3, 2), float64()};
ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
AssertTypeEqual(args[0].type, float64());
AssertTypeEqual(args[1].type, float64());

// float32 is also promoted: the common type is always float64
args = {float32(), float64(), decimal128(3, 2)};
ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
AssertTypeEqual(args[0].type, float64());
AssertTypeEqual(args[1].type, float64());
AssertTypeEqual(args[2].type, float64());

// Promote to common decimal width
args = {decimal128(3, 2), decimal256(3, 2)};
ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
AssertTypeEqual(args[0].type, decimal256(3, 2));
AssertTypeEqual(args[1].type, decimal256(3, 2));

// Rescale so all have common scale/precision
// (3, 0) rescaled to scale 2 needs 3 + 2 = 5 digits
args = {decimal128(3, 2), decimal128(3, 0)};
ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
AssertTypeEqual(args[0].type, decimal128(5, 2));
AssertTypeEqual(args[1].type, decimal128(5, 2));

// Negative scales are accepted here: (3, -2) rescaled to scale 2 needs
// 3 + (2 - (-2)) = 7 digits
args = {decimal128(3, 2), decimal128(3, -2)};
ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
AssertTypeEqual(args[0].type, decimal128(7, 2));
AssertTypeEqual(args[1].type, decimal128(7, 2));

// More than two arguments: the largest scale (2) wins
args = {decimal128(3, 0), decimal128(3, 1), decimal128(3, 2)};
ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
AssertTypeEqual(args[0].type, decimal128(5, 2));
AssertTypeEqual(args[1].type, decimal128(5, 2));
AssertTypeEqual(args[2].type, decimal128(5, 2));

// Integer -> decimal with appropriate precision
// (int64 is expected to need 19 digits)
args = {decimal128(3, 0), int64()};
ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
AssertTypeEqual(args[0].type, decimal128(19, 0));
AssertTypeEqual(args[1].type, decimal128(19, 0));

// The common scale is added on top of the integer's digits: 19 + 1 = 20
args = {decimal128(3, 1), int64()};
ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
AssertTypeEqual(args[0].type, decimal128(20, 1));
AssertTypeEqual(args[1].type, decimal128(20, 1));

// A negative scale never lowers the common scale below 0
args = {decimal128(3, -1), int64()};
ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
AssertTypeEqual(args[0].type, decimal128(19, 0));
AssertTypeEqual(args[1].type, decimal128(19, 0));

// Overflow decimal128 max precision -> promote to decimal256
// ((38, 0) rescaled to scale 2 needs 40 digits)
args = {decimal128(38, 0), decimal128(37, 2)};
ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
AssertTypeEqual(args[0].type, decimal256(40, 2));
AssertTypeEqual(args[1].type, decimal256(40, 2));

// Overflow decimal256 max precision
// ((76, 0) rescaled to scale 1 needs 77 digits) -> Invalid
args = {decimal256(76, 0), decimal256(75, 1)};
EXPECT_RAISES_WITH_MESSAGE_THAT(
Invalid,
::testing::HasSubstr(
"Result precision (77) exceeds max precision of Decimal256 (76)"),
CastDecimalArgs(args.data(), args.size()));

// Incompatible, no cast: a non-numeric argument leaves every type untouched,
// even though a float is present
args = {decimal256(3, 2), float64(), utf8()};
ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
AssertTypeEqual(args[0].type, decimal256(3, 2));
AssertTypeEqual(args[1].type, float64());
AssertTypeEqual(args[2].type, utf8());
}

// Checks the common type CommonTemporal picks for mixes of date and timestamp
// types, and that it returns nullptr when no common type exists.
TEST(TestDispatchBest, CommonTemporal) {
// Timestamps with different units -> the finest unit
AssertTypeEqual(timestamp(TimeUnit::NANO), CommonTemporal({timestamp(TimeUnit::SECOND),
timestamp(TimeUnit::NANO)}));
// A timezone shared by all inputs is kept on the result
AssertTypeEqual(timestamp(TimeUnit::NANO, "UTC"),
CommonTemporal({timestamp(TimeUnit::SECOND, "UTC"),
timestamp(TimeUnit::NANO, "UTC")}));
// Date mixed with timestamp -> timestamp; date64 raises the unit to at
// least milliseconds
AssertTypeEqual(timestamp(TimeUnit::NANO),
CommonTemporal({date32(), timestamp(TimeUnit::NANO)}));
AssertTypeEqual(timestamp(TimeUnit::MILLI),
CommonTemporal({date64(), timestamp(TimeUnit::SECOND)}));
// Dates only -> date64 if any date64 is present, otherwise date32
AssertTypeEqual(date32(), CommonTemporal({date32(), date32()}));
AssertTypeEqual(date64(), CommonTemporal({date64(), date64()}));
AssertTypeEqual(date64(), CommonTemporal({date32(), date64()}));
// No common temporal type: empty input, non-temporal inputs, or timestamps
// whose timezones differ (including zoned vs. timezone-less)
ASSERT_EQ(nullptr, CommonTemporal({}));
ASSERT_EQ(nullptr, CommonTemporal({float64(), int32()}));
ASSERT_EQ(nullptr, CommonTemporal({timestamp(TimeUnit::SECOND),
timestamp(TimeUnit::SECOND, "UTC")}));
ASSERT_EQ(nullptr, CommonTemporal({timestamp(TimeUnit::SECOND, "America/Phoenix"),
timestamp(TimeUnit::SECOND, "UTC")}));
}

} // namespace internal
} // namespace compute
} // namespace arrow
Loading