2525#include " arrow/engine/substrait/expression_internal.h"
2626#include " arrow/util/hash_util.h"
2727#include " arrow/util/hashing.h"
28+ #include " arrow/util/string.h"
2829
2930namespace arrow {
3031namespace engine {
@@ -121,7 +122,7 @@ class IdStorageImpl : public IdStorage {
121122
122123std::unique_ptr<IdStorage> IdStorage::Make () { return std::make_unique<IdStorageImpl>(); }
123124
124- Result<std::optional<std:: string_view>> SubstraitCall::GetEnumArg (uint32_t index) const {
125+ Result<std::string_view> SubstraitCall::GetEnumArg (int index) const {
125126 if (index >= size_) {
126127 return Status::Invalid (" Expected Substrait call to have an enum argument at index " ,
127128 index, " but it did not have enough arguments" );
@@ -134,16 +135,16 @@ Result<std::optional<std::string_view>> SubstraitCall::GetEnumArg(uint32_t index
134135 return enum_arg_it->second ;
135136}
136137
137- bool SubstraitCall::HasEnumArg (uint32_t index) const {
138+ bool SubstraitCall::HasEnumArg (int index) const {
138139 return enum_args_.find (index) != enum_args_.end ();
139140}
140141
141- void SubstraitCall::SetEnumArg (uint32_t index, std::optional<std:: string> enum_arg) {
142+ void SubstraitCall::SetEnumArg (int index, std::string enum_arg) {
142143 size_ = std::max (size_, index + 1 );
143144 enum_args_[index] = std::move (enum_arg);
144145}
145146
146- Result<compute::Expression> SubstraitCall::GetValueArg (uint32_t index) const {
147+ Result<compute::Expression> SubstraitCall::GetValueArg (int index) const {
147148 if (index >= size_) {
148149 return Status::Invalid (" Expected Substrait call to have a value argument at index " ,
149150 index, " but it did not have enough arguments" );
@@ -156,15 +157,32 @@ Result<compute::Expression> SubstraitCall::GetValueArg(uint32_t index) const {
156157 return value_arg_it->second ;
157158}
158159
159- bool SubstraitCall::HasValueArg (uint32_t index) const {
160+ bool SubstraitCall::HasValueArg (int index) const {
160161 return value_args_.find (index) != value_args_.end ();
161162}
162163
163- void SubstraitCall::SetValueArg (uint32_t index, compute::Expression value_arg) {
164+ void SubstraitCall::SetValueArg (int index, compute::Expression value_arg) {
164165 size_ = std::max (size_, index + 1 );
165166 value_args_[index] = std::move (value_arg);
166167}
167168
169+ std::optional<std::vector<std::string> const *> SubstraitCall::GetOption (
170+ std::string_view option_name) const {
171+ auto opt = options_.find (std::string (option_name));
172+ if (opt == options_.end ()) {
173+ return std::nullopt ;
174+ }
175+ return &opt->second ;
176+ }
177+
178+ void SubstraitCall::SetOption (std::string_view option_name,
179+ const std::vector<std::string_view>& option_preferences) {
180+ auto & prefs = options_[std::string (option_name)];
181+ for (std::string_view pref : option_preferences) {
182+ prefs.emplace_back (pref);
183+ }
184+ }
185+
168186// A builder used when creating a Substrait plan from an Arrow execution plan. In
169187// that situation we do not have a set of anchor values already defined so we keep
170188// a map of what Ids we have seen.
@@ -645,50 +663,91 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry {
645663};
646664
647665template <typename Enum>
648- using EnumParser = std::function<Result<Enum>(std::optional<std::string_view>)>;
649-
650- template < typename Enum>
651- EnumParser<Enum> GetEnumParser ( const std::vector<std::string>& options) {
652- std::unordered_map<std::string, Enum> parse_map ;
653- for (std:: size_t i = 0 ; i < options. size (); i++) {
654- parse_map[options[i]] = static_cast <Enum>(i + 1 );
666+ class EnumParser {
667+ public:
668+ explicit EnumParser ( const std::vector<std::string>& options) {
669+ for ( std::size_t i = 0 ; i < options. size (); i++ ) {
670+ parse_map_[options[i]] = static_cast < Enum>(i + 1 ) ;
671+ reverse_map_[ static_cast <Enum>(i + 1 )] = options[i];
672+ }
655673 }
656- return [parse_map](std::optional<std::string_view> enum_val) -> Result<Enum> {
657- if (!enum_val) {
658- // Assumes 0 is always kUnspecified in Enum
659- return static_cast <Enum>(0 );
674+
675+ Result<Enum> Parse (std::string_view enum_val) const {
676+ auto it = parse_map_.find (std::string (enum_val));
677+ if (it == parse_map_.end ()) {
678+ return Status::NotImplemented (" The value " , enum_val,
679+ " is not an expected enum value" );
660680 }
661- auto maybe_parsed = parse_map.find (std::string (*enum_val));
662- if (maybe_parsed == parse_map.end ()) {
663- return Status::Invalid (" The value " , *enum_val, " is not an expected enum value" );
681+ return it->second ;
682+ }
683+
684+ std::string ImplementedOptionsAsString (
685+ const std::vector<Enum>& implemented_opts) const {
686+ std::vector<std::string_view> opt_strs;
687+ for (const Enum& implemented_opt : implemented_opts) {
688+ auto it = reverse_map_.find (implemented_opt);
689+ if (it == reverse_map_.end ()) {
690+ opt_strs.emplace_back (" Unknown" );
691+ } else {
692+ opt_strs.emplace_back (it->second );
693+ }
664694 }
665- return maybe_parsed->second ;
666- };
667- }
695+ return arrow::internal::JoinStrings (opt_strs, " , " );
696+ }
697+
698+ private:
699+ std::unordered_map<std::string, Enum> parse_map_;
700+ std::unordered_map<Enum, std::string> reverse_map_;
701+ };
668702
669703enum class TemporalComponent { kUnspecified = 0 , kYear , kMonth , kDay , kSecond };
670704static std::vector<std::string> kTemporalComponentOptions = {" YEAR" , " MONTH" , " DAY" ,
671705 " SECOND" };
672- static EnumParser<TemporalComponent> kTemporalComponentParser =
673- GetEnumParser<TemporalComponent>(kTemporalComponentOptions );
706+ static EnumParser<TemporalComponent> kTemporalComponentParser (kTemporalComponentOptions );
674707
675708enum class OverflowBehavior { kUnspecified = 0 , kSilent , kSaturate , kError };
676709static std::vector<std::string> kOverflowOptions = {" SILENT" , " SATURATE" , " ERROR" };
677- static EnumParser<OverflowBehavior> kOverflowParser =
678- GetEnumParser<OverflowBehavior>(kOverflowOptions );
710+ static EnumParser<OverflowBehavior> kOverflowParser (kOverflowOptions );
679711
680712template <typename Enum>
681- Result<Enum> ParseEnumArg (const SubstraitCall& call, uint32_t arg_index,
713+ Result<Enum> ParseOptionOrElse (const SubstraitCall& call, std::string_view option_name,
714+ const EnumParser<Enum>& parser,
715+ const std::vector<Enum>& implemented_options,
716+ Enum fallback) {
717+ std::optional<std::vector<std::string> const *> enum_arg = call.GetOption (option_name);
718+ if (!enum_arg.has_value ()) {
719+ return fallback;
720+ }
721+ std::vector<std::string> const * prefs = *enum_arg;
722+ for (const std::string& pref : *prefs) {
723+ ARROW_ASSIGN_OR_RAISE (Enum parsed, parser.Parse (pref));
724+ for (Enum implemented_opt : implemented_options) {
725+ if (implemented_opt == parsed) {
726+ return parsed;
727+ }
728+ }
729+ }
730+
731+ // Prepare error message
732+ return Status::NotImplemented (
733+ " During a call to a function with id " , call.id ().uri , " #" , call.id ().name ,
734+ " the plan requested the option " , option_name, " to be one of [" ,
735+ arrow::internal::JoinStrings (*prefs, " , " ),
736+ " ] but the only supported options are [" ,
737+ parser.ImplementedOptionsAsString (implemented_options), " ]" );
738+ }
739+
740+ template <typename Enum>
741+ Result<Enum> ParseEnumArg (const SubstraitCall& call, int arg_index,
682742 const EnumParser<Enum>& parser) {
683- ARROW_ASSIGN_OR_RAISE (std::optional<std::string_view> enum_arg,
684- call.GetEnumArg (arg_index));
685- return parser (enum_arg);
743+ ARROW_ASSIGN_OR_RAISE (std::string_view enum_val, call.GetEnumArg (arg_index));
744+ return parser.Parse (enum_val);
686745}
687746
688747Result<std::vector<compute::Expression>> GetValueArgs (const SubstraitCall& call,
689748 int start_index) {
690749 std::vector<compute::Expression> expressions;
691- for (uint32_t index = start_index; index < call.size (); index++) {
750+ for (int index = start_index; index < call.size (); index++) {
692751 ARROW_ASSIGN_OR_RAISE (compute::Expression arg, call.GetValueArg (index));
693752 expressions.push_back (arg);
694753 }
@@ -698,13 +757,13 @@ Result<std::vector<compute::Expression>> GetValueArgs(const SubstraitCall& call,
698757ExtensionIdRegistry::SubstraitCallToArrow DecodeOptionlessOverflowableArithmetic (
699758 const std::string& function_name) {
700759 return [function_name](const SubstraitCall& call) -> Result<compute::Expression> {
701- ARROW_ASSIGN_OR_RAISE (OverflowBehavior overflow_behavior,
702- ParseEnumArg (call, 0 , kOverflowParser ));
760+ ARROW_ASSIGN_OR_RAISE (
761+ OverflowBehavior overflow_behavior,
762+ ParseOptionOrElse (call, " overflow" , kOverflowParser ,
763+ {OverflowBehavior::kSilent , OverflowBehavior::kError },
764+ OverflowBehavior::kSilent ));
703765 ARROW_ASSIGN_OR_RAISE (std::vector<compute::Expression> value_args,
704- GetValueArgs (call, 1 ));
705- if (overflow_behavior == OverflowBehavior::kUnspecified ) {
706- overflow_behavior = OverflowBehavior::kSilent ;
707- }
766+ GetValueArgs (call, 0 ));
708767 if (overflow_behavior == OverflowBehavior::kSilent ) {
709768 return arrow::compute::call (function_name, std::move (value_args));
710769 } else if (overflow_behavior == OverflowBehavior::kError ) {
@@ -727,12 +786,12 @@ ExtensionIdRegistry::ArrowToSubstraitCall EncodeOptionlessOverflowableArithmetic
727786 SubstraitCall substrait_call (substrait_fn_id, call.type .GetSharedPtr (),
728787 /* nullable=*/ true );
729788 if (kChecked ) {
730- substrait_call.SetEnumArg ( 0 , " ERROR" );
789+ substrait_call.SetOption ( " overflow " , { " ERROR" } );
731790 } else {
732- substrait_call.SetEnumArg ( 0 , " SILENT" );
791+ substrait_call.SetOption ( " overflow " , { " SILENT" } );
733792 }
734793 for (std::size_t i = 0 ; i < call.arguments .size (); i++) {
735- substrait_call.SetValueArg (static_cast <uint32_t >(i + 1 ), call.arguments [i]);
794+ substrait_call.SetValueArg (static_cast <int >(i), call.arguments [i]);
736795 }
737796 return std::move (substrait_call);
738797 };
@@ -746,14 +805,14 @@ ExtensionIdRegistry::ArrowToSubstraitCall EncodeOptionlessComparison(Id substrai
746805 SubstraitCall substrait_call (substrait_fn_id, call.type .GetSharedPtr (),
747806 /* nullable=*/ true );
748807 for (std::size_t i = 0 ; i < call.arguments .size (); i++) {
749- substrait_call.SetValueArg (static_cast <uint32_t >(i), call.arguments [i]);
808+ substrait_call.SetValueArg (static_cast <int >(i), call.arguments [i]);
750809 }
751810 return std::move (substrait_call);
752811 };
753812}
754813
755814ExtensionIdRegistry::SubstraitCallToArrow DecodeOptionlessBasicMapping (
756- const std::string& function_name, uint32_t max_args) {
815+ const std::string& function_name, int max_args) {
757816 return [function_name,
758817 max_args](const SubstraitCall& call) -> Result<compute::Expression> {
759818 if (call.size () > max_args) {
0 commit comments