Skip to content

Commit 3179cdb

Browse files
westonpacebkietz
andcommitted
ARROW-17966: [C++] Adjust to new format for Substrait optional arguments (apache#15)
* ARROW-17966: Updated to latest Substrait version. Switched from optional enum args to proper options. Added check for minimum Substrait version * ARROW-17966: Add version to python substrait examples. Fix version handling to check major version and not just minor * ARROW-17966: Update cpp/src/arrow/engine/substrait/extension_set.cc Co-authored-by: Benjamin Kietzman <bengilgit@gmail.com> * ARROW-17966: Update cpp/src/arrow/engine/substrait/extension_set.cc Co-authored-by: Benjamin Kietzman <bengilgit@gmail.com> * ARROW-17966: Update cpp/src/arrow/engine/substrait/extension_set.cc Co-authored-by: Benjamin Kietzman <bengilgit@gmail.com> * ARROW-17966: Update cpp/src/arrow/engine/substrait/extension_set.cc Co-authored-by: Benjamin Kietzman <bengilgit@gmail.com> * ARROW-17966: Display the available choices when a user enters a valid substrait option that Acero doesn't support * ARROW-17966: Simplify parsing boilerplate per review comments * ARROW-17966: Gracefully error if the user does not supply any preferences for an option * ARROW-17966: Prefer range loops where possible * ARROW-17966: Rebase cleanup * ARROW-17966: Minor fix to failing unit tests: remove enum="unspecified" * ARROW-17966: Minor lint fix * ARROW-17966: Cmake format Co-authored-by: Benjamin Kietzman <bengilgit@gmail.com>
1 parent f46e0ed commit 3179cdb

File tree

11 files changed

+374
-126
lines changed

11 files changed

+374
-126
lines changed

cpp/cmake_modules/ThirdpartyToolchain.cmake

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,13 @@ else()
657657
"${THIRDPARTY_MIRROR_URL}/snappy-${ARROW_SNAPPY_BUILD_VERSION}.tar.gz")
658658
endif()
659659

660+
# Remove these two lines once https://github.com/substrait-io/substrait/pull/342 merges
661+
set(ENV{ARROW_SUBSTRAIT_URL}
662+
"https://github.com/substrait-io/substrait/archive/e59008b6b202f8af06c2266991161b1e45cb056a.tar.gz"
663+
)
664+
set(ARROW_SUBSTRAIT_BUILD_SHA256_CHECKSUM
665+
"f64629cb377fcc62c9d3e8fe69fa6a4cf326f34d756e03db84843c5cce8d04cd")
666+
660667
if(DEFINED ENV{ARROW_SUBSTRAIT_URL})
661668
set(SUBSTRAIT_SOURCE_URL "$ENV{ARROW_SUBSTRAIT_URL}")
662669
else()

cpp/src/arrow/engine/substrait/expression_internal.cc

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -52,18 +52,15 @@ Id NormalizeFunctionName(Id id) {
5252

5353
} // namespace
5454

