Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
f30cf18
dictionary min max init
R-JunmingChen Aug 10, 2023
422a02f
fix bug
R-JunmingChen Aug 10, 2023
1a5da63
fix output type
R-JunmingChen Aug 10, 2023
89571ad
lint
R-JunmingChen Aug 11, 2023
24f5d95
add correct dictionary value type
R-JunmingChen Aug 11, 2023
5917cd2
lint
R-JunmingChen Aug 11, 2023
e9f4e5a
delete
R-JunmingChen Aug 11, 2023
7752011
unify out_type
R-JunmingChen Aug 11, 2023
640b17c
lint
R-JunmingChen Aug 11, 2023
d35f3b9
add one test
R-JunmingChen Aug 11, 2023
f205ade
type
R-JunmingChen Aug 12, 2023
b034d5d
basic test
R-JunmingChen Aug 12, 2023
649d69f
add decimal test
R-JunmingChen Aug 12, 2023
aa20a10
lint
R-JunmingChen Aug 12, 2023
0703113
optimize code
R-JunmingChen Aug 13, 2023
5d2f51a
simplify
R-JunmingChen Aug 13, 2023
2176153
lint
R-JunmingChen Aug 13, 2023
8452aa6
format
R-JunmingChen Aug 13, 2023
20af859
space
R-JunmingChen Aug 13, 2023
2a21b62
add boolean test
R-JunmingChen Aug 13, 2023
459609b
update
R-JunmingChen Aug 20, 2023
e30d33d
add a empty test
R-JunmingChen Aug 20, 2023
e64fd4e
Merge branch 'main' of https://github.com/R-JunmingChen/arrow into AR…
R-JunmingChen Aug 21, 2023
5cdb5b6
move and rename
R-JunmingChen Aug 22, 2023
84dd9ea
delete log
R-JunmingChen Aug 22, 2023
f458569
matcher
R-JunmingChen Aug 23, 2023
2b17751
Merge branch 'main' of https://github.com/R-JunmingChen/arrow into AR…
R-JunmingChen Oct 13, 2023
6a108a3
dictionary comaction
R-JunmingChen Oct 13, 2023
5859bc9
lint
R-JunmingChen Oct 13, 2023
0b23b35
Update cpp/src/arrow/compute/kernels/aggregate_basic_internal.h
R-JunmingChen Oct 15, 2023
2020d48
fix comment
R-JunmingChen Oct 15, 2023
b84486f
Merge branch 'ARROW-36831' of https://github.com/R-JunmingChen/arrow …
R-JunmingChen Oct 15, 2023
1594f44
fix bug
R-JunmingChen Oct 15, 2023
7c2ebe0
fix bug
R-JunmingChen Oct 15, 2023
bc5a647
lint
R-JunmingChen Oct 15, 2023
2623472
add float test
R-JunmingChen Oct 15, 2023
5599daa
add an interger test
R-JunmingChen Oct 15, 2023
38bb4d4
lint
R-JunmingChen Oct 15, 2023
208e8bd
add chunk test
R-JunmingChen Oct 16, 2023
656d4e2
Update cpp/src/arrow/compute/kernels/aggregate_basic.cc
R-JunmingChen Oct 16, 2023
dbf3cbe
Update cpp/src/arrow/compute/kernels/aggregate_basic.cc
R-JunmingChen Oct 16, 2023
e07f261
optimization
R-JunmingChen Oct 16, 2023
6f33e7b
Merge branch 'ARROW-36831' of https://github.com/R-JunmingChen/arrow …
R-JunmingChen Oct 16, 2023
b8f10db
lint
R-JunmingChen Oct 16, 2023
0a87506
add binary test
R-JunmingChen Oct 16, 2023
6979e48
lint
R-JunmingChen Oct 16, 2023
5e3f9a4
boolean test
R-JunmingChen Oct 16, 2023
4cb58bc
delete test
R-JunmingChen Oct 17, 2023
af93cbf
Update cpp/src/arrow/compute/kernels/aggregate_basic_internal.h
R-JunmingChen Oct 19, 2023
be9238e
Update cpp/src/arrow/compute/kernels/aggregate_basic_internal.h
R-JunmingChen Oct 19, 2023
d2ed61c
Update cpp/src/arrow/compute/kernels/aggregate_basic_internal.h
R-JunmingChen Oct 19, 2023
9c2174c
Update cpp/src/arrow/compute/kernels/aggregate_basic_internal.h
R-JunmingChen Oct 19, 2023
d486422
kernel member init
R-JunmingChen Oct 22, 2023
bccbc81
Merge branch 'main' of https://github.com/R-JunmingChen/arrow into AR…
R-JunmingChen Oct 22, 2023
3a58ab6
debug value state
R-JunmingChen Oct 25, 2023
a9ff8b3
format
R-JunmingChen Oct 25, 2023
ccdc4af
delete io
R-JunmingChen Oct 25, 2023
901484c
rename
R-JunmingChen Oct 25, 2023
f9b025f
reset
R-JunmingChen Nov 6, 2023
6a53df0
Merge branch 'main' of https://github.com/R-JunmingChen/arrow into AR…
R-JunmingChen Nov 28, 2023
7e71603
CheckDictionaryMinMax
R-JunmingChen Nov 29, 2023
7655122
logic null count
R-JunmingChen Nov 29, 2023
936924e
add test
R-JunmingChen Nov 30, 2023
05c85f1
lint
R-JunmingChen Nov 30, 2023
9260e2f
add a comment
R-JunmingChen Feb 4, 2024
0e53afe
local var
R-JunmingChen Feb 4, 2024
42b1cd3
Merge branch 'main' of https://github.com/R-JunmingChen/arrow into AR…
R-JunmingChen Apr 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions cpp/src/arrow/compute/kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,37 @@ std::shared_ptr<TypeMatcher> RunEndEncoded(
std::move(value_type_matcher));
}

