diff --git a/velox/substrait/CMakeLists.txt b/velox/substrait/CMakeLists.txt index 110edc9d7de..3367079cb63 100644 --- a/velox/substrait/CMakeLists.txt +++ b/velox/substrait/CMakeLists.txt @@ -47,9 +47,11 @@ set(SRCS SubstraitToVeloxExpr.cpp SubstraitToVeloxPlan.cpp TypeUtils.cpp + SubstraitExtensionCollector.cpp VeloxToSubstraitExpr.cpp VeloxToSubstraitPlan.cpp - VeloxToSubstraitType.cpp) + VeloxToSubstraitType.cpp + VeloxSubstraitSignature.cpp) add_library(velox_substrait_plan_converter ${SRCS}) target_include_directories(velox_substrait_plan_converter diff --git a/velox/substrait/SubstraitExtensionCollector.cpp b/velox/substrait/SubstraitExtensionCollector.cpp new file mode 100644 index 00000000000..fcba9259e4e --- /dev/null +++ b/velox/substrait/SubstraitExtensionCollector.cpp @@ -0,0 +1,89 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/substrait/SubstraitExtensionCollector.h" + +namespace facebook::velox::substrait { + +int SubstraitExtensionCollector::getReferenceNumber( + const std::string& functionName, + const std::vector& arguments) { + const auto& substraitFunctionSignature = + VeloxSubstraitSignature::toSubstraitSignature(functionName, arguments); + // TODO: Currently we treat all velox registry based function signatures as + // custom substrait extension, so no uri link and leave it as empty. 
+ return getReferenceNumber({"", substraitFunctionSignature}); +} + +int SubstraitExtensionCollector::getReferenceNumber( + const std::string& functionName, + const std::vector& arguments, + const core::AggregationNode::Step aggregationStep) { + // TODO: Ignore aggregationStep for now, will refactor when introduce velox + // registry for function signature binding + return getReferenceNumber(functionName, arguments); +} + +template +void SubstraitExtensionCollector::BiDirectionHashMap::putIfAbsent( + const int& key, + const T& value) { + if (forwardMap_.find(key) == forwardMap_.end()) { + forwardMap_[key] = value; + } + if (reverseMap_.find(value) == reverseMap_.end()) { + reverseMap_[value] = key; + } +} + +void SubstraitExtensionCollector::addExtensionsToPlan( + ::substrait::Plan* plan) const { + using SimpleExtensionURI = ::substrait::extensions::SimpleExtensionURI; + // Currently we don't introduce any substrait extension YAML files, so always + // only have one URI. + SimpleExtensionURI* extensionUri = plan->add_extension_uris(); + extensionUri->set_extension_uri_anchor(1); + + for (const auto& [referenceNum, functionId] : + extensionFunctions_->forwardMap()) { + auto extensionFunction = + plan->add_extensions()->mutable_extension_function(); + extensionFunction->set_extension_uri_reference( + extensionUri->extension_uri_anchor()); + extensionFunction->set_function_anchor(referenceNum); + extensionFunction->set_name(functionId.signature); + } +} + +SubstraitExtensionCollector::SubstraitExtensionCollector() { + extensionFunctions_ = + std::make_shared>(); +} + +int SubstraitExtensionCollector::getReferenceNumber( + const ExtensionFunctionId& extensionFunctionId) { + const auto& extensionFunctionAnchorIt = + extensionFunctions_->reverseMap().find(extensionFunctionId); + if (extensionFunctionAnchorIt != extensionFunctions_->reverseMap().end()) { + return extensionFunctionAnchorIt->second; + } + ++functionReferenceNumber; + extensionFunctions_->putIfAbsent( + 
functionReferenceNumber, extensionFunctionId); + return functionReferenceNumber; +} + +} // namespace facebook::velox::substrait diff --git a/velox/substrait/SubstraitExtensionCollector.h b/velox/substrait/SubstraitExtensionCollector.h new file mode 100644 index 00000000000..bcc5768caba --- /dev/null +++ b/velox/substrait/SubstraitExtensionCollector.h @@ -0,0 +1,118 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include "velox/core/Expressions.h" +#include "velox/core/PlanNode.h" +#include "velox/substrait/VeloxSubstraitSignature.h" +#include "velox/substrait/proto/substrait/algebra.pb.h" +#include "velox/substrait/proto/substrait/plan.pb.h" +#include "velox/type/Type.h" + +namespace facebook::velox::substrait { + +struct ExtensionFunctionId { + /// Substrait extension YAML file uri. + std::string uri; + + /// Substrait signature used in the function extension declaration is a + /// combination of the name of the function along with a list of input + /// argument types. The format is as follows: + /// <function name>:<short_arg_type0>_<short_arg_type1>_..._<short_arg_typeN>. + /// For more detailed information about the argument types, please refer to + /// https://substrait.io/extensions/#function-signature-compound-names.
+ std::string signature; + + bool operator==(const ExtensionFunctionId& other) const { + return (uri == other.uri && signature == other.signature); + } +}; + +/// Assigns unique IDs to function signatures using ExtensionFunctionId. +class SubstraitExtensionCollector { + public: + SubstraitExtensionCollector(); + + /// Given a scalar function name and argument types, return the functionId + /// using ExtensionFunctionId. + int getReferenceNumber( + const std::string& functionName, + const std::vector& arguments); + + /// Given an aggregate function name and argument types and aggregation Step, + /// return the functionId using ExtensionFunctionId. + int getReferenceNumber( + const std::string& functionName, + const std::vector& arguments, + core::AggregationNode::Step aggregationStep); + + /// Add extension functions to Substrait plan. + void addExtensionsToPlan(::substrait::Plan* plan) const; + + private: + /// A bi-direction hash map to keep the relation between reference number and + /// either function or type signature. + /// @tparam ExtensionFunctionId + template + class BiDirectionHashMap { + public: + /// For forwardMap_, if the specified key is not already associated with a + /// value, associates it with the given value and returns, else do nothing. + /// For reverseMap_, if the specified value is not already associated with a + /// key, associate it with the given key and returns, else do nothing. + void putIfAbsent(const int& key, const T& value); + + /// Returns the key -> value map. Returned by const reference, like + /// reverseMap(), so that callers iterating over it do not copy the map. + const std::unordered_map& forwardMap() const { + return forwardMap_; + } + + /// Returns the value -> key map, by const reference. + const std::unordered_map& reverseMap() const { + return reverseMap_; + } + + private: + /// Lookup from key to value; filled by putIfAbsent(). + std::unordered_map forwardMap_; + /// Lookup from value back to key; filled by putIfAbsent(). + std::unordered_map reverseMap_; + }; + + /// Assigns unique IDs to function signatures using ExtensionFunctionId.
+ int getReferenceNumber(const ExtensionFunctionId& extensionFunctionId); + + int functionReferenceNumber = -1; + std::shared_ptr> extensionFunctions_; +}; + +using SubstraitExtensionCollectorPtr = + std::shared_ptr; + +} // namespace facebook::velox::substrait + +namespace std { + +/// Hash function of facebook::velox::substrait::ExtensionFunctionId. +template <> +struct hash { + size_t operator()( + const facebook::velox::substrait::ExtensionFunctionId& k) const { + size_t val = hash()(k.uri); + val = val * 31 + hash()(k.signature); + return val; + } +}; + +}; // namespace std diff --git a/velox/substrait/VeloxSubstraitSignature.cpp b/velox/substrait/VeloxSubstraitSignature.cpp new file mode 100644 index 00000000000..084e1c436c6 --- /dev/null +++ b/velox/substrait/VeloxSubstraitSignature.cpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "velox/substrait/VeloxSubstraitSignature.h" +#include "velox/functions/FunctionRegistry.h" + +namespace facebook::velox::substrait { + +std::string VeloxSubstraitSignature::toSubstraitSignature( + const TypeKind typeKind) { + switch (typeKind) { + case TypeKind::BOOLEAN: + return "bool"; + case TypeKind::TINYINT: + return "i8"; + case TypeKind::SMALLINT: + return "i16"; + case TypeKind::INTEGER: + return "i32"; + case TypeKind::BIGINT: + return "i64"; + case TypeKind::REAL: + return "fp32"; + case TypeKind::DOUBLE: + return "fp64"; + case TypeKind::VARCHAR: + return "str"; + case TypeKind::VARBINARY: + return "vbin"; + case TypeKind::TIMESTAMP: + return "ts"; + case TypeKind::DATE: + return "date"; + case TypeKind::SHORT_DECIMAL: + return "dec"; + case TypeKind::LONG_DECIMAL: + return "dec"; + case TypeKind::ARRAY: + return "list"; + case TypeKind::MAP: + return "map"; + case TypeKind::ROW: + return "struct"; + case TypeKind::UNKNOWN: + return "u!name"; + default: + VELOX_UNSUPPORTED( + "Substrait type signature conversion not supported for type {}.", + mapTypeKindToName(typeKind)); + } +} + +std::string VeloxSubstraitSignature::toSubstraitSignature( + const std::string& functionName, + const std::vector& arguments) { + if (arguments.empty()) { + return functionName; + } + std::vector substraitTypeSignatures; + substraitTypeSignatures.reserve(arguments.size()); + for (const auto& type : arguments) { + substraitTypeSignatures.emplace_back(toSubstraitSignature(type->kind())); + } + return functionName + ":" + folly::join("_", substraitTypeSignatures); +} + +} // namespace facebook::velox::substrait diff --git a/velox/substrait/VeloxSubstraitSignature.h b/velox/substrait/VeloxSubstraitSignature.h new file mode 100644 index 00000000000..6f24312ea34 --- /dev/null +++ b/velox/substrait/VeloxSubstraitSignature.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/type/Type.h" + +namespace facebook::velox::substrait { + +class VeloxSubstraitSignature { + public: + /// Given a velox type kind, return the Substrait type signature, throw if no + /// match found. The Substrait signature used in a function extension + /// declaration is a combination of the name of the function along with a list + /// of input argument types. The format is as follows: + /// <function name>:<short_arg_type0>_<short_arg_type1>_..._<short_arg_typeN>. + /// For more detailed information about the argument types, please refer to + /// https://substrait.io/extensions/#function-signature-compound-names. + static std::string toSubstraitSignature(const TypeKind typeKind); + + /// Given a velox scalar function name and argument types, return the + /// substrait function signature. + static std::string toSubstraitSignature( + const std::string& functionName, + const std::vector& arguments); +}; + +} // namespace facebook::velox::substrait diff --git a/velox/substrait/VeloxToSubstraitExpr.cpp b/velox/substrait/VeloxToSubstraitExpr.cpp index 89dcee0a89b..eca1886d070 100644 --- a/velox/substrait/VeloxToSubstraitExpr.cpp +++ b/velox/substrait/VeloxToSubstraitExpr.cpp @@ -140,8 +140,14 @@ const ::substrait::Expression& VeloxToSubstraitExprConvertor::toSubstraitExpr( ::substrait::Expression_ScalarFunction* scalarExpr = substraitExpr->mutable_scalar_function(); - // TODO need to change yaml file to register function, now is dummy.
- scalarExpr->set_function_reference(functionMap_[functionName]); + std::vector types; + types.reserve(callTypeExpr->inputs().size()); + for (auto& typedExpr : callTypeExpr->inputs()) { + types.emplace_back(typedExpr->type()); + } + + scalarExpr->set_function_reference( + extensionCollector_->getReferenceNumber(functionName, types)); for (auto& arg : inputs) { scalarExpr->add_arguments()->mutable_value()->MergeFrom( diff --git a/velox/substrait/VeloxToSubstraitExpr.h b/velox/substrait/VeloxToSubstraitExpr.h index 080dcc00827..66274d3be2c 100644 --- a/velox/substrait/VeloxToSubstraitExpr.h +++ b/velox/substrait/VeloxToSubstraitExpr.h @@ -18,6 +18,7 @@ #include "velox/core/PlanNode.h" +#include "velox/substrait/SubstraitExtensionCollector.h" #include "velox/substrait/VeloxToSubstraitType.h" #include "velox/substrait/proto/substrait/algebra.pb.h" @@ -25,11 +26,9 @@ namespace facebook::velox::substrait { class VeloxToSubstraitExprConvertor { public: - /// @param functionMap: A pre-constructed map - /// storing the relations between the function name and the function id. explicit VeloxToSubstraitExprConvertor( - const std::unordered_map& functionMap) - : functionMap_(functionMap) {} + const SubstraitExtensionCollectorPtr& extensionCollector) + : extensionCollector_(extensionCollector) {} /// Convert Velox Expression to Substrait Expression. /// @param arena Arena to use for allocating Substrait plan objects. @@ -95,11 +94,12 @@ class VeloxToSubstraitExprConvertor { google::protobuf::Arena& arena, const velox::variant& variantValue); - std::shared_ptr typeConvertor_; + VeloxToSubstraitTypeConvertorPtr typeConvertor_; - /// The map storing the relations between the function name and the function - /// id. 
- std::unordered_map functionMap_; + SubstraitExtensionCollectorPtr extensionCollector_; }; +using VeloxToSubstraitExprConvertorPtr = + std::shared_ptr; + } // namespace facebook::velox::substrait diff --git a/velox/substrait/VeloxToSubstraitPlan.cpp b/velox/substrait/VeloxToSubstraitPlan.cpp index 4fb8b6710d6..b9795688842 100644 --- a/velox/substrait/VeloxToSubstraitPlan.cpp +++ b/velox/substrait/VeloxToSubstraitPlan.cpp @@ -46,21 +46,15 @@ ::substrait::AggregationPhase toAggregationPhase( ::substrait::Plan& VeloxToSubstraitPlanConvertor::toSubstrait( google::protobuf::Arena& arena, const core::PlanNodePtr& plan) { - // Assume only accepts a single plan fragment. - - // Construct the function map based on the Velox plan. - constructFunctionMap(); - + // Construct the extension collector. + extensionCollector_ = std::make_shared(); // Construct the expression converter. exprConvertor_ = - std::make_shared(functionMap_); + std::make_shared(extensionCollector_); - ::substrait::Plan* substraitPlan = + auto substraitPlan = google::protobuf::Arena::CreateMessage<::substrait::Plan>(&arena); - // Add Extension Functions. - substraitPlan->MergeFrom(addExtensionFunc(arena)); - // Add unknown type in extension. auto unknownType = substraitPlan->add_extensions()->mutable_extension_type(); @@ -73,6 +67,10 @@ ::substrait::Plan& VeloxToSubstraitPlanConvertor::toSubstrait( substraitPlan->add_relations()->mutable_root(); toSubstrait(arena, plan, rootRel->mutable_input()); + + // Add extensions for all functions and types seen in the plan. + extensionCollector_->addExtensionsToPlan(substraitPlan); + // Set RootRel names. for (const auto& name : plan->outputType()->names()) { rootRel->add_names(name); @@ -278,6 +276,9 @@ void VeloxToSubstraitPlanConvertor::toSubstrait( // Aggregation function name. const auto& funName = aggregatesExpr->name(); // set aggFunction args.
+ + std::vector arguments; + arguments.reserve(aggregatesExpr->inputs().size()); for (const auto& expr : aggregatesExpr->inputs()) { // If the expr is CallTypedExpr, people need to do project firstly. if (auto aggregatesExprInput = @@ -286,15 +287,15 @@ void VeloxToSubstraitPlanConvertor::toSubstrait( } else { aggFunction->add_arguments()->mutable_value()->MergeFrom( exprConvertor_->toSubstraitExpr(arena, expr, inputType)); + + arguments.emplace_back(expr->type()); } } - // Set substrait aggregate Function reference and output type. - if (functionMap_.find(funName) != functionMap_.end()) { - aggFunction->set_function_reference(functionMap_[funName]); - } else { - VELOX_NYI("Couldn't find the aggregate function '{}' ", funName); - } + auto referenceNumber = extensionCollector_->getReferenceNumber( + funName, arguments, aggregateNode->step()); + + aggFunction->set_function_reference(referenceNumber); aggFunction->mutable_output_type()->MergeFrom( typeConvertor_->toSubstraitType(arena, aggregatesExpr->type())); @@ -307,91 +308,4 @@ void VeloxToSubstraitPlanConvertor::toSubstrait( aggregateRel->mutable_common()->mutable_direct(); } -void VeloxToSubstraitPlanConvertor::constructFunctionMap() { - // TODO: Fetch all functions from velox's registry. - - functionMap_["plus"] = 0; - functionMap_["multiply"] = 1; - functionMap_["lt"] = 2; - functionMap_["divide"] = 3; - functionMap_["count"] = 4; - functionMap_["sum"] = 5; - functionMap_["mod"] = 6; - functionMap_["eq"] = 7; - functionMap_["row_constructor"] = 8; - functionMap_["avg"] = 9; -} - -::substrait::Plan& VeloxToSubstraitPlanConvertor::addExtensionFunc( - google::protobuf::Arena& arena) { - // TODO: Fetch all functions from velox's registry and add them into substrait - // extensions. - // Now we just work around this part and add one function as dummy version to - // pass filter and project round-trip test. 
- auto substraitPlan = - google::protobuf::Arena::CreateMessage<::substrait::Plan>(&arena); - - auto extensionFunction = - substraitPlan->add_extensions()->mutable_extension_function(); - - extensionFunction->set_extension_uri_reference(0); - extensionFunction->set_function_anchor(0); - extensionFunction->set_name("add:opt_i32_i32"); - - extensionFunction = - substraitPlan->add_extensions()->mutable_extension_function(); - extensionFunction->set_extension_uri_reference(0); - extensionFunction->set_function_anchor(1); - extensionFunction->set_name("multiply:opt_i32_i32"); - - extensionFunction = - substraitPlan->add_extensions()->mutable_extension_function(); - extensionFunction->set_extension_uri_reference(1); - extensionFunction->set_function_anchor(2); - extensionFunction->set_name("lt:i32_i32"); - - extensionFunction = - substraitPlan->add_extensions()->mutable_extension_function(); - extensionFunction->set_extension_uri_reference(0); - extensionFunction->set_function_anchor(3); - extensionFunction->set_name("divide:i32_i32"); - - extensionFunction = - substraitPlan->add_extensions()->mutable_extension_function(); - extensionFunction->set_extension_uri_reference(0); - extensionFunction->set_function_anchor(4); - extensionFunction->set_name("count:opt_i32"); - - extensionFunction = - substraitPlan->add_extensions()->mutable_extension_function(); - extensionFunction->set_extension_uri_reference(0); - extensionFunction->set_function_anchor(5); - extensionFunction->set_name("sum:opt_f64"); - - extensionFunction = - substraitPlan->add_extensions()->mutable_extension_function(); - extensionFunction->set_extension_uri_reference(0); - extensionFunction->set_function_anchor(6); - extensionFunction->set_name("modulus:i32_i32"); - - extensionFunction = - substraitPlan->add_extensions()->mutable_extension_function(); - extensionFunction->set_extension_uri_reference(0); - extensionFunction->set_function_anchor(7); - extensionFunction->set_name("equal:i64_i64"); - - 
extensionFunction = - substraitPlan->add_extensions()->mutable_extension_function(); - extensionFunction->set_extension_uri_reference(0); - extensionFunction->set_function_anchor(8); - extensionFunction->set_name("row_constructor:f64_i32"); - - extensionFunction = - substraitPlan->add_extensions()->mutable_extension_function(); - extensionFunction->set_extension_uri_reference(0); - extensionFunction->set_function_anchor(9); - extensionFunction->set_name("avg:opt_f64"); - return *substraitPlan; -} - } // namespace facebook::velox::substrait diff --git a/velox/substrait/VeloxToSubstraitPlan.h b/velox/substrait/VeloxToSubstraitPlan.h index 26905234597..a93a83b96b4 100644 --- a/velox/substrait/VeloxToSubstraitPlan.h +++ b/velox/substrait/VeloxToSubstraitPlan.h @@ -23,6 +23,7 @@ #include "velox/core/PlanNode.h" #include "velox/type/Type.h" +#include "velox/substrait/SubstraitExtensionCollector.h" #include "velox/substrait/VeloxToSubstraitExpr.h" #include "velox/substrait/proto/substrait/algebra.pb.h" #include "velox/substrait/proto/substrait/plan.pb.h" @@ -72,21 +73,17 @@ class VeloxToSubstraitPlanConvertor { const std::shared_ptr& aggregateNode, ::substrait::AggregateRel* aggregateRel); - /// Construct the function map between the Velox function name and index. - void constructFunctionMap(); - - /// Fetch all functions from Velox's registry and create Substrait extensions - /// for these. - ::substrait::Plan& addExtensionFunc(google::protobuf::Arena& arena); - /// The Expression converter used to convert Velox representations into /// Substrait expressions. - std::shared_ptr exprConvertor_; + VeloxToSubstraitExprConvertorPtr exprConvertor_; + /// The Type converter used to convert Velox representations into Substrait + /// types. std::shared_ptr typeConvertor_; - /// The map storing the relations between the function name and the function - /// id. Will be constructed based on the Velox representation.
- std::unordered_map functionMap_; + /// The Extension collector storing the relations between the function + /// signature and the function reference number. + SubstraitExtensionCollectorPtr extensionCollector_; }; + } // namespace facebook::velox::substrait diff --git a/velox/substrait/VeloxToSubstraitType.h b/velox/substrait/VeloxToSubstraitType.h index 4b8b0a2fc00..5a1f297a7d4 100644 --- a/velox/substrait/VeloxToSubstraitType.h +++ b/velox/substrait/VeloxToSubstraitType.h @@ -36,4 +36,7 @@ class VeloxToSubstraitTypeConvertor { const velox::TypePtr& type); }; +using VeloxToSubstraitTypeConvertorPtr = + std::shared_ptr; + } // namespace facebook::velox::substrait diff --git a/velox/substrait/tests/CMakeLists.txt b/velox/substrait/tests/CMakeLists.txt index d14925e012a..a455195d328 100644 --- a/velox/substrait/tests/CMakeLists.txt +++ b/velox/substrait/tests/CMakeLists.txt @@ -14,9 +14,15 @@ add_executable( velox_plan_conversion_test - VeloxToSubstraitTypeTest.cpp Substrait2VeloxPlanConversionTest.cpp - Substrait2VeloxValuesNodeConversionTest.cpp FunctionTest.cpp - JsonToProtoConverter.cpp VeloxSubstraitRoundTripPlanConverterTest.cpp) + VeloxToSubstraitTypeTest.cpp + Substrait2VeloxPlanConversionTest.cpp + Substrait2VeloxValuesNodeConversionTest.cpp + FunctionTest.cpp + JsonToProtoConverter.cpp + VeloxSubstraitRoundTripPlanConverterTest.cpp + VeloxToSubstraitTypeTest.cpp + VeloxSubstraitSignatureTest.cpp + SubstraitExtensionCollectorTest.cpp) add_dependencies(velox_plan_conversion_test velox_substrait_plan_converter) diff --git a/velox/substrait/tests/SubstraitExtensionCollectorTest.cpp b/velox/substrait/tests/SubstraitExtensionCollectorTest.cpp new file mode 100644 index 00000000000..81b6907d408 --- /dev/null +++ b/velox/substrait/tests/SubstraitExtensionCollectorTest.cpp @@ -0,0 +1,156 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/substrait/SubstraitExtensionCollector.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/core/PlanNode.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" + +using namespace facebook::velox; +using namespace facebook::velox::substrait; + +namespace facebook::velox::substrait::test { + +class SubstraitExtensionCollectorTest : public ::testing::Test { + protected: + void SetUp() override { + Test::SetUp(); + functions::prestosql::registerAllScalarFunctions(); + } + + int getReferenceNumber( + const std::string& functionName, + std::vector&& arguments) { + int referenceNumber1 = + extensionCollector_->getReferenceNumber(functionName, arguments); + // Repeat the call to make sure properly de-duplicated. + int referenceNumber2 = + extensionCollector_->getReferenceNumber(functionName, arguments); + EXPECT_EQ(referenceNumber1, referenceNumber2); + return referenceNumber1; + } + + int getReferenceNumber( + const std::string& functionName, + std::vector&& arguments, + core::AggregationNode::Step step) { + int referenceNumber1 = + extensionCollector_->getReferenceNumber(functionName, arguments, step); + // Repeat the call to make sure properly de-duplicated. 
+ int referenceNumber2 = + extensionCollector_->getReferenceNumber(functionName, arguments, step); + EXPECT_EQ(referenceNumber1, referenceNumber2); + return referenceNumber2; + } + + /// Given a substrait plan, return the sorted extension functions by the + /// function anchor. + ::google::protobuf::RepeatedPtrField< + ::substrait::extensions::SimpleExtensionDeclaration> + getSortedSubstraitExtension(const ::substrait::Plan* substraitPlan) { + auto substraitExtensions = substraitPlan->extensions(); + std::sort( + substraitExtensions.begin(), + substraitExtensions.end(), + [](const ::substrait::extensions::SimpleExtensionDeclaration a, + const ::substrait::extensions::SimpleExtensionDeclaration b) { + return a.extension_function().function_anchor() < + b.extension_function().function_anchor(); + }); + + return substraitExtensions; + } + + SubstraitExtensionCollectorPtr extensionCollector_ = + std::make_shared(); +}; + +TEST_F(SubstraitExtensionCollectorTest, getReferenceNumberForScalarFunction) { + ASSERT_EQ(getReferenceNumber("plus", {INTEGER(), INTEGER()}), 0); + ASSERT_EQ(getReferenceNumber("divide", {INTEGER(), INTEGER()}), 1); + ASSERT_EQ(getReferenceNumber("cardinality", {ARRAY(INTEGER())}), 2); + ASSERT_EQ(getReferenceNumber("array_sum", {ARRAY(INTEGER())}), 3); + + auto functionType = std::make_shared( + std::vector{INTEGER(), VARCHAR()}, BIGINT()); + std::vector types = {MAP(INTEGER(), VARCHAR()), functionType}; + ASSERT_ANY_THROW(getReferenceNumber("transform_keys", std::move(types))); +} + +TEST_F( + SubstraitExtensionCollectorTest, + getReferenceNumberForAggregateFunction) { + // Sum aggregate function have same argument type for each aggregation step. + ASSERT_EQ( + getReferenceNumber( + "sum", {INTEGER()}, core::AggregationNode::Step::kSingle), + 0); + + // Partial avg aggregate function should use primitive integral type. 
+ ASSERT_EQ( + getReferenceNumber( + "avg", {INTEGER()}, core::AggregationNode::Step::kPartial), + 1); + + // Final avg aggregate function should use struct type, like + // 'ROW' + ASSERT_EQ( + getReferenceNumber( + "avg", + {ROW({DOUBLE(), BIGINT()})}, + core::AggregationNode::Step::kFinal), + 2); + + // Count aggregate function have same argument type for each aggregation step. + ASSERT_EQ( + getReferenceNumber( + "count", {INTEGER()}, core::AggregationNode::Step::kFinal), + 3); +} + +TEST_F(SubstraitExtensionCollectorTest, addExtensionsToPlan) { + getReferenceNumber("plus", {INTEGER(), INTEGER()}); + getReferenceNumber("divide", {INTEGER(), INTEGER()}); + getReferenceNumber("cardinality", {ARRAY(INTEGER())}); + getReferenceNumber("array_sum", {ARRAY(INTEGER())}); + getReferenceNumber("sum", {INTEGER()}); + getReferenceNumber("avg", {INTEGER()}); + getReferenceNumber("avg", {ROW({DOUBLE(), BIGINT()})}); + getReferenceNumber("count", {INTEGER()}); + + google::protobuf::Arena arena; + auto* substraitPlan = + google::protobuf::Arena::CreateMessage<::substrait::Plan>(&arena); + + extensionCollector_->addExtensionsToPlan(substraitPlan); + + const auto& substraitExtensions = getSortedSubstraitExtension(substraitPlan); + auto getFunctionName = [&](auto id) { + return substraitExtensions[id].extension_function().name(); + }; + + ASSERT_EQ(substraitPlan->extensions().size(), 8); + ASSERT_EQ(getFunctionName(0), "plus:i32_i32"); + ASSERT_EQ(getFunctionName(1), "divide:i32_i32"); + ASSERT_EQ(getFunctionName(2), "cardinality:list"); + ASSERT_EQ(getFunctionName(3), "array_sum:list"); + ASSERT_EQ(getFunctionName(4), "sum:i32"); + ASSERT_EQ(getFunctionName(5), "avg:i32"); + ASSERT_EQ(getFunctionName(6), "avg:struct"); + ASSERT_EQ(getFunctionName(7), "count:i32"); +} + +} // namespace facebook::velox::substrait::test diff --git a/velox/substrait/tests/VeloxSubstraitSignatureTest.cpp b/velox/substrait/tests/VeloxSubstraitSignatureTest.cpp new file mode 100644 index 
00000000000..c651cd793da --- /dev/null +++ b/velox/substrait/tests/VeloxSubstraitSignatureTest.cpp @@ -0,0 +1,113 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/substrait/VeloxSubstraitSignature.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" + +using namespace facebook::velox; +using namespace facebook::velox::substrait; + +namespace facebook::velox::substrait::test { + +class VeloxSubstraitSignatureTest : public ::testing::Test { + protected: + void SetUp() override { + Test::SetUp(); + functions::prestosql::registerAllScalarFunctions(); + } + + static std::string toSubstraitSignature(const TypePtr& type) { + return VeloxSubstraitSignature::toSubstraitSignature(type->kind()); + } + + static std::string toSubstraitSignature( + const std::string& functionName, + const std::vector& arguments) { + return VeloxSubstraitSignature::toSubstraitSignature( + functionName, arguments); + } +}; + +TEST_F(VeloxSubstraitSignatureTest, toSubstraitSignatureWithType) { + ASSERT_EQ(toSubstraitSignature(BOOLEAN()), "bool"); + + ASSERT_EQ(toSubstraitSignature(TINYINT()), "i8"); + ASSERT_EQ(toSubstraitSignature(SMALLINT()), "i16"); + ASSERT_EQ(toSubstraitSignature(INTEGER()), "i32"); + ASSERT_EQ(toSubstraitSignature(BIGINT()), "i64"); + ASSERT_EQ(toSubstraitSignature(REAL()), "fp32"); + 
ASSERT_EQ(toSubstraitSignature(DOUBLE()), "fp64"); + ASSERT_EQ(toSubstraitSignature(VARCHAR()), "str"); + ASSERT_EQ(toSubstraitSignature(VARBINARY()), "vbin"); + ASSERT_EQ(toSubstraitSignature(TIMESTAMP()), "ts"); + ASSERT_EQ(toSubstraitSignature(DATE()), "date"); + ASSERT_EQ(toSubstraitSignature(SHORT_DECIMAL(18, 2)), "dec"); + ASSERT_EQ(toSubstraitSignature(LONG_DECIMAL(18, 2)), "dec"); + ASSERT_EQ(toSubstraitSignature(ARRAY(BOOLEAN())), "list"); + ASSERT_EQ(toSubstraitSignature(ARRAY(INTEGER())), "list"); + ASSERT_EQ(toSubstraitSignature(MAP(INTEGER(), BIGINT())), "map"); + ASSERT_EQ(toSubstraitSignature(ROW({INTEGER(), BIGINT()})), "struct"); + ASSERT_EQ(toSubstraitSignature(ROW({ARRAY(INTEGER())})), "struct"); + ASSERT_EQ(toSubstraitSignature(ROW({MAP(INTEGER(), INTEGER())})), "struct"); + ASSERT_EQ(toSubstraitSignature(ROW({ROW({INTEGER()})})), "struct"); + ASSERT_EQ(toSubstraitSignature(UNKNOWN()), "u!name"); + + ASSERT_ANY_THROW(toSubstraitSignature(INTERVAL_DAY_TIME())); +} + +TEST_F( + VeloxSubstraitSignatureTest, + toSubstraitSignatureWithFunctionNameAndArguments) { + ASSERT_EQ(toSubstraitSignature("eq", {INTEGER(), INTEGER()}), "eq:i32_i32"); + ASSERT_EQ(toSubstraitSignature("gt", {INTEGER(), INTEGER()}), "gt:i32_i32"); + ASSERT_EQ(toSubstraitSignature("lt", {INTEGER(), INTEGER()}), "lt:i32_i32"); + ASSERT_EQ(toSubstraitSignature("gte", {INTEGER(), INTEGER()}), "gte:i32_i32"); + ASSERT_EQ(toSubstraitSignature("lte", {INTEGER(), INTEGER()}), "lte:i32_i32"); + + ASSERT_EQ( + toSubstraitSignature("and", {BOOLEAN(), BOOLEAN()}), "and:bool_bool"); + ASSERT_EQ(toSubstraitSignature("or", {BOOLEAN(), BOOLEAN()}), "or:bool_bool"); + ASSERT_EQ(toSubstraitSignature("not", {BOOLEAN()}), "not:bool"); + ASSERT_EQ( + toSubstraitSignature("xor", {BOOLEAN(), BOOLEAN()}), "xor:bool_bool"); + + ASSERT_EQ( + toSubstraitSignature("between", {INTEGER(), INTEGER(), INTEGER()}), + "between:i32_i32_i32"); + + ASSERT_EQ( + toSubstraitSignature("plus", {INTEGER(), INTEGER()}), 
"plus:i32_i32"); + ASSERT_EQ( + toSubstraitSignature("divide", {INTEGER(), INTEGER()}), "divide:i32_i32"); + + ASSERT_EQ( + toSubstraitSignature("cardinality", {ARRAY(INTEGER())}), + "cardinality:list"); + ASSERT_EQ( + toSubstraitSignature("array_sum", {ARRAY(INTEGER())}), "array_sum:list"); + + ASSERT_EQ(toSubstraitSignature("sum", {INTEGER()}), "sum:i32"); + ASSERT_EQ(toSubstraitSignature("avg", {INTEGER()}), "avg:i32"); + ASSERT_EQ(toSubstraitSignature("count", {INTEGER()}), "count:i32"); + + auto functionType = std::make_shared( + std::vector{INTEGER(), VARCHAR()}, BIGINT()); + std::vector types = {MAP(INTEGER(), VARCHAR()), functionType}; + ASSERT_ANY_THROW(toSubstraitSignature("transform_keys", std::move(types))); +} + +} // namespace facebook::velox::substrait::test