Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cpp/cmake_modules/ThirdpartyToolchain.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,13 @@ else()
endif()
endif()

# Temporary pin: point the Substrait source download at a specific upstream
# commit archive. Remove both set() calls below once
# https://github.com/substrait-io/substrait/pull/342 merges.
#
# NOTE: this unconditionally overwrites ARROW_SUBSTRAIT_URL in the environment,
# so the if(DEFINED ENV{ARROW_SUBSTRAIT_URL}) check that follows is always true
# and any URL supplied by the user is ignored while this pin is in place.
set(ENV{ARROW_SUBSTRAIT_URL}
"https://github.com/substrait-io/substrait/archive/e59008b6b202f8af06c2266991161b1e45cb056a.tar.gz"
)
# Presumably the SHA-256 of the archive pinned above; keep the two in sync
# (TODO confirm against the downloaded tarball when changing the URL).
set(ARROW_SUBSTRAIT_BUILD_SHA256_CHECKSUM
"f64629cb377fcc62c9d3e8fe69fa6a4cf326f34d756e03db84843c5cce8d04cd")

if(DEFINED ENV{ARROW_SUBSTRAIT_URL})
set(SUBSTRAIT_SOURCE_URL "$ENV{ARROW_SUBSTRAIT_URL}")
else()
Expand Down
37 changes: 23 additions & 14 deletions cpp/src/arrow/engine/substrait/expression_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,15 @@ Id NormalizeFunctionName(Id id) {

} // namespace