55-
Status DecodeArg(const substrait::FunctionArgument& arg, uint32_t idx,
56-
SubstraitCall* call, const ExtensionSet& ext_set,
55+
Status DecodeArg(const substrait::FunctionArgument& arg, int idx, SubstraitCall* call,
56+
const ExtensionSet& ext_set,
5757
const ConversionOptions& conversion_options) {
5858
if (arg.has_enum_()) {
5959
const substrait::FunctionArgument::Enum& enum_val = arg.enum_();
6060
switch (enum_val.enum_kind_case()) {
6161
case substrait::FunctionArgument::Enum::EnumKindCase::kSpecified:
6262
call->SetEnumArg(idx, enum_val.specified());
6363
break;
64-
case substrait::FunctionArgument::Enum::EnumKindCase::kUnspecified:
65-
call->SetEnumArg(idx, std::nullopt);
66-
break;
6764
default:
6865
return Status::Invalid("Unrecognized enum kind case: ",
6966
enum_val.enum_kind_case());
@@ -80,15 +77,31 @@ Status DecodeArg(const substrait::FunctionArgument& arg, uint32_t idx,
8077
return Status::OK();
8178
}
8279

80+
Status DecodeOption(const substrait::FunctionOption& opt, SubstraitCall* call) {
81+
std::vector<std::string_view> prefs;
82+
if (opt.preference_size() == 0) {
83+
return Status::Invalid("Invalid Substrait plan. The option ", opt.name(),
84+
" is specified but does not list any choices");
85+
}
86+
for (const auto& preference : opt.preference()) {
87+
prefs.push_back(preference);
88+
}
89+
call->SetOption(opt.name(), prefs);
90+
return Status::OK();
91+
}
92+
8393
Result<SubstraitCall> DecodeScalarFunction(
8494
Id id, const substrait::Expression::ScalarFunction& scalar_fn,
8595
const ExtensionSet& ext_set, const ConversionOptions& conversion_options) {
8696
ARROW_ASSIGN_OR_RAISE(auto output_type_and_nullable,
8797
FromProto(scalar_fn.output_type(), ext_set, conversion_options));
8898
SubstraitCall call(id, output_type_and_nullable.first, output_type_and_nullable.second);
8999
for (int i = 0; i < scalar_fn.arguments_size(); i++) {
90-
ARROW_RETURN_NOT_OK(DecodeArg(scalar_fn.arguments(i), static_cast<uint32_t>(i), &call,
91-
ext_set, conversion_options));
100+
ARROW_RETURN_NOT_OK(
101+
DecodeArg(scalar_fn.arguments(i), i, &call, ext_set, conversion_options));
102+
}
103+
for (const auto& opt : scalar_fn.options()) {
104+
ARROW_RETURN_NOT_OK(DecodeOption(opt, &call));
92105
}
93106
return std::move(call);
94107
}
@@ -926,16 +939,12 @@ Result<std::unique_ptr<substrait::Expression::ScalarFunction>> EncodeSubstraitCa
926939
ToProto(*call.output_type(), call.output_nullable(), ext_set, conversion_options));
927940
scalar_fn->set_allocated_output_type(output_type.release());
928941

929-
for (uint32_t i = 0; i < call.size(); i++) {
942+
for (int i = 0; i < call.size(); i++) {
930943
substrait::FunctionArgument* arg = scalar_fn->add_arguments();
931944
if (call.HasEnumArg(i)) {
932945
auto enum_val = std::make_unique<substrait::FunctionArgument::Enum>();
933-
ARROW_ASSIGN_OR_RAISE(std::optional<std::string_view> enum_arg, call.GetEnumArg(i));
934-
if (enum_arg) {
935-
enum_val->set_specified(std::string(*enum_arg));
936-
} else {
937-
enum_val->set_allocated_unspecified(new google::protobuf::Empty());
938-
}
946+
ARROW_ASSIGN_OR_RAISE(std::string_view enum_arg, call.GetEnumArg(i));
947+
enum_val->set_specified(std::string(enum_arg));
939948
arg->set_allocated_enum_(enum_val.release());
940949
} else if (call.HasValueArg(i)) {
941950
ARROW_ASSIGN_OR_RAISE(compute::Expression value_arg, call.GetValueArg(i));

cpp/src/arrow/engine/substrait/extension_set.cc

Lines changed: 102 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "arrow/engine/substrait/expression_internal.h"
2626
#include "arrow/util/hash_util.h"
2727
#include "arrow/util/hashing.h"
28+
#include "arrow/util/string.h"
2829

2930
namespace arrow {
3031
namespace engine {
@@ -121,7 +122,7 @@ class IdStorageImpl : public IdStorage {
121122

122123
std::unique_ptr<IdStorage> IdStorage::Make() { return std::make_unique<IdStorageImpl>(); }
123124

124-
Result<std::optional<std::string_view>> SubstraitCall::GetEnumArg(uint32_t index) const {
125+
Result<std::string_view> SubstraitCall::GetEnumArg(int index) const {
125126
if (index >= size_) {
126127
return Status::Invalid("Expected Substrait call to have an enum argument at index ",
127128
index, " but it did not have enough arguments");
@@ -134,16 +135,16 @@ Result<std::optional<std::string_view>> SubstraitCall::GetEnumArg(uint32_t index
134135
return enum_arg_it->second;
135136
}
136137

137-
bool SubstraitCall::HasEnumArg(uint32_t index) const {
138+
bool SubstraitCall::HasEnumArg(int index) const {
138139
return enum_args_.find(index) != enum_args_.end();
139140
}
140141

141-
void SubstraitCall::SetEnumArg(uint32_t index, std::optional<std::string> enum_arg) {
142+
void SubstraitCall::SetEnumArg(int index, std::string enum_arg) {
142143
size_ = std::max(size_, index + 1);
143144
enum_args_[index] = std::move(enum_arg);
144145
}
145146

146-
Result<compute::Expression> SubstraitCall::GetValueArg(uint32_t index) const {
147+
Result<compute::Expression> SubstraitCall::GetValueArg(int index) const {
147148
if (index >= size_) {
148149
return Status::Invalid("Expected Substrait call to have a value argument at index ",
149150
index, " but it did not have enough arguments");
@@ -156,15 +157,32 @@ Result<compute::Expression> SubstraitCall::GetValueArg(uint32_t index) const {
156157
return value_arg_it->second;
157158
}
158159

159-
bool SubstraitCall::HasValueArg(uint32_t index) const {
160+
bool SubstraitCall::HasValueArg(int index) const {
160161
return value_args_.find(index) != value_args_.end();
161162
}
162163

163-
void SubstraitCall::SetValueArg(uint32_t index, compute::Expression value_arg) {
164+
void SubstraitCall::SetValueArg(int index, compute::Expression value_arg) {
164165
size_ = std::max(size_, index + 1);
165166
value_args_[index] = std::move(value_arg);
166167
}
167168

169+
std::optional<std::vector<std::string> const*> SubstraitCall::GetOption(
170+
std::string_view option_name) const {
171+
auto opt = options_.find(std::string(option_name));
172+
if (opt == options_.end()) {
173+
return std::nullopt;
174+
}
175+
return &opt->second;
176+
}
177+
178+
void SubstraitCall::SetOption(std::string_view option_name,
179+
const std::vector<std::string_view>& option_preferences) {
180+
auto& prefs = options_[std::string(option_name)];
181+
for (std::string_view pref : option_preferences) {
182+
prefs.emplace_back(pref);
183+
}
184+
}
185+
168186
// A builder used when creating a Substrait plan from an Arrow execution plan. In
169187
// that situation we do not have a set of anchor values already defined so we keep
170188
// a map of what Ids we have seen.
@@ -645,50 +663,91 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry {
645663
};
646664

647665
template <typename Enum>
648-
using EnumParser = std::function<Result<Enum>(std::optional<std::string_view>)>;
649-
650-
template <typename Enum>
651-
EnumParser<Enum> GetEnumParser(const std::vector<std::string>& options) {
652-
std::unordered_map<std::string, Enum> parse_map;
653-
for (std::size_t i = 0; i < options.size(); i++) {
654-
parse_map[options[i]] = static_cast<Enum>(i + 1);
666+
class EnumParser {
667+
public:
668+
explicit EnumParser(const std::vector<std::string>& options) {
669+
for (std::size_t i = 0; i < options.size(); i++) {
670+
parse_map_[options[i]] = static_cast<Enum>(i + 1);
671+
reverse_map_[static_cast<Enum>(i + 1)] = options[i];
672+
}
655673
}
656-
return [parse_map](std::optional<std::string_view> enum_val) -> Result<Enum> {
657-
if (!enum_val) {
658-
// Assumes 0 is always kUnspecified in Enum
659-
return static_cast<Enum>(0);
674+
675+
Result<Enum> Parse(std::string_view enum_val) const {
676+
auto it = parse_map_.find(std::string(enum_val));
677+
if (it == parse_map_.end()) {
678+
return Status::NotImplemented("The value ", enum_val,
679+
" is not an expected enum value");
660680
}
661-
auto maybe_parsed = parse_map.find(std::string(*enum_val));
662-
if (maybe_parsed == parse_map.end()) {
663-
return Status::Invalid("The value ", *enum_val, " is not an expected enum value");
681+
return it->second;
682+
}
683+
684+
std::string ImplementedOptionsAsString(
685+
const std::vector<Enum>& implemented_opts) const {
686+
std::vector<std::string_view> opt_strs;
687+
for (const Enum& implemented_opt : implemented_opts) {
688+
auto it = reverse_map_.find(implemented_opt);
689+
if (it == reverse_map_.end()) {
690+
opt_strs.emplace_back("Unknown");
691+
} else {
692+
opt_strs.emplace_back(it->second);
693+
}
664694
}
665-
return maybe_parsed->second;
666-
};
667-
}
695+
return arrow::internal::JoinStrings(opt_strs, ", ");
696+
}
697+
698+
private:
699+
std::unordered_map<std::string, Enum> parse_map_;
700+
std::unordered_map<Enum, std::string> reverse_map_;
701+
};
668702

669703
enum class TemporalComponent { kUnspecified = 0, kYear, kMonth, kDay, kSecond };
670704
static std::vector<std::string> kTemporalComponentOptions = {"YEAR", "MONTH", "DAY",
671705
"SECOND"};
672-
static EnumParser<TemporalComponent> kTemporalComponentParser =
673-
GetEnumParser<TemporalComponent>(kTemporalComponentOptions);
706+
static EnumParser<TemporalComponent> kTemporalComponentParser(kTemporalComponentOptions);
674707

675708
enum class OverflowBehavior { kUnspecified = 0, kSilent, kSaturate, kError };
676709
static std::vector<std::string> kOverflowOptions = {"SILENT", "SATURATE", "ERROR"};
677-
static EnumParser<OverflowBehavior> kOverflowParser =
678-
GetEnumParser<OverflowBehavior>(kOverflowOptions);
710+
static EnumParser<OverflowBehavior> kOverflowParser(kOverflowOptions);
679711

680712
template <typename Enum>
681-
Result<Enum> ParseEnumArg(const SubstraitCall& call, uint32_t arg_index,
713+
Result<Enum> ParseOptionOrElse(const SubstraitCall& call, std::string_view option_name,
714+
const EnumParser<Enum>& parser,
715+
const std::vector<Enum>& implemented_options,
716+
Enum fallback) {
717+
std::optional<std::vector<std::string> const*> enum_arg = call.GetOption(option_name);
718+
if (!enum_arg.has_value()) {
719+
return fallback;
720+
}
721+
std::vector<std::string> const* prefs = *enum_arg;
722+
for (const std::string& pref : *prefs) {
723+
ARROW_ASSIGN_OR_RAISE(Enum parsed, parser.Parse(pref));
724+
for (Enum implemented_opt : implemented_options) {
725+
if (implemented_opt == parsed) {
726+
return parsed;
727+
}
728+
}
729+
}
730+
731+
// Prepare error message
732+
return Status::NotImplemented(
733+
"During a call to a function with id ", call.id().uri, "#", call.id().name,
734+
" the plan requested the option ", option_name, " to be one of [",
735+
arrow::internal::JoinStrings(*prefs, ", "),
736+
"] but the only supported options are [",
737+
parser.ImplementedOptionsAsString(implemented_options), "]");
738+
}
739+
740+
template <typename Enum>
741+
Result<Enum> ParseEnumArg(const SubstraitCall& call, int arg_index,
682742
const EnumParser<Enum>& parser) {
683-
ARROW_ASSIGN_OR_RAISE(std::optional<std::string_view> enum_arg,
684-
call.GetEnumArg(arg_index));
685-
return parser(enum_arg);
743+
ARROW_ASSIGN_OR_RAISE(std::string_view enum_val, call.GetEnumArg(arg_index));
744+
return parser.Parse(enum_val);
686745
}
687746

688747
Result<std::vector<compute::Expression>> GetValueArgs(const SubstraitCall& call,
689748
int start_index) {
690749
std::vector<compute::Expression> expressions;
691-
for (uint32_t index = start_index; index < call.size(); index++) {
750+
for (int index = start_index; index < call.size(); index++) {
692751
ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(index));
693752
expressions.push_back(arg);
694753
}
@@ -698,13 +757,13 @@ Result<std::vector<compute::Expression>> GetValueArgs(const SubstraitCall& call,
698757
ExtensionIdRegistry::SubstraitCallToArrow DecodeOptionlessOverflowableArithmetic(
699758
const std::string& function_name) {
700759
return [function_name](const SubstraitCall& call) -> Result<compute::Expression> {
701-
ARROW_ASSIGN_OR_RAISE(OverflowBehavior overflow_behavior,
702-
ParseEnumArg(call, 0, kOverflowParser));
760+
ARROW_ASSIGN_OR_RAISE(
761+
OverflowBehavior overflow_behavior,
762+
ParseOptionOrElse(call, "overflow", kOverflowParser,
763+
{OverflowBehavior::kSilent, OverflowBehavior::kError},
764+
OverflowBehavior::kSilent));
703765
ARROW_ASSIGN_OR_RAISE(std::vector<compute::Expression> value_args,
704-
GetValueArgs(call, 1));
705-
if (overflow_behavior == OverflowBehavior::kUnspecified) {
706-
overflow_behavior = OverflowBehavior::kSilent;
707-
}
766+
GetValueArgs(call, 0));
708767
if (overflow_behavior == OverflowBehavior::kSilent) {
709768
return arrow::compute::call(function_name, std::move(value_args));
710769
} else if (overflow_behavior == OverflowBehavior::kError) {
@@ -727,12 +786,12 @@ ExtensionIdRegistry::ArrowToSubstraitCall EncodeOptionlessOverflowableArithmetic
727786
SubstraitCall substrait_call(substrait_fn_id, call.type.GetSharedPtr(),
728787
/*nullable=*/true);
729788
if (kChecked) {
730-
substrait_call.SetEnumArg(0, "ERROR");
789+
substrait_call.SetOption("overflow", {"ERROR"});
731790
} else {
732-
substrait_call.SetEnumArg(0, "SILENT");
791+
substrait_call.SetOption("overflow", {"SILENT"});
733792
}
734793
for (std::size_t i = 0; i < call.arguments.size(); i++) {
735-
substrait_call.SetValueArg(static_cast<uint32_t>(i + 1), call.arguments[i]);
794+
substrait_call.SetValueArg(static_cast<int>(i), call.arguments[i]);
736795
}
737796
return std::move(substrait_call);
738797
};
@@ -746,14 +805,14 @@ ExtensionIdRegistry::ArrowToSubstraitCall EncodeOptionlessComparison(Id substrai
746805
SubstraitCall substrait_call(substrait_fn_id, call.type.GetSharedPtr(),
747806
/*nullable=*/true);
748807
for (std::size_t i = 0; i < call.arguments.size(); i++) {
749-
substrait_call.SetValueArg(static_cast<uint32_t>(i), call.arguments[i]);
808+
substrait_call.SetValueArg(static_cast<int>(i), call.arguments[i]);
750809
}
751810
return std::move(substrait_call);
752811
};
753812
}
754813

755814
ExtensionIdRegistry::SubstraitCallToArrow DecodeOptionlessBasicMapping(
756-
const std::string& function_name, uint32_t max_args) {
815+
const std::string& function_name, int max_args) {
757816
return [function_name,
758817
max_args](const SubstraitCall& call) -> Result<compute::Expression> {
759818
if (call.size() > max_args) {

cpp/src/arrow/engine/substrait/extension_set.h

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,17 @@ class SubstraitCall {
119119
bool output_nullable() const { return output_nullable_; }
120120
bool is_hash() const { return is_hash_; }
121121

122-
bool HasEnumArg(uint32_t index) const;
123-
Result<std::optional<std::string_view>> GetEnumArg(uint32_t index) const;
124-
void SetEnumArg(uint32_t index, std::optional<std::string> enum_arg);
125-
Result<compute::Expression> GetValueArg(uint32_t index) const;
126-
bool HasValueArg(uint32_t index) const;
127-
void SetValueArg(uint32_t index, compute::Expression value_arg);
128-
uint32_t size() const { return size_; }
122+
bool HasEnumArg(int index) const;
123+
Result<std::string_view> GetEnumArg(int index) const;
124+
void SetEnumArg(int index, std::string enum_arg);
125+
Result<compute::Expression> GetValueArg(int index) const;
126+
bool HasValueArg(int index) const;
127+
void SetValueArg(int index, compute::Expression value_arg);
128+
std::optional<std::vector<std::string> const*> GetOption(
129+
std::string_view option_name) const;
130+
void SetOption(std::string_view option_name,
131+
const std::vector<std::string_view>& option_preferences);
132+
int size() const { return size_; }
129133

130134
private:
131135
Id id_;
@@ -134,9 +138,10 @@ class SubstraitCall {
134138
// Only needed when converting from Substrait -> Arrow aggregates. The
135139
// Arrow function name depends on whether or not there are any groups
136140
bool is_hash_;
137-
std::unordered_map<uint32_t, std::optional<std::string>> enum_args_;
138-
std::unordered_map<uint32_t, compute::Expression> value_args_;
139-
uint32_t size_ = 0;
141+
std::unordered_map<int, std::string> enum_args_;
142+
std::unordered_map<int, compute::Expression> value_args_;
143+
std::unordered_map<std::string, std::vector<std::string>> options_;
144+
int size_ = 0;
140145
};
141146

142147
/// Substrait identifies functions and custom data types using a (uri, name) pair.

0 commit comments

Comments
 (0)