/// \brief A TypeMatcher that negates another matcher: it matches exactly the
/// types that the wrapped matcher rejects.
class NotMatcher : public TypeMatcher {
 public:
  /// \param[in] base_matcher the matcher whose result is inverted
  explicit NotMatcher(std::shared_ptr<TypeMatcher> base_matcher)
      : base_matcher_{std::move(base_matcher)} {}

  ~NotMatcher() override = default;

  /// Return true iff the wrapped matcher does not match `type`.
  bool Matches(const DataType& type) const override {
    return !base_matcher_->Matches(type);
  }

  /// Two NotMatchers are equal when the matchers they wrap are equal.
  bool Equals(const TypeMatcher& other) const override {
    if (this == &other) {
      return true;
    }
    // A matcher of any other concrete class is never equal to a NotMatcher
    const auto* casted = dynamic_cast<const NotMatcher*>(&other);
    return casted != nullptr && base_matcher_->Equals(*casted->base_matcher_);
  }

  std::string ToString() const override {
    return "not(" + base_matcher_->ToString() + ")";
  }

 private:
  std::shared_ptr<TypeMatcher> base_matcher_;
};

// Factory for a matcher that accepts every type `base_matcher` rejects.
std::shared_ptr<TypeMatcher> Not(std::shared_ptr<TypeMatcher> base_matcher) {
  auto negated = std::make_shared<NotMatcher>(std::move(base_matcher));
  return negated;
}

} // namespace match

// ----------------------------------------------------------------------
Expand Down
5 changes: 5 additions & 0 deletions cpp/src/arrow/compute/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,11 @@ ARROW_EXPORT std::shared_ptr<TypeMatcher> RunEndEncoded(
std::shared_ptr<TypeMatcher> run_end_type_matcher,
std::shared_ptr<TypeMatcher> value_type_matcher);

/// \brief Match any type that the given matcher does not match
///
/// \param[in] base_matcher the matcher whose result is negated
ARROW_EXPORT std::shared_ptr<TypeMatcher> Not(std::shared_ptr<TypeMatcher> base_matcher);

} // namespace match