Status DecodeArg(const substrait::FunctionArgument& arg, uint32_t idx,
SubstraitCall* call, const ExtensionSet& ext_set,
Status DecodeArg(const substrait::FunctionArgument& arg, int idx, SubstraitCall* call,
const ExtensionSet& ext_set,
const ConversionOptions& conversion_options) {
if (arg.has_enum_()) {
const substrait::FunctionArgument::Enum& enum_val = arg.enum_();
switch (enum_val.enum_kind_case()) {
case substrait::FunctionArgument::Enum::EnumKindCase::kSpecified:
call->SetEnumArg(idx, enum_val.specified());
break;
case substrait::FunctionArgument::Enum::EnumKindCase::kUnspecified:
call->SetEnumArg(idx, std::nullopt);
break;
default:
return Status::Invalid("Unrecognized enum kind case: ",
enum_val.enum_kind_case());
Expand All @@ -80,15 +77,31 @@ Status DecodeArg(const substrait::FunctionArgument& arg, uint32_t idx,
return Status::OK();
}

/// \brief Copy a single Substrait function option onto a SubstraitCall
///
/// \param[in] opt the option message; its name and ordered preference list are read
/// \param[in,out] call receives the option through SubstraitCall::SetOption
/// \return Status::Invalid if the option lists no preferences, otherwise Status::OK
Status DecodeOption(const substrait::FunctionOption& opt, SubstraitCall* call) {
  // An option that names no choices cannot be honoured, so reject it up front.
  if (opt.preference_size() == 0) {
    return Status::Invalid("Invalid Substrait plan. The option ", opt.name(),
                           " is specified but does not list any choices");
  }

  // Views into the message's strings, kept in the order the plan listed them.
  const auto& listed = opt.preference();
  std::vector<std::string_view> choices(listed.begin(), listed.end());

  call->SetOption(opt.name(), choices);
  return Status::OK();
}

Result<SubstraitCall> DecodeScalarFunction(
Id id, const substrait::Expression::ScalarFunction& scalar_fn,
const ExtensionSet& ext_set, const ConversionOptions& conversion_options) {
ARROW_ASSIGN_OR_RAISE(auto output_type_and_nullable,
FromProto(scalar_fn.output_type(), ext_set, conversion_options));
SubstraitCall call(id, output_type_and_nullable.first, output_type_and_nullable.second);
for (int i = 0; i < scalar_fn.arguments_size(); i++) {
ARROW_RETURN_NOT_OK(DecodeArg(scalar_fn.arguments(i), static_cast<uint32_t>(i), &call,
ext_set, conversion_options));
ARROW_RETURN_NOT_OK(
DecodeArg(scalar_fn.arguments(i), i, &call, ext_set, conversion_options));
}
for (const auto& opt : scalar_fn.options()) {
ARROW_RETURN_NOT_OK(DecodeOption(opt, &call));
}
return std::move(call);
}
Expand Down Expand Up @@ -926,16 +939,12 @@ Result<std::unique_ptr<substrait::Expression::ScalarFunction>> EncodeSubstraitCa
ToProto(*call.output_type(), call.output_nullable(), ext_set, conversion_options));
scalar_fn->set_allocated_output_type(output_type.release());

for (uint32_t i = 0; i < call.size(); i++) {
for (int i = 0; i < call.size(); i++) {
substrait::FunctionArgument* arg = scalar_fn->add_arguments();
if (call.HasEnumArg(i)) {
auto enum_val = std::make_unique<substrait::FunctionArgument::Enum>();
ARROW_ASSIGN_OR_RAISE(std::optional<std::string_view> enum_arg, call.GetEnumArg(i));
if (enum_arg) {
enum_val->set_specified(std::string(*enum_arg));
} else {
enum_val->set_allocated_unspecified(new google::protobuf::Empty());
}
ARROW_ASSIGN_OR_RAISE(std::string_view enum_arg, call.GetEnumArg(i));
enum_val->set_specified(std::string(enum_arg));
arg->set_allocated_enum_(enum_val.release());
} else if (call.HasValueArg(i)) {
ARROW_ASSIGN_OR_RAISE(compute::Expression value_arg, call.GetValueArg(i));
Expand Down
145 changes: 102 additions & 43 deletions cpp/src/arrow/engine/substrait/extension_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "arrow/engine/substrait/expression_internal.h"
#include "arrow/util/hash_util.h"
#include "arrow/util/hashing.h"
#include "arrow/util/string.h"

namespace arrow {
namespace engine {
Expand Down Expand Up @@ -121,7 +122,7 @@ class IdStorageImpl : public IdStorage {

std::unique_ptr<IdStorage> IdStorage::Make() { return std::make_unique<IdStorageImpl>(); }

Result<std::optional<std::string_view>> SubstraitCall::GetEnumArg(uint32_t index) const {
Result<std::string_view> SubstraitCall::GetEnumArg(int index) const {
if (index >= size_) {
return Status::Invalid("Expected Substrait call to have an enum argument at index ",
index, " but it did not have enough arguments");
Expand All @@ -134,16 +135,16 @@ Result<std::optional<std::string_view>> SubstraitCall::GetEnumArg(uint32_t index
return enum_arg_it->second;
}

bool SubstraitCall::HasEnumArg(uint32_t index) const {
bool SubstraitCall::HasEnumArg(int index) const {
return enum_args_.find(index) != enum_args_.end();
}

void SubstraitCall::SetEnumArg(uint32_t index, std::optional<std::string> enum_arg) {
void SubstraitCall::SetEnumArg(int index, std::string enum_arg) {
size_ = std::max(size_, index + 1);
enum_args_[index] = std::move(enum_arg);
}

Result<compute::Expression> SubstraitCall::GetValueArg(uint32_t index) const {
Result<compute::Expression> SubstraitCall::GetValueArg(int index) const {
if (index >= size_) {
return Status::Invalid("Expected Substrait call to have a value argument at index ",
index, " but it did not have enough arguments");
Expand All @@ -156,15 +157,32 @@ Result<compute::Expression> SubstraitCall::GetValueArg(uint32_t index) const {
return value_arg_it->second;
}

bool SubstraitCall::HasValueArg(uint32_t index) const {
bool SubstraitCall::HasValueArg(int index) const {
return value_args_.find(index) != value_args_.end();
}

void SubstraitCall::SetValueArg(uint32_t index, compute::Expression value_arg) {
void SubstraitCall::SetValueArg(int index, compute::Expression value_arg) {
size_ = std::max(size_, index + 1);
value_args_[index] = std::move(value_arg);
}

/// \brief Look up the preferences stored for an option
///
/// \param[in] option_name the name the option was stored under
/// \return std::nullopt when no option with that name is stored; otherwise a
///         pointer to the stored preference vector held in this call's option map
std::optional<std::vector<std::string> const*> SubstraitCall::GetOption(
    std::string_view option_name) const {
  // The map is keyed by std::string, so the view is materialised for the lookup.
  const auto found = options_.find(std::string(option_name));
  if (found != options_.end()) {
    return &found->second;
  }
  return std::nullopt;
}

/// \brief Record preferences for an option on this call
///
/// Each preference is copied into an owned std::string and appended, in order,
/// to the list stored under `option_name`. The entry is created if it does not
/// exist yet; if it already exists the new preferences are added after the
/// existing ones rather than replacing them.
///
/// \param[in] option_name the name to store the preferences under
/// \param[in] option_preferences the preferences to append
void SubstraitCall::SetOption(std::string_view option_name,
                              const std::vector<std::string_view>& option_preferences) {
  std::vector<std::string>& stored = options_[std::string(option_name)];
  for (std::size_t i = 0; i < option_preferences.size(); ++i) {
    stored.emplace_back(option_preferences[i]);
  }
}

// A builder used when creating a Substrait plan from an Arrow execution plan. In
// that situation we do not have a set of anchor values already defined so we keep
// a map of what Ids we have seen.
Expand Down Expand Up @@ -645,50 +663,91 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry {
};

template <typename Enum>
using EnumParser = std::function<Result<Enum>(std::optional<std::string_view>)>;

template <typename Enum>
EnumParser<Enum> GetEnumParser(const std::vector<std::string>& options) {
std::unordered_map<std::string, Enum> parse_map;
for (std::size_t i = 0; i < options.size(); i++) {
parse_map[options[i]] = static_cast<Enum>(i + 1);
class EnumParser {
public:
explicit EnumParser(const std::vector<std::string>& options) {
for (std::size_t i = 0; i < options.size(); i++) {
parse_map_[options[i]] = static_cast<Enum>(i + 1);
reverse_map_[static_cast<Enum>(i + 1)] = options[i];
}
}
return [parse_map](std::optional<std::string_view> enum_val) -> Result<Enum> {
if (!enum_val) {
// Assumes 0 is always kUnspecified in Enum
return static_cast<Enum>(0);

Result<Enum> Parse(std::string_view enum_val) const {
auto it = parse_map_.find(std::string(enum_val));
if (it == parse_map_.end()) {
return Status::NotImplemented("The value ", enum_val,
" is not an expected enum value");
}
auto maybe_parsed = parse_map.find(std::string(*enum_val));
if (maybe_parsed == parse_map.end()) {
return Status::Invalid("The value ", *enum_val, " is not an expected enum value");
return it->second;
}

std::string ImplementedOptionsAsString(
const std::vector<Enum>& implemented_opts) const {
std::vector<std::string_view> opt_strs;
for (const Enum& implemented_opt : implemented_opts) {
auto it = reverse_map_.find(implemented_opt);
if (it == reverse_map_.end()) {
opt_strs.emplace_back("Unknown");
} else {
opt_strs.emplace_back(it->second);
}
}
return maybe_parsed->second;
};
}
return arrow::internal::JoinStrings(opt_strs, ", ");
}

private:
std::unordered_map<std::string, Enum> parse_map_;
std::unordered_map<Enum, std::string> reverse_map_;
};

enum class TemporalComponent { kUnspecified = 0, kYear, kMonth, kDay, kSecond };
static std::vector<std::string> kTemporalComponentOptions = {"YEAR", "MONTH", "DAY",
"SECOND"};
static EnumParser<TemporalComponent> kTemporalComponentParser =
GetEnumParser<TemporalComponent>(kTemporalComponentOptions);
static EnumParser<TemporalComponent> kTemporalComponentParser(kTemporalComponentOptions);

enum class OverflowBehavior { kUnspecified = 0, kSilent, kSaturate, kError };
static std::vector<std::string> kOverflowOptions = {"SILENT", "SATURATE", "ERROR"};
static EnumParser<OverflowBehavior> kOverflowParser =
GetEnumParser<OverflowBehavior>(kOverflowOptions);
static EnumParser<OverflowBehavior> kOverflowParser(kOverflowOptions);

/// \brief Resolve an option on a call to one of the implemented enum values
///
/// If the call does not carry `option_name`, `fallback` is returned. Otherwise
/// the stored preferences are parsed one by one, in order, and the first one
/// that appears in `implemented_options` is returned. A preference the parser
/// rejects ends the search immediately with the parser's error.
///
/// \param[in] call the call whose options are consulted
/// \param[in] option_name the option to look up
/// \param[in] parser converts a preference string to an Enum value
/// \param[in] implemented_options the values the caller is able to honour
/// \param[in] fallback the value to use when the option is absent
/// \return the chosen value, or Status::NotImplemented when none of the
///         requested preferences is in `implemented_options`
template <typename Enum>
Result<Enum> ParseOptionOrElse(const SubstraitCall& call, std::string_view option_name,
                               const EnumParser<Enum>& parser,
                               const std::vector<Enum>& implemented_options,
                               Enum fallback) {
  const std::optional<std::vector<std::string> const*> maybe_requested =
      call.GetOption(option_name);
  if (!maybe_requested.has_value()) {
    return fallback;
  }
  const std::vector<std::string>& requested = **maybe_requested;

  // Linear membership test; the implemented list is supplied by the caller.
  const auto is_implemented = [&implemented_options](Enum candidate) {
    for (Enum supported : implemented_options) {
      if (supported == candidate) {
        return true;
      }
    }
    return false;
  };

  // Preferences are tried in the order the plan listed them.
  for (const std::string& choice : requested) {
    ARROW_ASSIGN_OR_RAISE(Enum candidate, parser.Parse(choice));
    if (is_implemented(candidate)) {
      return candidate;
    }
  }

  // Nothing requested is supported: report both lists to the user.
  return Status::NotImplemented(
      "During a call to a function with id ", call.id().uri, "#", call.id().name,
      " the plan requested the option ", option_name, " to be one of [",
      arrow::internal::JoinStrings(requested, ", "),
      "] but the only supported options are [",
      parser.ImplementedOptionsAsString(implemented_options), "]");
}

template <typename Enum>
Result<Enum> ParseEnumArg(const SubstraitCall& call, int arg_index,
const EnumParser<Enum>& parser) {
ARROW_ASSIGN_OR_RAISE(std::optional<std::string_view> enum_arg,
call.GetEnumArg(arg_index));
return parser(enum_arg);
ARROW_ASSIGN_OR_RAISE(std::string_view enum_val, call.GetEnumArg(arg_index));
return parser.Parse(enum_val);
}

Result<std::vector<compute::Expression>> GetValueArgs(const SubstraitCall& call,
int start_index) {
std::vector<compute::Expression> expressions;
for (uint32_t index = start_index; index < call.size(); index++) {
for (int index = start_index; index < call.size(); index++) {
ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(index));
expressions.push_back(arg);
}
Expand All @@ -698,13 +757,13 @@ Result<std::vector<compute::Expression>> GetValueArgs(const SubstraitCall& call,
ExtensionIdRegistry::SubstraitCallToArrow DecodeOptionlessOverflowableArithmetic(
const std::string& function_name) {
return [function_name](const SubstraitCall& call) -> Result<compute::Expression> {
ARROW_ASSIGN_OR_RAISE(OverflowBehavior overflow_behavior,
ParseEnumArg(call, 0, kOverflowParser));
ARROW_ASSIGN_OR_RAISE(
OverflowBehavior overflow_behavior,
ParseOptionOrElse(call, "overflow", kOverflowParser,
{OverflowBehavior::kSilent, OverflowBehavior::kError},
OverflowBehavior::kSilent));
ARROW_ASSIGN_OR_RAISE(std::vector<compute::Expression> value_args,
GetValueArgs(call, 1));
if (overflow_behavior == OverflowBehavior::kUnspecified) {
overflow_behavior = OverflowBehavior::kSilent;
}
GetValueArgs(call, 0));
if (overflow_behavior == OverflowBehavior::kSilent) {
return arrow::compute::call(function_name, std::move(value_args));
} else if (overflow_behavior == OverflowBehavior::kError) {
Expand All @@ -727,12 +786,12 @@ ExtensionIdRegistry::ArrowToSubstraitCall EncodeOptionlessOverflowableArithmetic
SubstraitCall substrait_call(substrait_fn_id, call.type.GetSharedPtr(),
/*nullable=*/true);
if (kChecked) {
substrait_call.SetEnumArg(0, "ERROR");
substrait_call.SetOption("overflow", {"ERROR"});
} else {
substrait_call.SetEnumArg(0, "SILENT");
substrait_call.SetOption("overflow", {"SILENT"});
}
for (std::size_t i = 0; i < call.arguments.size(); i++) {
substrait_call.SetValueArg(static_cast<uint32_t>(i + 1), call.arguments[i]);
substrait_call.SetValueArg(static_cast<int>(i), call.arguments[i]);
}
return std::move(substrait_call);
};
Expand All @@ -746,14 +805,14 @@ ExtensionIdRegistry::ArrowToSubstraitCall EncodeOptionlessComparison(Id substrai
SubstraitCall substrait_call(substrait_fn_id, call.type.GetSharedPtr(),
/*nullable=*/true);
for (std::size_t i = 0; i < call.arguments.size(); i++) {
substrait_call.SetValueArg(static_cast<uint32_t>(i), call.arguments[i]);
substrait_call.SetValueArg(static_cast<int>(i), call.arguments[i]);
}
return std::move(substrait_call);
};
}

ExtensionIdRegistry::SubstraitCallToArrow DecodeOptionlessBasicMapping(
const std::string& function_name, uint32_t max_args) {
const std::string& function_name, int max_args) {
return [function_name,
max_args](const SubstraitCall& call) -> Result<compute::Expression> {
if (call.size() > max_args) {
Expand Down
25 changes: 15 additions & 10 deletions cpp/src/arrow/engine/substrait/extension_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,17 @@ class SubstraitCall {
bool output_nullable() const { return output_nullable_; }
bool is_hash() const { return is_hash_; }

bool HasEnumArg(uint32_t index) const;
Result<std::optional<std::string_view>> GetEnumArg(uint32_t index) const;
void SetEnumArg(uint32_t index, std::optional<std::string> enum_arg);
Result<compute::Expression> GetValueArg(uint32_t index) const;
bool HasValueArg(uint32_t index) const;
void SetValueArg(uint32_t index, compute::Expression value_arg);
uint32_t size() const { return size_; }
bool HasEnumArg(int index) const;
Result<std::string_view> GetEnumArg(int index) const;
void SetEnumArg(int index, std::string enum_arg);
Result<compute::Expression> GetValueArg(int index) const;
bool HasValueArg(int index) const;
void SetValueArg(int index, compute::Expression value_arg);
std::optional<std::vector<std::string> const*> GetOption(
std::string_view option_name) const;
void SetOption(std::string_view option_name,
const std::vector<std::string_view>& option_preferences);
int size() const { return size_; }

private:
Id id_;
Expand All @@ -134,9 +138,10 @@ class SubstraitCall {
// Only needed when converting from Substrait -> Arrow aggregates. The
// Arrow function name depends on whether or not there are any groups
bool is_hash_;
std::unordered_map<uint32_t, std::optional<std::string>> enum_args_;
std::unordered_map<uint32_t, compute::Expression> value_args_;
uint32_t size_ = 0;
std::unordered_map<int, std::string> enum_args_;
std::unordered_map<int, compute::Expression> value_args_;
std::unordered_map<std::string, std::vector<std::string>> options_;
int size_ = 0;
};

/// Substrait identifies functions and custom data types using a (uri, name) pair.
Expand Down
Loading