diff --git a/presto-native-execution/presto_cpp/main/types/VeloxToPrestoExpr.cpp b/presto-native-execution/presto_cpp/main/types/VeloxToPrestoExpr.cpp
index aa8a228360c98..91e1c309d8a7e 100644
--- a/presto-native-execution/presto_cpp/main/types/VeloxToPrestoExpr.cpp
+++ b/presto-native-execution/presto_cpp/main/types/VeloxToPrestoExpr.cpp
@@ -17,6 +17,7 @@
 #include "presto_cpp/main/types/PrestoToVeloxExpr.h"
 #include "velox/core/ITypedExpr.h"
 #include "velox/expression/ExprConstants.h"
+#include "velox/vector/BaseVector.h"
 #include "velox/vector/ConstantVector.h"
 
 using namespace facebook::presto;
@@ -161,6 +162,78 @@ VeloxToPrestoExprConverter::getSwitchSpecialFormExpressionArgs(
   return result;
 }
 
+void VeloxToPrestoExprConverter::getArgsFromConstantInList(
+    const velox::core::ConstantTypedExpr* inList,
+    std::vector<RowExpressionPtr>& result) const {
+  const auto inListVector = inList->toConstantVector(pool_);
+  auto* constantVector =
+      inListVector->as<velox::ConstantVector<velox::ComplexType>>();
+  VELOX_CHECK_NOT_NULL(
+      constantVector, "Expected ConstantVector of Array type for IN-list.");
+  const auto* arrayVector =
+      constantVector->wrappedVector()->as<velox::ArrayVector>();
+  VELOX_CHECK_NOT_NULL(
+      arrayVector,
+      "Expected constant IN-list to be of Array type, but got {}.",
+      constantVector->wrappedVector()->type()->toString());
+
+  auto wrappedIdx = constantVector->wrappedIndex(0);
+  auto size = arrayVector->sizeAt(wrappedIdx);
+  auto offset = arrayVector->offsetAt(wrappedIdx);
+  auto elementsVector = arrayVector->elements();
+
+  for (velox::vector_size_t i = 0; i < size; i++) {
+    auto elementIndex = offset + i;
+    auto elementConstant =
+        velox::BaseVector::wrapInConstant(1, elementIndex, elementsVector);
+    // Construct a core::ConstantTypedExpr from the constant value at this
+    // index in array vector, then convert it to a protocol::RowExpression.
+    const auto constantExpr =
+        std::make_shared<velox::core::ConstantTypedExpr>(elementConstant);
+    result.push_back(getConstantExpression(constantExpr.get()));
+  }
+}
+
+// IN expression in Presto is of form `expr0 IN [expr1, expr2, ..., exprN]`.
+// The Velox representation of IN expression has the same form as Presto when
+// any of the expressions in the IN list is non-constant; when the IN list only
+// has constant expressions, it is of form `expr0 IN constantExpr(ARRAY[
+// expr1.constantValue(), expr2.constantValue(), ..., exprN.constantValue()])`.
+// This function retrieves the arguments to Presto IN expression from Velox IN
+// expression in both of these forms.
+std::vector<RowExpressionPtr>
+VeloxToPrestoExprConverter::getInSpecialFormExpressionArgs(
+    const velox::core::CallTypedExpr* inExpr) const {
+  std::vector<RowExpressionPtr> result;
+  const auto& inputs = inExpr->inputs();
+  const auto numInputs = inputs.size();
+  VELOX_CHECK_GE(numInputs, 2, "IN expression should have at least 2 inputs");
+
+  // Value being searched for with this `IN` expression is always the first
+  // input, convert it to a Presto expression.
+  result.push_back(getRowExpression(inputs.at(0)));
+  const auto& inList = inputs.at(1);
+  if (numInputs == 2 && inList->isConstantKind()) {
+    // Converts inputs from constant Velox IN-list to arguments in the Presto
+    // `IN` expression. Eg: For expression `col0 IN ['apple', 'foo', 'bar']`,
+    // `apple`, `foo`, and `bar` from the IN-list are converted to equivalent
+    // Presto constant expressions.
+    const auto* constantInList =
+        inList->asUnchecked<velox::core::ConstantTypedExpr>();
+    getArgsFromConstantInList(constantInList, result);
+  } else {
+    // Converts inputs from the Velox IN-list to arguments in the Presto `IN`
+    // expression when the Velox IN-list has at least one non-constant
+    // expression. Eg: For expression `col0 IN ['apple', col1, 'foo']`, `apple`,
+    // col1, and `foo` from the IN-list are converted to equivalent
+    // Presto expressions.
+    for (size_t i = 1; i < numInputs; ++i) {
+      result.push_back(getRowExpression(inputs[i]));
+    }
+  }
+  return result;
+}
+
 SpecialFormExpressionPtr VeloxToPrestoExprConverter::getSpecialFormExpression(
     const velox::core::CallTypedExpr* expr) const {
   VELOX_CHECK(
@@ -181,11 +254,14 @@ SpecialFormExpressionPtr VeloxToPrestoExprConverter::getSpecialFormExpression(
   // Arguments for switch expression include 'WHEN' special form expression(s)
   // so they are constructed separately.
   static constexpr char const* kSwitch = "SWITCH";
+  static constexpr char const* kIn = "IN";
   if (name == kSwitch) {
     result.arguments = getSwitchSpecialFormExpressionArgs(expr);
+  } else if (name == kIn) {
+    result.arguments = getInSpecialFormExpressionArgs(expr);
   } else {
-    // Presto special form expressions that are not of type `SWITCH`, such as
-    // `IN`, `AND`, `OR` etc,. are handled in this clause. The list of Presto
+    // Presto special form expressions that are not of type `SWITCH` and `IN`,
+    // such as `AND`, `OR`, are handled in this clause. The list of Presto
     // special form expressions can be found in `kPrestoSpecialForms` in the
     // helper function `isPrestoSpecialForm`.
     auto exprInputs = expr->inputs();
diff --git a/presto-native-execution/presto_cpp/main/types/VeloxToPrestoExpr.h b/presto-native-execution/presto_cpp/main/types/VeloxToPrestoExpr.h
index 08e3de660e0ae..26369c7ec382e 100644
--- a/presto-native-execution/presto_cpp/main/types/VeloxToPrestoExpr.h
+++ b/presto-native-execution/presto_cpp/main/types/VeloxToPrestoExpr.h
@@ -81,6 +81,17 @@ class VeloxToPrestoExprConverter {
   std::vector<RowExpressionPtr> getSwitchSpecialFormExpressionArgs(
       const velox::core::CallTypedExpr* switchExpr) const;
 
+  /// Helper function to convert values from a constant `IN` list in Velox
+  /// expression to equivalent Presto expressions.
+  void getArgsFromConstantInList(
+      const velox::core::ConstantTypedExpr* inList,
+      std::vector<RowExpressionPtr>& result) const;
+
+  /// Helper function to get the arguments for Presto `IN` expression from
+  /// Velox `IN` expression.
+  std::vector<RowExpressionPtr> getInSpecialFormExpressionArgs(
+      const velox::core::CallTypedExpr* inExpr) const;
+
   /// Helper function to construct a Presto `protocol::SpecialFormExpression`
   /// from a Velox call expression. This function should be called only on call
   /// expressions that map to a Presto `SpecialFormExpression`. This can be
diff --git a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java
index 49cb71e9ffaea..f325a315fd8e1 100644
--- a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java
+++ b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java
@@ -653,7 +653,7 @@ public void testP4HyperLogLogWithApproxSet()
     // are addressed using the native expression optimizer, and it is enabled everywhere.
 
     @Test
-    public void testQueriesUsingNativeOptimizer()
+    public void testNativeExpressionOptimizer()
     {
         Session session = Session.builder(getSession())
                 .setSystemProperty(EXPRESSION_OPTIMIZER_NAME, "native")
@@ -681,6 +681,14 @@ public void testQueriesUsingNativeOptimizer()
         assertQuerySucceeds(session, "SELECT * FROM (SELECT row_number() over(partition by orderstatus order by orderkey, orderstatus) rn, * from orders) WHERE rn = 1");
         assertQuerySucceeds(session, "WITH t AS (SELECT linenumber, row_number() over (partition by linenumber order by linenumber) as rn FROM lineitem) SELECT * FROM t WHERE rn = 1");
         assertQuerySucceeds(session, "SELECT row_number() OVER (PARTITION BY orderdate ORDER BY orderdate) FROM orders");
+
+        // IN expressions
+        assertQuerySucceeds(session, "SELECT table_name FROM information_schema.columns WHERE table_name IN ('nation', 'region')");
+        assertQuerySucceeds(session, "SELECT name FROM nation WHERE nationkey NOT IN (1, 2, 3, 4, 5, 10, 11, 12, 13)");
+        assertQuerySucceeds(session, "SELECT orderkey FROM lineitem WHERE shipmode IN ('TRUCK', 'FOB', 'RAIL')");
+        assertQuerySucceeds(session, "SELECT table_name, COALESCE(abs(ordinal_position), 0) as abs_pos FROM information_schema.columns WHERE table_catalog = 'hive' AND table_name IN ('nation', 'region') ORDER BY table_name, ordinal_position");
+        assertQuerySucceeds(session, "SELECT table_name, ordinal_position FROM information_schema.columns WHERE abs(ordinal_position) IN (1, 2, 3) AND table_catalog = 'hive' AND table_name != 'roles' ORDER BY table_name, ordinal_position");
+        assertQuerySucceeds(session, "select lower(table_name) from information_schema.tables where table_name = 'lineitem' or table_name = 'LINEITEM'");
     }
 
     private String generateRandomTableName()