/// \brief An object used for type-checking arguments to be passed to a kernel
Expand Down
34 changes: 28 additions & 6 deletions cpp/src/arrow/compute/kernels/aggregate_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,9 @@ void AddFirstOrLastAggKernel(ScalarAggregateFunction* func,
// ----------------------------------------------------------------------
// MinMax implementation

using arrow::compute::match::Not;
using arrow::compute::match::SameTypeId;

Result<std::unique_ptr<KernelState>> MinMaxInit(KernelContext* ctx,
const KernelInitArgs& args) {
ARROW_ASSIGN_OR_RAISE(TypeHolder out_type,
Expand All @@ -494,9 +497,10 @@ Result<std::unique_ptr<KernelState>> MinMaxInit(KernelContext* ctx,

// For "min" and "max" functions: override finalize and return the actual value
template <MinOrMax min_or_max>
void AddMinOrMaxAggKernel(ScalarAggregateFunction* func,
ScalarAggregateFunction* min_max_func) {
auto sig = KernelSignature::Make({InputType::Any()}, FirstType);
void AddMinOrMaxAggKernels(ScalarAggregateFunction* func,
ScalarAggregateFunction* min_max_func) {
std::shared_ptr<arrow::compute::KernelSignature> sig =
KernelSignature::Make({InputType(Not(SameTypeId(Type::DICTIONARY)))}, FirstType);
auto init = [min_max_func](
KernelContext* ctx,
const KernelInitArgs& args) -> Result<std::unique_ptr<KernelState>> {
Expand All @@ -516,6 +520,9 @@ void AddMinOrMaxAggKernel(ScalarAggregateFunction* func,

// Note SIMD level is always NONE, but the convenience kernel will
// dispatch to an appropriate implementation
AddAggKernel(std::move(sig), init, finalize, func);

sig = KernelSignature::Make({InputType(Type::DICTIONARY)}, DictionaryValueType);
AddAggKernel(std::move(sig), std::move(init), std::move(finalize), func);
}

Expand Down Expand Up @@ -873,6 +880,15 @@ Result<TypeHolder> MinMaxType(KernelContext*, const std::vector<TypeHolder>& typ
return struct_({field("min", ty), field("max", ty)});
}

// Output type resolver for min_max on dictionary input: the result struct is
// expressed in terms of the dictionary's value type, not the dictionary type.
Result<TypeHolder> DictionaryMinMaxType(KernelContext*,
                                        const std::vector<TypeHolder>& types) {
  // dictionary<values=V> -> struct<min: V, max: V>
  const auto& dict_type = checked_cast<const DictionaryType&>(*types.front());
  const auto value_type = dict_type.value_type();
  return struct_({field("min", value_type), field("max", value_type)});
}

} // namespace

Result<TypeHolder> FirstLastType(KernelContext*, const std::vector<TypeHolder>& types) {
Expand All @@ -896,7 +912,12 @@ void AddFirstLastKernels(KernelInit init,

// Register one "min_max" kernel for the input type id `get_id.id`.
//
// Dictionary input resolves its output type with DictionaryMinMaxType; every
// other input type uses MinMaxType.
void AddMinMaxKernel(KernelInit init, internal::detail::GetTypeId get_id,
                     ScalarAggregateFunction* func, SimdLevel::type simd_level) {
  const auto out_type_resolver =
      get_id.id == Type::DICTIONARY ? DictionaryMinMaxType : MinMaxType;
  auto sig = KernelSignature::Make({InputType(get_id.id)}, out_type_resolver);
  AddAggKernel(std::move(sig), std::move(init), func, simd_level);
}

Expand Down Expand Up @@ -1118,6 +1139,7 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
AddMinMaxKernels(MinMaxInit, NumericTypes(), func.get());
AddMinMaxKernels(MinMaxInit, TemporalTypes(), func.get());
AddMinMaxKernels(MinMaxInit, BaseBinaryTypes(), func.get());
AddMinMaxKernel(MinMaxInit, Type::DICTIONARY, func.get());
AddMinMaxKernel(MinMaxInit, Type::FIXED_SIZE_BINARY, func.get());
AddMinMaxKernel(MinMaxInit, Type::INTERVAL_MONTHS, func.get());
AddMinMaxKernel(MinMaxInit, Type::DECIMAL128, func.get());
Expand All @@ -1140,12 +1162,12 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
// Add min/max as convenience functions
func = std::make_shared<ScalarAggregateFunction>("min", Arity::Unary(), min_or_max_doc,
&default_scalar_aggregate_options);
AddMinOrMaxAggKernel<MinOrMax::Min>(func.get(), min_max_func);
AddMinOrMaxAggKernels<MinOrMax::Min>(func.get(), min_max_func);
DCHECK_OK(registry->AddFunction(std::move(func)));

func = std::make_shared<ScalarAggregateFunction>("max", Arity::Unary(), min_or_max_doc,
&default_scalar_aggregate_options);
AddMinOrMaxAggKernel<MinOrMax::Max>(func.get(), min_max_func);
AddMinOrMaxAggKernels<MinOrMax::Max>(func.get(), min_max_func);
DCHECK_OK(registry->AddFunction(std::move(func)));

func = std::make_shared<ScalarAggregateFunction>("product", Arity::Unary(), product_doc,
Expand Down
118 changes: 118 additions & 0 deletions cpp/src/arrow/compute/kernels/aggregate_basic_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,11 @@ struct FirstLastInitState {
}
};

// Forward declaration: MinMaxInitState::Visit(const DictionaryType&) calls this
// factory, while the definition appears further down in this header.
template <SimdLevel::type SimdLevel>
std::unique_ptr<KernelState> DictionaryMinMaxImplFunc(const DataType& in_type,
std::shared_ptr<DataType> out_type,
ScalarAggregateOptions options);

template <SimdLevel::type SimdLevel>
struct MinMaxInitState {
std::unique_ptr<KernelState> state;
Expand Down Expand Up @@ -1002,6 +1007,11 @@ struct MinMaxInitState {
return Status::OK();
}

// Dictionary input: build the aggregator state through the
// DictionaryMinMaxImplFunc factory.
Status Visit(const DictionaryType&) {
  this->state = DictionaryMinMaxImplFunc<SimdLevel>(in_type, out_type, options);
  return Status::OK();
}

template <typename Type>
enable_if_physical_integer<Type, Status> Visit(const Type&) {
using PhysicalType = typename Type::PhysicalType;
Expand Down Expand Up @@ -1033,4 +1043,112 @@ struct MinMaxInitState {
}
};

/// \brief Min/max aggregator for dictionary-encoded input.
///
/// Null and non-null counts are tracked on the dictionary array itself. The
/// actual min/max comparison is delegated to a nested min/max state created
/// for the dictionary's value type and fed with the compacted dictionary
/// values of each consumed batch.
template <SimdLevel::type SimdLevel>
struct DictionaryMinMaxImpl : public ScalarAggregator {
  using ThisType = DictionaryMinMaxImpl<SimdLevel>;

  /// \param[in] in_type the dictionary type of the input
  /// \param[in] out_type the struct type of the result
  /// \param[in] options aggregation options; min_count is raised to at least 1
  DictionaryMinMaxImpl(const DataType& in_type, std::shared_ptr<DataType> out_type,
                       ScalarAggregateOptions options)
      : options(std::move(options)),
        out_type(std::move(out_type)),
        has_nulls(false),
        count(0),
        value_type(checked_cast<const DictionaryType&>(in_type).value_type()),
        value_state(nullptr) {
    this->options.min_count = std::max<uint32_t>(1, this->options.min_count);
  }

  /// Consume one batch. Only array input is supported; scalar input returns
  /// NotImplemented.
  Status Consume(KernelContext* ctx, const ExecSpan& batch) override {
    if (batch[0].is_scalar()) {
      return Status::NotImplemented("No min/max implemented for DictionaryScalar");
    }
    RETURN_NOT_OK(this->InitValueState());

    // The min/max is computed from the dictionary values. Some values may not
    // be referenced by any index, so the dictionary is compacted first.
    DictionaryArray dict_arr(batch[0].array.ToArrayData());
    ARROW_ASSIGN_OR_RAISE(auto compacted_arr, dict_arr.Compact(ctx->memory_pool()));
    const auto& compacted_dict_arr =
        checked_cast<const DictionaryArray&>(*compacted_arr);
    const int64_t null_count = compacted_dict_arr.ComputeLogicalNullCount();
    const int64_t non_null_count = compacted_dict_arr.length() - null_count;

    this->has_nulls |= null_count > 0;
    this->count += non_null_count;
    if ((this->has_nulls && !options.skip_nulls) || (non_null_count == 0)) {
      // Either the result is already known to be null, or this batch has
      // nothing to contribute.
      return Status::OK();
    }

    const ArrayData& dict_data = *compacted_dict_arr.dictionary()->data();
    // The span length is the number of dictionary values handed over.
    return checked_cast<ScalarAggregator*>(this->value_state.get())
        ->Consume(nullptr, ExecSpan(std::vector<ExecValue>{ExecValue(dict_data)},
                                    dict_data.length));
  }

  /// Merge the counters and the nested value state of `src` into this state.
  Status MergeFrom(KernelContext*, KernelState&& src) override {
    auto&& other = checked_cast<ThisType&&>(src);
    this->has_nulls |= other.has_nulls;
    this->count += other.count;
    if ((this->has_nulls && !options.skip_nulls) || other.value_state == nullptr) {
      return Status::OK();
    }

    if (this->value_state == nullptr) {
      // Nothing consumed on this side yet: adopt the other side's state.
      this->value_state = std::move(other.value_state);
      return Status::OK();
    }
    return checked_cast<ScalarAggregator*>(this->value_state.get())
        ->MergeFrom(nullptr, std::move(*other.value_state));
  }

  /// Emit the result struct scalar. Both fields are null when nulls were seen
  /// with skip_nulls disabled, when fewer than min_count non-null values were
  /// consumed, or when no batch was ever consumed.
  Status Finalize(KernelContext*, Datum* out) override {
    if ((this->has_nulls && !options.skip_nulls) || (this->count < options.min_count) ||
        this->value_state == nullptr) {
      const auto& struct_type = checked_cast<const StructType&>(*out_type);
      const auto& child_type = struct_type.field(0)->type();

      std::shared_ptr<Scalar> null_scalar = MakeNullScalar(child_type);
      std::vector<std::shared_ptr<Scalar>> values = {null_scalar, null_scalar};
      out->value = std::make_shared<StructScalar>(std::move(values), this->out_type);
      return Status::OK();
    }

    Datum temp;
    RETURN_NOT_OK(checked_cast<ScalarAggregator*>(this->value_state.get())
                      ->Finalize(nullptr, &temp));
    const auto& result = temp.scalar_as<StructScalar>();
    DCHECK(result.is_valid);
    out->value = result.GetSharedPtr();
    return Status::OK();
  }

  ScalarAggregateOptions options;
  std::shared_ptr<DataType> out_type;
  // Whether any consumed batch contained a logical null
  bool has_nulls;
  // Number of non-null elements consumed so far
  int64_t count;
  // Value type of the input dictionary type
  std::shared_ptr<DataType> value_type;
  // Nested min/max state over the dictionary values; created on first Consume
  std::unique_ptr<KernelState> value_state;

 private:
  // Create the nested min/max state for the value type on first use.
  Status InitValueState() {
    if (this->value_state != nullptr) {
      return Status::OK();
    }
    // The nested state runs with default options; this class applies the
    // user-supplied skip_nulls and min_count itself.
    const ScalarAggregateOptions value_options = ScalarAggregateOptions::Defaults();
    MinMaxInitState<SimdLevel::NONE> value_init_state(nullptr, *this->value_type,
                                                      out_type, value_options);
    ARROW_ASSIGN_OR_RAISE(this->value_state, value_init_state.Create());
    return Status::OK();
  }
};

// Factory for the dictionary min/max aggregator state.
//
// `out_type` and `options` are taken by value and moved into the new state, so
// no extra shared_ptr or options copy is made.
template <SimdLevel::type SimdLevel>
std::unique_ptr<KernelState> DictionaryMinMaxImplFunc(const DataType& in_type,
                                                      std::shared_ptr<DataType> out_type,
                                                      ScalarAggregateOptions options) {
  return std::make_unique<DictionaryMinMaxImpl<SimdLevel>>(in_type, std::move(out_type),
                                                           std::move(options));
}

} // namespace arrow::compute::internal
Loading