diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc index 5d7d66225e1..45880bce507 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.cc +++ b/cpp/src/arrow/engine/substrait/expression_internal.cc @@ -31,6 +31,7 @@ #include "arrow/util/make_unique.h" #include "arrow/visit_scalar_inline.h" + namespace arrow { using internal::checked_cast; @@ -159,21 +160,17 @@ Result FromProto(const substrait::Expression& expr, ARROW_ASSIGN_OR_RAISE(auto decoded_function, ext_set.DecodeFunction(scalar_fn.function_reference())); + ARROW_ASSIGN_OR_RAISE(auto arrow_function, ext_set.GetFunctionMap().GetArrowFromSubstrait(decoded_function.name.to_string())); /* fetch the converter registered under the decoded Substrait function name; a failed lookup is returned to the caller */ + return arrow_function(scalar_fn); /* the converter receives the whole ScalarFunction message and produces the Arrow expression */ + } - std::vector arguments(scalar_fn.args_size()); - for (int i = 0; i < scalar_fn.args_size(); ++i) { - ARROW_ASSIGN_OR_RAISE(arguments[i], FromProto(scalar_fn.args(i), ext_set)); - } - - auto func_name = decoded_function.name.to_string(); - if (func_name != "cast") { - return compute::call(func_name, std::move(arguments)); - } else { - ARROW_ASSIGN_OR_RAISE(auto output_type_desc, - FromProto(scalar_fn.output_type(), ext_set)); - auto cast_options = compute::CastOptions::Safe(std::move(output_type_desc.first)); - return compute::call(func_name, std::move(arguments), std::move(cast_options)); - } + case substrait::Expression::kEnum: { /* enum arguments are surfaced as a literal holding the specified value */ + auto enum_expr = expr.enum_(); + if(enum_expr.has_specified()){ + return compute::literal(std::move(enum_expr.specified())); + } else { + return Status::Invalid("Substrait Enum value not specified"); /* an enum without a specified value is rejected */ + } } default: diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index a30c740b181..5406d80f3f0 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -20,6 +20,11 @@ #include #include +#include "arrow/compute/api_aggregate.h" +#include
"arrow/compute/api_scalar.h" +#include "arrow/compute/cast.h" +#include "arrow/engine/substrait/expression_internal.h" +#include "arrow/engine/substrait/type_internal.h" #include "arrow/util/hash_util.h" #include "arrow/util/hashing.h" #include "arrow/util/string_view.h" @@ -204,6 +209,521 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) { return &it->second; } +Status FunctionMapping::AddArrowToSubstrait(std::string arrow_function_name, ArrowToSubstrait conversion_func){ /* Stores conversion_func under arrow_function_name. An existing entry is never overwritten: the call returns AlreadyExists instead. */ + if (arrow_to_substrait.find(arrow_function_name) == arrow_to_substrait.end()){ + arrow_to_substrait[arrow_function_name] = conversion_func; + return Status::OK(); + } + return Status::AlreadyExists("Arrow function already exist in the conversion map"); +} + +Status FunctionMapping::AddSubstraitToArrow(std::string substrait_function_name, SubstraitToArrow conversion_func){ /* Stores conversion_func under substrait_function_name. An existing entry is never overwritten: the call returns AlreadyExists instead. */ + if (substrait_to_arrow.find(substrait_function_name) == substrait_to_arrow.end()){ + substrait_to_arrow[substrait_function_name] = conversion_func; + return Status::OK(); + } + return Status::AlreadyExists("Substrait function already exist in the conversion map"); +} + +Result FunctionMapping::GetArrowFromSubstrait(std::string name) const { /* Returns the Substrait-to-Arrow converter stored under name, or KeyError when none was registered. */ + if (FunctionMapping::substrait_to_arrow.find(name)!=FunctionMapping::substrait_to_arrow.end()){ + return FunctionMapping::substrait_to_arrow.at(name); + } else { + return Status::KeyError("Substrait function doesn't exist in the mapping registry"); + } + } + +Result FunctionMapping::GetSubstraitFromArrow(std::string name) const { /* Returns the Arrow-to-Substrait converter stored under name, or KeyError when none was registered. */ + if (FunctionMapping::arrow_to_substrait.find(name)!=FunctionMapping::arrow_to_substrait.end()){ + return FunctionMapping::arrow_to_substrait.at(name); + } else { + return Status::KeyError("Arrow function doesn't exist in the mapping registry"); + } + } + +std::vector ConvertSubstraitArguments(const substrait::Expression::ScalarFunction& call){ /* NOTE(review): only part of this helper is visible in this extract. It appears to build func_args from the arguments of call using a default-constructed ExtensionSet - confirm against the full file. */ + substrait::Expression value; + ExtensionSet ext_set; + arrow::compute::Expression expression;
+ std::vector func_args; + for(int i=0; i value; + for(size_t i = 0; iadd_args()->CopyFrom(*value); + } + return *substrait_call; +} + +substrait::Expression::ScalarFunction ConvertArrowEnumArguments(const arrow::compute::Expression::Call& call, substrait::Expression::ScalarFunction* substrait_call, ExtensionSet* ext_set, std::string enum_value){ /* Prepends an enum argument carrying enum_value to substrait_call, then appends the converted arguments of call through ConvertArrowArguments and returns a copy of the result. */ + /* Fill the Enum inside the newly added argument. set_allocated_enum_ takes ownership of its pointer, so handing it the address of a local object made protobuf delete stack memory. */ + substrait::Expression::Enum* options = substrait_call->add_args()->mutable_enum_(); + options->set_specified(enum_value); + return ConvertArrowArguments(call, substrait_call, ext_set); +} + + +SubstraitToArrow substrait_add_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { /* Argument 0 selects the overflow behaviour, arguments 1 and 2 are the operands: SILENT maps to add, SATURATE is rejected, anything else maps to add_checked. */ + auto func_args = ConvertSubstraitArguments(call); + if(func_args[0].ToString() == "SILENT"){ + return arrow::compute::call("add", {func_args[1], func_args[2]}, compute::ArithmeticOptions()); + } else if (func_args[0].ToString() == "SATURATE") { + return Status::Invalid("Arrow does not support a saturating add"); + } else { + return arrow::compute::call("add_checked", {func_args[1], func_args[2]}, compute::ArithmeticOptions(true)); + } + }; + +SubstraitToArrow substrait_subtract_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { /* Same dispatch as the add converter: SILENT maps to subtract, SATURATE is rejected, anything else maps to subtract_checked. */ + auto func_args = ConvertSubstraitArguments(call); + if(func_args[0].ToString() == "SILENT"){ + return arrow::compute::call("subtract", {func_args[1], func_args[2]}, compute::ArithmeticOptions()); + } else if (func_args[0].ToString() == "SATURATE") { + return Status::Invalid("Arrow does not support a saturating subtract"); + } else { + return arrow::compute::call("subtract_checked", {func_args[1], func_args[2]}, compute::ArithmeticOptions(true)); + } +}; + +SubstraitToArrow substrait_multiply_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { /* Same dispatch as the add converter, for multiplication. */ + auto func_args = ConvertSubstraitArguments(call); + if(func_args[0].ToString() == "SILENT"){ + return arrow::compute::call("multiply", {func_args[1], func_args[2]},
compute::ArithmeticOptions()); + } else if (func_args[0].ToString() == "SATURATE") { + return Status::Invalid("Arrow does not support a saturating multiply"); + } else { + return arrow::compute::call("mutiply_checked", {func_args[1], func_args[2]}, compute::ArithmeticOptions(true)); + } +}; + +SubstraitToArrow substrait_divide_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { + auto func_args = ConvertSubstraitArguments(call); + if(func_args[0].ToString() == "SILENT"){ + return arrow::compute::call("divide", {func_args[1], func_args[2]}, compute::ArithmeticOptions()); + } else if (func_args[0].ToString() == "SATURATE") { + return Status::Invalid("Arrow does not support a saturating divide"); + } else { + return arrow::compute::call("divide_checked", {func_args[1], func_args[2]}, compute::ArithmeticOptions(true)); + } +}; + +SubstraitToArrow substrait_modulus_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { + return arrow::compute::call("abs", ConvertSubstraitArguments(call)); +}; + +ArrowToSubstrait arrow_add_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("add")); + substrait_call.set_function_reference(function_reference); + return ConvertArrowEnumArguments(call, &substrait_call, ext_set, "ERROR"); + }; + +ArrowToSubstrait arrow_unchecked_add_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("add")); + substrait_call.set_function_reference(function_reference); + return ConvertArrowEnumArguments(call, &substrait_call, ext_set, "SILENT"); +}; + +ArrowToSubstrait arrow_subtract_to_substrait = [] (const arrow::compute::Expression::Call& call, 
arrow::engine::ExtensionSet* ext_set) -> Result { /* Emits a Substrait subtract call whose leading enum argument is ERROR. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("subtract")); + substrait_call.set_function_reference(function_reference); + return ConvertArrowEnumArguments(call, &substrait_call, ext_set, "ERROR"); +}; + +ArrowToSubstrait arrow_unchecked_subtract_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { /* Emits a Substrait subtract call whose leading enum argument is SILENT. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("subtract")); + substrait_call.set_function_reference(function_reference); + return ConvertArrowEnumArguments(call, &substrait_call, ext_set, "SILENT") ; +}; + +ArrowToSubstrait arrow_multiply_to_substrait = [] (const arrow::compute::Expression::Call& call, arrow::engine::ExtensionSet* ext_set) -> Result { /* Emits a Substrait multiply call whose leading enum argument is ERROR. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("multiply")); + substrait_call.set_function_reference(function_reference); + return ConvertArrowEnumArguments(call, &substrait_call, ext_set, "ERROR"); +}; + +ArrowToSubstrait arrow_unchecked_multiply_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { /* Emits a Substrait multiply call whose leading enum argument is SILENT. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("multiply")); + substrait_call.set_function_reference(function_reference); + return ConvertArrowEnumArguments(call, &substrait_call, ext_set, "SILENT"); +}; + + +ArrowToSubstrait arrow_divide_to_substrait = [] (const arrow::compute::Expression::Call& call, arrow::engine::ExtensionSet* ext_set) -> Result { /* Emits a Substrait divide call whose leading enum argument is ERROR. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("divide")); + substrait_call.set_function_reference(function_reference); + return
ConvertArrowEnumArguments(call, &substrait_call, ext_set, "ERROR"); +}; + +ArrowToSubstrait arrow_unchecked_divide_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { /* Emits a Substrait divide call whose leading enum argument is SILENT. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("divide")); + substrait_call.set_function_reference(function_reference); + return ConvertArrowEnumArguments(call, &substrait_call, ext_set, "SILENT"); +}; + +ArrowToSubstrait arrow_abs_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { /* Emits a Substrait modulus call with the converted arguments and no enum argument. NOTE(review): confirm that modulus is the intended counterpart of abs. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("modulus")); + substrait_call.set_function_reference(function_reference); + return ConvertArrowArguments(call, &substrait_call, ext_set); +}; + +// Boolean Functions mappings +SubstraitToArrow substrait_not_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { /* Substrait not is forwarded to Arrow invert. */ + return arrow::compute::call("invert", ConvertSubstraitArguments(call)); +}; + +SubstraitToArrow substrait_or_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { /* Substrait or is forwarded to Arrow or_kleene. */ + return arrow::compute::call("or_kleene", ConvertSubstraitArguments(call)); +}; + +SubstraitToArrow substrait_and_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { /* Substrait and is forwarded to Arrow and_kleene. */ + return arrow::compute::call("and_kleene", ConvertSubstraitArguments(call)); +}; + +SubstraitToArrow substrait_xor_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { /* Substrait xor is forwarded to Arrow xor. */ + return arrow::compute::call("xor", ConvertSubstraitArguments(call)); +}; + +ArrowToSubstrait arrow_invert_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { /* Emits a Substrait not call with the converted arguments. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("not")); +
substrait_call.set_function_reference(function_reference); + return ConvertArrowArguments(call, &substrait_call, ext_set); +}; + +ArrowToSubstrait arrow_or_kleene_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { /* Emits a Substrait or call with the converted arguments. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("or")); + substrait_call.set_function_reference(function_reference); + return ConvertArrowArguments(call, &substrait_call, ext_set); +}; + +ArrowToSubstrait arrow_and_kleene_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { /* Emits a Substrait and call with the converted arguments. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("and")); + substrait_call.set_function_reference(function_reference); + return ConvertArrowArguments(call, &substrait_call, ext_set); +}; + +ArrowToSubstrait arrow_xor_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { /* Emits a Substrait xor call with the converted arguments. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("xor")); + substrait_call.set_function_reference(function_reference); + return ConvertArrowArguments(call, &substrait_call, ext_set); +}; + +// Comparison Functions mapping +SubstraitToArrow substrait_lt_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { /* Substrait lt is forwarded to Arrow less. */ + return arrow::compute::call("less", ConvertSubstraitArguments(call)); +}; + +SubstraitToArrow substrait_gt_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { /* Substrait gt is forwarded to Arrow greater. */ + return arrow::compute::call("greater", ConvertSubstraitArguments(call)); +}; + +SubstraitToArrow substrait_lte_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { /* Substrait lte is forwarded to Arrow less_equal. */ + return arrow::compute::call("less_equal", ConvertSubstraitArguments(call)); +}; + +SubstraitToArrow substrait_gte_to_arrow = [] (const
substrait::Expression::ScalarFunction& call) -> Result { /* Substrait gte is forwarded to Arrow greater_equal. */ + return arrow::compute::call("greater_equal", ConvertSubstraitArguments(call)); +}; + +SubstraitToArrow substrait_not_equal_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { /* Substrait not_equal is forwarded to Arrow not_equal. */ + return arrow::compute::call("not_equal", ConvertSubstraitArguments(call)); +}; + +SubstraitToArrow substrait_equal_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { /* Substrait equal is forwarded to Arrow equal. */ + return arrow::compute::call("equal", ConvertSubstraitArguments(call)); +}; + +SubstraitToArrow substrait_is_null_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { /* Substrait is_null is forwarded to Arrow is_null. */ + return arrow::compute::call("is_null", ConvertSubstraitArguments(call)); +}; + +SubstraitToArrow substrait_is_not_null_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { /* Substrait is_not_null is forwarded to Arrow is_valid. */ + return arrow::compute::call("is_valid", ConvertSubstraitArguments(call)); +}; + +SubstraitToArrow substrait_is_not_distinct_from_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { /* Null-safe equality of arguments 0 and 1: true when both are null or both are equal, false otherwise, never null. The previous body tested null_check_1 twice with IsNullLiteral, which is never true for a call expression, and returned not_equal, the inverse of this function. */ + std::vector func_args = ConvertSubstraitArguments(call); + auto null_check_1 = arrow::compute::call("is_null", {func_args[0]}); + auto null_check_2 = arrow::compute::call("is_null", {func_args[1]}); + auto either_null = arrow::compute::call("or", {null_check_1, null_check_2}); + auto both_null = arrow::compute::call("and", {null_check_1, null_check_2}); + auto values_equal = arrow::compute::call("equal", func_args); + return arrow::compute::call("if_else", {either_null, both_null, values_equal}); /* with a null on either side the answer is whether both sides are null; otherwise it is plain equality */ +}; + +ArrowToSubstrait arrow_less_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { /* Emits a Substrait lt call with the converted arguments. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("lt")); + substrait_call.set_function_reference(function_reference); + return ConvertArrowArguments(call, &substrait_call, ext_set); +}; + +ArrowToSubstrait arrow_greater_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { +
substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("gt")); /* the enclosing converter emits a Substrait gt call */ + substrait_call.set_function_reference(function_reference); + return ConvertArrowArguments(call, &substrait_call, ext_set); +}; + +ArrowToSubstrait arrow_less_equal_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { /* Emits a Substrait lte call with the converted arguments. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("lte")); + substrait_call.set_function_reference(function_reference); + return ConvertArrowArguments(call, &substrait_call, ext_set); +}; + +ArrowToSubstrait arrow_greater_equal_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { /* Emits a Substrait gte call with the converted arguments. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("gte")); + substrait_call.set_function_reference(function_reference); + return ConvertArrowArguments(call, &substrait_call, ext_set); +}; + +ArrowToSubstrait arrow_equal_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { /* Emits a Substrait equal call with the converted arguments. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("equal")); + substrait_call.set_function_reference(function_reference); + return ConvertArrowArguments(call, &substrait_call, ext_set); +}; + +ArrowToSubstrait arrow_not_equal_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { /* Emits a Substrait not_equal call with the converted arguments. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("not_equal")); + substrait_call.set_function_reference(function_reference); + return ConvertArrowArguments(call, &substrait_call, ext_set); +}; + +ArrowToSubstrait arrow_is_null_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet*
ext_set) -> Result { /* Emits a Substrait is_null call with the converted arguments. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("is_null")); + substrait_call.set_function_reference(function_reference); + return ConvertArrowArguments(call, &substrait_call, ext_set); +}; + +ArrowToSubstrait arrow_is_valid_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { /* Emits a Substrait is_not_null call with the converted arguments. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("is_not_null")); + substrait_call.set_function_reference(function_reference); + return ConvertArrowArguments(call, &substrait_call, ext_set); +}; + +// Strings function mapping +SubstraitToArrow substrait_like_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { /* Argument 0 is the input; the printed form of argument 1 becomes the match_like pattern. */ + auto func_args = ConvertSubstraitArguments(call); + return arrow::compute::call("match_like", {func_args[0]}, compute::MatchSubstringOptions(func_args[1].ToString())); +}; + +SubstraitToArrow substrait_substring_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { /* Argument 0 is the input; arguments 1 and 2 are read as literals and passed as the slice start and stop. NOTE(review): literal() is dereferenced without a null check, and it should be confirmed whether Substrait supplies a stop or a length as the third argument. */ + auto func_args = ConvertSubstraitArguments(call); + auto start = func_args[1].literal()->scalar_as(); + auto stop = func_args[2].literal()->scalar_as(); + return arrow::compute::call("utf8_slice_codeunits", {func_args[0]}, compute::SliceOptions(static_cast(start.value), static_cast(stop.value))); +}; + +SubstraitToArrow substrait_concat_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { /* Builds a string array from the printed forms of arguments 0 and 1 and joins it with an empty separator through binary_join. NOTE(review): the results of Append and Finish are not checked here. */ + auto func_args = ConvertSubstraitArguments(call); + arrow::StringBuilder builder; + builder.Append(func_args[0].ToString()); + builder.Append(func_args[1].ToString()); + auto strings_datum = arrow::Datum(*builder.Finish()); + auto separator_datum = arrow::Datum(""); + return arrow::compute::call("binary_join", {arrow::compute::Expression(strings_datum), arrow::compute::Expression(separator_datum)}); +}; + +ArrowToSubstrait
arrow_match_like_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { /* Emits a Substrait like call: argument 0 of the Arrow call followed by the pattern taken from the call options. Conversion failures are returned to the caller instead of aborting through ValueOrDie. NOTE(review): the options cast is dereferenced without a null check. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("like")); + substrait_call.set_function_reference(function_reference); + + arrow::compute::Expression expression_1, expression_2; + std::unique_ptr string_1, string_2; + expression_1 = call.arguments[0]; + ARROW_ASSIGN_OR_RAISE(string_1, ToProto(expression_1, ext_set)); + substrait_call.add_args()->CopyFrom(*string_1); + + auto pattern_string = std::dynamic_pointer_cast(call.options)->pattern; + expression_2 = arrow::compute::Expression(arrow::Datum(pattern_string)); + ARROW_ASSIGN_OR_RAISE(string_2, ToProto(expression_2, ext_set)); + substrait_call.add_args()->CopyFrom(*string_2); + + return std::move(substrait_call); +}; + +ArrowToSubstrait arrow_utf8_slice_codeunits_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { /* Emits a Substrait substring call: argument 0 of the Arrow call followed by the start and stop taken from the call options. Conversion failures are returned to the caller instead of aborting through ValueOrDie. NOTE(review): the options cast is dereferenced without a null check. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("substring")); + substrait_call.set_function_reference(function_reference); + arrow::compute::Expression expression_1, expression_2, expression_3; + std::unique_ptr string, start, stop; + expression_1 = call.arguments[0]; + ARROW_ASSIGN_OR_RAISE(string, ToProto(expression_1, ext_set)); + substrait_call.add_args()->CopyFrom(*string); + + auto start_index = std::dynamic_pointer_cast(call.options)->start; + auto stop_index = std::dynamic_pointer_cast(call.options)->stop; + expression_2 = arrow::compute::Expression(arrow::Datum(start_index)); + expression_3 = arrow::compute::Expression(arrow::Datum(stop_index)); + ARROW_ASSIGN_OR_RAISE(start, ToProto(expression_2, ext_set)); + ARROW_ASSIGN_OR_RAISE(stop, ToProto(expression_3, ext_set)); + substrait_call.add_args()->CopyFrom(*start); + substrait_call.add_args()->CopyFrom(*stop); + + return std::move(substrait_call); +}; +
+ArrowToSubstrait arrow_binary_join_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { /* Emits a Substrait concat call from the first two elements of the literal array held in argument 0. Conversion failures are returned to the caller instead of aborting through ValueOrDie. NOTE(review): literal() and GetScalar are dereferenced without checks. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("concat")); + substrait_call.set_function_reference(function_reference); + arrow::compute::Expression expression_1, expression_2; + std::unique_ptr string_1, string_2; + + auto strings_list = call.arguments[0].literal()->make_array(); + expression_1 = arrow::compute::Expression(*(strings_list->GetScalar(0))); + expression_2 = arrow::compute::Expression(*(strings_list->GetScalar(1))); + + ARROW_ASSIGN_OR_RAISE(string_1, ToProto(expression_1, ext_set)); + ARROW_ASSIGN_OR_RAISE(string_2, ToProto(expression_2, ext_set)); + substrait_call.add_args()->CopyFrom(*string_1); + substrait_call.add_args()->CopyFrom(*string_2); + return std::move(substrait_call); +}; + +// Cast function mapping +SubstraitToArrow substrait_cast_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { /* Converts the declared output type of the call into safe cast options and casts argument 0 with them. */ + ExtensionSet ext_set; + ARROW_ASSIGN_OR_RAISE(auto output_type_desc, + FromProto(call.output_type(), ext_set)); + auto cast_options = compute::CastOptions::Safe(std::move(output_type_desc.first)); + return compute::call("cast", {ConvertSubstraitArguments(call)[0]}, std::move(cast_options)); +}; + +ArrowToSubstrait arrow_cast_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { /* Emits a Substrait cast call whose output type is the to_type of the Arrow cast options and whose single argument is argument 0 of the Arrow call. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("cast")); + substrait_call.set_function_reference(function_reference); + + auto arrow_to_type = std::dynamic_pointer_cast(call.options)->to_type; + ARROW_ASSIGN_OR_RAISE(auto substrait_to_type, ToProto(*arrow_to_type, false, ext_set)); + substrait_call.set_allocated_output_type(substrait_to_type.release()); /* set_allocated takes ownership, so the smart pointer must give it up; passing get() left two owners and a double free */ + + auto expression = call.arguments[0]; + ARROW_ASSIGN_OR_RAISE(auto value,
ToProto(expression, ext_set)); + substrait_call.add_args()->CopyFrom(*value); + + return substrait_call; +}; + +// Datetime functions mapping +SubstraitToArrow substrait_extract_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { /* Argument 0 names the component and argument 1 is the input: YEAR, MONTH and DAY map to the Arrow functions of the same name. NOTE(review): every other component value, not only SECOND, falls through to second. */ + auto func_args = ConvertSubstraitArguments(call); + if(func_args[0].ToString() == "YEAR"){ + return arrow::compute::call("year", {func_args[1]}); + } else if (func_args[0].ToString() == "MONTH") { + return arrow::compute::call("month", {func_args[1]}); + } else if (func_args[0].ToString() == "DAY") { + return arrow::compute::call("day", {func_args[1]}); + } else { + return arrow::compute::call("second", {func_args[1]}); + } +}; + +ArrowToSubstrait arrow_year_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { /* Emits a Substrait extract call whose leading enum argument is YEAR. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("extract")); + substrait_call.set_function_reference(function_reference); + return ConvertArrowEnumArguments(call, &substrait_call, ext_set, "YEAR"); +}; + +ArrowToSubstrait arrow_month_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { /* Emits a Substrait extract call whose leading enum argument is MONTH. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("extract")); + substrait_call.set_function_reference(function_reference); + return ConvertArrowEnumArguments(call, &substrait_call, ext_set, "MONTH"); +}; + +ArrowToSubstrait arrow_day_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { /* Emits a Substrait extract call whose leading enum argument is DAY. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("extract")); + substrait_call.set_function_reference(function_reference); + return ConvertArrowEnumArguments(call, &substrait_call, ext_set, "DAY"); +}; + +ArrowToSubstrait arrow_second_to_substrait = [] (const
arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { /* Emits a Substrait extract call whose leading enum argument is SECOND. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("extract")); + substrait_call.set_function_reference(function_reference); + return ConvertArrowEnumArguments(call, &substrait_call, ext_set, "SECOND"); +}; + +// Substrait Datetime add/subtract mappings should work for datetime intervals functions as well +SubstraitToArrow substrait_datetime_add_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { /* Forwards every converted argument to Arrow add with default arithmetic options; no enum argument is expected. */ + return arrow::compute::call("add", ConvertSubstraitArguments(call), compute::ArithmeticOptions()); + }; + +SubstraitToArrow substrait_datetime_subtract_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { /* Forwards every converted argument to Arrow subtract with default arithmetic options; no enum argument is expected. */ + return arrow::compute::call("subtract", ConvertSubstraitArguments(call), compute::ArithmeticOptions()); + }; + +ArrowToSubstrait arrow_datetime_add_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { /* Emits a Substrait add call with the converted arguments and no enum argument. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("add")); + substrait_call.set_function_reference(function_reference); + return ConvertArrowArguments(call, &substrait_call, ext_set); +}; + +ArrowToSubstrait arrow_datetime_subtract_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { /* Emits a Substrait subtract call with the converted arguments and no enum argument. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("subtract")); + substrait_call.set_function_reference(function_reference); + return ConvertArrowArguments(call, &substrait_call, ext_set); +}; + +ArrowToSubstrait arrow_datetime_add_intervals_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { /* Emits a Substrait add_intervals call with the converted arguments. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference,
ext_set->EncodeFunction("add_intervals")); + substrait_call.set_function_reference(function_reference); + return ConvertArrowArguments(call, &substrait_call, ext_set); +}; + +ArrowToSubstrait arrow_datetime_subtract_intervals_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set) -> Result { /* Emits a Substrait subtract_intervals call with the converted arguments. */ + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set->EncodeFunction("subtract_intervals")); + substrait_call.set_function_reference(function_reference); + return ConvertArrowArguments(call, &substrait_call, ext_set); +}; + +// Aggregate functions mapping +SubstraitToArrow substrait_aggregate_sum_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { /* Forwards argument 1 to Arrow sum with default aggregate options. NOTE(review): confirm what argument 0 carries for aggregates. */ + return arrow::compute::call("sum", {ConvertSubstraitArguments(call)[1]}, compute::ScalarAggregateOptions()); +}; + +SubstraitToArrow substrait_aggregate_avg_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { /* Forwards argument 1 to Arrow mean with default aggregate options. Arrow has no compute function called avg; its averaging aggregate is named mean. */ + return arrow::compute::call("mean", {ConvertSubstraitArguments(call)[1]}, compute::ScalarAggregateOptions()); +}; + namespace { struct ExtensionIdRegistryImpl : ExtensionIdRegistry { @@ -288,6 +808,11 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { return Status::OK(); } + Status RegisterFunctionMapping(Id id, SubstraitToArrow conversion_func) override { /* Stores conversion_func under the name of id, then registers id with an Arrow function name equal to that same name. */ + ARROW_RETURN_NOT_OK(functions_map.AddSubstraitToArrow(id.name.to_string(), conversion_func)); /* report a failed insert such as AlreadyExists to the caller rather than asserting on it */ + return RegisterFunction(id, id.name.to_string()); + } + Status RegisterFunction(Id id, std::string arrow_function_name) override { DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size()); @@ -432,17 +957,45 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type))); } - // TODO: this is just a placeholder right now.
We'll need a YAML file for - // all functions (and prototypes) that Arrow provides that are relevant - // for Substrait, and include mappings for all of them here. See - // ARROW-15535. - for (util::string_view name : { - "add", - "equal", - "is_not_distinct_from", - }) { - DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string())); - } + // registering arithmetic function mappings + DCHECK_OK(RegisterFunctionMapping({kArrowExtTypesUri, "add"}, substrait_add_to_arrow)); /* each call stores the converter in functions_map and registers the Substrait id under the same name, as implemented by RegisterFunctionMapping */ + DCHECK_OK(RegisterFunctionMapping({kArrowExtTypesUri, "subtract"}, substrait_subtract_to_arrow)); + DCHECK_OK(RegisterFunctionMapping({kArrowExtTypesUri, "multiply"}, substrait_multiply_to_arrow)); + DCHECK_OK(RegisterFunctionMapping({kArrowExtTypesUri, "divide"}, substrait_divide_to_arrow)); + DCHECK_OK(RegisterFunctionMapping({kArrowExtTypesUri, "modulus"}, substrait_modulus_to_arrow)); + + // registering boolean function mappings + DCHECK_OK(RegisterFunctionMapping({kArrowExtTypesUri, "not"}, substrait_not_to_arrow)); + DCHECK_OK(RegisterFunctionMapping({kArrowExtTypesUri, "and"}, substrait_and_to_arrow)); + DCHECK_OK(RegisterFunctionMapping({kArrowExtTypesUri, "or"}, substrait_or_to_arrow)); + DCHECK_OK(RegisterFunctionMapping({kArrowExtTypesUri, "xor"}, substrait_xor_to_arrow)); + + // registering comparison function mappings + DCHECK_OK(RegisterFunctionMapping({kArrowExtTypesUri, "lt"}, substrait_lt_to_arrow)); + DCHECK_OK(RegisterFunctionMapping({kArrowExtTypesUri, "gt"}, substrait_gt_to_arrow)); + DCHECK_OK(RegisterFunctionMapping({kArrowExtTypesUri, "lte"}, substrait_lte_to_arrow)); + DCHECK_OK(RegisterFunctionMapping({kArrowExtTypesUri, "gte"}, substrait_gte_to_arrow)); + DCHECK_OK(RegisterFunctionMapping({kArrowExtTypesUri, "equal"}, substrait_equal_to_arrow)); + DCHECK_OK(RegisterFunctionMapping({kArrowExtTypesUri, "not_equal"}, substrait_not_equal_to_arrow)); + DCHECK_OK(RegisterFunctionMapping({kArrowExtTypesUri, "is_null"}, substrait_is_null_to_arrow)); +
DCHECK_OK(RegisterFunctionMapping({kArrowExtTypesUri, "is_not_null"}, substrait_is_not_null_to_arrow)); + DCHECK_OK(RegisterFunctionMapping({kArrowExtTypesUri, "is_not_distinct_from"}, substrait_is_not_distinct_from_to_arrow)); + + // registering string function mappings + DCHECK_OK(RegisterFunctionMapping({kArrowExtTypesUri, "like"}, substrait_like_to_arrow)); + DCHECK_OK(RegisterFunctionMapping({kArrowExtTypesUri, "substring"}, substrait_substring_to_arrow)); + DCHECK_OK(RegisterFunctionMapping({kArrowExtTypesUri, "concat"}, substrait_concat_to_arrow)); + /* substring is registered once above: a second registration of the same name gets AlreadyExists from FunctionMapping::AddSubstraitToArrow and trips DCHECK_OK */ + + // registering cast function mapping + DCHECK_OK(RegisterFunctionMapping({kArrowExtTypesUri, "cast"}, substrait_cast_to_arrow)); + + // registering datetime function mappings + DCHECK_OK(RegisterFunctionMapping({kArrowExtTypesUri, "extract"}, substrait_extract_to_arrow)); + + // registering aggregate function mappings + DCHECK_OK(RegisterFunctionMapping({kArrowExtTypesUri, "sum"}, substrait_aggregate_sum_to_arrow)); + DCHECK_OK(RegisterFunctionMapping({kArrowExtTypesUri, "avg"}, substrait_aggregate_avg_to_arrow)); } }; diff --git a/cpp/src/arrow/engine/substrait/extension_set.h b/cpp/src/arrow/engine/substrait/extension_set.h index de013015a72..d438249cfd5 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.h +++ b/cpp/src/arrow/engine/substrait/extension_set.h @@ -22,16 +22,38 @@ #include #include +#include "arrow/compute/function.h" +#include "arrow/compute/exec/expression.h" #include "arrow/engine/substrait/visibility.h" #include "arrow/type_fwd.h" #include "arrow/util/optional.h" #include "arrow/util/string_view.h" #include "arrow/util/hash_util.h" +#include "substrait/expression.pb.h" // IWYU pragma: export namespace arrow { namespace engine { +class ExtensionSet; +using ArrowToSubstrait = std::function(const arrow::compute::Expression::Call&, arrow::engine::ExtensionSet*)>; +using SubstraitToArrow
= std::function(const substrait::Expression::ScalarFunction&)>; + +class FunctionMapping { /* Two tables of converter callbacks keyed by function name: Substrait name to SubstraitToArrow, and Arrow name to ArrowToSubstrait. */ + + std::unordered_map substrait_to_arrow; + std::unordered_map arrow_to_substrait; + + public: + // Registration API + Status AddArrowToSubstrait(std::string arrow_function_name, ArrowToSubstrait conversion_func); /* returns AlreadyExists when the name already has a converter */ + Status AddSubstraitToArrow(std::string substrait_function_name, SubstraitToArrow conversion_func); /* returns AlreadyExists when the name already has a converter */ + + Result GetArrowFromSubstrait(std::string name) const; /* looks up by Substrait name; KeyError when absent */ + Result GetSubstraitFromArrow(std::string name) const; /* looks up by Arrow name; KeyError when absent */ +}; + + /// Substrait identifies functions and custom data types using a (uri, name) pair. /// /// This registry is a bidirectional mapping between Substrait IDs and their corresponding @@ -89,11 +111,13 @@ class ARROW_ENGINE_EXPORT ExtensionIdRegistry { Id id; const std::string& function_name; }; + arrow::engine::FunctionMapping functions_map; /* converter tables held by the registry; read through ExtensionSet::GetFunctionMap */ virtual util::optional GetFunction(Id) const = 0; virtual util::optional GetFunction( util::string_view arrow_function_name) const = 0; virtual Status CanRegisterFunction(Id, const std::string& arrow_function_name) const = 0; + virtual Status RegisterFunctionMapping(Id id, SubstraitToArrow conversion_func) = 0; /* registers id together with the converter that turns its Substrait calls into Arrow expressions */ virtual Status RegisterFunction(Id, std::string arrow_function_name) = 0; }; @@ -243,7 +267,7 @@ class ARROW_ENGINE_EXPORT ExtensionSet { /// value larger than the actual number of functions. This behavior may change in the /// future; see ARROW-15583. std::size_t num_functions() const { return functions_.size(); } - + private: const ExtensionIdRegistry* registry_; @@ -261,11 +285,17 @@ class ARROW_ENGINE_EXPORT ExtensionSet { // Map from function names to anchor values. Used during Arrow->Substrait // and built as the plan is created.
std::unordered_map functions_map_; - + Status CheckHasUri(util::string_view uri); void AddUri(std::pair uri); - Status AddUri(Id id); + Status AddUri(Id id); + + public: + const FunctionMapping& GetFunctionMap() const { return registry_->functions_map;} /* Converter tables of the backing registry. Returned by reference so that a lookup does not copy both tables of callbacks on every call. */ + }; + + } // namespace engine } // namespace arrow