diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlToRowExpressionTranslator.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlToRowExpressionTranslator.java index 0fa0f13d0304a..e9be50cd3a14d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlToRowExpressionTranslator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlToRowExpressionTranslator.java @@ -841,44 +841,47 @@ protected RowExpression visitNullIfExpression(NullIfExpression node, Context con { RowExpression first = process(node.getFirst(), context); RowExpression second = process(node.getSecond(), context); + Type returnType = getType(node); - if (isNative && !second.getType().equals(first.getType())) { - Optional commonType = functionAndTypeResolver.getCommonSuperType(first.getType(), second.getType()); - if (!commonType.isPresent()) { - throw new SemanticException(TYPE_MISMATCH, node, "Types are not comparable with NULLIF: %s vs %s", first.getType(), second.getType()); - } - - Type returnType = getType(node); + if (isNative) { // If the first type is unknown, as per presto's NULL_IF semantics we should not infer the type using second argument. // Always return a null with unknown type. if (first.getType().equals(UnknownType.UNKNOWN)) { return constantNull(UnknownType.UNKNOWN); } - RowExpression originalFirst = first; - // cast(first as ) - if (!first.getType().equals(commonType.get())) { - first = call( - getSourceLocation(node), - CAST.name(), - functionAndTypeResolver.lookupCast(CAST.name(), first.getType(), commonType.get()), - commonType.get(), first); - } - // cast(second as ) - if (!second.getType().equals(commonType.get())) { - second = call( - getSourceLocation(node), - CAST.name(), - functionAndTypeResolver.lookupCast(CAST.name(), second.getType(), commonType.get()), - commonType.get(), second); + RowExpression firstArgWithoutCast = first; + + if (!second.getType().equals(first.getType())) { + Optional commonType = functionAndTypeResolver.getCommonSuperType(first.getType(), second.getType()); + if (!commonType.isPresent()) { + throw new SemanticException(TYPE_MISMATCH, node, "Types are not comparable with NULLIF: %s vs %s", first.getType(), second.getType()); + } + + // cast(first as ) + if (!first.getType().equals(commonType.get())) { + first = call( + getSourceLocation(node), + CAST.name(), + functionAndTypeResolver.lookupCast(CAST.name(), first.getType(), commonType.get()), + commonType.get(), first); + } + // cast(second as ) + if (!second.getType().equals(commonType.get())) { + second = call( + getSourceLocation(node), + CAST.name(), + functionAndTypeResolver.lookupCast(CAST.name(), second.getType(), commonType.get()), + commonType.get(), second); + } } FunctionHandle equalsFunctionHandle = functionAndTypeResolver.resolveOperator(EQUAL, fromTypes(first.getType(), second.getType())); // equal(cast(first as ), cast(second as )) RowExpression equal = call(EQUAL.name(), equalsFunctionHandle, BOOLEAN, first, second); // if (equal(cast(first as ), cast(second as )), cast(null as firstType), first) - return specialForm(IF, returnType, equal, constantNull(originalFirst.getType()), originalFirst); + return specialForm(IF, returnType, equal, constantNull(returnType), firstArgWithoutCast); } - return specialForm(getSourceLocation(node), NULL_IF, getType(node), first, second); + return specialForm(getSourceLocation(node), NULL_IF, returnType, first, second); } @Override diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp index b33c64f10cfae..1bff719f4174f 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp @@ -611,22 +611,6 @@ TypedExprPtr convertDereferenceExpr( return std::make_shared(returnType, input, childName); } - -TypedExprPtr convertNullIfExpr( - const velox::TypePtr& returnType, - const std::vector& args) { - VELOX_CHECK_EQ(args.size(), 2); - - // Convert nullif(a, b) to if(a = b, null, a). - - std::vector newArgs = { - std::make_shared( - velox::BOOLEAN(), args, "presto.default.eq"), - std::make_shared( - returnType, velox::variant::null(returnType->kind())), - args[0]}; - return std::make_shared(returnType, newArgs, "if"); -} } // namespace TypedExprPtr VeloxExprConverter::toVeloxExpr( @@ -657,7 +641,7 @@ TypedExprPtr VeloxExprConverter::toVeloxExpr( } if (pexpr->form == protocol::Form::NULL_IF) { - return convertNullIfExpr(returnType, args); + VELOX_UNREACHABLE("NULL_IF not supported in specialForm") } auto form = std::string(json(pexpr->form)); diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeGeneralQueries.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeGeneralQueries.java index c0829e0035046..6d9fdd2d5f019 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeGeneralQueries.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeGeneralQueries.java @@ -263,7 +263,7 @@ public void testTopN() assertQuery("SELECT linenumber, NULL FROM lineitem ORDER BY 1 LIMIT 23"); } - @Test (enabled = false) + @Test public void testNullIf() { assertQuery("SELECT NULLIF(totalprice, 0) FROM (SELECT SUM(extendedprice) AS totalprice FROM lineitem WHERE shipdate >= '1995-09-01')"); @@ -963,7 +963,7 @@ public void testInsertIntoSpecialPartitionName() // For special character in partition name, without correct handling, it would throw errors like 'Invalid partition spec: nationkey=A/B' // In this test, verify those partition names can be successfully created - String[] specialCharacters = new String[]{"\"", "#", "%", "''", "*", "/", ":", "=", "?", "\\", "\\x7F", "{", "[", "]", "^"}; // escape single quote for sql + String[] specialCharacters = new String[] {"\"", "#", "%", "''", "*", "/", ":", "=", "?", "\\", "\\x7F", "{", "[", "]", "^"}; // escape single quote for sql for (String specialCharacter : specialCharacters) { getQueryRunner().execute(writeSession, String.format("INSERT INTO %s VALUES ('name', 'A%sB')", tmpTableName, specialCharacter)); assertQuery(String.format("SELECT nationkey FROM %s", tmpTableName), String.format("VALUES('A%sB')", specialCharacter)); diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/PrestoNativeQueryRunnerUtils.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/PrestoNativeQueryRunnerUtils.java index 83a28231b67f4..8937a56e47a3f 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/PrestoNativeQueryRunnerUtils.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/PrestoNativeQueryRunnerUtils.java @@ -140,6 +140,7 @@ public static QueryRunner createNativeQueryRunner( ImmutableMap.builder() .put("http-server.http.port", "8080") .put("experimental.internal-communication.thrift-transport-enabled", String.valueOf(useThrift)) + .put("native-execution-enabled", "true") .putAll(getNativeWorkerSystemProperties()) .build(), ImmutableMap.of(), diff --git a/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeSimpleQueries.java b/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeSimpleQueries.java index ff624b2874c1e..c5bb8acefb93c 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeSimpleQueries.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeSimpleQueries.java @@ -84,13 +84,8 @@ public void testMapOnlyQueries() { assertQuery("SELECT * FROM orders"); assertQuery("SELECT orderkey, custkey FROM orders WHERE orderkey <= 200"); - assertQuery("SELECT orderkey, custkey FROM orders ORDER BY orderkey LIMIT 4"); - } - - @Test (enabled = false) - public void testNullIf() - { assertQuery("SELECT nullif(orderkey, custkey) FROM orders"); + assertQuery("SELECT orderkey, custkey FROM orders ORDER BY orderkey LIMIT 4"); } @Test