diff --git a/presto-druid/src/test/java/com/facebook/presto/druid/TestDruidQueryBase.java b/presto-druid/src/test/java/com/facebook/presto/druid/TestDruidQueryBase.java index 5123d1cd0787d..997f94ab639ca 100644 --- a/presto-druid/src/test/java/com/facebook/presto/druid/TestDruidQueryBase.java +++ b/presto-druid/src/test/java/com/facebook/presto/druid/TestDruidQueryBase.java @@ -199,7 +199,7 @@ protected RowExpression toRowExpression(Expression expression, Session session) expression, ImmutableMap.of(), WarningCollector.NOOP); - return SqlToRowExpressionTranslator.translate(expression, expressionTypes, ImmutableMap.of(), functionAndTypeManager.getFunctionAndTypeResolver(), session); + return SqlToRowExpressionTranslator.translate(expression, expressionTypes, ImmutableMap.of(), functionAndTypeManager, session); } protected LimitNode limit(PlanBuilder pb, long count, PlanNode source) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/MaterializedViewQueryOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/MaterializedViewQueryOptimizer.java index 3bd1e0346b3a5..edeb6898a8838 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/MaterializedViewQueryOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/MaterializedViewQueryOptimizer.java @@ -759,7 +759,7 @@ private RowExpression convertToRowExpression(Expression expression, Scope scope) coercedMaybe, coercedExpressionAnalysis.getExpressionTypes(), ImmutableMap.of(), - metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver(), + metadata.getFunctionAndTypeManager(), session); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java index 16597000e5ff8..8f8669042d0e0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java @@ -1495,7 +1495,7 @@ private MaterializedViewStatus getMaterializedViewStatus(QualifiedObjectName mat viewQueryWhereClause, analysis.getTypes(), ImmutableMap.of(), - functionAndTypeResolver, + metadata.getFunctionAndTypeManager(), session); TupleDomain viewQueryDomain = MaterializedViewUtils.getDomainFromFilter(session, domainTranslator, rowExpression); @@ -2099,11 +2099,7 @@ private List analyzeWindowFunctions(QuerySpecification node, List< } Window window = windowFunction.getWindow().get(); - if (window.getOrderBy().filter( - orderBy -> orderBy.getSortItems() - .stream() - .anyMatch(item -> item.getSortKey() instanceof Literal)) - .isPresent()) { + if (window.getOrderBy().filter(orderBy -> orderBy.getSortItems().stream().anyMatch(item -> item.getSortKey() instanceof Literal)).isPresent()) { if (isAllowWindowOrderByLiterals(session)) { warningCollector.add( new PrestoWarning( diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/RowExpressionCompiler.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/RowExpressionCompiler.java index a5457fd60409e..6b7ed5b5ad505 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/RowExpressionCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/RowExpressionCompiler.java @@ -144,7 +144,7 @@ public BytecodeNode visitCall(CallExpression call, Context context) RowExpression function = getSqlFunctionRowExpression( functionMetadata, functionImplementation, - metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver(), + metadata.getFunctionAndTypeManager(), sqlFunctionProperties, sessionFunctions, call.getArguments()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/RowExpressionInterpreter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/RowExpressionInterpreter.java index e2c97fb80ac64..a4ee88c1a327e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/RowExpressionInterpreter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/RowExpressionInterpreter.java @@ -282,7 +282,7 @@ else if (implementationType.equals(JAVA)) { RowExpression function = getSqlFunctionRowExpression( functionMetadata, functionImplementation, - metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver(), + metadata.getFunctionAndTypeManager(), session.getSqlFunctionProperties(), session.getSessionFunctions(), node.getArguments()); @@ -582,10 +582,10 @@ else if (!found && result) { if (hasUnresolvedValue) { List simplifiedExpressionValues = Stream.concat( - Stream.concat( - Stream.of(toRowExpression(target, node.getArguments().get(0))), - unresolvedValues.stream().filter(determinismEvaluator::isDeterministic).distinct()), - unresolvedValues.stream().filter((expression -> !determinismEvaluator.isDeterministic(expression)))) + Stream.concat( + Stream.of(toRowExpression(target, node.getArguments().get(0))), + unresolvedValues.stream().filter(determinismEvaluator::isDeterministic).distinct()), + unresolvedValues.stream().filter((expression -> !determinismEvaluator.isDeterministic(expression)))) .collect(toImmutableList()); return new SpecialFormExpression(IN, node.getType(), simplifiedExpressionValues); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/TranslateExpressionsUtil.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/TranslateExpressionsUtil.java index df56342d3d8df..386295db8fcc9 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/TranslateExpressionsUtil.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/TranslateExpressionsUtil.java @@ -74,7 +74,7 @@ public static RowExpression toRowExpression(Expression expression, Metadata meta public static RowExpression toRowExpression(Expression expression, Metadata metadata, Session session, Map, Type> types, SqlToRowExpressionTranslator.Context context) { - return SqlToRowExpressionTranslator.translate(expression, types, ImmutableMap.of(), metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver(), session, context); + return SqlToRowExpressionTranslator.translate(expression, types, ImmutableMap.of(), metadata.getFunctionAndTypeManager(), session, context); } public static Map, Type> analyzeCallExpressionTypes( diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineSqlFunctions.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineSqlFunctions.java index 6b4bba2935ca4..4f305f69b8530 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineSqlFunctions.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineSqlFunctions.java @@ -106,7 +106,7 @@ public RowExpression rewriteCall(CallExpression expression, Void context, RowExp return getSqlFunctionRowExpression( functionMetadata, (SqlInvokedScalarFunctionImplementation) metadata.getFunctionAndTypeManager().getScalarFunctionImplementation(functionHandle), - metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver(), + metadata.getFunctionAndTypeManager(), session.getSqlFunctionProperties(), session.getSessionFunctions(), rewrittenArguments); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ExpressionEquivalence.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ExpressionEquivalence.java index 6ac0c5ab24e38..3bd1404a9eab2 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ExpressionEquivalence.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ExpressionEquivalence.java @@ -119,7 +119,7 @@ private RowExpression toRowExpression(Session session, Expression expression, Ma WarningCollector.NOOP); // convert to row expression - return translate(expression, expressionTypes, variableInput, metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver(), session); + return translate(expression, expressionTypes, variableInput, metadata.getFunctionAndTypeManager(), session); } private static class CanonicalizationVisitor diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlFunctionUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlFunctionUtils.java index 0e6952a1d4ede..43c212cb0445f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlFunctionUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlFunctionUtils.java @@ -17,6 +17,7 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.expressions.RowExpressionRewriter; import com.facebook.presto.expressions.RowExpressionTreeRewriter; +import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.function.FunctionMetadata; import com.facebook.presto.spi.function.SqlFunctionId; @@ -82,31 +83,34 @@ public static Expression getSqlFunctionExpression( public static RowExpression getSqlFunctionRowExpression( FunctionMetadata functionMetadata, SqlInvokedScalarFunctionImplementation implementation, - FunctionAndTypeResolver functionAndTypeResolver, + FunctionAndTypeManager functionAndTypeManager, SqlFunctionProperties sqlFunctionProperties, Map sessionFunctions, List arguments) { VariableAllocator variableAllocator = new VariableAllocator(); - Map argumentVariables = allocateFunctionArgumentVariables(functionMetadata, functionAndTypeResolver, variableAllocator); - Expression expression = getSqlFunctionImplementationExpression(functionMetadata, implementation, functionAndTypeResolver, variableAllocator, sqlFunctionProperties, argumentVariables); + Map argumentVariables = allocateFunctionArgumentVariables(functionMetadata, functionAndTypeManager.getFunctionAndTypeResolver(), variableAllocator); + Expression expression = getSqlFunctionImplementationExpression(functionMetadata, implementation, functionAndTypeManager.getFunctionAndTypeResolver(), variableAllocator, sqlFunctionProperties, argumentVariables); // Translate to row expression return SqlFunctionArgumentBinder.bindFunctionArguments( SqlToRowExpressionTranslator.translate( expression, analyzeSqlFunctionExpression( - functionAndTypeResolver, + functionAndTypeManager.getFunctionAndTypeResolver(), sqlFunctionProperties, expression, argumentVariables.values().stream() .collect(toImmutableMap(VariableReferenceExpression::getName, VariableReferenceExpression::getType))).getExpressionTypes(), ImmutableMap.of(), - functionAndTypeResolver, + functionAndTypeManager, Optional.empty(), Optional.empty(), sqlFunctionProperties, sessionFunctions, + // TODO: use session to determine if this is a native query + // https://github.com/prestodb/presto/issues/20008 + false, new SqlToRowExpressionTranslator.Context()), functionMetadata.getArgumentNames().get(), arguments, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlToRowExpressionTranslator.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlToRowExpressionTranslator.java index 65f7223f128cf..0fa0f13d0304a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlToRowExpressionTranslator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlToRowExpressionTranslator.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.relational; import com.facebook.presto.Session; +import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.common.function.OperatorType; import com.facebook.presto.common.function.SqlFunctionProperties; import com.facebook.presto.common.transaction.TransactionId; @@ -27,6 +28,8 @@ import com.facebook.presto.common.type.TypeWithName; import com.facebook.presto.common.type.UnknownType; import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.function.SqlFunctionId; import com.facebook.presto.spi.function.SqlInvokedFunction; import com.facebook.presto.spi.relation.ConstantExpression; @@ -139,6 +142,7 @@ import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.WHEN; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.getSourceLocation; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.resolveEnumLiteral; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TYPE_MISMATCH; import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.sql.relational.Expressions.call; import static com.facebook.presto.sql.relational.Expressions.constant; @@ -171,14 +175,14 @@ public static RowExpression translate( Expression expression, Map, Type> types, Map layout, - FunctionAndTypeResolver functionAndTypeResolver, + FunctionAndTypeManager functionAndTypeManager, Session session) { return translate( expression, types, layout, - functionAndTypeResolver, + functionAndTypeManager, session, new Context()); } @@ -187,7 +191,7 @@ public static RowExpression translate( Expression expression, Map, Type> types, Map layout, - FunctionAndTypeResolver functionAndTypeResolver, + FunctionAndTypeManager functionAndTypeManager, Session session, Context context) { @@ -195,11 +199,12 @@ public static RowExpression translate( expression, types, layout, - functionAndTypeResolver, + functionAndTypeManager, Optional.of(session.getUser()), session.getTransactionId(), session.getSqlFunctionProperties(), session.getSessionFunctions(), + SystemSessionProperties.isNativeExecutionEnabled(session), context); } @@ -207,21 +212,23 @@ public static RowExpression translate( Expression expression, Map, Type> types, Map layout, - FunctionAndTypeResolver functionAndTypeResolver, + FunctionAndTypeManager functionAndTypeManager, Optional user, Optional transactionId, SqlFunctionProperties sqlFunctionProperties, Map sessionFunctions, + boolean isNative, Context context) { Visitor visitor = new Visitor( types, layout, - functionAndTypeResolver, + functionAndTypeManager, user, transactionId, sqlFunctionProperties, - sessionFunctions); + sessionFunctions, + isNative); RowExpression result = visitor.process(expression, context); requireNonNull(result, "translated expression is null"); return result; @@ -256,30 +263,35 @@ private static class Visitor { private final Map, Type> types; private final Map layout; + private final FunctionAndTypeManager functionAndTypeManager; private final FunctionAndTypeResolver functionAndTypeResolver; private final Optional user; private final Optional transactionId; private final SqlFunctionProperties sqlFunctionProperties; private final Map sessionFunctions; private final FunctionResolution functionResolution; + private final boolean isNative; private Visitor( Map, Type> types, Map layout, - FunctionAndTypeResolver functionAndTypeResolver, + FunctionAndTypeManager functionAndTypeManager, Optional user, Optional transactionId, SqlFunctionProperties sqlFunctionProperties, - Map sessionFunctions) + Map sessionFunctions, + boolean isNative) { this.types = requireNonNull(types, "types is null"); - this.layout = layout; - this.functionAndTypeResolver = functionAndTypeResolver; - this.user = user; - this.transactionId = transactionId; - this.sqlFunctionProperties = sqlFunctionProperties; + this.layout = requireNonNull(layout); + this.functionAndTypeManager = requireNonNull(functionAndTypeManager); + this.functionAndTypeResolver = functionAndTypeManager.getFunctionAndTypeResolver(); + this.user = requireNonNull(user); + this.transactionId = requireNonNull(transactionId); + this.sqlFunctionProperties = requireNonNull(sqlFunctionProperties); this.functionResolution = new FunctionResolution(functionAndTypeResolver); - this.sessionFunctions = sessionFunctions; + this.sessionFunctions = requireNonNull(sessionFunctions); + this.isNative = isNative; } private Type getType(Expression node) @@ -830,6 +842,42 @@ protected RowExpression visitNullIfExpression(NullIfExpression node, Context con RowExpression first = process(node.getFirst(), context); RowExpression second = process(node.getSecond(), context); + if (isNative && !second.getType().equals(first.getType())) { + Optional commonType = functionAndTypeResolver.getCommonSuperType(first.getType(), second.getType()); + if (!commonType.isPresent()) { + throw new SemanticException(TYPE_MISMATCH, node, "Types are not comparable with NULLIF: %s vs %s", first.getType(), second.getType()); + } + + Type returnType = getType(node); + // If the first type is unknown, as per presto's NULL_IF semantics we should not infer the type using second argument. + // Always return a null with unknown type. + if (first.getType().equals(UnknownType.UNKNOWN)) { + return constantNull(UnknownType.UNKNOWN); + } + RowExpression originalFirst = first; + // cast(first as ) + if (!first.getType().equals(commonType.get())) { + first = call( + getSourceLocation(node), + CAST.name(), + functionAndTypeResolver.lookupCast(CAST.name(), first.getType(), commonType.get()), + commonType.get(), first); + } + // cast(second as ) + if (!second.getType().equals(commonType.get())) { + second = call( + getSourceLocation(node), + CAST.name(), + functionAndTypeResolver.lookupCast(CAST.name(), second.getType(), commonType.get()), + commonType.get(), second); + } + FunctionHandle equalsFunctionHandle = functionAndTypeResolver.resolveOperator(EQUAL, fromTypes(first.getType(), second.getType())); + // equal(cast(first as ), cast(second as )) + RowExpression equal = call(EQUAL.name(), equalsFunctionHandle, BOOLEAN, first, second); + + // if (equal(cast(first as ), cast(second as )), cast(null as firstType), first) + return specialForm(IF, returnType, equal, constantNull(originalFirst.getType()), originalFirst); + } return specialForm(getSourceLocation(node), NULL_IF, getType(node), first, second); } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/CustomFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/CustomFunctions.java index 7800f94cea768..e3243114b77f6 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/CustomFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/CustomFunctions.java @@ -14,9 +14,13 @@ package com.facebook.presto.operator.scalar; import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.SqlInvokedScalarFunction; import com.facebook.presto.spi.function.SqlNullable; +import com.facebook.presto.spi.function.SqlParameter; +import com.facebook.presto.spi.function.SqlParameters; import com.facebook.presto.spi.function.SqlType; import io.airlift.slice.Slice; @@ -45,4 +49,13 @@ public static boolean customIsNullBigint(@SqlNullable @SqlType(StandardTypes.BIG { return value == null; } + + @SqlInvokedScalarFunction(value = "custom_square", deterministic = true, calledOnNullInput = false) + @Description("Custom SQL to test NULLIF in Functions") + @SqlParameters({@SqlParameter(name = "x", type = "integer"), @SqlParameter(name = "y", type = "integer")}) + @SqlType("integer") + public static String customSquare() + { + return "RETURN IF(NULLIF(x, y) IS NOT NULL, x * x, y * y)"; + } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java index 08b28b8a42948..87e54df900b84 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java @@ -1020,7 +1020,7 @@ private static SourceOperatorFactory compileScanFilterProject(SqlFunctionPropert private RowExpression toRowExpression(Expression projection, Map, Type> expressionTypes, Map layout) { - return translate(projection, expressionTypes, layout, metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver(), session); + return translate(projection, expressionTypes, layout, metadata.getFunctionAndTypeManager(), session); } private static Page getAtMostOnePage(Operator operator, Page sourcePage) diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestCustomFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestCustomFunctions.java index 5ab166f95bc8b..be335dd54c5c9 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestCustomFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestCustomFunctions.java @@ -13,19 +13,36 @@ */ package com.facebook.presto.operator.scalar; +import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.operator.scalar.annotations.SqlInvokedScalarFromAnnotationsParser; +import com.facebook.presto.spi.function.SqlInvokedFunction; +import com.facebook.presto.sql.analyzer.FeaturesConfig; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; +import java.util.List; + import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; public class TestCustomFunctions extends AbstractTestFunctions { + public TestCustomFunctions() + { + } + + protected TestCustomFunctions(FeaturesConfig config) + { + super(config); + } + @BeforeClass public void setupClass() { registerScalar(CustomFunctions.class); + List functions = SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinitions(CustomFunctions.class); + this.functionAssertions.addFunctions(functions); } @Test @@ -47,4 +64,11 @@ public void testLongIsNull() assertFunction("custom_is_null(CAST(NULL AS BIGINT))", BOOLEAN, true); assertFunction("custom_is_null(0)", BOOLEAN, false); } + + @Test + public void testNullIf() + { + assertFunction("custom_square(2, 5)", IntegerType.INTEGER, 4); + assertFunction("custom_square(5, 5)", IntegerType.INTEGER, 25); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/TestRowExpressionSerde.java b/presto-main/src/test/java/com/facebook/presto/sql/TestRowExpressionSerde.java index 163416a04eaed..485dbfbacdfc8 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/TestRowExpressionSerde.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/TestRowExpressionSerde.java @@ -268,7 +268,7 @@ private JsonCodec getJsonCodec() private RowExpression translate(Expression expression, boolean optimize) { - RowExpression rowExpression = SqlToRowExpressionTranslator.translate(expression, getExpressionTypes(expression), ImmutableMap.of(), metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver(), TEST_SESSION); + RowExpression rowExpression = SqlToRowExpressionTranslator.translate(expression, getExpressionTypes(expression), ImmutableMap.of(), metadata.getFunctionAndTypeManager(), TEST_SESSION); if (optimize) { RowExpressionOptimizer optimizer = new RowExpressionOptimizer(metadata); return optimizer.optimize(rowExpression, OPTIMIZED, TEST_SESSION.toConnectorSession()); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/TestingRowExpressionTranslator.java b/presto-main/src/test/java/com/facebook/presto/sql/TestingRowExpressionTranslator.java index 1689b0bddabf8..fe2f86d8d92ff 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/TestingRowExpressionTranslator.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/TestingRowExpressionTranslator.java @@ -74,13 +74,13 @@ public RowExpression translate(Expression expression, TypeProvider typeProvider) expression, getExpressionTypes(expression, typeProvider), ImmutableMap.of(), - metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver(), + metadata.getFunctionAndTypeManager(), TEST_SESSION); } public RowExpression translateAndOptimize(Expression expression, Map, Type> types) { - RowExpression rowExpression = SqlToRowExpressionTranslator.translate(expression, types, ImmutableMap.of(), metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver(), TEST_SESSION); + RowExpression rowExpression = SqlToRowExpressionTranslator.translate(expression, types, ImmutableMap.of(), metadata.getFunctionAndTypeManager(), TEST_SESSION); RowExpressionOptimizer optimizer = new RowExpressionOptimizer(metadata); return optimizer.optimize(rowExpression, OPTIMIZED, TEST_SESSION.toConnectorSession()); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/gen/CommonSubExpressionBenchmark.java b/presto-main/src/test/java/com/facebook/presto/sql/gen/CommonSubExpressionBenchmark.java index 3745ef2770d1b..95788532f898e 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/gen/CommonSubExpressionBenchmark.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/gen/CommonSubExpressionBenchmark.java @@ -184,7 +184,7 @@ private RowExpression rowExpression(String value) Expression expression = createExpression(value, METADATA, TypeProvider.copyOf(symbolTypes)); Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, TypeProvider.copyOf(symbolTypes), expression, emptyMap(), WarningCollector.NOOP); - RowExpression rowExpression = SqlToRowExpressionTranslator.translate(expression, expressionTypes, sourceLayout, METADATA.getFunctionAndTypeManager().getFunctionAndTypeResolver(), TEST_SESSION); + RowExpression rowExpression = SqlToRowExpressionTranslator.translate(expression, expressionTypes, sourceLayout, METADATA.getFunctionAndTypeManager(), TEST_SESSION); RowExpressionOptimizer optimizer = new RowExpressionOptimizer(METADATA); return optimizer.optimize(rowExpression, OPTIMIZED, TEST_SESSION.toConnectorSession()); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/gen/PageProcessorBenchmark.java b/presto-main/src/test/java/com/facebook/presto/sql/gen/PageProcessorBenchmark.java index 930e990e1ac90..3061e5625896c 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/gen/PageProcessorBenchmark.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/gen/PageProcessorBenchmark.java @@ -182,7 +182,7 @@ private RowExpression rowExpression(String value) Expression expression = createExpression(value, METADATA, TypeProvider.copyOf(symbolTypes)); Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, TypeProvider.copyOf(symbolTypes), expression, emptyMap(), WarningCollector.NOOP); - RowExpression rowExpression = SqlToRowExpressionTranslator.translate(expression, expressionTypes, sourceLayout, METADATA.getFunctionAndTypeManager().getFunctionAndTypeResolver(), TEST_SESSION); + RowExpression rowExpression = SqlToRowExpressionTranslator.translate(expression, expressionTypes, sourceLayout, METADATA.getFunctionAndTypeManager(), TEST_SESSION); RowExpressionOptimizer optimizer = new RowExpressionOptimizer(METADATA); return optimizer.optimize(rowExpression, OPTIMIZED, TEST_SESSION.toConnectorSession()); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/gen/TestCommonSubExpressionRewriter.java b/presto-main/src/test/java/com/facebook/presto/sql/gen/TestCommonSubExpressionRewriter.java index b358321947d27..2be033a220041 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/gen/TestCommonSubExpressionRewriter.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/gen/TestCommonSubExpressionRewriter.java @@ -147,7 +147,7 @@ private RowExpression rowExpression(String sql) expression, expressionTypes, ImmutableMap.of(), - METADATA.getFunctionAndTypeManager().getFunctionAndTypeResolver(), + METADATA.getFunctionAndTypeManager(), SESSION); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/gen/TestExpressionCompiler.java b/presto-main/src/test/java/com/facebook/presto/sql/gen/TestExpressionCompiler.java index 6f5d062cabd43..dfd268a014ec7 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/gen/TestExpressionCompiler.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/gen/TestExpressionCompiler.java @@ -167,7 +167,7 @@ public void setupClass() else { executor = newDirectExecutorService(); } - functionAssertions = new FunctionAssertions(); + this.functionAssertions = setFunctionAssertions(); } @AfterClass(alwaysRun = true) @@ -195,6 +195,13 @@ public void tearDown(Method method) log.info("FINISHED %s in %s verified %s expressions", method.getName(), Duration.nanosSince(start), futures.size()); } + // This is a setter and not a test method. + // TestNG considers this as a test method since we have annotated the class with @Test. + public FunctionAssertions setFunctionAssertions() + { + return new FunctionAssertions(); + } + @Test public void smokedTest() throws Exception @@ -640,12 +647,12 @@ public void testNestedColumnFilter() // combination of types in one filter assertFilter( ImmutableList.of( - "bound_row.nested_column_0 = 1234", "bound_row.nested_column_7 >= 1234", - "bound_row.nested_column_1 = 34", "bound_row.nested_column_8 >= 33", - "bound_row.nested_column_2 = 'hello'", "bound_row.nested_column_9 >= 'hello'", - "bound_row.nested_column_3 = 12.34", "bound_row.nested_column_10 >= 12.34", - "bound_row.nested_column_4 = true", "NOT (bound_row.nested_column_11 = false)", - "bound_row.nested_column_6.nested_nested_column = 'innerFieldValue'", "bound_row.nested_column_13.nested_nested_column LIKE 'innerFieldValue'") + "bound_row.nested_column_0 = 1234", "bound_row.nested_column_7 >= 1234", + "bound_row.nested_column_1 = 34", "bound_row.nested_column_8 >= 33", + "bound_row.nested_column_2 = 'hello'", "bound_row.nested_column_9 >= 'hello'", + "bound_row.nested_column_3 = 12.34", "bound_row.nested_column_10 >= 12.34", + "bound_row.nested_column_4 = true", "NOT (bound_row.nested_column_11 = false)", + "bound_row.nested_column_6.nested_nested_column = 'innerFieldValue'", "bound_row.nested_column_13.nested_nested_column LIKE 'innerFieldValue'") .stream().collect(joining(" AND ")), true); } @@ -1606,6 +1613,9 @@ public void testCoalesce() public void testNullif() throws Exception { + assertExecute("nullif(BIGINT '2', INT '2')", BIGINT, null); + assertExecute("nullif(INT '2', BIGINT '2')", INTEGER, null); + assertExecute("nullif(INT '2', BIGINT '3')", INTEGER, 2); assertExecute("nullif(NULL, NULL)", UNKNOWN, null); assertExecute("nullif(NULL, 2)", UNKNOWN, null); assertExecute("nullif(2, NULL)", INTEGER, 2); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineSqlFunctions.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineSqlFunctions.java index 1e4bbcf74753e..dc9ca779761db 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineSqlFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineSqlFunctions.java @@ -183,7 +183,7 @@ public void testNoInlineIntoPlanWhenInlineIsDisabled() IntegerType.INTEGER); } - private void assertInlined(String inputExpressionStr, String expectedExpressionStr, String variable, Type type) + protected void assertInlined(String inputExpressionStr, String expectedExpressionStr, String variable, Type type) { RowExpression inputExpression = new TestingRowExpressionTranslator(tester.getMetadata()).translate(inputExpressionStr, ImmutableMap.of(variable, type)); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java index 980c4c9acf73c..38b2dc40e4e9a 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java @@ -986,7 +986,7 @@ private RowExpression rowExpression(Expression expression) expression, expressionTypes, ImmutableMap.of(), - metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver(), + metadata.getFunctionAndTypeManager(), session); } diff --git a/presto-main/src/test/java/com/facebook/presto/type/BenchmarkDecimalOperators.java b/presto-main/src/test/java/com/facebook/presto/type/BenchmarkDecimalOperators.java index 6c19da8d531d2..eb26a38db4e13 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/BenchmarkDecimalOperators.java +++ b/presto-main/src/test/java/com/facebook/presto/type/BenchmarkDecimalOperators.java @@ -609,7 +609,7 @@ private RowExpression rowExpression(String value) Expression expression = createExpression(value, metadata, TypeProvider.copyOf(symbolTypes)); Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, metadata, SQL_PARSER, TypeProvider.copyOf(symbolTypes), expression, emptyMap(), WarningCollector.NOOP); - RowExpression rowExpression = SqlToRowExpressionTranslator.translate(expression, expressionTypes, sourceLayout, metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver(), TEST_SESSION); + RowExpression rowExpression = SqlToRowExpressionTranslator.translate(expression, expressionTypes, sourceLayout, metadata.getFunctionAndTypeManager(), TEST_SESSION); RowExpressionOptimizer optimizer = new RowExpressionOptimizer(metadata); return optimizer.optimize(rowExpression, OPTIMIZED, TEST_SESSION.toConnectorSession()); } diff --git a/presto-native-execution/pom.xml b/presto-native-execution/pom.xml index 920ed19e95f0d..0d311d1773904 100644 --- a/presto-native-execution/pom.xml +++ b/presto-native-execution/pom.xml @@ -38,6 +38,13 @@ presto-main + + com.facebook.presto + presto-main + test-jar + test + + com.facebook.presto presto-common diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestExpressionCompiler.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestExpressionCompiler.java new file mode 100644 index 0000000000000..9ed6cbf0e9981 --- /dev/null +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestExpressionCompiler.java @@ -0,0 +1,236 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.nativeworker; + +import com.facebook.presto.operator.scalar.FunctionAssertions; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.gen.TestExpressionCompiler; +import com.facebook.presto.testing.QueryRunner; +import org.testng.annotations.Ignore; + +public abstract class AbstractTestExpressionCompiler + extends TestExpressionCompiler +{ + @Override + public FunctionAssertions setFunctionAssertions() + { + return new FunctionAssertions(getQueryRunner().getDefaultSession(), new FeaturesConfig().setNativeExecutionEnabled(true)); + } + + protected abstract QueryRunner getQueryRunner(); + + // TODO: The following test have trouble converting long to Decimal. + // https://github.com/prestodb/presto/issues/19999 + @Override + @Ignore + public void testBinaryOperatorsDecimalBigint() + throws Exception + { + } + + @Override + @Ignore + public void testBinaryOperatorsDecimalInteger() + throws Exception + { + } + + @Override + @Ignore + public void testBinaryOperatorsDecimalDouble() + throws Exception + { + } + + // Remove the override from the following tests on a need basis, not all expressions have custom handling for native query runner, hence they are ignored. + @Override + @Ignore + public void smokedTest() + { + } + + @Override + @Ignore + public void filterFunction() + { + } + + @Override + @Ignore + public void testUnaryOperators() + { + } + + @Override + @Ignore + public void testFilterEmptyInput() + { + } + + @Override + @Ignore + public void testNestedColumnFilter() + { + } + + @Override + @Ignore + public void testTernaryOperatorsLongLong() + { + } + + @Override + @Ignore + public void testTernaryOperatorsLongDouble() + { + } + + @Override + @Ignore + public void testTernaryOperatorsDoubleDouble() + { + } + + @Override + @Ignore + public void testTernaryOperatorsString() + { + } + + @Override + @Ignore + public void testTernaryOperatorsLongDecimal() + { + } + + @Override + @Ignore + public void testTernaryOperatorsDecimalDouble() + { + } + + @Override + @Ignore + public void testCast() + { + } + + @Override + @Ignore + public void testTryCast() + { + } + + @Override + @Ignore + public void testAnd() + { + } + + @Override + @Ignore + public void testOr() + { + } + + @Override + @Ignore + public void testNot() + { + } + + @Override + @Ignore + public void testIf() + { + } + + @Override + @Ignore + public void testSimpleCase() + { + } + + @Override + @Ignore + public void testSearchCaseSingle() + { + } + + @Override + @Ignore + public void testSearchCaseMultiple() + { + } + + @Override + @Ignore + public void testIn() + { + } + + @Override + @Ignore + public void testHugeIn() + { + } + + @Override + @Ignore + public void testInComplexTypes() + { + } + + @Override + @Ignore + public void testFunctionCall() + { + } + + @Override + @Ignore + public void testFunctionCallRegexp() + { + } + + @Override + @Ignore + public void testFunctionCallJson() + { + } + + @Override + @Ignore + public void testFunctionWithSessionCall() + { + } + + @Override + @Ignore + public void testExtract() + { + } + + @Override + @Ignore + public void testLike() + { + } + + @Override + @Ignore + public void testCoalesce() + { + } +} diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeGeneralQueries.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeGeneralQueries.java index f10c32851ce4e..a62e1ab5f6a12 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeGeneralQueries.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeGeneralQueries.java @@ -261,6 +261,19 @@ public void testTopN() assertQuery("SELECT linenumber, NULL FROM lineitem ORDER BY 1 LIMIT 23"); } + @Test + public void testNullIf() + { + assertQuery("SELECT NULLIF(totalprice, 0) FROM (SELECT SUM(extendedprice) AS totalprice FROM lineitem WHERE shipdate >= '1995-09-01')"); + assertQuery("SELECT NULLIF(totalprice, 0) FROM (SELECT SUM(extendedprice) AS totalprice FROM lineitem WHERE shipdate >= '9999-99-99')"); + assertQuery("SELECT NULLIF(totalprice, 0.5) FROM (SELECT SUM(extendedprice) AS totalprice FROM lineitem WHERE shipdate >= '1995-09-01')"); + assertQuery("SELECT NULLIF(totalprice, 0.5) FROM (SELECT SUM(extendedprice) AS totalprice FROM lineitem WHERE shipdate >= '9999-99-99')"); + assertQuery("SELECT NULLIF(totalprice, 0) FROM (SELECT COUNT(1) AS totalprice FROM lineitem WHERE shipdate >= '1995-09-01')"); + assertQuery("SELECT NULLIF(totalprice, 0) FROM (SELECT COUNT(1) AS totalprice FROM lineitem WHERE shipdate >= '9999-99-99')"); + assertQuery("SELECT NULLIF(totalprice, 0.5) FROM (SELECT COUNT(1) AS totalprice FROM lineitem WHERE shipdate >= '1995-09-01')"); + assertQuery("SELECT NULLIF(totalprice, 0.5) FROM (SELECT COUNT(1) AS totalprice FROM lineitem WHERE shipdate >= '9999-99-99')"); + } + @Test public void testCast() { diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestPrestoNativeGeneralQueriesJSON.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestPrestoNativeGeneralQueriesJSON.java index b343bf38fee2b..a26963d12fe19 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestPrestoNativeGeneralQueriesJSON.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestPrestoNativeGeneralQueriesJSON.java @@ -20,13 +20,15 @@ public class TestPrestoNativeGeneralQueriesJSON extends AbstractTestNativeGeneralQueries { @Override - protected QueryRunner createQueryRunner() throws Exception + protected QueryRunner createQueryRunner() + throws Exception { return PrestoNativeQueryRunnerUtils.createNativeQueryRunner(false); } @Override - protected ExpectedQueryRunner createExpectedQueryRunner() throws Exception + protected ExpectedQueryRunner createExpectedQueryRunner() + throws Exception { return PrestoNativeQueryRunnerUtils.createJavaQueryRunner(); } diff --git a/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkExpressionCompiler.java b/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkExpressionCompiler.java new file mode 100644 index 0000000000000..9c854ad3782ed --- /dev/null +++ b/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkExpressionCompiler.java @@ -0,0 +1,27 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spark; + +import com.facebook.presto.nativeworker.AbstractTestExpressionCompiler; +import com.facebook.presto.testing.QueryRunner; + +public class TestPrestoSparkExpressionCompiler + extends AbstractTestExpressionCompiler +{ + @Override + protected QueryRunner getQueryRunner() + { + return PrestoSparkNativeQueryRunnerUtils.createHiveRunner(); + } +} diff --git a/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkSqlFunctions.java b/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkSqlFunctions.java new file mode 100644 index 0000000000000..a9de4232645ba --- /dev/null +++ b/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkSqlFunctions.java @@ -0,0 +1,26 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spark; + +import com.facebook.presto.operator.scalar.TestCustomFunctions; +import com.facebook.presto.sql.analyzer.FeaturesConfig; + +public class TestPrestoSparkSqlFunctions + extends TestCustomFunctions +{ + public TestPrestoSparkSqlFunctions() + { + super(new FeaturesConfig().setNativeExecutionEnabled(true)); + } +} diff --git a/presto-pinot-toolkit/src/test/java/com/facebook/presto/pinot/TestPinotQueryBase.java b/presto-pinot-toolkit/src/test/java/com/facebook/presto/pinot/TestPinotQueryBase.java index 4e6f618c30f6f..df13f993f9416 100644 --- a/presto-pinot-toolkit/src/test/java/com/facebook/presto/pinot/TestPinotQueryBase.java +++ b/presto-pinot-toolkit/src/test/java/com/facebook/presto/pinot/TestPinotQueryBase.java @@ -235,7 +235,7 @@ protected RowExpression toRowExpression(Expression expression, Session session) expression, ImmutableMap.of(), WarningCollector.NOOP); - return SqlToRowExpressionTranslator.translate(expression, expressionTypes, ImmutableMap.of(), functionAndTypeManager.getFunctionAndTypeResolver(), session); + return SqlToRowExpressionTranslator.translate(expression, expressionTypes, ImmutableMap.of(), functionAndTypeManager, session); } protected LimitNode limit(PlanBuilder pb, long count, PlanNode source)