diff --git a/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java b/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java index d7d5527cd069..7d726dcbeccd 100644 --- a/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java +++ b/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java @@ -15,6 +15,7 @@ import com.google.inject.Inject; import io.trino.Session; +import io.trino.spi.function.CatalogSchemaFunctionName; import io.trino.spi.type.BigintType; import io.trino.spi.type.DecimalType; import io.trino.spi.type.IntegerType; @@ -22,14 +23,12 @@ import io.trino.spi.type.TinyintType; import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Coalesce; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.ir.IrVisitor; -import io.trino.sql.ir.Negation; import io.trino.sql.ir.Reference; import io.trino.sql.planner.IrExpressionInterpreter; import io.trino.sql.planner.NoOpSymbolResolver; @@ -37,6 +36,13 @@ import java.util.OptionalDouble; +import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName; +import static io.trino.spi.function.OperatorType.ADD; +import static io.trino.spi.function.OperatorType.DIVIDE; +import static io.trino.spi.function.OperatorType.MODULUS; +import static io.trino.spi.function.OperatorType.MULTIPLY; +import static io.trino.spi.function.OperatorType.NEGATION; +import static io.trino.spi.function.OperatorType.SUBTRACT; import static io.trino.spi.statistics.StatsUtil.toStatsRepresentation; import static io.trino.util.MoreMath.max; import static io.trino.util.MoreMath.min; @@ -109,6 +115,21 @@ protected SymbolStatsEstimate visitConstant(Constant node, Void context) @Override protected SymbolStatsEstimate visitCall(Call node, Void context) { + if (node.function().getName().equals(builtinFunctionName(NEGATION))) { + SymbolStatsEstimate stats = process(node.arguments().getFirst()); + return SymbolStatsEstimate.buildFrom(stats) + .setLowValue(-stats.getHighValue()) + .setHighValue(-stats.getLowValue()) + .build(); + } + else if (node.function().getName().equals(builtinFunctionName(ADD)) || + node.function().getName().equals(builtinFunctionName(SUBTRACT)) || + node.function().getName().equals(builtinFunctionName(MULTIPLY)) || + node.function().getName().equals(builtinFunctionName(DIVIDE)) || + node.function().getName().equals(builtinFunctionName(MODULUS))) { + return processArithmetic(node); + } + IrExpressionInterpreter interpreter = new IrExpressionInterpreter(node, plannerContext, session); Object value = interpreter.optimize(NoOpSymbolResolver.INSTANCE); @@ -175,22 +196,11 @@ private boolean isIntegralType(Type type) return false; } - @Override - protected SymbolStatsEstimate visitNegation(Negation node, Void context) - { - SymbolStatsEstimate stats = process(node.value()); - return SymbolStatsEstimate.buildFrom(stats) - .setLowValue(-stats.getHighValue()) - .setHighValue(-stats.getLowValue()) - .build(); - } - - @Override - protected SymbolStatsEstimate visitArithmetic(Arithmetic node, Void context) + protected SymbolStatsEstimate processArithmetic(Call node) { requireNonNull(node, "node is null"); - SymbolStatsEstimate left = process(node.left()); - SymbolStatsEstimate right = process(node.right()); + SymbolStatsEstimate left = process(node.arguments().get(0)); + SymbolStatsEstimate right = process(node.arguments().get(1)); if (left.isUnknown() || right.isUnknown()) { return SymbolStatsEstimate.unknown(); } @@ -208,11 +218,11 @@ protected SymbolStatsEstimate visitArithmetic(Arithmetic node, Void context) result.setLowValue(NaN) .setHighValue(NaN); } - else if (node.operator() == Arithmetic.Operator.DIVIDE && rightLow < 0 && rightHigh > 0) { + else if (node.function().getName().equals(builtinFunctionName(DIVIDE)) && rightLow < 0 && rightHigh > 0) { result.setLowValue(Double.NEGATIVE_INFINITY) .setHighValue(Double.POSITIVE_INFINITY); } - else if (node.operator() == Arithmetic.Operator.MODULUS) { + else if (node.function().getName().equals(builtinFunctionName(MODULUS))) { double maxDivisor = max(abs(rightLow), abs(rightHigh)); if (leftHigh <= 0) { result.setLowValue(max(-maxDivisor, leftLow)) @@ -228,10 +238,10 @@ else if (leftLow >= 0) { } } else { - double v1 = operate(node.operator(), leftLow, rightLow); - double v2 = operate(node.operator(), leftLow, rightHigh); - double v3 = operate(node.operator(), leftHigh, rightLow); - double v4 = operate(node.operator(), leftHigh, rightHigh); + double v1 = operate(node.function().getName(), leftLow, rightLow); + double v2 = operate(node.function().getName(), leftLow, rightHigh); + double v3 = operate(node.function().getName(), leftHigh, rightLow); + double v4 = operate(node.function().getName(), leftHigh, rightHigh); double lowValue = min(v1, v2, v3, v4); double highValue = max(v1, v2, v3, v4); @@ -242,21 +252,16 @@ else if (leftLow >= 0) { return result.build(); } - private double operate(Arithmetic.Operator operator, double left, double right) + private double operate(CatalogSchemaFunctionName function, double left, double right) { - switch (operator) { - case ADD: - return left + right; - case SUBTRACT: - return left - right; - case MULTIPLY: - return left * right; - case DIVIDE: - return left / right; - case MODULUS: - return left % right; - } - throw new IllegalStateException("Unsupported ArithmeticBinaryExpression.Operator: " + operator); + return switch (function) { + case CatalogSchemaFunctionName name when name.equals(builtinFunctionName(ADD)) -> left + right; + case CatalogSchemaFunctionName name when name.equals(builtinFunctionName(SUBTRACT)) -> left - right; + case CatalogSchemaFunctionName name when name.equals(builtinFunctionName(MULTIPLY)) -> left * right; + case CatalogSchemaFunctionName name when name.equals(builtinFunctionName(DIVIDE)) -> left / right; + case CatalogSchemaFunctionName name when name.equals(builtinFunctionName(MODULUS)) -> left % right; + default -> throw new IllegalStateException("Unsupported binary arithmetic operation: " + function); + }; } @Override diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/Arithmetic.java b/core/trino-main/src/main/java/io/trino/sql/ir/Arithmetic.java deleted file mode 100644 index c31d196dfcb8..000000000000 --- a/core/trino-main/src/main/java/io/trino/sql/ir/Arithmetic.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.sql.ir; - -import com.fasterxml.jackson.databind.annotation.JsonSerialize; -import com.google.common.collect.ImmutableList; -import io.trino.metadata.ResolvedFunction; -import io.trino.spi.type.Type; - -import java.util.List; - -import static io.trino.sql.ir.IrUtils.validateType; -import static java.util.Objects.requireNonNull; - -@JsonSerialize -public record Arithmetic(ResolvedFunction function, Operator operator, Expression left, Expression right) - implements Expression -{ - public enum Operator - { - ADD("+"), - SUBTRACT("-"), - MULTIPLY("*"), - DIVIDE("/"), - MODULUS("%"); - private final String value; - - Operator(String value) - { - this.value = value; - } - - public String getValue() - { - return value; - } - } - - public Arithmetic - { - requireNonNull(operator, "operator is null"); - validateType(function.getSignature().getArgumentType(0), left); - validateType(function.getSignature().getArgumentType(1), right); - } - - @Override - public Type type() - { - return function.getSignature().getReturnType(); - } - - @Override - public R accept(IrVisitor visitor, C context) - { - return visitor.visitArithmetic(this, context); - } - - @Override - public List children() - { - return ImmutableList.of(left, right); - } - - @Override - public String toString() - { - return "%s(%s, %s)".formatted(operator.getValue(), left, right); - } -} diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/DefaultTraversalVisitor.java b/core/trino-main/src/main/java/io/trino/sql/ir/DefaultTraversalVisitor.java index cf0951283c93..25509d2dfad1 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/DefaultTraversalVisitor.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/DefaultTraversalVisitor.java @@ -23,15 +23,6 @@ protected Void visitCast(Cast node, C context) return null; } - @Override - protected Void visitArithmetic(Arithmetic node, C context) - { - process(node.left(), context); - process(node.right(), context); - - return null; - } - @Override protected Void visitBetween(Between node, C context) { @@ -125,13 +116,6 @@ protected Void visitBind(Bind node, C context) return null; } - @Override - protected Void visitNegation(Negation node, C context) - { - process(node.value(), context); - return null; - } - @Override protected Void visitNot(Not node, C context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/Expression.java b/core/trino-main/src/main/java/io/trino/sql/ir/Expression.java index 6874d85afeae..1d44c86eab1f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/Expression.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/Expression.java @@ -24,8 +24,6 @@ @Immutable @JsonTypeInfo(use = JsonTypeInfo.Id.NAME) @JsonSubTypes({ - @JsonSubTypes.Type(value = Arithmetic.class, name = "arithmetic"), - @JsonSubTypes.Type(value = Negation.class, name = "negation"), @JsonSubTypes.Type(value = Between.class, name = "between"), @JsonSubTypes.Type(value = Bind.class, name = "bind"), @JsonSubTypes.Type(value = Cast.class, name = "cast"), @@ -46,9 +44,9 @@ @JsonSubTypes.Type(value = Reference.class, name = "reference"), }) public sealed interface Expression - permits Arithmetic, Between, Bind, Call, Case, Cast, Coalesce, + permits Between, Bind, Call, Case, Cast, Coalesce, Comparison, Constant, FieldReference, In, IsNull, Lambda, Logical, - Negation, Not, NullIf, Reference, Row, Switch + Not, NullIf, Reference, Row, Switch { Type type(); diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionFormatter.java b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionFormatter.java index 566b91b8be48..fb22e84e7bbc 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionFormatter.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionFormatter.java @@ -163,18 +163,6 @@ protected String visitCoalesce(Coalesce node, Void context) return "COALESCE(" + joinExpressions(node.operands()) + ")"; } - @Override - protected String visitNegation(Negation node, Void context) - { - return "-(" + process(node.value(), context) + ")"; - } - - @Override - protected String visitArithmetic(Arithmetic node, Void context) - { - return formatBinaryExpression(node.operator().getValue(), node.left(), node.right()); - } - @Override public String visitCast(Cast node, Void context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionRewriter.java b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionRewriter.java index 3d5cbb662af5..8a50b216e237 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionRewriter.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionRewriter.java @@ -25,16 +25,6 @@ public Expression rewriteRow(Row node, C context, ExpressionTreeRewriter tree return rewriteExpression(node, context, treeRewriter); } - public Expression rewriteNegation(Negation node, C context, ExpressionTreeRewriter treeRewriter) - { - return rewriteExpression(node, context, treeRewriter); - } - - public Expression rewriteArithmetic(Arithmetic node, C context, ExpressionTreeRewriter treeRewriter) - { - return rewriteExpression(node, context, treeRewriter); - } - public Expression rewriteComparison(Comparison node, C context, ExpressionTreeRewriter treeRewriter) { return rewriteExpression(node, context, treeRewriter); diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionTreeRewriter.java b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionTreeRewriter.java index a8e8f4eceaac..f7e28a45158c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionTreeRewriter.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionTreeRewriter.java @@ -97,44 +97,6 @@ protected Expression visitRow(Row node, Context context) return node; } - @Override - protected Expression visitNegation(Negation node, Context context) - { - if (!context.isDefaultRewrite()) { - Expression result = rewriter.rewriteNegation(node, context.get(), ExpressionTreeRewriter.this); - if (result != null) { - return result; - } - } - - Expression child = rewrite(node.value(), context.get()); - if (child != node.value()) { - return new Negation(child); - } - - return node; - } - - @Override - public Expression visitArithmetic(Arithmetic node, Context context) - { - if (!context.isDefaultRewrite()) { - Expression result = rewriter.rewriteArithmetic(node, context.get(), ExpressionTreeRewriter.this); - if (result != null) { - return result; - } - } - - Expression left = rewrite(node.left(), context.get()); - Expression right = rewrite(node.right(), context.get()); - - if (left != node.left() || right != node.right()) { - return new Arithmetic(node.function(), node.operator(), left, right); - } - - return node; - } - @Override protected Expression visitFieldReference(FieldReference node, Context context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/IrVisitor.java b/core/trino-main/src/main/java/io/trino/sql/ir/IrVisitor.java index 1f6ef21321ae..97b65655291b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/IrVisitor.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/IrVisitor.java @@ -32,11 +32,6 @@ protected R visitExpression(Expression node, C context) return null; } - protected R visitArithmetic(Arithmetic node, C context) - { - return visitExpression(node, context); - } - protected R visitBetween(Between node, C context) { return visitExpression(node, context); @@ -82,11 +77,6 @@ protected R visitNullIf(NullIf node, C context) return visitExpression(node, context); } - protected R visitNegation(Negation node, C context) - { - return visitExpression(node, context); - } - protected R visitNot(Not node, C context) { return visitExpression(node, context); diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/Negation.java b/core/trino-main/src/main/java/io/trino/sql/ir/Negation.java deleted file mode 100644 index 6bf3f7f9f45c..000000000000 --- a/core/trino-main/src/main/java/io/trino/sql/ir/Negation.java +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.sql.ir; - -import com.fasterxml.jackson.databind.annotation.JsonSerialize; -import com.google.common.collect.ImmutableList; -import io.trino.spi.type.Type; - -import java.util.List; - -import static java.util.Objects.requireNonNull; - -@JsonSerialize -public record Negation(Expression value) - implements Expression -{ - public Negation - { - requireNonNull(value, "value is null"); - } - - @Override - public Type type() - { - return value.type(); - } - - @Override - public R accept(IrVisitor visitor, C context) - { - return visitor.visitNegation(this, context); - } - - @Override - public List children() - { - return ImmutableList.of(value); - } - - @Override - public String toString() - { - return "-(%s)".formatted(value); - } -} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java index e93eceeac7ed..d7856a627621 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java @@ -36,7 +36,6 @@ import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Between; import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; @@ -48,7 +47,6 @@ import io.trino.sql.ir.IrVisitor; import io.trino.sql.ir.IsNull; import io.trino.sql.ir.Logical; -import io.trino.sql.ir.Negation; import io.trino.sql.ir.Not; import io.trino.sql.ir.NullIf; import io.trino.sql.ir.Reference; @@ -93,6 +91,12 @@ import static io.trino.spi.expression.StandardFunctions.NULLIF_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.OR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.SUBTRACT_FUNCTION_NAME; +import static io.trino.spi.function.OperatorType.ADD; +import static io.trino.spi.function.OperatorType.DIVIDE; +import static io.trino.spi.function.OperatorType.MODULUS; +import static io.trino.spi.function.OperatorType.MULTIPLY; +import static io.trino.spi.function.OperatorType.NEGATION; +import static io.trino.spi.function.OperatorType.SUBTRACT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; @@ -160,18 +164,6 @@ static FunctionName functionNameForComparisonOperator(Comparison.Operator operat }; } - @VisibleForTesting - static FunctionName functionNameForArithmeticBinaryOperator(Arithmetic.Operator operator) - { - return switch (operator) { - case ADD -> ADD_FUNCTION_NAME; - case SUBTRACT -> SUBTRACT_FUNCTION_NAME; - case MULTIPLY -> MULTIPLY_FUNCTION_NAME; - case DIVIDE -> DIVIDE_FUNCTION_NAME; - case MODULUS -> MODULUS_FUNCTION_NAME; - }; - } - public record ConnectorExpressionTranslation(ConnectorExpression connectorExpression, Expression remainingExpression) { public ConnectorExpressionTranslation @@ -270,7 +262,7 @@ protected Optional translateCall(io.trino.spi.expression.Call call) // arithmetic binary if (call.getArguments().size() == 2) { - Optional operator = arithmeticBinaryOperatorForFunctionName(call.getFunctionName()); + Optional operator = arithmeticBinaryOperatorForFunctionName(call.getFunctionName()); if (operator.isPresent()) { return translateArithmeticBinary(operator.get(), call.getArguments().get(0), call.getArguments().get(1)); } @@ -278,7 +270,9 @@ protected Optional translateCall(io.trino.spi.expression.Call call) // arithmetic unary if (NEGATE_FUNCTION_NAME.equals(call.getFunctionName()) && call.getArguments().size() == 1) { - return translate(getOnlyElement(call.getArguments())).map(argument -> new Negation(argument)); + ConnectorExpression argument = getOnlyElement(call.getArguments()); + ResolvedFunction function = plannerContext.getMetadata().resolveOperator(NEGATION, ImmutableList.of(argument.getType())); + return translate(argument).map(value -> new Call(function, ImmutableList.of(value))); } if (StandardFunctions.LIKE_FUNCTION_NAME.equals(call.getFunctionName())) { @@ -416,38 +410,31 @@ private Optional comparisonOperatorForFunctionName(Function return Optional.empty(); } - private Optional translateArithmeticBinary(Arithmetic.Operator operator, ConnectorExpression left, ConnectorExpression right) + private Optional translateArithmeticBinary(OperatorType operator, ConnectorExpression left, ConnectorExpression right) { - OperatorType operatorType = switch (operator) { - case ADD -> OperatorType.ADD; - case SUBTRACT -> OperatorType.SUBTRACT; - case MULTIPLY -> OperatorType.MULTIPLY; - case DIVIDE -> OperatorType.DIVIDE; - case MODULUS -> OperatorType.MODULUS; - }; - ResolvedFunction function = plannerContext.getMetadata().resolveOperator(operatorType, ImmutableList.of(left.getType(), right.getType())); + ResolvedFunction function = plannerContext.getMetadata().resolveOperator(operator, ImmutableList.of(left.getType(), right.getType())); return translate(left).flatMap(leftTranslated -> translate(right).map(rightTranslated -> - new Arithmetic(function, operator, leftTranslated, rightTranslated))); + new Call(function, ImmutableList.of(leftTranslated, rightTranslated)))); } - private Optional arithmeticBinaryOperatorForFunctionName(FunctionName functionName) + private Optional arithmeticBinaryOperatorForFunctionName(FunctionName functionName) { if (ADD_FUNCTION_NAME.equals(functionName)) { - return Optional.of(Arithmetic.Operator.ADD); + return Optional.of(ADD); } if (SUBTRACT_FUNCTION_NAME.equals(functionName)) { - return Optional.of(Arithmetic.Operator.SUBTRACT); + return Optional.of(SUBTRACT); } if (MULTIPLY_FUNCTION_NAME.equals(functionName)) { - return Optional.of(Arithmetic.Operator.MULTIPLY); + return Optional.of(MULTIPLY); } if (DIVIDE_FUNCTION_NAME.equals(functionName)) { - return Optional.of(Arithmetic.Operator.DIVIDE); + return Optional.of(DIVIDE); } if (MODULUS_FUNCTION_NAME.equals(functionName)) { - return Optional.of(Arithmetic.Operator.MODULUS); + return Optional.of(MODULUS); } return Optional.empty(); } @@ -593,16 +580,6 @@ protected Optional visitComparison(Comparison node, Void co new io.trino.spi.expression.Call(((Expression) node).type(), functionNameForComparisonOperator(node.operator()), ImmutableList.of(left, right)))); } - @Override - protected Optional visitArithmetic(Arithmetic node, Void context) - { - if (!isComplexExpressionPushdown(session)) { - return Optional.empty(); - } - return process(node.left()).flatMap(left -> process(node.right()).map(right -> - new io.trino.spi.expression.Call(((Expression) node).type(), functionNameForArithmeticBinaryOperator(node.operator()), ImmutableList.of(left, right)))); - } - @Override protected Optional visitBetween(Between node, Void context) { @@ -620,13 +597,10 @@ protected Optional visitBetween(Between node, Void context) new io.trino.spi.expression.Call(BOOLEAN, LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME, ImmutableList.of(value, max))))))); } - @Override - protected Optional visitNegation(Negation node, Void context) + protected Optional translateNegation(Call node) { - if (!isComplexExpressionPushdown(session)) { - return Optional.empty(); - } - return process(node.value()).map(value -> new io.trino.spi.expression.Call(((Expression) node).type(), NEGATE_FUNCTION_NAME, ImmutableList.of(value))); + return process(node.arguments().getFirst()) + .map(value -> new io.trino.spi.expression.Call(node.type(), NEGATE_FUNCTION_NAME, ImmutableList.of(value))); } @Override @@ -669,6 +643,29 @@ protected Optional visitCall(Call node, Void context) if (functionName.equals(builtinFunctionName(LIKE_FUNCTION_NAME))) { return translateLike(node); } + else if (functionName.equals(builtinFunctionName(NEGATION))) { + return translateNegation(node); + } + else if (functionName.equals(builtinFunctionName(ADD))) { + return process(node.arguments().get(0)).flatMap(left -> process(node.arguments().get(1)).map(right -> + new io.trino.spi.expression.Call(node.type(), ADD_FUNCTION_NAME, ImmutableList.of(left, right)))); + } + else if (functionName.equals(builtinFunctionName(SUBTRACT))) { + return process(node.arguments().get(0)).flatMap(left -> process(node.arguments().get(1)).map(right -> + new io.trino.spi.expression.Call(node.type(), SUBTRACT_FUNCTION_NAME, ImmutableList.of(left, right)))); + } + else if (functionName.equals(builtinFunctionName(MULTIPLY))) { + return process(node.arguments().get(0)).flatMap(left -> process(node.arguments().get(1)).map(right -> + new io.trino.spi.expression.Call(node.type(), MULTIPLY_FUNCTION_NAME, ImmutableList.of(left, right)))); + } + else if (functionName.equals(builtinFunctionName(DIVIDE))) { + return process(node.arguments().get(0)).flatMap(left -> process(node.arguments().get(1)).map(right -> + new io.trino.spi.expression.Call(node.type(), DIVIDE_FUNCTION_NAME, ImmutableList.of(left, right)))); + } + else if (functionName.equals(builtinFunctionName(MODULUS))) { + return process(node.arguments().get(0)).flatMap(left -> process(node.arguments().get(1)).map(right -> + new io.trino.spi.expression.Call(node.type(), MODULUS_FUNCTION_NAME, ImmutableList.of(left, right)))); + } ImmutableList.Builder arguments = ImmutableList.builder(); for (Expression argumentExpression : node.arguments()) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/GroupingOperationRewriter.java b/core/trino-main/src/main/java/io/trino/sql/planner/GroupingOperationRewriter.java index 00b601df2041..26a0c8c54450 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/GroupingOperationRewriter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/GroupingOperationRewriter.java @@ -24,7 +24,6 @@ import io.trino.sql.analyzer.FieldId; import io.trino.sql.analyzer.RelationId; import io.trino.sql.analyzer.ResolvedField; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Call; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; @@ -41,7 +40,6 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static java.util.Objects.requireNonNull; public final class GroupingOperationRewriter @@ -99,11 +97,9 @@ public static Expression rewriteGroupingOperation( .setName(ArrayConstructor.NAME) .setArguments(Collections.nCopies(groupingResults.size(), type), groupingResults) .build(), - new Arithmetic( + new Call( metadata.resolveOperator(OperatorType.ADD, ImmutableList.of(BIGINT, BIGINT)), - ADD, - groupIdSymbol.get().toSymbolReference(), - new Constant(BIGINT, 1L)))); + ImmutableList.of(groupIdSymbol.get().toSymbolReference(), new Constant(BIGINT, 1L))))); } private static int translateFieldToInteger(FieldId fieldId, RelationId requiredOriginRelationId) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java b/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java index eb3a16d86e2c..e0578e21d320 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java @@ -23,7 +23,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.function.CatalogSchemaFunctionName; import io.trino.spi.function.FunctionNullability; -import io.trino.spi.function.InvocationConvention; import io.trino.spi.function.OperatorType; import io.trino.spi.type.ArrayType; import io.trino.spi.type.MapType; @@ -31,7 +30,6 @@ import io.trino.spi.type.Type; import io.trino.sql.InterpretedFunctionInvoker; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Between; import io.trino.sql.ir.Bind; import io.trino.sql.ir.Call; @@ -48,7 +46,6 @@ import io.trino.sql.ir.IsNull; import io.trino.sql.ir.Lambda; import io.trino.sql.ir.Logical; -import io.trino.sql.ir.Negation; import io.trino.sql.ir.Not; import io.trino.sql.ir.NullIf; import io.trino.sql.ir.Reference; @@ -76,11 +73,11 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Predicates.instanceOf; -import static com.google.common.base.Throwables.throwIfInstanceOf; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName; +import static io.trino.metadata.OperatorNameUtil.mangleOperatorName; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.block.RowValueBuilder.buildRowValue; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; @@ -88,6 +85,7 @@ import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; import static io.trino.spi.function.InvocationConvention.simpleConvention; import static io.trino.spi.function.OperatorType.HASH_CODE; +import static io.trino.spi.function.OperatorType.NEGATION; import static io.trino.spi.type.TypeUtils.readNativeValue; import static io.trino.spi.type.TypeUtils.writeNativeValue; import static io.trino.sql.DynamicFilters.isDynamicFilter; @@ -445,57 +443,6 @@ else if (!found && result) { return false; } - @Override - protected Object visitNegation(Negation node, Object context) - { - Object value = processWithExceptionHandling(node.value(), context); - if (value == null) { - return null; - } - if (value instanceof Expression) { - Expression valueExpression = toExpression(value, node.value().type()); - if (valueExpression instanceof Negation argument) { - return argument.value(); - } - return new Negation(valueExpression); - } - - ResolvedFunction resolvedOperator = metadata.resolveOperator(OperatorType.NEGATION, types(node.value())); - InvocationConvention invocationConvention = new InvocationConvention(ImmutableList.of(NEVER_NULL), FAIL_ON_NULL, true, false); - MethodHandle handle = plannerContext.getFunctionManager().getScalarFunctionImplementation(resolvedOperator, invocationConvention).getMethodHandle(); - - if (handle.type().parameterCount() > 0 && handle.type().parameterType(0) == ConnectorSession.class) { - handle = handle.bindTo(connectorSession); - } - try { - return handle.invokeWithArguments(value); - } - catch (Throwable throwable) { - throwIfInstanceOf(throwable, RuntimeException.class); - throwIfInstanceOf(throwable, Error.class); - throw new RuntimeException(throwable.getMessage(), throwable); - } - } - - @Override - protected Object visitArithmetic(Arithmetic node, Object context) - { - Object left = processWithExceptionHandling(node.left(), context); - if (left == null) { - return null; - } - Object right = processWithExceptionHandling(node.right(), context); - if (right == null) { - return null; - } - - if (hasUnresolvedValue(left, right)) { - return new Arithmetic(node.function(), node.operator(), toExpression(left, node.left().type()), toExpression(right, node.right().type())); - } - - return functionInvoker.invoke(node.function(), connectorSession, ImmutableList.of(left, right)); - } - @Override protected Object visitComparison(Comparison node, Object context) { @@ -727,6 +674,10 @@ protected Object visitLogical(Logical node, Object context) @Override protected Object visitCall(Call node, Object context) { + if (node.function().getName().getFunctionName().equals(mangleOperatorName(NEGATION))) { + return processNegation(node, context); + } + List argumentTypes = new ArrayList<>(); List argumentValues = new ArrayList<>(); for (Expression expression : node.arguments()) { @@ -757,6 +708,18 @@ protected Object visitCall(Call node, Object context) return functionInvoker.invoke(resolvedFunction, connectorSession, argumentValues); } + private Object processNegation(Call negation, Object context) + { + Object value = processWithExceptionHandling(negation.arguments().getFirst(), context); + + return switch (value) { + case Call inner when inner.function().getName().getFunctionName().equals(mangleOperatorName(NEGATION)) -> inner.arguments().getFirst(); // double negation + case Expression inner -> new Call(negation.function(), ImmutableList.of(inner)); + case null -> null; + default -> functionInvoker.invoke(negation.function(), connectorSession, ImmutableList.of(value)); + }; + } + @Override protected Object visitLambda(Lambda node, Object context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/NullabilityAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/planner/NullabilityAnalyzer.java index 3b2468bb089a..25210a943375 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/NullabilityAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/NullabilityAnalyzer.java @@ -106,8 +106,10 @@ protected Void visitFieldReference(FieldReference node, AtomicBoolean result) @Override protected Void visitCall(Call node, AtomicBoolean result) { - // TODO: this should look at whether the return type of the function is annotated with @SqlNullable - result.set(true); + if (node.function().getFunctionNullability().isReturnNullable()) { + result.set(true); + } + return null; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java b/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java index 3a9d65f8eb4d..ac211d2f500d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java @@ -40,7 +40,6 @@ import io.trino.sql.analyzer.ResolvedField; import io.trino.sql.analyzer.Scope; import io.trino.sql.analyzer.TypeSignatureTranslator; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Between; import io.trino.sql.ir.Call; import io.trino.sql.ir.Case; @@ -52,7 +51,6 @@ import io.trino.sql.ir.IsNull; import io.trino.sql.ir.Lambda; import io.trino.sql.ir.Logical; -import io.trino.sql.ir.Negation; import io.trino.sql.ir.Not; import io.trino.sql.ir.NullIf; import io.trino.sql.ir.Reference; @@ -380,7 +378,9 @@ private io.trino.sql.ir.Expression translate(ArithmeticUnaryExpression expressio { return switch (expression.getSign()) { case PLUS -> translateExpression(expression.getValue()); - case MINUS -> new Negation(translateExpression(expression.getValue())); + case MINUS -> new io.trino.sql.ir.Call( + plannerContext.getMetadata().resolveOperator(OperatorType.NEGATION, ImmutableList.of(analysis.getType(expression.getValue()))), + ImmutableList.of(translateExpression(expression.getValue()))); }; } @@ -578,17 +578,11 @@ private io.trino.sql.ir.Expression translate(ArithmeticBinaryExpression expressi case MODULUS -> OperatorType.MODULUS; }; - return new Arithmetic( + return new Call( plannerContext.getMetadata().resolveOperator(operatorType, ImmutableList.of(getCoercedType(expression.getLeft()), getCoercedType(expression.getRight()))), - switch (expression.getOperator()) { - case ADD -> Arithmetic.Operator.ADD; - case SUBTRACT -> Arithmetic.Operator.SUBTRACT; - case MULTIPLY -> Arithmetic.Operator.MULTIPLY; - case DIVIDE -> Arithmetic.Operator.DIVIDE; - case MODULUS -> Arithmetic.Operator.MODULUS; - }, - translateExpression(expression.getLeft()), - translateExpression(expression.getRight())); + ImmutableList.of( + translateExpression(expression.getLeft()), + translateExpression(expression.getRight()))); } private Type getCoercedType(Expression left) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/CanonicalizeExpressionRewriter.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/CanonicalizeExpressionRewriter.java index 440d9cfb95d4..a6868e96fa67 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/CanonicalizeExpressionRewriter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/CanonicalizeExpressionRewriter.java @@ -21,7 +21,6 @@ import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Comparison; @@ -33,11 +32,12 @@ import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName; import static io.trino.spi.type.DateType.DATE; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; -import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; public final class CanonicalizeExpressionRewriter { + private static final CatalogSchemaFunctionName MULTIPLY_BUILTIN_FUNCTION = builtinFunctionName(OperatorType.MULTIPLY); + private static final CatalogSchemaFunctionName ADD_BUILTIN_FUNCTION = builtinFunctionName(OperatorType.ADD); + public static Expression canonicalizeExpression(Expression expression, PlannerContext plannerContext) { return ExpressionTreeRewriter.rewriteWith(new Visitor(plannerContext), expression); @@ -77,38 +77,30 @@ public Expression rewriteComparison(Comparison node, Void context, ExpressionTre return treeRewriter.defaultRewrite(node, context); } - @SuppressWarnings("ArgumentSelectionDefectChecker") @Override - public Expression rewriteArithmetic(Arithmetic node, Void context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteCall(Call node, Void context, ExpressionTreeRewriter treeRewriter) { - if (node.operator() == MULTIPLY || node.operator() == ADD) { - // if we have a operation of the form [+|*] , normalize it to - // [+|*] - if (isConstant(node.left()) && !isConstant(node.right())) { - node = new Arithmetic( + CatalogSchemaFunctionName functionName = node.function().getName(); + + if (functionName.equals(MULTIPLY_BUILTIN_FUNCTION) || + functionName.equals(ADD_BUILTIN_FUNCTION)) { + // normalize [*/+] , normalize it to [*/+] + Expression left = treeRewriter.rewrite(node.arguments().get(0), context); + Expression right = treeRewriter.rewrite(node.arguments().get(1), context); + if (isConstant(left) && !isConstant(right)) { + return new Call( plannerContext.getMetadata().resolveOperator( - switch (node.operator()) { - case ADD -> OperatorType.ADD; - case MULTIPLY -> OperatorType.MULTIPLY; - default -> throw new IllegalStateException("Unexpected value: " + node.operator()); - }, + getOperator(functionName), ImmutableList.of( node.function().getSignature().getArgumentType(1), node.function().getSignature().getArgumentType(0))), - node.operator(), - node.right(), - node.left()); + ImmutableList.of(right, left)); + } + else { + return new Call(node.function(), ImmutableList.of(left, right)); } } - - return treeRewriter.defaultRewrite(node, context); - } - - @Override - public Expression rewriteCall(Call node, Void context, ExpressionTreeRewriter treeRewriter) - { - CatalogSchemaFunctionName functionName = node.function().getName(); - if (functionName.equals(builtinFunctionName("date")) && node.arguments().size() == 1) { + else if (functionName.equals(builtinFunctionName("date")) && node.arguments().size() == 1) { Expression argument = node.arguments().get(0); Type argumentType = argument.type(); if (argumentType instanceof TimestampType @@ -127,4 +119,13 @@ private boolean isConstant(Expression expression) return expression instanceof Constant; } } + + private static OperatorType getOperator(CatalogSchemaFunctionName function) + { + return switch (function) { + case CatalogSchemaFunctionName name when name.equals(ADD_BUILTIN_FUNCTION) -> OperatorType.ADD; + case CatalogSchemaFunctionName name when name.equals(MULTIPLY_BUILTIN_FUNCTION) -> OperatorType.MULTIPLY; + default -> throw new IllegalArgumentException("Unexpected operator: " + function); + }; + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementExceptAll.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementExceptAll.java index fc4888650c9e..e4beb73c86ff 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementExceptAll.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementExceptAll.java @@ -19,7 +19,6 @@ import io.trino.metadata.Metadata; import io.trino.metadata.ResolvedFunction; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Call; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; @@ -33,7 +32,6 @@ import static com.google.common.base.Preconditions.checkState; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.planner.plan.Patterns.Except.distinct; import static io.trino.sql.planner.plan.Patterns.except; @@ -104,11 +102,9 @@ public Result apply(ExceptNode node, Captures captures, Context context) count = new Call( greatest, ImmutableList.of( - new Arithmetic( + new Call( metadata.resolveOperator(OperatorType.SUBTRACT, ImmutableList.of(BIGINT, BIGINT)), - SUBTRACT, - count, - result.getCountSymbols().get(i).toSymbolReference()), + ImmutableList.of(count, result.getCountSymbols().get(i).toSymbolReference())), new Constant(BIGINT, 0L))); } diff --git a/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java b/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java index 5e3da5f41a3f..e0e6ec615a93 100644 --- a/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java @@ -20,7 +20,6 @@ import io.trino.metadata.ResolvedFunction; import io.trino.spi.type.Type; import io.trino.spi.type.TypeManager; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Between; import io.trino.sql.ir.Bind; import io.trino.sql.ir.Call; @@ -37,7 +36,6 @@ import io.trino.sql.ir.IsNull; import io.trino.sql.ir.Lambda; import io.trino.sql.ir.Logical; -import io.trino.sql.ir.Negation; import io.trino.sql.ir.Not; import io.trino.sql.ir.NullIf; import io.trino.sql.ir.Reference; @@ -58,7 +56,6 @@ import static io.trino.spi.function.OperatorType.HASH_CODE; import static io.trino.spi.function.OperatorType.INDETERMINATE; import static io.trino.spi.function.OperatorType.LESS_THAN_OR_EQUAL; -import static io.trino.spi.function.OperatorType.NEGATION; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; @@ -212,27 +209,6 @@ protected RowExpression visitBind(Bind node, Void context) return new SpecialForm(BIND, ((Expression) node).type(), argumentsBuilder.build()); } - @Override - protected RowExpression visitArithmetic(Arithmetic node, Void context) - { - RowExpression left = process(node.left(), context); - RowExpression right = process(node.right(), context); - - return call( - standardFunctionResolution.arithmeticFunction(node.operator(), left.getType(), right.getType()), - left, - right); - } - - @Override - protected RowExpression visitNegation(Negation node, Void context) - { - RowExpression expression = process(node.value(), context); - return call( - metadata.resolveOperator(NEGATION, ImmutableList.of(expression.getType())), - expression); - } - @Override protected RowExpression visitLogical(Logical node, Void context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/relational/StandardFunctionResolution.java b/core/trino-main/src/main/java/io/trino/sql/relational/StandardFunctionResolution.java index 18fa35ff36ef..397e61876740 100644 --- a/core/trino-main/src/main/java/io/trino/sql/relational/StandardFunctionResolution.java +++ b/core/trino-main/src/main/java/io/trino/sql/relational/StandardFunctionResolution.java @@ -18,18 +18,12 @@ import io.trino.metadata.ResolvedFunction; import io.trino.spi.function.OperatorType; import io.trino.spi.type.Type; -import io.trino.sql.ir.Arithmetic.Operator; import io.trino.sql.ir.Comparison; -import static io.trino.spi.function.OperatorType.ADD; -import static io.trino.spi.function.OperatorType.DIVIDE; import static io.trino.spi.function.OperatorType.EQUAL; import static io.trino.spi.function.OperatorType.IS_DISTINCT_FROM; import static io.trino.spi.function.OperatorType.LESS_THAN; import static io.trino.spi.function.OperatorType.LESS_THAN_OR_EQUAL; -import static io.trino.spi.function.OperatorType.MODULUS; -import static io.trino.spi.function.OperatorType.MULTIPLY; -import static io.trino.spi.function.OperatorType.SUBTRACT; import static java.util.Objects.requireNonNull; public final class StandardFunctionResolution @@ -41,31 +35,6 @@ public StandardFunctionResolution(Metadata metadata) this.metadata = requireNonNull(metadata, "metadata is null"); } - public ResolvedFunction arithmeticFunction(Operator operator, Type leftType, Type rightType) - { - OperatorType operatorType; - switch (operator) { - case ADD: - operatorType = ADD; - break; - case SUBTRACT: - operatorType = SUBTRACT; - break; - case MULTIPLY: - operatorType = MULTIPLY; - break; - case DIVIDE: - operatorType = DIVIDE; - break; - case MODULUS: - operatorType = MODULUS; - break; - default: - throw new IllegalStateException("Unknown arithmetic operator: " + operator); - } - return metadata.resolveOperator(operatorType, ImmutableList.of(leftType, rightType)); - } - public ResolvedFunction comparisonFunction(Comparison.Operator operator, Type leftType, Type rightType) { OperatorType operatorType; diff --git a/core/trino-main/src/test/java/io/trino/cost/TestFilterProjectAggregationStatsRule.java b/core/trino-main/src/test/java/io/trino/cost/TestFilterProjectAggregationStatsRule.java index 2e12fac74e16..7bb9f3a6b752 100644 --- a/core/trino-main/src/test/java/io/trino/cost/TestFilterProjectAggregationStatsRule.java +++ b/core/trino-main/src/test/java/io/trino/cost/TestFilterProjectAggregationStatsRule.java @@ -18,7 +18,7 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Reference; @@ -35,7 +35,6 @@ import static io.trino.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.iterative.rule.test.PlanBuilder.aggregation; @@ -146,7 +145,7 @@ public void testFilterAndProjectOverAggregationStats() return pb.filter( new Comparison(GREATER_THAN, new Reference(INTEGER, "count_on_x"), new Constant(INTEGER, 0L)), // Non-narrowing projection - pb.project(Assignments.of(pb.symbol("x_1"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "x"), new Constant(INTEGER, 1L)), aggregatedOutput, aggregatedOutput.toSymbolReference()), + pb.project(Assignments.of(pb.symbol("x_1"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "x"), new Constant(INTEGER, 1L))), aggregatedOutput, aggregatedOutput.toSymbolReference()), pb.aggregation(ab -> ab .addAggregation(aggregatedOutput, aggregation("count", ImmutableList.of(new Reference(BIGINT, "x"))), ImmutableList.of(BIGINT)) .singleGroupingSet(pb.symbol("y", BIGINT)) diff --git a/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsCalculator.java b/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsCalculator.java index 4f91125f2f6f..6f8fa91c0762 100644 --- a/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsCalculator.java +++ b/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsCalculator.java @@ -26,7 +26,6 @@ import io.trino.spi.type.Decimals; import io.trino.spi.type.VarcharType; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Between; import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; @@ -37,7 +36,6 @@ import io.trino.sql.ir.In; import io.trino.sql.ir.IsNull; import io.trino.sql.ir.Logical; -import io.trino.sql.ir.Negation; import io.trino.sql.ir.Not; import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; @@ -55,9 +53,6 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; -import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; -import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; import static io.trino.sql.ir.Booleans.FALSE; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.Comparison.Operator.EQUAL; @@ -173,6 +168,7 @@ public class TestFilterStatsCalculator private static final ResolvedFunction MULTIPLY_DOUBLE = FUNCTIONS.resolveOperator(OperatorType.MULTIPLY, ImmutableList.of(DOUBLE, DOUBLE)); private static final ResolvedFunction SUBTRACT_INTEGER = FUNCTIONS.resolveOperator(OperatorType.SUBTRACT, ImmutableList.of(INTEGER, INTEGER)); private static final ResolvedFunction ADD_INTEGER = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(INTEGER, INTEGER)); + private static final ResolvedFunction NEGATION_DOUBLE = FUNCTIONS.resolveOperator(OperatorType.NEGATION, ImmutableList.of(DOUBLE)); @Test public void testBooleanLiteralStats() @@ -195,13 +191,13 @@ public void testComparison() .distinctValuesCount(26) .nullsFraction(0.0)); - assertExpression(new Comparison(GREATER_THAN, new Negation(new Reference(DOUBLE, "x")), new Constant(DOUBLE, -3.0))) + assertExpression(new Comparison(GREATER_THAN, new Call(NEGATION_DOUBLE, ImmutableList.of(new Reference(DOUBLE, "x"))), new Constant(DOUBLE, -3.0))) .outputRowsCount(lessThan3Rows); for (Expression minusThree : ImmutableList.of( new Constant(createDecimalType(3), Decimals.valueOfShort(new BigDecimal("-3"))), new Constant(DOUBLE, -3.0), - new Arithmetic(SUBTRACT_DOUBLE, SUBTRACT, new Constant(DOUBLE, 4.0), new Constant(DOUBLE, 7.0)), new Cast(new Constant(INTEGER, -3L), createDecimalType(7, 3)))) { + new Call(SUBTRACT_DOUBLE, ImmutableList.of(new Constant(DOUBLE, 4.0), new Constant(DOUBLE, 7.0))), new Cast(new Constant(INTEGER, -3L), createDecimalType(7, 3)))) { assertExpression(new Comparison(EQUAL, new Reference(DOUBLE, "x"), new Cast(minusThree, DOUBLE))) .outputRowsCount(18.75) .symbolStats(new Symbol(UNKNOWN, "x"), symbolAssert -> @@ -223,7 +219,7 @@ public void testComparison() assertExpression(new Comparison( EQUAL, new Coalesce( - new Arithmetic(MULTIPLY_DOUBLE, MULTIPLY, new Reference(DOUBLE, "x"), new Constant(DOUBLE, null)), + new Call(MULTIPLY_DOUBLE, ImmutableList.of(new Reference(DOUBLE, "x"), new Constant(DOUBLE, null))), new Reference(DOUBLE, "x")), new Cast(minusThree, DOUBLE))) .outputRowsCount(18.75) @@ -260,29 +256,29 @@ public void testInequalityComparisonApproximation() assertExpression(new Comparison(GREATER_THAN, new Reference(DOUBLE, "x"), new Reference(DOUBLE, "emptyRange"))) .outputRowsCount(0); - assertExpression(new Comparison(GREATER_THAN, new Reference(INTEGER, "x"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "y"), new Constant(INTEGER, 20L)))) + assertExpression(new Comparison(GREATER_THAN, new Reference(INTEGER, "x"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "y"), new Constant(INTEGER, 20L))))) .outputRowsCount(0); - assertExpression(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(INTEGER, "x"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "y"), new Constant(INTEGER, 20L)))) + assertExpression(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(INTEGER, "x"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "y"), new Constant(INTEGER, 20L))))) .outputRowsCount(0); - assertExpression(new Comparison(LESS_THAN, new Reference(INTEGER, "x"), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "y"), new Constant(INTEGER, 25L)))) + assertExpression(new Comparison(LESS_THAN, new Reference(INTEGER, "x"), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "y"), new Constant(INTEGER, 25L))))) .outputRowsCount(0); - assertExpression(new Comparison(LESS_THAN_OR_EQUAL, new Reference(INTEGER, "x"), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "y"), new Constant(INTEGER, 25L)))) + assertExpression(new Comparison(LESS_THAN_OR_EQUAL, new Reference(INTEGER, "x"), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "y"), new Constant(INTEGER, 25L))))) .outputRowsCount(0); double nullsFractionY = 0.5; double inputRowCount = standardInputStatistics.getOutputRowCount(); double nonNullRowCount = inputRowCount * (1 - nullsFractionY); SymbolStatsEstimate nonNullStatsX = xStats.mapNullsFraction(nullsFraction -> 0.0); - assertExpression(new Comparison(GREATER_THAN, new Reference(DOUBLE, "x"), new Arithmetic(SUBTRACT_DOUBLE, SUBTRACT, new Reference(DOUBLE, "y"), new Constant(DOUBLE, 25.0)))) + assertExpression(new Comparison(GREATER_THAN, new Reference(DOUBLE, "x"), new Call(SUBTRACT_DOUBLE, ImmutableList.of(new Reference(DOUBLE, "y"), new Constant(DOUBLE, 25.0))))) .outputRowsCount(nonNullRowCount) .symbolStats("x", symbolAssert -> symbolAssert.isEqualTo(nonNullStatsX)); - assertExpression(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(DOUBLE, "x"), new Arithmetic(SUBTRACT_DOUBLE, SUBTRACT, new Reference(DOUBLE, "y"), new Constant(DOUBLE, 25.0)))) + assertExpression(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(DOUBLE, "x"), new Call(SUBTRACT_DOUBLE, ImmutableList.of(new Reference(DOUBLE, "y"), new Constant(DOUBLE, 25.0))))) .outputRowsCount(nonNullRowCount) .symbolStats("x", symbolAssert -> symbolAssert.isEqualTo(nonNullStatsX)); - assertExpression(new Comparison(LESS_THAN, new Reference(DOUBLE, "x"), new Arithmetic(ADD_DOUBLE, ADD, new Reference(DOUBLE, "y"), new Constant(DOUBLE, 20.0)))) + assertExpression(new Comparison(LESS_THAN, new Reference(DOUBLE, "x"), new Call(ADD_DOUBLE, ImmutableList.of(new Reference(DOUBLE, "y"), new Constant(DOUBLE, 20.0))))) .outputRowsCount(nonNullRowCount) .symbolStats("x", symbolAssert -> symbolAssert.isEqualTo(nonNullStatsX)); - assertExpression(new Comparison(LESS_THAN_OR_EQUAL, new Reference(DOUBLE, "x"), new Arithmetic(ADD_DOUBLE, ADD, new Reference(DOUBLE, "y"), new Constant(DOUBLE, 20.0)))) + assertExpression(new Comparison(LESS_THAN_OR_EQUAL, new Reference(DOUBLE, "x"), new Call(ADD_DOUBLE, ImmutableList.of(new Reference(DOUBLE, "y"), new Constant(DOUBLE, 20.0))))) .outputRowsCount(nonNullRowCount) .symbolStats("x", symbolAssert -> symbolAssert.isEqualTo(nonNullStatsX)); } @@ -361,7 +357,7 @@ public void testAndStats() .nullsFraction(0.0)); // Impossible, with symbol-to-expression comparisons - assertExpression(new Logical(AND, ImmutableList.of(new Comparison(EQUAL, new Reference(DOUBLE, "x"), new Arithmetic(ADD_DOUBLE, ADD, new Constant(DOUBLE, 0.0), new Constant(DOUBLE, 1.0))), new Comparison(EQUAL, new Reference(DOUBLE, "x"), new Arithmetic(ADD_DOUBLE, ADD, new Constant(DOUBLE, 0.0), new Constant(DOUBLE, 3.0)))))) + assertExpression(new Logical(AND, ImmutableList.of(new Comparison(EQUAL, new Reference(DOUBLE, "x"), new Call(ADD_DOUBLE, ImmutableList.of(new Constant(DOUBLE, 0.0), new Constant(DOUBLE, 1.0)))), new Comparison(EQUAL, new Reference(DOUBLE, "x"), new Call(ADD_DOUBLE, ImmutableList.of(new Constant(DOUBLE, 0.0), new Constant(DOUBLE, 3.0))))))) .outputRowsCount(0) .symbolStats(new Symbol(UNKNOWN, "x"), SymbolStatsAssertion::emptyRange) .symbolStats(new Symbol(UNKNOWN, "y"), SymbolStatsAssertion::emptyRange); @@ -750,7 +746,7 @@ public void testInPredicateFilter() .lowValue(-7.5) .highValue(-7.5) .nullsFraction(0.0)); - assertExpression(new In(new Reference(DOUBLE, "x"), ImmutableList.of(new Arithmetic(ADD_DOUBLE, ADD, new Constant(DOUBLE, 2.0), new Constant(DOUBLE, 5.5))))) + assertExpression(new In(new Reference(DOUBLE, "x"), ImmutableList.of(new Call(ADD_DOUBLE, ImmutableList.of(new Constant(DOUBLE, 2.0), new Constant(DOUBLE, 5.5)))))) .outputRowsCount(18.75) .symbolStats("x", symbolStats -> symbolStats.distinctValuesCount(1.0) diff --git a/core/trino-main/src/test/java/io/trino/cost/TestScalarStatsCalculator.java b/core/trino-main/src/test/java/io/trino/cost/TestScalarStatsCalculator.java index 628f95c2a097..be476b31d0d8 100644 --- a/core/trino-main/src/test/java/io/trino/cost/TestScalarStatsCalculator.java +++ b/core/trino-main/src/test/java/io/trino/cost/TestScalarStatsCalculator.java @@ -24,7 +24,7 @@ import io.trino.spi.function.OperatorType; import io.trino.spi.type.Decimals; import io.trino.spi.type.VarcharType; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Coalesce; import io.trino.sql.ir.Constant; @@ -46,11 +46,6 @@ import static io.trino.spi.type.TinyintType.TINYINT; import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.spi.type.VarcharType.createVarcharType; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; -import static io.trino.sql.ir.Arithmetic.Operator.DIVIDE; -import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; -import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; -import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; import static io.trino.testing.TestingSession.testSessionBuilder; import static io.trino.testing.TransactionBuilder.transaction; import static io.trino.type.UnknownType.UNKNOWN; @@ -320,26 +315,26 @@ public void testNonDivideArithmeticBinaryExpression() .setOutputRowCount(10) .build(); - assertCalculate(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), relationStats) + assertCalculate(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), relationStats) .distinctValuesCount(10.0) .lowValue(-3.0) .highValue(15.0) .nullsFraction(0.28) .averageRowSize(2.0); - assertCalculate(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "x"), new Reference(BIGINT, "unknown")), relationStats) + assertCalculate(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "unknown"))), relationStats) .isEqualTo(SymbolStatsEstimate.unknown()); - assertCalculate(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "unknown"), new Reference(BIGINT, "unknown")), relationStats) + assertCalculate(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "unknown"), new Reference(BIGINT, "unknown"))), relationStats) .isEqualTo(SymbolStatsEstimate.unknown()); - assertCalculate(new Arithmetic(SUBTRACT_BIGINT, SUBTRACT, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), relationStats) + assertCalculate(new Call(SUBTRACT_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), relationStats) .distinctValuesCount(10.0) .lowValue(-6.0) .highValue(12.0) .nullsFraction(0.28) .averageRowSize(2.0); - assertCalculate(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), relationStats) + assertCalculate(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), relationStats) .distinctValuesCount(10.0) .lowValue(-20.0) .highValue(50.0) @@ -363,95 +358,95 @@ public void testArithmeticBinaryWithAllNullsSymbol() .setOutputRowCount(10) .build(); - assertCalculate(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "x"), new Reference(BIGINT, "all_null")), relationStats) + assertCalculate(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "all_null"))), relationStats) .isEqualTo(allNullStats); - assertCalculate(new Arithmetic(SUBTRACT_BIGINT, SUBTRACT, new Reference(BIGINT, "x"), new Reference(BIGINT, "all_null")), relationStats) + assertCalculate(new Call(SUBTRACT_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "all_null"))), relationStats) .isEqualTo(allNullStats); - assertCalculate(new Arithmetic(SUBTRACT_BIGINT, SUBTRACT, new Reference(BIGINT, "all_null"), new Reference(BIGINT, "x")), relationStats) + assertCalculate(new Call(SUBTRACT_BIGINT, ImmutableList.of(new Reference(BIGINT, "all_null"), new Reference(BIGINT, "x"))), relationStats) .isEqualTo(allNullStats); - assertCalculate(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "all_null"), new Reference(BIGINT, "x")), relationStats) + assertCalculate(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "all_null"), new Reference(BIGINT, "x"))), relationStats) .isEqualTo(allNullStats); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "all_null")), relationStats) + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "all_null"))), relationStats) .isEqualTo(allNullStats); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "all_null"), new Reference(BIGINT, "x")), relationStats) + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "all_null"), new Reference(BIGINT, "x"))), relationStats) .isEqualTo(allNullStats); - assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "all_null")), relationStats) + assertCalculate(new Call(DIVIDE_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "all_null"))), relationStats) .isEqualTo(allNullStats); - assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "all_null"), new Reference(BIGINT, "x")), relationStats) + assertCalculate(new Call(DIVIDE_BIGINT, ImmutableList.of(new Reference(BIGINT, "all_null"), new Reference(BIGINT, "x"))), relationStats) .isEqualTo(allNullStats); } @Test public void testDivideArithmeticBinaryExpression() { - assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-11, -3, -5, -4)).lowValue(0.6).highValue(2.75); - assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-11, -3, -5, 4)).lowValue(NEGATIVE_INFINITY).highValue(POSITIVE_INFINITY); - assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-11, -3, 4, 5)).lowValue(-2.75).highValue(-0.6); + assertCalculate(new Call(DIVIDE_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-11, -3, -5, -4)).lowValue(0.6).highValue(2.75); + assertCalculate(new Call(DIVIDE_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-11, -3, -5, 4)).lowValue(NEGATIVE_INFINITY).highValue(POSITIVE_INFINITY); + assertCalculate(new Call(DIVIDE_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-11, -3, 4, 5)).lowValue(-2.75).highValue(-0.6); - assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-11, 0, -5, -4)).lowValue(0).highValue(2.75); - assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-11, 0, -5, 4)).lowValue(NEGATIVE_INFINITY).highValue(POSITIVE_INFINITY); - assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-11, 0, 4, 5)).lowValue(-2.75).highValue(0); + assertCalculate(new Call(DIVIDE_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-11, 0, -5, -4)).lowValue(0).highValue(2.75); + assertCalculate(new Call(DIVIDE_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-11, 0, -5, 4)).lowValue(NEGATIVE_INFINITY).highValue(POSITIVE_INFINITY); + assertCalculate(new Call(DIVIDE_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-11, 0, 4, 5)).lowValue(-2.75).highValue(0); - assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-11, 3, -5, -4)).lowValue(-0.75).highValue(2.75); - assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-11, 3, -5, 4)).lowValue(NEGATIVE_INFINITY).highValue(POSITIVE_INFINITY); - assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-11, 3, 4, 5)).lowValue(-2.75).highValue(0.75); + assertCalculate(new Call(DIVIDE_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-11, 3, -5, -4)).lowValue(-0.75).highValue(2.75); + assertCalculate(new Call(DIVIDE_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-11, 3, -5, 4)).lowValue(NEGATIVE_INFINITY).highValue(POSITIVE_INFINITY); + assertCalculate(new Call(DIVIDE_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-11, 3, 4, 5)).lowValue(-2.75).highValue(0.75); - assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(0, 3, -5, -4)).lowValue(-0.75).highValue(0); - assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(0, 3, -5, 4)).lowValue(NEGATIVE_INFINITY).highValue(POSITIVE_INFINITY); - assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(0, 3, 4, 5)).lowValue(0).highValue(0.75); + assertCalculate(new Call(DIVIDE_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(0, 3, -5, -4)).lowValue(-0.75).highValue(0); + assertCalculate(new Call(DIVIDE_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(0, 3, -5, 4)).lowValue(NEGATIVE_INFINITY).highValue(POSITIVE_INFINITY); + assertCalculate(new Call(DIVIDE_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(0, 3, 4, 5)).lowValue(0).highValue(0.75); - assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(3, 11, -5, -4)).lowValue(-2.75).highValue(-0.6); - assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(3, 11, -5, 4)).lowValue(NEGATIVE_INFINITY).highValue(POSITIVE_INFINITY); - assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(3, 11, 4, 5)).lowValue(0.6).highValue(2.75); + assertCalculate(new Call(DIVIDE_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(3, 11, -5, -4)).lowValue(-2.75).highValue(-0.6); + assertCalculate(new Call(DIVIDE_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(3, 11, -5, 4)).lowValue(NEGATIVE_INFINITY).highValue(POSITIVE_INFINITY); + assertCalculate(new Call(DIVIDE_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(3, 11, 4, 5)).lowValue(0.6).highValue(2.75); } @Test public void testModulusArithmeticBinaryExpression() { // negative - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-1, 0, -6, -4)).lowValue(-1).highValue(0); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-5, 0, -6, -4)).lowValue(-5).highValue(0); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-8, 0, -6, -4)).lowValue(-6).highValue(0); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-8, 0, -6, -4)).lowValue(-6).highValue(0); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-8, 0, -6, 4)).lowValue(-6).highValue(0); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-8, 0, -6, 6)).lowValue(-6).highValue(0); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-8, 0, 4, 6)).lowValue(-6).highValue(0); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-1, 0, 4, 6)).lowValue(-1).highValue(0); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-5, 0, 4, 6)).lowValue(-5).highValue(0); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-8, 0, 4, 6)).lowValue(-6).highValue(0); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-1, 0, -6, -4)).lowValue(-1).highValue(0); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-5, 0, -6, -4)).lowValue(-5).highValue(0); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-8, 0, -6, -4)).lowValue(-6).highValue(0); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-8, 0, -6, -4)).lowValue(-6).highValue(0); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-8, 0, -6, 4)).lowValue(-6).highValue(0); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-8, 0, -6, 6)).lowValue(-6).highValue(0); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-8, 0, 4, 6)).lowValue(-6).highValue(0); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-1, 0, 4, 6)).lowValue(-1).highValue(0); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-5, 0, 4, 6)).lowValue(-5).highValue(0); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-8, 0, 4, 6)).lowValue(-6).highValue(0); // positive - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(0, 5, -6, -4)).lowValue(0).highValue(5); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(0, 8, -6, -4)).lowValue(0).highValue(6); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(0, 1, -6, 4)).lowValue(0).highValue(1); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(0, 5, -6, 4)).lowValue(0).highValue(5); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(0, 8, -6, 4)).lowValue(0).highValue(6); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(0, 1, 4, 6)).lowValue(0).highValue(1); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(0, 5, 4, 6)).lowValue(0).highValue(5); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(0, 8, 4, 6)).lowValue(0).highValue(6); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(0, 5, -6, -4)).lowValue(0).highValue(5); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(0, 8, -6, -4)).lowValue(0).highValue(6); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(0, 1, -6, 4)).lowValue(0).highValue(1); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(0, 5, -6, 4)).lowValue(0).highValue(5); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(0, 8, -6, 4)).lowValue(0).highValue(6); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(0, 1, 4, 6)).lowValue(0).highValue(1); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(0, 5, 4, 6)).lowValue(0).highValue(5); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(0, 8, 4, 6)).lowValue(0).highValue(6); // mix - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-1, 1, -6, -4)).lowValue(-1).highValue(1); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-1, 5, -6, -4)).lowValue(-1).highValue(5); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-5, 1, -6, -4)).lowValue(-5).highValue(1); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-5, 5, -6, -4)).lowValue(-5).highValue(5); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-5, 8, -6, -4)).lowValue(-5).highValue(6); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-8, 5, -6, -4)).lowValue(-6).highValue(5); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-8, 8, -6, -4)).lowValue(-6).highValue(6); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-1, 1, -6, 4)).lowValue(-1).highValue(1); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-1, 5, -6, 4)).lowValue(-1).highValue(5); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-5, 1, -6, 4)).lowValue(-5).highValue(1); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-5, 5, -6, 4)).lowValue(-5).highValue(5); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-5, 8, -6, 4)).lowValue(-5).highValue(6); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-8, 5, -6, 4)).lowValue(-6).highValue(5); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-8, 8, -6, 4)).lowValue(-6).highValue(6); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-1, 1, 4, 6)).lowValue(-1).highValue(1); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-1, 5, 4, 6)).lowValue(-1).highValue(5); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-5, 1, 4, 6)).lowValue(-5).highValue(1); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-5, 5, 4, 6)).lowValue(-5).highValue(5); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-5, 8, 4, 6)).lowValue(-5).highValue(6); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-8, 5, 4, 6)).lowValue(-6).highValue(5); - assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-8, 8, 4, 6)).lowValue(-6).highValue(6); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-1, 1, -6, -4)).lowValue(-1).highValue(1); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-1, 5, -6, -4)).lowValue(-1).highValue(5); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-5, 1, -6, -4)).lowValue(-5).highValue(1); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-5, 5, -6, -4)).lowValue(-5).highValue(5); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-5, 8, -6, -4)).lowValue(-5).highValue(6); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-8, 5, -6, -4)).lowValue(-6).highValue(5); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-8, 8, -6, -4)).lowValue(-6).highValue(6); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-1, 1, -6, 4)).lowValue(-1).highValue(1); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-1, 5, -6, 4)).lowValue(-1).highValue(5); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-5, 1, -6, 4)).lowValue(-5).highValue(1); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-5, 5, -6, 4)).lowValue(-5).highValue(5); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-5, 8, -6, 4)).lowValue(-5).highValue(6); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-8, 5, -6, 4)).lowValue(-6).highValue(5); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-8, 8, -6, 4)).lowValue(-6).highValue(6); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-1, 1, 4, 6)).lowValue(-1).highValue(1); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-1, 5, 4, 6)).lowValue(-1).highValue(5); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-5, 1, 4, 6)).lowValue(-5).highValue(1); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-5, 5, 4, 6)).lowValue(-5).highValue(5); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-5, 8, 4, 6)).lowValue(-5).highValue(6); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-8, 5, 4, 6)).lowValue(-6).highValue(5); + assertCalculate(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), xyStats(-8, 8, 4, 6)).lowValue(-6).highValue(6); } private PlanNodeStatsEstimate xyStats(double lowX, double highX, double lowY, double highY) diff --git a/core/trino-main/src/test/java/io/trino/cost/TestValuesNodeStats.java b/core/trino-main/src/test/java/io/trino/cost/TestValuesNodeStats.java index a353332280ae..581e952cbff5 100644 --- a/core/trino-main/src/test/java/io/trino/cost/TestValuesNodeStats.java +++ b/core/trino-main/src/test/java/io/trino/cost/TestValuesNodeStats.java @@ -19,7 +19,7 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; import io.trino.spi.type.VarcharType; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Constant; import io.trino.sql.planner.Symbol; import org.junit.jupiter.api.Test; @@ -29,7 +29,6 @@ import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.createVarcharType; -import static io.trino.sql.ir.Arithmetic.Operator.DIVIDE; import static io.trino.type.UnknownType.UNKNOWN; public class TestValuesNodeStats @@ -93,7 +92,7 @@ public void testDivisionByZero() { tester().assertStatsFor(pb -> pb .values(ImmutableList.of(pb.symbol("a", BIGINT)), - ImmutableList.of(ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 1L), new Constant(INTEGER, 0L)))))) + ImmutableList.of(ImmutableList.of(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 0L))))))) .check(outputStats -> outputStats.equalTo(unknown())); } diff --git a/core/trino-main/src/test/java/io/trino/operator/BenchmarkScanFilterAndProjectOperator.java b/core/trino-main/src/test/java/io/trino/operator/BenchmarkScanFilterAndProjectOperator.java index de24ec574c33..87c2bea25fc6 100644 --- a/core/trino-main/src/test/java/io/trino/operator/BenchmarkScanFilterAndProjectOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/BenchmarkScanFilterAndProjectOperator.java @@ -34,7 +34,6 @@ import io.trino.sql.PlannerContext; import io.trino.sql.gen.ExpressionCompiler; import io.trino.sql.gen.PageFunctionCompiler; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Comparison; @@ -82,8 +81,6 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; -import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.planner.TestingPlannerContext.plannerContextBuilder; import static io.trino.testing.TestingHandles.TEST_CATALOG_HANDLE; @@ -218,10 +215,10 @@ private List createInputPages(List types) private RowExpression getFilter(Type type) { if (type == VARCHAR) { - return rowExpression(new Comparison(EQUAL, new Arithmetic(MODULUS_INTEGER, MODULUS, new Cast(new Reference(VARCHAR, "varchar0"), INTEGER), new Constant(INTEGER, 2L)), new Constant(INTEGER, 0L))); + return rowExpression(new Comparison(EQUAL, new Call(MODULUS_INTEGER, ImmutableList.of(new Cast(new Reference(VARCHAR, "varchar0"), INTEGER), new Constant(INTEGER, 2L))), new Constant(INTEGER, 0L))); } if (type == BIGINT) { - return rowExpression(new Comparison(EQUAL, new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(INTEGER, "bigint0"), new Constant(INTEGER, 2L)), new Constant(INTEGER, 0L))); + return rowExpression(new Comparison(EQUAL, new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(INTEGER, "bigint0"), new Constant(INTEGER, 2L))), new Constant(INTEGER, 0L))); } throw new IllegalArgumentException("filter not supported for type : " + type); } @@ -231,7 +228,7 @@ private List getProjections(Type type) ImmutableList.Builder builder = ImmutableList.builder(); if (type == BIGINT) { for (int i = 0; i < columnCount; i++) { - builder.add(rowExpression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "bigint" + i), new Constant(BIGINT, 5L)))); + builder.add(rowExpression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "bigint" + i), new Constant(BIGINT, 5L))))); } } else if (type == VARCHAR) { diff --git a/core/trino-main/src/test/java/io/trino/operator/project/TestPageFieldsToInputParametersRewriter.java b/core/trino-main/src/test/java/io/trino/operator/project/TestPageFieldsToInputParametersRewriter.java index 398513d5960a..37117fa514f2 100644 --- a/core/trino-main/src/test/java/io/trino/operator/project/TestPageFieldsToInputParametersRewriter.java +++ b/core/trino-main/src/test/java/io/trino/operator/project/TestPageFieldsToInputParametersRewriter.java @@ -24,7 +24,6 @@ import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Between; import io.trino.sql.ir.Call; import io.trino.sql.ir.Case; @@ -36,7 +35,6 @@ import io.trino.sql.ir.In; import io.trino.sql.ir.Lambda; import io.trino.sql.ir.Logical; -import io.trino.sql.ir.Negation; import io.trino.sql.ir.NullIf; import io.trino.sql.ir.Reference; import io.trino.sql.ir.Switch; @@ -64,10 +62,6 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; -import static io.trino.sql.ir.Arithmetic.Operator.DIVIDE; -import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; -import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.ir.Logical.Operator.AND; @@ -93,6 +87,7 @@ public class TestPageFieldsToInputParametersRewriter private static final ResolvedFunction ADD_INTEGER = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(INTEGER, INTEGER)); private static final ResolvedFunction MULTIPLY_INTEGER = FUNCTIONS.resolveOperator(OperatorType.MULTIPLY, ImmutableList.of(INTEGER, INTEGER)); private static final ResolvedFunction MODULUS_BIGINT = FUNCTIONS.resolveOperator(OperatorType.MODULUS, ImmutableList.of(BIGINT, BIGINT)); + private static final ResolvedFunction NEGATION_BIGINT = FUNCTIONS.resolveOperator(OperatorType.NEGATION, ImmutableList.of(BIGINT)); @Test public void testEagerLoading() @@ -100,20 +95,20 @@ public void testEagerLoading() RowExpressionBuilder builder = RowExpressionBuilder.create() .addSymbol("bigint0", BIGINT) .addSymbol("bigint1", BIGINT); - verifyEagerlyLoadedColumns(builder.buildExpression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "bigint0"), new Constant(INTEGER, 5L))), 1); - verifyEagerlyLoadedColumns(builder.buildExpression(new Cast(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "bigint0"), new Constant(INTEGER, 10L)), INTEGER)), 1); - verifyEagerlyLoadedColumns(builder.buildExpression(new Coalesce(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "bigint0"), new Constant(BIGINT, 2L)), new Reference(BIGINT, "bigint0"))), 1); + verifyEagerlyLoadedColumns(builder.buildExpression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "bigint0"), new Constant(INTEGER, 5L)))), 1); + verifyEagerlyLoadedColumns(builder.buildExpression(new Cast(new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "bigint0"), new Constant(INTEGER, 10L))), INTEGER)), 1); + verifyEagerlyLoadedColumns(builder.buildExpression(new Coalesce(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "bigint0"), new Constant(BIGINT, 2L))), new Reference(BIGINT, "bigint0"))), 1); verifyEagerlyLoadedColumns(builder.buildExpression(new In(new Reference(BIGINT, "bigint0"), ImmutableList.of(new Constant(BIGINT, 1L), new Constant(BIGINT, 2L), new Constant(BIGINT, 3L)))), 1); verifyEagerlyLoadedColumns(builder.buildExpression(new Comparison(GREATER_THAN, new Reference(INTEGER, "bigint0"), new Constant(INTEGER, 0L))), 1); - verifyEagerlyLoadedColumns(builder.buildExpression(new Comparison(EQUAL, new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "bigint0"), new Constant(BIGINT, 1L)), new Constant(BIGINT, 0L))), 1); + verifyEagerlyLoadedColumns(builder.buildExpression(new Comparison(EQUAL, new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "bigint0"), new Constant(BIGINT, 1L))), new Constant(BIGINT, 0L))), 1); verifyEagerlyLoadedColumns(builder.buildExpression(new Between(new Reference(BIGINT, "bigint0"), new Constant(BIGINT, 1L), new Constant(BIGINT, 10L))), 1); verifyEagerlyLoadedColumns(builder.buildExpression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Reference(INTEGER, "bigint0"), new Constant(INTEGER, 0L)), new Reference(BIGINT, "bigint0"))), Optional.of(new Constant(BIGINT, null)))), 1); - verifyEagerlyLoadedColumns(builder.buildExpression(new Switch(new Reference(BIGINT, "bigint0"), ImmutableList.of(new WhenClause(new Constant(BIGINT, 1L), new Constant(BIGINT, 1L))), Optional.of(new Negation(new Reference(BIGINT, "bigint0"))))), 1); - verifyEagerlyLoadedColumns(builder.buildExpression(new Arithmetic(ADD_BIGINT, ADD, new Coalesce(new Constant(BIGINT, 0L), new Reference(BIGINT, "bigint0")), new Reference(BIGINT, "bigint0"))), 1); + verifyEagerlyLoadedColumns(builder.buildExpression(new Switch(new Reference(BIGINT, "bigint0"), ImmutableList.of(new WhenClause(new Constant(BIGINT, 1L), new Constant(BIGINT, 1L))), Optional.of(new Call(NEGATION_BIGINT, ImmutableList.of(new Reference(BIGINT, "bigint0")))))), 1); + verifyEagerlyLoadedColumns(builder.buildExpression(new Call(ADD_BIGINT, ImmutableList.of(new Coalesce(new Constant(BIGINT, 0L), new Reference(BIGINT, "bigint0")), new Reference(BIGINT, "bigint0")))), 1); - verifyEagerlyLoadedColumns(builder.buildExpression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "bigint0"), new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Constant(BIGINT, 2L), new Reference(BIGINT, "bigint1")))), 2); + verifyEagerlyLoadedColumns(builder.buildExpression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "bigint0"), new Call(MULTIPLY_BIGINT, ImmutableList.of(new Constant(BIGINT, 2L), new Reference(BIGINT, "bigint1")))))), 2); verifyEagerlyLoadedColumns(builder.buildExpression(new NullIf(new Reference(BIGINT, "bigint0"), new Reference(BIGINT, "bigint1"))), 2); - verifyEagerlyLoadedColumns(builder.buildExpression(new Coalesce(new Call(CEIL, ImmutableList.of(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "bigint0"), new Reference(BIGINT, "bigint1")))), new Constant(BIGINT, 0L))), 2); + verifyEagerlyLoadedColumns(builder.buildExpression(new Coalesce(new Call(CEIL, ImmutableList.of(new Call(DIVIDE_BIGINT, ImmutableList.of(new Reference(BIGINT, "bigint0"), new Reference(BIGINT, "bigint1"))))), new Constant(BIGINT, 0L))), 2); verifyEagerlyLoadedColumns(builder.buildExpression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Reference(BIGINT, "bigint0"), new Reference(BIGINT, "bigint1")), new Constant(INTEGER, 1L))), Optional.of(new Constant(INTEGER, 0L)))), 2); verifyEagerlyLoadedColumns( builder.buildExpression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Reference(BIGINT, "bigint0"), new Constant(BIGINT, 0L)), new Reference(BIGINT, "bigint1"))), Optional.of(new Constant(BIGINT, 0L)))), 2, ImmutableSet.of(0)); @@ -126,8 +121,8 @@ public void testEagerLoading() .addSymbol("array_bigint0", new ArrayType(BIGINT)) .addSymbol("array_bigint1", new ArrayType(BIGINT)); verifyEagerlyLoadedColumns(builder.buildExpression(new Call(TRANSFORM, ImmutableList.of(new Reference(new ArrayType(BIGINT), "array_bigint0"), new Lambda(ImmutableList.of(new Symbol(BIGINT, "x")), new Constant(INTEGER, 1L))))), 1, ImmutableSet.of()); - verifyEagerlyLoadedColumns(builder.buildExpression(new Call(TRANSFORM, ImmutableList.of(new Reference(new ArrayType(BIGINT), "array_bigint0"), new Lambda(ImmutableList.of(new Symbol(BIGINT, "x")), new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Constant(INTEGER, 2L), new Reference(INTEGER, "x")))))), 1, ImmutableSet.of()); - verifyEagerlyLoadedColumns(builder.buildExpression(new Call(ZIP_WITH, ImmutableList.of(new Reference(new ArrayType(BIGINT), "array_bigint0"), new Reference(new ArrayType(BIGINT), "array_bigint1"), new Lambda(ImmutableList.of(new Symbol(BIGINT, "x"), new Symbol(BIGINT, "y")), new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Constant(BIGINT, 2L), new Reference(BIGINT, "x")))))), 2, ImmutableSet.of()); + verifyEagerlyLoadedColumns(builder.buildExpression(new Call(TRANSFORM, ImmutableList.of(new Reference(new ArrayType(BIGINT), "array_bigint0"), new Lambda(ImmutableList.of(new Symbol(BIGINT, "x")), new Call(MULTIPLY_INTEGER, ImmutableList.of(new Constant(INTEGER, 2L), new Reference(INTEGER, "x"))))))), 1, ImmutableSet.of()); + verifyEagerlyLoadedColumns(builder.buildExpression(new Call(ZIP_WITH, ImmutableList.of(new Reference(new ArrayType(BIGINT), "array_bigint0"), new Reference(new ArrayType(BIGINT), "array_bigint1"), new Lambda(ImmutableList.of(new Symbol(BIGINT, "x"), new Symbol(BIGINT, "y")), new Call(MULTIPLY_BIGINT, ImmutableList.of(new Constant(BIGINT, 2L), new Reference(BIGINT, "x"))))))), 2, ImmutableSet.of()); } private static void verifyEagerlyLoadedColumns(RowExpression rowExpression, int columnCount) diff --git a/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java b/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java index a822c064a0d2..5a6eece76c48 100644 --- a/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java +++ b/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java @@ -19,7 +19,6 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Between; import io.trino.sql.ir.Call; import io.trino.sql.ir.Case; @@ -32,7 +31,6 @@ import io.trino.sql.ir.In; import io.trino.sql.ir.IsNull; import io.trino.sql.ir.Logical; -import io.trino.sql.ir.Negation; import io.trino.sql.ir.Not; import io.trino.sql.ir.NullIf; import io.trino.sql.ir.Reference; @@ -59,10 +57,6 @@ import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.ExpressionTestUtils.assertExpressionEquals; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; -import static io.trino.sql.ir.Arithmetic.Operator.DIVIDE; -import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; -import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; import static io.trino.sql.ir.Booleans.FALSE; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.Comparison.Operator.EQUAL; @@ -102,6 +96,7 @@ public class TestExpressionInterpreter private static final ResolvedFunction SUBTRACT_INTEGER = FUNCTIONS.resolveOperator(OperatorType.SUBTRACT, ImmutableList.of(INTEGER, INTEGER)); private static final ResolvedFunction MULTIPLY_INTEGER = FUNCTIONS.resolveOperator(OperatorType.MULTIPLY, ImmutableList.of(INTEGER, INTEGER)); private static final ResolvedFunction DIVIDE_INTEGER = FUNCTIONS.resolveOperator(OperatorType.DIVIDE, ImmutableList.of(INTEGER, INTEGER)); + private static final ResolvedFunction NEGATION_INTEGER = FUNCTIONS.resolveOperator(OperatorType.NEGATION, ImmutableList.of(INTEGER)); @Test public void testAnd() @@ -235,7 +230,7 @@ public void testIsNull() new IsNull(new Constant(INTEGER, 1L)), FALSE); assertOptimizedEquals( - new IsNull(new Arithmetic(ADD_INTEGER, ADD, new Constant(INTEGER, null), new Constant(INTEGER, 1L))), + new IsNull(new Call(ADD_INTEGER, ImmutableList.of(new Constant(INTEGER, null), new Constant(INTEGER, 1L)))), TRUE); } @@ -249,7 +244,7 @@ public void testIsNotNull() new Not(new IsNull(new Constant(INTEGER, 1L))), TRUE); assertOptimizedEquals( - new Not(new IsNull(new Arithmetic(ADD_INTEGER, ADD, new Constant(INTEGER, null), new Constant(INTEGER, 1L)))), + new Not(new IsNull(new Call(ADD_INTEGER, ImmutableList.of(new Constant(INTEGER, null), new Constant(INTEGER, 1L))))), FALSE); } @@ -277,11 +272,11 @@ public void testNullIf() public void testNegative() { assertOptimizedEquals( - new Negation(new Constant(INTEGER, 1L)), + new Call(NEGATION_INTEGER, ImmutableList.of(new Constant(INTEGER, 1L))), new Constant(INTEGER, -1L)); assertOptimizedEquals( - new Negation(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L))), - new Negation(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L)))); + new Call(NEGATION_INTEGER, ImmutableList.of(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L))))), + new Call(NEGATION_INTEGER, ImmutableList.of(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L)))))); } @Test @@ -394,32 +389,32 @@ public void testIn() new Comparison(EQUAL, new Reference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L))); assertOptimizedEquals( - new In(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)))), - new In(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L))))); + new In(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 5L), new Constant(INTEGER, 0L))))), + new In(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)))))); assertOptimizedEquals( - new In(new Constant(INTEGER, null), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)))), - new In(new Constant(INTEGER, null), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L))))); + new In(new Constant(INTEGER, null), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 5L), new Constant(INTEGER, 0L))))), + new In(new Constant(INTEGER, null), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)))))); assertOptimizedEquals( - new In(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Constant(INTEGER, null), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)))), - new In(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Constant(INTEGER, null), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L))))); + new In(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Constant(INTEGER, null), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 5L), new Constant(INTEGER, 0L))))), + new In(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Constant(INTEGER, null), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)))))); assertOptimizedEquals( - new In(new Constant(INTEGER, null), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, null), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)))), - new In(new Constant(INTEGER, null), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, null), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L))))); + new In(new Constant(INTEGER, null), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, null), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 5L), new Constant(INTEGER, 0L))))), + new In(new Constant(INTEGER, null), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, null), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)))))); assertOptimizedEquals( - new In(new Constant(INTEGER, 3L), ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)))), - new In(new Constant(INTEGER, 3L), ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L))))); - assertTrinoExceptionThrownBy(() -> evaluate(new In(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)))))) + new In(new Constant(INTEGER, 3L), ImmutableList.of(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 5L), new Constant(INTEGER, 0L))), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 5L), new Constant(INTEGER, 0L))))), + new In(new Constant(INTEGER, 3L), ImmutableList.of(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 5L), new Constant(INTEGER, 0L))), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)))))); + assertTrinoExceptionThrownBy(() -> evaluate(new In(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 5L), new Constant(INTEGER, 0L))))))) .hasErrorCode(DIVISION_BY_ZERO); assertOptimizedEquals( - new In(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Constant(INTEGER, 5L))), - new In(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Constant(INTEGER, 5L)))); + new In(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Constant(INTEGER, 5L))), + new In(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Constant(INTEGER, 5L)))); assertOptimizedEquals( - new In(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 2L), new Constant(INTEGER, 4L))), - new In(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L)))); + new In(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 2L), new Constant(INTEGER, 4L))), + new In(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L)))); assertOptimizedEquals( - new In(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 2L))), - new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))); + new In(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 2L))), + new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 2L))); } @Test @@ -488,60 +483,60 @@ public void testSearchCase() assertOptimizedMatches( new Case(ImmutableList.of( - new WhenClause(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), + new WhenClause(new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), Optional.empty()), new Case(ImmutableList.of( - new WhenClause(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), + new WhenClause(new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), Optional.empty())); assertOptimizedEquals( new Case(ImmutableList.of( - new WhenClause(TRUE, new Constant(VARCHAR, Slices.utf8Slice("a"))), new WhenClause(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("b")))), + new WhenClause(TRUE, new Constant(VARCHAR, Slices.utf8Slice("a"))), new WhenClause(new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("b")))), Optional.of(new Constant(VARCHAR, Slices.utf8Slice("c")))), new Constant(VARCHAR, Slices.utf8Slice("a"))); assertOptimizedEquals( new Case(ImmutableList.of( - new WhenClause(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a"))), new WhenClause(TRUE, new Constant(VARCHAR, Slices.utf8Slice("b")))), + new WhenClause(new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a"))), new WhenClause(TRUE, new Constant(VARCHAR, Slices.utf8Slice("b")))), Optional.of(new Constant(VARCHAR, Slices.utf8Slice("c")))), new Case(ImmutableList.of( - new WhenClause(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a")))), + new WhenClause(new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a")))), Optional.of(new Constant(VARCHAR, Slices.utf8Slice("b"))))); assertOptimizedEquals( new Case(ImmutableList.of( - new WhenClause(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a"))), new WhenClause(FALSE, new Constant(VARCHAR, Slices.utf8Slice("b")))), + new WhenClause(new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a"))), new WhenClause(FALSE, new Constant(VARCHAR, Slices.utf8Slice("b")))), Optional.of(new Constant(VARCHAR, Slices.utf8Slice("c")))), new Case(ImmutableList.of( - new WhenClause(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a")))), + new WhenClause(new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a")))), Optional.of(new Constant(VARCHAR, Slices.utf8Slice("c"))))); assertOptimizedEquals( new Case(ImmutableList.of( - new WhenClause(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a"))), + new WhenClause(new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a"))), new WhenClause(FALSE, new Constant(VARCHAR, Slices.utf8Slice("b")))), Optional.empty()), new Case(ImmutableList.of( - new WhenClause(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a")))), + new WhenClause(new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a")))), Optional.empty())); assertOptimizedEquals( new Case(ImmutableList.of( - new WhenClause(TRUE, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), + new WhenClause(TRUE, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), new WhenClause(FALSE, new Constant(INTEGER, 1L))), Optional.empty()), - new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))); + new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))); assertOptimizedEquals( new Case(ImmutableList.of( new WhenClause(FALSE, new Constant(INTEGER, 1L)), new WhenClause(FALSE, new Constant(INTEGER, 2L))), - Optional.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), - new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))); + Optional.of(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))))), + new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))); assertEvaluatedEquals( new Case(ImmutableList.of( - new WhenClause(FALSE, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new WhenClause(TRUE, new Constant(INTEGER, 1L))), - Optional.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), + new WhenClause(FALSE, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), new WhenClause(TRUE, new Constant(INTEGER, 1L))), + Optional.of(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))))), new Constant(INTEGER, 1L)); assertEvaluatedEquals( new Case(ImmutableList.of( - new WhenClause(TRUE, new Constant(INTEGER, 1L)), new WhenClause(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), - Optional.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), + new WhenClause(TRUE, new Constant(INTEGER, 1L)), new WhenClause(new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), + Optional.of(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))))), new Constant(INTEGER, 1L)); } @@ -609,34 +604,34 @@ public void testSimpleCase() TRUE, ImmutableList.of( new WhenClause(new Comparison(EQUAL, new Reference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L)), - new WhenClause(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), + new WhenClause(new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), Optional.of(new Constant(INTEGER, 33L))), new Switch( TRUE, ImmutableList.of( new WhenClause(new Comparison(EQUAL, new Reference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L)), - new WhenClause(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), + new WhenClause(new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), Optional.of(new Constant(INTEGER, 33L)))); assertOptimizedMatches( new Switch( new Constant(INTEGER, 1L), ImmutableList.of( - new WhenClause(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L)), - new WhenClause(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), + new WhenClause(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 1L)), + new WhenClause(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 2L))), Optional.of(new Constant(INTEGER, 1L))), new Switch( new Constant(INTEGER, 1L), ImmutableList.of( - new WhenClause(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L)), - new WhenClause(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), + new WhenClause(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 1L)), + new WhenClause(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 2L))), Optional.of(new Constant(INTEGER, 1L)))); assertOptimizedEquals( new Switch( new Constant(INTEGER, null), ImmutableList.of( - new WhenClause(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), + new WhenClause(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))))), Optional.of(new Constant(INTEGER, 1L))), new Constant(INTEGER, 1L)); assertOptimizedEquals( @@ -644,16 +639,16 @@ public void testSimpleCase() new Constant(INTEGER, null), ImmutableList.of( new WhenClause(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), - Optional.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), - new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))); + Optional.of(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))))), + new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))); assertOptimizedEquals( new Switch( - new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), + new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), ImmutableList.of( new WhenClause(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), Optional.of(new Constant(INTEGER, 3L))), new Switch( - new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), + new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), ImmutableList.of( new WhenClause(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), Optional.of(new Constant(INTEGER, 3L)))); @@ -661,35 +656,35 @@ public void testSimpleCase() new Switch( new Constant(INTEGER, 1L), ImmutableList.of( - new WhenClause(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), + new WhenClause(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 2L))), Optional.of(new Constant(INTEGER, 3L))), new Switch( new Constant(INTEGER, 1L), ImmutableList.of( - new WhenClause(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), + new WhenClause(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 2L))), Optional.of(new Constant(INTEGER, 3L)))); assertOptimizedEquals( new Switch( new Constant(INTEGER, 1L), ImmutableList.of( new WhenClause(new Constant(INTEGER, 2L), new Constant(INTEGER, 2L)), - new WhenClause(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 3L))), + new WhenClause(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 3L))), Optional.of(new Constant(INTEGER, 4L))), new Switch( new Constant(INTEGER, 1L), ImmutableList.of( - new WhenClause(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 3L))), + new WhenClause(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 3L))), Optional.of(new Constant(INTEGER, 4L)))); assertOptimizedEquals( new Switch( new Constant(INTEGER, 1L), ImmutableList.of( - new WhenClause(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), + new WhenClause(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 2L))), Optional.empty()), new Switch( new Constant(INTEGER, 1L), ImmutableList.of( - new WhenClause(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), + new WhenClause(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 2L))), Optional.empty())); assertOptimizedEquals( new Switch( @@ -704,14 +699,14 @@ public void testSimpleCase() new Switch( new Constant(INTEGER, null), ImmutableList.of( - new WhenClause(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), + new WhenClause(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))))), Optional.of(new Constant(INTEGER, 1L))), new Constant(INTEGER, 1L)); assertEvaluatedEquals( new Switch( new Constant(INTEGER, 1L), ImmutableList.of( - new WhenClause(new Constant(INTEGER, 2L), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), + new WhenClause(new Constant(INTEGER, 2L), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))))), Optional.of(new Constant(INTEGER, 3L))), new Constant(INTEGER, 3L)); assertEvaluatedEquals( @@ -719,7 +714,7 @@ public void testSimpleCase() new Constant(INTEGER, 1L), ImmutableList.of( new WhenClause(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)), - new WhenClause(new Constant(INTEGER, 1L), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), + new WhenClause(new Constant(INTEGER, 1L), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))))), Optional.empty()), new Constant(INTEGER, 2L)); assertEvaluatedEquals( @@ -727,7 +722,7 @@ public void testSimpleCase() new Constant(INTEGER, 1L), ImmutableList.of( new WhenClause(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), - Optional.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), + Optional.of(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))))), new Constant(INTEGER, 2L)); } @@ -735,8 +730,8 @@ public void testSimpleCase() public void testCoalesce() { assertOptimizedEquals( - new Coalesce(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "unbound_value"), new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Constant(INTEGER, 2L), new Constant(INTEGER, 3L))), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Constant(INTEGER, 1L), new Constant(INTEGER, 1L)), new Constant(INTEGER, null)), - new Coalesce(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "unbound_value"), new Constant(INTEGER, 6L)), new Constant(INTEGER, 0L))); + new Coalesce(new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "unbound_value"), new Call(MULTIPLY_INTEGER, ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 3L))))), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 1L))), new Constant(INTEGER, null)), + new Coalesce(new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "unbound_value"), new Constant(INTEGER, 6L))), new Constant(INTEGER, 0L))); assertOptimizedMatches( new Coalesce(new Reference(INTEGER, "unbound_value"), new Reference(INTEGER, "unbound_value")), new Reference(INTEGER, "unbound_value")); @@ -754,28 +749,28 @@ public void testCoalesce() new Coalesce(new Constant(INTEGER, null), new Coalesce(new Constant(INTEGER, null), new Coalesce(new Constant(INTEGER, null), new Constant(INTEGER, null), new Constant(INTEGER, 1L)))), new Constant(INTEGER, 1L)); assertOptimizedEquals( - new Coalesce(new Constant(INTEGER, 1L), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), + new Coalesce(new Constant(INTEGER, 1L), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), new Constant(INTEGER, 1L)); assertOptimizedEquals( - new Coalesce(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L)), - new Coalesce(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))); + new Coalesce(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 1L)), + new Coalesce(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 1L))); assertOptimizedEquals( - new Coalesce(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L), new Constant(INTEGER, null)), - new Coalesce(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))); + new Coalesce(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 1L), new Constant(INTEGER, null)), + new Coalesce(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 1L))); assertOptimizedEquals( - new Coalesce(new Constant(INTEGER, 1L), new Coalesce(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), + new Coalesce(new Constant(INTEGER, 1L), new Coalesce(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 2L))), new Constant(INTEGER, 1L)); assertOptimizedEquals( - new Coalesce(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, null), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 1L), new Constant(INTEGER, 0L)), new Constant(INTEGER, null), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), - new Coalesce(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 1L), new Constant(INTEGER, 0L)))); + new Coalesce(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, null), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 0L))), new Constant(INTEGER, null), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), + new Coalesce(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 0L))))); assertOptimizedEquals( new Coalesce(new Call(RANDOM, ImmutableList.of()), new Call(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 1.0), new Call(RANDOM, ImmutableList.of())), new Coalesce(new Call(RANDOM, ImmutableList.of()), new Call(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 1.0))); assertEvaluatedEquals( - new Coalesce(new Constant(INTEGER, 1L), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), + new Coalesce(new Constant(INTEGER, 1L), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), new Constant(INTEGER, 1L)); - assertTrinoExceptionThrownBy(() -> evaluate(new Coalesce(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L)))) + assertTrinoExceptionThrownBy(() -> evaluate(new Coalesce(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 1L)))) .hasErrorCode(DIVISION_BY_ZERO); } @@ -819,28 +814,28 @@ public void testIf() new Constant(UNKNOWN, null)); assertOptimizedEquals( - ifExpression(TRUE, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), - new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))); + ifExpression(TRUE, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), + new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))); assertOptimizedEquals( - ifExpression(TRUE, new Constant(INTEGER, 1L), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), + ifExpression(TRUE, new Constant(INTEGER, 1L), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), new Constant(INTEGER, 1L)); assertOptimizedEquals( - ifExpression(FALSE, new Constant(INTEGER, 1L), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), - new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))); + ifExpression(FALSE, new Constant(INTEGER, 1L), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), + new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))); assertOptimizedEquals( - ifExpression(FALSE, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L)), + ifExpression(FALSE, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L)); assertOptimizedEquals( - ifExpression(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)), - ifExpression(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))); + ifExpression(new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)), + ifExpression(new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))); assertEvaluatedEquals( - ifExpression(TRUE, new Constant(INTEGER, 1L), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), + ifExpression(TRUE, new Constant(INTEGER, 1L), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), new Constant(INTEGER, 1L)); assertEvaluatedEquals( - ifExpression(FALSE, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L)), + ifExpression(FALSE, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L)); - assertTrinoExceptionThrownBy(() -> evaluate(ifExpression(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)))) + assertTrinoExceptionThrownBy(() -> evaluate(ifExpression(new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)))) .hasErrorCode(DIVISION_BY_ZERO); } @@ -848,10 +843,10 @@ public void testIf() public void testOptimizeDivideByZero() { assertOptimizedEquals( - new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), - new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))); + new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), + new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))); - assertTrinoExceptionThrownBy(() -> evaluate(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))) + assertTrinoExceptionThrownBy(() -> evaluate(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))))) .hasErrorCode(DIVISION_BY_ZERO); } @@ -881,15 +876,15 @@ public void testRowSubscript() new FieldReference(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(UNKNOWN, null))), 1), new Constant(UNKNOWN, null)); assertOptimizedEquals( - new FieldReference(new Row(ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), 0), - new FieldReference(new Row(ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), 0)); + new FieldReference(new Row(ImmutableList.of(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 1L))), 0), + new FieldReference(new Row(ImmutableList.of(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 1L))), 0)); assertOptimizedEquals( - new FieldReference(new Row(ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), 1), - new FieldReference(new Row(ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), 1)); + new FieldReference(new Row(ImmutableList.of(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 1L))), 1), + new FieldReference(new Row(ImmutableList.of(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 1L))), 1)); - assertTrinoExceptionThrownBy(() -> evaluate(new FieldReference(new Row(ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), 1))) + assertTrinoExceptionThrownBy(() -> evaluate(new FieldReference(new Row(ImmutableList.of(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 1L))), 1))) .hasErrorCode(DIVISION_BY_ZERO); - assertTrinoExceptionThrownBy(() -> evaluate(new FieldReference(new Row(ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), 1))) + assertTrinoExceptionThrownBy(() -> evaluate(new FieldReference(new Row(ImmutableList.of(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 1L))), 1))) .hasErrorCode(DIVISION_BY_ZERO); } diff --git a/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor2.java b/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor2.java index d9fdb7acf61d..28a3bf85bda1 100644 --- a/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor2.java +++ b/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor2.java @@ -31,7 +31,6 @@ import io.trino.spi.function.OperatorType; import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Comparison; @@ -69,8 +68,6 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; -import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.planner.TestingPlannerContext.plannerContextBuilder; import static io.trino.type.UnknownType.UNKNOWN; @@ -163,10 +160,10 @@ public List> columnOriented() private RowExpression getFilter(Type type) { if (type == VARCHAR) { - return rowExpression(new Comparison(EQUAL, new Arithmetic(MODULUS_BIGINT, MODULUS, new Cast(new Reference(VARCHAR, "varchar0"), BIGINT), new Constant(INTEGER, 2L)), new Constant(INTEGER, 0L))); + return rowExpression(new Comparison(EQUAL, new Call(MODULUS_BIGINT, ImmutableList.of(new Cast(new Reference(VARCHAR, "varchar0"), BIGINT), new Constant(INTEGER, 2L))), new Constant(INTEGER, 0L))); } if (type == BIGINT) { - return rowExpression(new Comparison(EQUAL, new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(INTEGER, "bigint0"), new Constant(INTEGER, 2L)), new Constant(INTEGER, 0L))); + return rowExpression(new Comparison(EQUAL, new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(INTEGER, "bigint0"), new Constant(INTEGER, 2L))), new Constant(INTEGER, 0L))); } throw new IllegalArgumentException("filter not supported for type : " + type); } @@ -176,7 +173,7 @@ private List getProjections(Type type) ImmutableList.Builder builder = ImmutableList.builder(); if (type == BIGINT) { for (int i = 0; i < columnCount; i++) { - builder.add(rowExpression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "bigint" + i), new Constant(BIGINT, 5L)))); + builder.add(rowExpression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "bigint" + i), new Constant(BIGINT, 5L))))); } } else if (type == VARCHAR) { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/AbstractPredicatePushdownTest.java b/core/trino-main/src/test/java/io/trino/sql/planner/AbstractPredicatePushdownTest.java index c343c56af43e..e1ef492f251b 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/AbstractPredicatePushdownTest.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/AbstractPredicatePushdownTest.java @@ -20,7 +20,6 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Comparison; @@ -47,10 +46,6 @@ import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; -import static io.trino.sql.ir.Arithmetic.Operator.DIVIDE; -import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; -import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; @@ -309,8 +304,8 @@ public void testPredicatePushDownOverProjection() "SELECT * FROM t WHERE x + x > 1", anyTree( filter( - new Comparison(GREATER_THAN, new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "expr"), new Reference(BIGINT, "expr")), new Constant(BIGINT, 1L)), - project(ImmutableMap.of("expr", expression(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "orderkey"), new Constant(BIGINT, 2L)))), + new Comparison(GREATER_THAN, new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "expr"), new Reference(BIGINT, "expr"))), new Constant(BIGINT, 1L)), + project(ImmutableMap.of("expr", expression(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "orderkey"), new Constant(BIGINT, 2L))))), tableScan("orders", ImmutableMap.of("orderkey", "orderkey")))))); // constant non-singleton should be pushed down @@ -320,7 +315,7 @@ public void testPredicatePushDownOverProjection() anyTree( project( filter( - new Comparison(GREATER_THAN, new Arithmetic(ADD_BIGINT, ADD, new Arithmetic(ADD_BIGINT, ADD, new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "orderkey"), new Constant(BIGINT, 2L)), new Constant(BIGINT, 1L)), new Constant(BIGINT, 1L)), new Constant(BIGINT, 1L)), + new Comparison(GREATER_THAN, new Call(ADD_BIGINT, ImmutableList.of(new Call(ADD_BIGINT, ImmutableList.of(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "orderkey"), new Constant(BIGINT, 2L))), new Constant(BIGINT, 1L))), new Constant(BIGINT, 1L))), new Constant(BIGINT, 1L)), tableScan("orders", ImmutableMap.of( "orderkey", "orderkey")))))); @@ -331,7 +326,7 @@ public void testPredicatePushDownOverProjection() anyTree( project( filter( - new Comparison(GREATER_THAN, new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "orderkey"), new Constant(BIGINT, 2L)), new Constant(BIGINT, 1L)), + new Comparison(GREATER_THAN, new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "orderkey"), new Constant(BIGINT, 2L))), new Constant(BIGINT, 1L)), tableScan("orders", ImmutableMap.of( "orderkey", "orderkey")))))); @@ -342,7 +337,7 @@ public void testPredicatePushDownOverProjection() anyTree( project( filter( - new Comparison(GREATER_THAN, new Arithmetic(ADD_BIGINT, ADD, new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "orderkey"), new Constant(BIGINT, 2L)), new Reference(BIGINT, "orderkey")), new Constant(BIGINT, 1L)), + new Comparison(GREATER_THAN, new Call(ADD_BIGINT, ImmutableList.of(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "orderkey"), new Constant(BIGINT, 2L))), new Reference(BIGINT, "orderkey"))), new Constant(BIGINT, 1L)), tableScan("orders", ImmutableMap.of( "orderkey", "orderkey")))))); @@ -363,7 +358,7 @@ public void testPredicatePushDownOverProjection() anyTree( filter( new Comparison(GREATER_THAN, new Reference(DOUBLE, "expr"), new Constant(DOUBLE, 5000.0)), - project(ImmutableMap.of("expr", expression(new Arithmetic(MULTIPLY_DOUBLE, MULTIPLY, new Call(RANDOM, ImmutableList.of()), new Cast(new Reference(BIGINT, "orderkey"), DOUBLE)))), + project(ImmutableMap.of("expr", expression(new Call(MULTIPLY_DOUBLE, ImmutableList.of(new Call(RANDOM, ImmutableList.of()), new Cast(new Reference(BIGINT, "orderkey"), DOUBLE))))), tableScan("orders", ImmutableMap.of( "orderkey", "orderkey")))))); } @@ -399,7 +394,7 @@ public void testConjunctsOrder() // Order matters: size<>1 should be before 100/(size-1)=2. // In this particular example, reversing the order leads to div-by-zero error. filter( - new Logical(AND, ImmutableList.of(new Comparison(NOT_EQUAL, new Reference(INTEGER, "size"), new Constant(INTEGER, 1L)), new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 100L), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "size"), new Constant(INTEGER, 1L))), new Constant(INTEGER, 2L)))), + new Logical(AND, ImmutableList.of(new Comparison(NOT_EQUAL, new Reference(INTEGER, "size"), new Constant(INTEGER, 1L)), new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 100L), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "size"), new Constant(INTEGER, 1L))))), new Constant(INTEGER, 2L)))), tableScan("part", ImmutableMap.of( "partkey", "partkey", "size", "size"))))); @@ -442,7 +437,7 @@ public void testPredicateOnNonDeterministicSymbolsPushedDown() anyTree( filter( new Comparison(GREATER_THAN, new Reference(DOUBLE, "ROUND"), new Constant(DOUBLE, 100.0)), - project(ImmutableMap.of("ROUND", expression(new Call(ROUND, ImmutableList.of(new Arithmetic(MULTIPLY_DOUBLE, MULTIPLY, new Cast(new Reference(BIGINT, "CUST_KEY"), DOUBLE), new Call(RANDOM, ImmutableList.of())))))), + project(ImmutableMap.of("ROUND", expression(new Call(ROUND, ImmutableList.of(new Call(MULTIPLY_DOUBLE, ImmutableList.of(new Cast(new Reference(BIGINT, "CUST_KEY"), DOUBLE), new Call(RANDOM, ImmutableList.of()))))))), tableScan( "orders", ImmutableMap.of("CUST_KEY", "custkey")))))))); @@ -458,7 +453,7 @@ public void testNonDeterministicPredicateNotPushedDown() ") WHERE custkey > 100*rand()", anyTree( filter( - new Comparison(GREATER_THAN, new Cast(new Reference(BIGINT, "CUST_KEY"), DOUBLE), new Arithmetic(MULTIPLY_DOUBLE, MULTIPLY, new Call(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 100.0))), + new Comparison(GREATER_THAN, new Cast(new Reference(BIGINT, "CUST_KEY"), DOUBLE), new Call(MULTIPLY_DOUBLE, ImmutableList.of(new Call(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 100.0)))), anyTree( node(WindowNode.class, anyTree( @@ -559,7 +554,7 @@ public void testDoesNotCreatePredicateFromInferredPredicate() join(INNER, builder -> builder .equiCriteria("L_NATIONKEY2", "R_NATIONKEY") .left( - project(ImmutableMap.of("L_NATIONKEY2", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "L_NATIONKEY"), new Constant(BIGINT, 1L)))), + project(ImmutableMap.of("L_NATIONKEY2", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "L_NATIONKEY"), new Constant(BIGINT, 1L))))), tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey")))) .right( anyTree( @@ -586,7 +581,7 @@ public void testSimplifiesStraddlingPredicate() assertPlan("SELECT * FROM (SELECT * FROM NATION WHERE nationkey = 5) a JOIN nation b ON a.nationkey = b.nationkey AND a.nationkey = a.regionkey + b.regionkey", output( filter( - new Comparison(EQUAL, new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "L_REGIONKEY"), new Reference(BIGINT, "R_REGIONKEY")), new Constant(BIGINT, 5L)), + new Comparison(EQUAL, new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "L_REGIONKEY"), new Reference(BIGINT, "R_REGIONKEY"))), new Constant(BIGINT, 5L)), join(INNER, builder -> builder .equiCriteria(ImmutableList.of()) .left( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestAddDynamicFilterSource.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestAddDynamicFilterSource.java index aa5c6c570f86..3e06d01bc8fb 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestAddDynamicFilterSource.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestAddDynamicFilterSource.java @@ -24,8 +24,8 @@ import io.trino.operator.RetryPolicy; import io.trino.spi.function.OperatorType; import io.trino.sql.DynamicFilters; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Between; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Logical; @@ -53,8 +53,6 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; -import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; @@ -82,8 +80,8 @@ public class TestAddDynamicFilterSource extends BasePlanTest { private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); - private static final ResolvedFunction SUBTRACT_BIGINT = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(BIGINT, BIGINT)); - private static final ResolvedFunction MODULUS_INTEGER = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(INTEGER, INTEGER)); + private static final ResolvedFunction SUBTRACT_BIGINT = FUNCTIONS.resolveOperator(OperatorType.SUBTRACT, ImmutableList.of(BIGINT, BIGINT)); + private static final ResolvedFunction MODULUS_INTEGER = FUNCTIONS.resolveOperator(OperatorType.MODULUS, ImmutableList.of(INTEGER, INTEGER)); public TestAddDynamicFilterSource() { @@ -175,7 +173,7 @@ public void testSemiJoin() DynamicFilterSourceNode.class, project( filter( - new Comparison(EQUAL, new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "Z"), new Constant(INTEGER, 4L)), new Constant(INTEGER, 0L)), + new Comparison(EQUAL, new Call(MODULUS_INTEGER, ImmutableList.of(new Reference(INTEGER, "Z"), new Constant(INTEGER, 4L))), new Constant(INTEGER, 0L)), tableScan("lineitem", ImmutableMap.of("Y", "orderkey", "Z", "linenumber"))))))))))); } } @@ -275,7 +273,7 @@ public void testCrossJoinInequality() exchange( LOCAL, project( - ImmutableMap.of("expr", expression(new Arithmetic(SUBTRACT_BIGINT, SUBTRACT, new Reference(BIGINT, "L_PARTKEY"), new Constant(BIGINT, 1L)))), + ImmutableMap.of("expr", expression(new Call(SUBTRACT_BIGINT, ImmutableList.of(new Reference(BIGINT, "L_PARTKEY"), new Constant(BIGINT, 1L))))), exchange( REMOTE, tableScan("lineitem", ImmutableMap.of("L_ORDERKEY", "orderkey", "L_PARTKEY", "partkey")))))))))); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java index f19b70f10349..3b87d03e148a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java @@ -19,6 +19,7 @@ import io.trino.Session; import io.trino.metadata.Metadata; import io.trino.metadata.MetadataManager; +import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.operator.scalar.JsonPath; import io.trino.security.AllowAllAccessControl; @@ -31,7 +32,6 @@ import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Between; import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; @@ -42,7 +42,6 @@ import io.trino.sql.ir.In; import io.trino.sql.ir.IsNull; import io.trino.sql.ir.Logical; -import io.trino.sql.ir.Negation; import io.trino.sql.ir.Not; import io.trino.sql.ir.NullIf; import io.trino.sql.ir.Reference; @@ -52,6 +51,7 @@ import io.trino.type.LikeFunctions; import org.junit.jupiter.api.Test; +import java.util.EnumSet; import java.util.List; import java.util.Map; import java.util.Optional; @@ -59,15 +59,25 @@ import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.airlift.slice.Slices.utf8Slice; import static io.trino.operator.scalar.JoniRegexpCasts.joniRegexp; +import static io.trino.spi.expression.StandardFunctions.ADD_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.AND_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.ARRAY_CONSTRUCTOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.CAST_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.DIVIDE_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.IS_NULL_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.MODULUS_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.MULTIPLY_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.NEGATE_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.NOT_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.NULLIF_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.SUBTRACT_FUNCTION_NAME; +import static io.trino.spi.function.OperatorType.ADD; +import static io.trino.spi.function.OperatorType.DIVIDE; +import static io.trino.spi.function.OperatorType.MODULUS; +import static io.trino.spi.function.OperatorType.MULTIPLY; +import static io.trino.spi.function.OperatorType.SUBTRACT; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.DecimalType.createDecimalType; @@ -98,6 +108,9 @@ public class TestConnectorExpressionTranslator private static final VarcharType VARCHAR_TYPE = createUnboundedVarcharType(); private static final ArrayType VARCHAR_ARRAY_TYPE = new ArrayType(VARCHAR_TYPE); + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction NEGATION_DOUBLE = FUNCTIONS.resolveOperator(OperatorType.NEGATION, ImmutableList.of(DOUBLE)); + private static final Map symbols = ImmutableMap.builder() .put(new Symbol(DOUBLE, "double_symbol_1"), DOUBLE) .put(new Symbol(DOUBLE, "double_symbol_2"), DOUBLE) @@ -195,33 +208,35 @@ public void testTranslateComparisonExpression() public void testTranslateArithmeticBinary() { TestingFunctionResolution resolver = new TestingFunctionResolution(); - for (Arithmetic.Operator operator : Arithmetic.Operator.values()) { + for (OperatorType operator : EnumSet.of(ADD, SUBTRACT, MULTIPLY, DIVIDE, MODULUS)) { assertTranslationRoundTrips( - new Arithmetic( - resolver.resolveOperator( - switch (operator) { - case ADD -> OperatorType.ADD; - case SUBTRACT -> OperatorType.SUBTRACT; - case MULTIPLY -> OperatorType.MULTIPLY; - case DIVIDE -> OperatorType.DIVIDE; - case MODULUS -> OperatorType.MODULUS; - }, - ImmutableList.of(DOUBLE, DOUBLE)), + new Call(resolver.resolveOperator( operator, - new Reference(DOUBLE, "double_symbol_1"), - new Reference(DOUBLE, "double_symbol_2")), + ImmutableList.of(DOUBLE, DOUBLE)), ImmutableList.of(new Reference(DOUBLE, "double_symbol_1"), new Reference(DOUBLE, "double_symbol_2"))), new io.trino.spi.expression.Call( DOUBLE, - ConnectorExpressionTranslator.functionNameForArithmeticBinaryOperator(operator), + functionNameForArithmeticBinaryOperator(operator), List.of(new Variable("double_symbol_1", DOUBLE), new Variable("double_symbol_2", DOUBLE)))); } } + private static FunctionName functionNameForArithmeticBinaryOperator(OperatorType operator) + { + return switch (operator) { + case ADD -> ADD_FUNCTION_NAME; + case SUBTRACT -> SUBTRACT_FUNCTION_NAME; + case MULTIPLY -> MULTIPLY_FUNCTION_NAME; + case DIVIDE -> DIVIDE_FUNCTION_NAME; + case MODULUS -> MODULUS_FUNCTION_NAME; + default -> throw new IllegalArgumentException("Unsupported operator: " + operator); + }; + } + @Test public void testTranslateArithmeticUnaryMinus() { assertTranslationRoundTrips( - new Negation(new Reference(DOUBLE, "double_symbol_1")), + new Call(NEGATION_DOUBLE, ImmutableList.of(new Reference(DOUBLE, "double_symbol_1"))), new io.trino.spi.expression.Call(DOUBLE, NEGATE_FUNCTION_NAME, List.of(new Variable("double_symbol_1", DOUBLE)))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestDereferencePushDown.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestDereferencePushDown.java index d22b5794b707..c072a699887d 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestDereferencePushDown.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestDereferencePushDown.java @@ -20,7 +20,6 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; import io.trino.spi.type.RowType; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Call; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; @@ -35,7 +34,6 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.Logical.Operator.OR; import static io.trino.sql.planner.assertions.PlanMatchPattern.any; @@ -220,7 +218,7 @@ public void testDereferencePushdownLimit() assertPlan("WITH t(msg) AS (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))), ROW(CAST(ROW(3, 4.0) AS ROW(x BIGINT, y DOUBLE))))" + "SELECT msg.x * 3 FROM t limit 1", anyTree( - strictProject(ImmutableMap.of("x_into_3", expression(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "msg_x"), new Constant(BIGINT, 3L)))), + strictProject(ImmutableMap.of("x_into_3", expression(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "msg_x"), new Constant(BIGINT, 3L))))), limit(1, strictProject(ImmutableMap.of("msg_x", expression(new FieldReference(new Reference(RowType.anonymousRow(BIGINT, DOUBLE), "msg"), 0))), values("msg")))))); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestDynamicFilter.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestDynamicFilter.java index 0a49be5d7231..fedf55c0de19 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestDynamicFilter.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestDynamicFilter.java @@ -23,7 +23,6 @@ import io.trino.spi.function.OperatorType; import io.trino.spi.type.Type; import io.trino.sql.DynamicFilters; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Between; import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; @@ -55,8 +54,6 @@ import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; -import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; @@ -158,7 +155,7 @@ public void testRightEquiJoinWithLeftExpression() .right( anyTree( project( - ImmutableMap.of("expr", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "LINEITEM_OK"), new Constant(BIGINT, 1L)))), + ImmutableMap.of("expr", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "LINEITEM_OK"), new Constant(BIGINT, 1L))))), tableScan("lineitem", ImmutableMap.of("LINEITEM_OK", "orderkey")))))))); } @@ -246,7 +243,7 @@ public void testCrossJoinBetweenDF() assertPlan("SELECT o.orderkey FROM orders o, lineitem l WHERE o.orderkey BETWEEN l.orderkey AND l.partkey - 1", anyTree(filter( - new Between(new Reference(BIGINT, "O_ORDERKEY"), new Reference(BIGINT, "L_ORDERKEY"), new Arithmetic(SUBTRACT_BIGINT, SUBTRACT, new Reference(BIGINT, "L_PARTKEY"), new Constant(BIGINT, 1L))), + new Between(new Reference(BIGINT, "O_ORDERKEY"), new Reference(BIGINT, "L_ORDERKEY"), new Call(SUBTRACT_BIGINT, ImmutableList.of(new Reference(BIGINT, "L_PARTKEY"), new Constant(BIGINT, 1L)))), join(INNER, builder -> builder .dynamicFilter(ImmutableList.of(new DynamicFilterPattern(new Reference(BIGINT, "O_ORDERKEY"), GREATER_THAN_OR_EQUAL, "L_ORDERKEY"))) .left( @@ -259,7 +256,7 @@ public void testCrossJoinBetweenDF() assertPlan("SELECT o.orderkey FROM orders o, lineitem l WHERE o.orderkey BETWEEN l.orderkey + 1 AND l.partkey", anyTree(filter( - new Between(new Reference(BIGINT, "O_ORDERKEY"), new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "L_ORDERKEY"), new Constant(BIGINT, 1L)), new Reference(BIGINT, "L_PARTKEY")), + new Between(new Reference(BIGINT, "O_ORDERKEY"), new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "L_ORDERKEY"), new Constant(BIGINT, 1L))), new Reference(BIGINT, "L_PARTKEY")), join(INNER, builder -> builder .dynamicFilter(ImmutableList.of( new DynamicFilterPattern(new Reference(BIGINT, "O_ORDERKEY"), LESS_THAN_OR_EQUAL, "L_PARTKEY"))) @@ -654,7 +651,7 @@ public void testNonPushedDownJoinFilterRemoval() .equiCriteria(ImmutableList.of(equiJoinClause("K0", "K2"), equiJoinClause("S", "V2"))) .left( project( - ImmutableMap.of("S", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "V0"), new Reference(BIGINT, "V1")))), + ImmutableMap.of("S", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "V0"), new Reference(BIGINT, "V1"))))), join(INNER, leftJoinBuilder -> leftJoinBuilder .equiCriteria("K0", "K1") .dynamicFilter(BIGINT, "K0", "K1") @@ -766,7 +763,7 @@ public void testSemiJoinUnsupportedDynamicFilterRemoval() filter( new Reference(BOOLEAN, "S0"), semiJoin("LINEITEM_PK_PLUS_1000", "PART_PK", "S0", false, - project(ImmutableMap.of("LINEITEM_PK_PLUS_1000", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "LINEITEM_PK"), new Constant(BIGINT, 1000L)))), + project(ImmutableMap.of("LINEITEM_PK_PLUS_1000", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "LINEITEM_PK"), new Constant(BIGINT, 1000L))))), tableScan("lineitem", ImmutableMap.of("LINEITEM_PK", "partkey"))), anyTree( tableScan("part", ImmutableMap.of("PART_PK", "partkey"))))))); @@ -782,7 +779,7 @@ public void testExpressionPushedDownToLeftJoinSourceWhenUsingOn() join(INNER, builder -> builder .left( project( - ImmutableMap.of("expr", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "ORDERS_OK"), new Constant(BIGINT, 1L)))), + ImmutableMap.of("expr", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "ORDERS_OK"), new Constant(BIGINT, 1L))))), tableScan("orders", ImmutableMap.of("ORDERS_OK", "orderkey")))) .right( anyTree( @@ -804,7 +801,7 @@ public void testExpressionPushedDownToRightJoinSourceWhenUsingOn() .right( anyTree( project( - ImmutableMap.of("expr", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "LINEITEM_OK"), new Constant(BIGINT, 1L)))), + ImmutableMap.of("expr", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "LINEITEM_OK"), new Constant(BIGINT, 1L))))), tableScan("lineitem", ImmutableMap.of("LINEITEM_OK", "orderkey"))))))))); } @@ -813,7 +810,7 @@ public void testExpressionNotPushedDownToLeftJoinSource() { assertPlan("SELECT o.orderkey FROM orders o, lineitem l WHERE o.orderkey + 1 < l.orderkey", anyTree(filter( - new Comparison(LESS_THAN, new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "ORDERS_OK"), new Constant(BIGINT, 1L)), new Reference(BIGINT, "LINEITEM_OK")), + new Comparison(LESS_THAN, new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "ORDERS_OK"), new Constant(BIGINT, 1L))), new Reference(BIGINT, "LINEITEM_OK")), join(INNER, builder -> builder .left(tableScan("orders", ImmutableMap.of("ORDERS_OK", "orderkey"))) .right(exchange( @@ -836,7 +833,7 @@ public void testExpressionPushedDownToRightJoinSource() .right( anyTree( project( - ImmutableMap.of("expr", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "LINEITEM_OK"), new Constant(BIGINT, 1L)))), + ImmutableMap.of("expr", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "LINEITEM_OK"), new Constant(BIGINT, 1L))))), tableScan("lineitem", ImmutableMap.of("LINEITEM_OK", "orderkey"))))))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestEqualityInference.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestEqualityInference.java index 7c0599cf4bb5..86614ce2e218 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestEqualityInference.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestEqualityInference.java @@ -22,7 +22,7 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.operator.scalar.TryFunction; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Case; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Comparison; @@ -361,12 +361,12 @@ private static Expression add(String symbol1, String symbol2) private static Expression add(Expression expression1, Expression expression2) { - return new Arithmetic(ADD_BIGINT, Arithmetic.Operator.ADD, expression1, expression2); + return new Call(ADD_BIGINT, ImmutableList.of(expression1, expression2)); } private static Expression multiply(Expression expression1, Expression expression2) { - return new Arithmetic(MULTIPLY_BIGINT, Arithmetic.Operator.MULTIPLY, expression1, expression2); + return new Call(MULTIPLY_BIGINT, ImmutableList.of(expression1, expression2)); } private static Expression equals(String symbol1, String symbol2) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java index d16bbcc2ae96..a018fce1f91f 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java @@ -30,7 +30,6 @@ import io.trino.spi.predicate.TupleDomain; import io.trino.spi.predicate.ValueSet; import io.trino.spi.type.RowType; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Coalesce; @@ -116,9 +115,6 @@ import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; -import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; -import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; @@ -529,10 +525,10 @@ public void testInequalityPredicatePushdownWithOuterJoin() anyTree( // predicate above outer join is not pushed to build side filter( - new Comparison(LESS_THAN, new Arithmetic(SUBTRACT_BIGINT, SUBTRACT, new Reference(BIGINT, "O_CUSTKEY"), new Constant(BIGINT, 24L)), new Coalesce(new Arithmetic(SUBTRACT_BIGINT, SUBTRACT, new Reference(BIGINT, "L_PARTKEY"), new Constant(BIGINT, 24L)), new Constant(BIGINT, 0L))), + new Comparison(LESS_THAN, new Call(SUBTRACT_BIGINT, ImmutableList.of(new Reference(BIGINT, "O_CUSTKEY"), new Constant(BIGINT, 24L))), new Coalesce(new Call(SUBTRACT_BIGINT, ImmutableList.of(new Reference(BIGINT, "L_PARTKEY"), new Constant(BIGINT, 24L))), new Constant(BIGINT, 0L))), join(LEFT, builder -> builder .equiCriteria("O_ORDERKEY", "L_ORDERKEY") - .filter(new Comparison(LESS_THAN, new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "O_CUSTKEY"), new Constant(BIGINT, 42L)), new Reference(BIGINT, "EXPR"))) + .filter(new Comparison(LESS_THAN, new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "O_CUSTKEY"), new Constant(BIGINT, 42L))), new Reference(BIGINT, "EXPR"))) .left( tableScan( "orders", @@ -542,7 +538,7 @@ public void testInequalityPredicatePushdownWithOuterJoin() .right( anyTree( project( - ImmutableMap.of("EXPR", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "L_PARTKEY"), new Constant(BIGINT, 42L)))), + ImmutableMap.of("EXPR", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "L_PARTKEY"), new Constant(BIGINT, 42L))))), tableScan( "lineitem", ImmutableMap.of( @@ -2415,9 +2411,9 @@ public void testMergePatternRecognitionNodesWithProjections() project( ImmutableMap.of( "output1", expression(new Reference(INTEGER, "id")), - "output2", expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "value"), new Constant(INTEGER, 2L))), + "output2", expression(new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "value"), new Constant(INTEGER, 2L)))), "output3", expression(new Call(LOWER, ImmutableList.of(new Reference(VARCHAR, "label")))), - "output4", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "min"), new Constant(INTEGER, 1L)))), + "output4", expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "min"), new Constant(INTEGER, 1L))))), project( ImmutableMap.of( "id", expression(new Reference(INTEGER, "id")), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestMaterializedViews.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestMaterializedViews.java index d78811bd436e..9f744580af09 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestMaterializedViews.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestMaterializedViews.java @@ -40,7 +40,7 @@ import io.trino.spi.type.Type; import io.trino.spi.type.TypeManager; import io.trino.spi.type.TypeParameter; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; @@ -65,7 +65,6 @@ import static io.trino.spi.type.TimestampWithTimeZoneType.createTimestampWithTimeZoneType; import static io.trino.spi.type.TinyintType.TINYINT; import static io.trino.spi.type.VarcharType.VARCHAR; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange; @@ -321,7 +320,7 @@ public void testMaterializedViewWithCasts() anyTree( project( ImmutableMap.of( - "A_CAST", expression(new Arithmetic(ADD_BIGINT, ADD, new Cast(new Reference(BIGINT, "A"), BIGINT), new Constant(BIGINT, 1L))), + "A_CAST", expression(new Call(ADD_BIGINT, ImmutableList.of(new Cast(new Reference(BIGINT, "A"), BIGINT), new Constant(BIGINT, 1L)))), "B_CAST", expression(new Cast(new Reference(BIGINT, "B"), BIGINT))), tableScan("storage_table_with_casts", ImmutableMap.of("A", "a", "B", "b"))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java index 63ee1ed3cccf..9cdd6fc748fe 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java @@ -21,7 +21,6 @@ import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.function.OperatorType; import io.trino.spi.type.RowType; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Constant; @@ -38,7 +37,6 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.planner.ConnectorExpressionTranslator.translate; import static io.trino.sql.planner.PartialTranslator.extractPartialTranslations; import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT; @@ -66,7 +64,7 @@ public void testPartialTranslator() assertFullTranslation(symbolReference1); assertFullTranslation(dereferenceExpression1); assertFullTranslation(stringLiteral); - assertFullTranslation(new Arithmetic(ADD_INTEGER, ADD, symbolReference1, dereferenceExpression1)); + assertFullTranslation(new Call(ADD_INTEGER, ImmutableList.of(symbolReference1, dereferenceExpression1))); Expression functionCallExpression = new Call( PLANNER_CONTEXT.getMetadata().resolveBuiltinFunction("concat", fromTypes(VARCHAR, VARCHAR)), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestPlanMatchingFramework.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestPlanMatchingFramework.java index c6fef26f96e4..0e8c5acd7853 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestPlanMatchingFramework.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestPlanMatchingFramework.java @@ -19,7 +19,7 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Reference; @@ -35,7 +35,6 @@ import static io.trino.SystemSessionProperties.JOIN_REORDERING_STRATEGY; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.columnReference; @@ -124,7 +123,7 @@ public void testAliasExpressionFromProject() { assertMinimallyOptimizedPlan("SELECT orderkey, 1 + orderkey FROM lineitem", output(ImmutableList.of("ORDERKEY", "EXPRESSION"), - project(ImmutableMap.of("EXPRESSION", expression(new Arithmetic(ADD_BIGINT, ADD, new Cast(new Constant(INTEGER, 1L), BIGINT), new Reference(BIGINT, "ORDERKEY")))), + project(ImmutableMap.of("EXPRESSION", expression(new Call(ADD_BIGINT, ImmutableList.of(new Cast(new Constant(INTEGER, 1L), BIGINT), new Reference(BIGINT, "ORDERKEY"))))), tableScan("lineitem", ImmutableMap.of("ORDERKEY", "orderkey"))))); } @@ -134,7 +133,7 @@ public void testStrictProject() assertMinimallyOptimizedPlan("SELECT orderkey, 1 + orderkey FROM lineitem", output(ImmutableList.of("ORDERKEY", "EXPRESSION"), strictProject(ImmutableMap.of( - "EXPRESSION", expression(new Arithmetic(ADD_BIGINT, ADD, new Cast(new Constant(INTEGER, 1L), BIGINT), new Reference(BIGINT, "ORDERKEY"))), + "EXPRESSION", expression(new Call(ADD_BIGINT, ImmutableList.of(new Cast(new Constant(INTEGER, 1L), BIGINT), new Reference(BIGINT, "ORDERKEY")))), "ORDERKEY", expression(new Reference(BIGINT, "ORDERKEY"))), tableScan("lineitem", ImmutableMap.of("ORDERKEY", "orderkey"))))); } @@ -146,7 +145,7 @@ public void testIdentityAliasFromProject() output(ImmutableList.of("ORDERKEY", "EXPRESSION"), project(ImmutableMap.of( "ORDERKEY", expression(new Reference(BIGINT, "ORDERKEY")), - "EXPRESSION", expression(new Arithmetic(ADD_BIGINT, ADD, new Cast(new Constant(INTEGER, 1L), BIGINT), new Reference(BIGINT, "ORDERKEY")))), + "EXPRESSION", expression(new Call(ADD_BIGINT, ImmutableList.of(new Cast(new Constant(INTEGER, 1L), BIGINT), new Reference(BIGINT, "ORDERKEY"))))), tableScan("lineitem", ImmutableMap.of("ORDERKEY", "orderkey"))))); } @@ -253,7 +252,7 @@ public void testStrictProjectExtraSymbols() { assertThatThrownBy(() -> assertMinimallyOptimizedPlan("SELECT discount, orderkey, 1 + orderkey FROM lineitem", output(ImmutableList.of("ORDERKEY", "EXPRESSION"), - strictProject(ImmutableMap.of("EXPRESSION", expression(new Arithmetic(ADD_BIGINT, ADD, new Constant(BIGINT, 1L), new Reference(BIGINT, "ORDERKEY"))), "ORDERKEY", expression(new Reference(BIGINT, "ORDERKEY"))), + strictProject(ImmutableMap.of("EXPRESSION", expression(new Call(ADD_BIGINT, ImmutableList.of(new Constant(BIGINT, 1L), new Reference(BIGINT, "ORDERKEY")))), "ORDERKEY", expression(new Reference(BIGINT, "ORDERKEY"))), tableScan("lineitem", ImmutableMap.of("ORDERKEY", "orderkey")))))) .isInstanceOf(AssertionError.class) .hasMessageStartingWith("Plan does not match"); @@ -283,7 +282,7 @@ public void testProjectLimitsScope() { assertThatThrownBy(() -> assertMinimallyOptimizedPlan("SELECT 1 + orderkey FROM lineitem", output(ImmutableList.of("ORDERKEY"), - project(ImmutableMap.of("EXPRESSION", expression(new Arithmetic(ADD_BIGINT, ADD, new Cast(new Constant(INTEGER, 1L), BIGINT), new Reference(BIGINT, "ORDERKEY")))), + project(ImmutableMap.of("EXPRESSION", expression(new Call(ADD_BIGINT, ImmutableList.of(new Cast(new Constant(INTEGER, 1L), BIGINT), new Reference(BIGINT, "ORDERKEY"))))), tableScan("lineitem", ImmutableMap.of("ORDERKEY", "orderkey")))))) .isInstanceOf(IllegalStateException.class) .hasMessageMatching("missing expression for alias .*"); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestRecursiveCte.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestRecursiveCte.java index 77bbd30eb050..b99e4724ea54 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestRecursiveCte.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestRecursiveCte.java @@ -20,7 +20,6 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Comparison; @@ -38,7 +37,6 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; @@ -90,7 +88,7 @@ public void testRecursiveQuery() values()))), // first recursion step project(project(project( - ImmutableMap.of("expr_0", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "expr"), new Constant(INTEGER, 2L)))), + ImmutableMap.of("expr_0", expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "expr"), new Constant(INTEGER, 2L))))), filter( new Comparison(LESS_THAN, new Reference(INTEGER, "expr"), new Constant(INTEGER, 6L)), project(project(project( @@ -107,13 +105,13 @@ public void testRecursiveQuery() "count", windowFunction("count", ImmutableList.of(), DEFAULT_FRAME)), project(project(project( - ImmutableMap.of("expr_1", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "expr"), new Constant(INTEGER, 2L)))), + ImmutableMap.of("expr_1", expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "expr"), new Constant(INTEGER, 2L))))), filter( new Comparison(LESS_THAN, new Reference(INTEGER, "expr"), new Constant(INTEGER, 6L)), project( ImmutableMap.of("expr", expression(new Reference(INTEGER, "expr_0"))), project(project(project( - ImmutableMap.of("expr_0", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "expr"), new Constant(INTEGER, 2L)))), + ImmutableMap.of("expr_0", expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "expr"), new Constant(INTEGER, 2L))))), filter( new Comparison(LESS_THAN, new Reference(INTEGER, "expr"), new Constant(INTEGER, 6L)), project(project(project( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestSortExpressionExtractor.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestSortExpressionExtractor.java index a94496f96bdc..2a0f1823a843 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestSortExpressionExtractor.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestSortExpressionExtractor.java @@ -18,7 +18,6 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Between; import io.trino.sql.ir.Call; import io.trino.sql.ir.Comparison; @@ -35,7 +34,6 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; @@ -74,7 +72,7 @@ public void testGetSortExpression() "b1", new Comparison(GREATER_THAN, new Reference(BIGINT, "b1"), new Reference(BIGINT, "p1"))); - assertNoSortExpression(new Comparison(GREATER_THAN, new Reference(BIGINT, "b1"), new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "p1"), new Reference(BIGINT, "b2")))); + assertNoSortExpression(new Comparison(GREATER_THAN, new Reference(BIGINT, "b1"), new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "p1"), new Reference(BIGINT, "b2"))))); assertNoSortExpression(new Logical(OR, ImmutableList.of(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "b1"), new Reference(BIGINT, "p1")), new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "b2"), new Reference(BIGINT, "p1"))))); @@ -89,10 +87,10 @@ public void testGetSortExpression() new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "b1"), new Reference(BIGINT, "p1"))); assertGetSortExpression( - new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "b1"), new Reference(BIGINT, "p1")), new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "b1"), new Reference(BIGINT, "p1")), new Comparison(GREATER_THAN, new Reference(BIGINT, "b2"), new Reference(BIGINT, "p1")), new Comparison(LESS_THAN, new Reference(BIGINT, "b2"), new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "p1"), new Constant(BIGINT, 10L))), new Comparison(GREATER_THAN, new Reference(BIGINT, "b2"), new Reference(BIGINT, "p2")))), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "b1"), new Reference(BIGINT, "p1")), new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "b1"), new Reference(BIGINT, "p1")), new Comparison(GREATER_THAN, new Reference(BIGINT, "b2"), new Reference(BIGINT, "p1")), new Comparison(LESS_THAN, new Reference(BIGINT, "b2"), new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "p1"), new Constant(BIGINT, 10L)))), new Comparison(GREATER_THAN, new Reference(BIGINT, "b2"), new Reference(BIGINT, "p2")))), "b2", new Comparison(GREATER_THAN, new Reference(BIGINT, "b2"), new Reference(BIGINT, "p1")), - new Comparison(LESS_THAN, new Reference(BIGINT, "b2"), new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "p1"), new Constant(BIGINT, 10L))), + new Comparison(LESS_THAN, new Reference(BIGINT, "b2"), new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "p1"), new Constant(BIGINT, 10L)))), new Comparison(GREATER_THAN, new Reference(BIGINT, "b2"), new Reference(BIGINT, "p2"))); assertGetSortExpression( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestWindowClause.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestWindowClause.java index 8f47f6269e3d..a2d1931bbc8b 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestWindowClause.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestWindowClause.java @@ -19,11 +19,9 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.connector.SortOrder; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.Negation; import io.trino.sql.ir.Reference; import io.trino.sql.planner.assertions.BasePlanTest; import io.trino.sql.planner.assertions.PlanMatchPattern; @@ -36,7 +34,6 @@ import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.planner.LogicalPlanner.Stage.CREATED; import static io.trino.sql.planner.assertions.PlanMatchPattern.any; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; @@ -64,6 +61,7 @@ public class TestWindowClause private static final ResolvedFunction ADD_INTEGER = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(INTEGER, INTEGER)); private static final ResolvedFunction SUBTRACT_INTEGER = FUNCTIONS.resolveOperator(OperatorType.SUBTRACT, ImmutableList.of(INTEGER, INTEGER)); private static final ResolvedFunction ADD_DOUBLE = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(DOUBLE, DOUBLE)); + private static final ResolvedFunction NEGATION_INTEGER = FUNCTIONS.resolveOperator(OperatorType.NEGATION, ImmutableList.of(INTEGER)); @Test public void testPreprojectExpression() @@ -81,7 +79,7 @@ public void testPreprojectExpression() "max_result", windowFunction("max", ImmutableList.of("b"), DEFAULT_FRAME)), anyTree(project( - ImmutableMap.of("expr", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L)))), + ImmutableMap.of("expr", expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))))), anyTree(values("a", "b")))))); assertPlan(sql, CREATED, pattern); @@ -116,9 +114,9 @@ public void testPreprojectExpressions() ImmutableMap.of("frame_start", expression(new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "expr_b"), new Reference(INTEGER, "expr_c"))))), anyTree(project( ImmutableMap.of( - "expr_a", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))), - "expr_b", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "b"), new Constant(INTEGER, 2L))), - "expr_c", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "c"), new Constant(INTEGER, 3L)))), + "expr_a", expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 1L)))), + "expr_b", expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "b"), new Constant(INTEGER, 2L)))), + "expr_c", expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "c"), new Constant(INTEGER, 3L))))), anyTree(values("a", "b", "c"))))))); assertPlan(sql, CREATED, pattern); @@ -141,9 +139,9 @@ public void testWindowFunctionsInSelectAndOrderBy() "max_result", windowFunction("max", ImmutableList.of("minus_a"), DEFAULT_FRAME)), any(project( - ImmutableMap.of("order_by_window_sortkey", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "minus_a"), new Constant(INTEGER, 1L)))), + ImmutableMap.of("order_by_window_sortkey", expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "minus_a"), new Constant(INTEGER, 1L))))), project( - ImmutableMap.of("minus_a", expression(new Negation(new Reference(INTEGER, "a")))), + ImmutableMap.of("minus_a", expression(new Call(NEGATION_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"))))), window( windowMatcherBuilder -> windowMatcherBuilder .specification(specification( @@ -154,7 +152,7 @@ public void testWindowFunctionsInSelectAndOrderBy() "array_agg_result", windowFunction("array_agg", ImmutableList.of("a"), DEFAULT_FRAME)), anyTree(project( - ImmutableMap.of("select_window_sortkey", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L)))), + ImmutableMap.of("select_window_sortkey", expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))))), anyTree(values("a")))))))))))); assertPlan(sql, CREATED, pattern); @@ -194,9 +192,9 @@ public void testWindowWithFrameCoercions() project(project( ImmutableMap.of( // sort key based on "a" in source scope - "sortkey", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))), + "sortkey", expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 1L)))), // frame offset based on "a" in output scope - "frame_offset", expression(new Arithmetic(ADD_DOUBLE, ADD, new Reference(DOUBLE, "new_a"), new Constant(DOUBLE, 1.0)))), + "frame_offset", expression(new Call(ADD_DOUBLE, ImmutableList.of(new Reference(DOUBLE, "new_a"), new Constant(DOUBLE, 1.0))))), project(// output expression ImmutableMap.of("new_a", expression(new Constant(DOUBLE, 2E0))), project(project(values("a"))))))))))))); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionVerifier.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionVerifier.java index 58cca84fbb45..e3106ef081ca 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionVerifier.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionVerifier.java @@ -13,7 +13,6 @@ */ package io.trino.sql.planner.assertions; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Between; import io.trino.sql.ir.Call; import io.trino.sql.ir.Case; @@ -28,7 +27,6 @@ import io.trino.sql.ir.IsNull; import io.trino.sql.ir.Lambda; import io.trino.sql.ir.Logical; -import io.trino.sql.ir.Negation; import io.trino.sql.ir.Not; import io.trino.sql.ir.Reference; import io.trino.sql.ir.Row; @@ -168,28 +166,6 @@ protected Boolean visitBetween(Between actual, Expression expectedExpression) process(actual.max(), expected.max()); } - @Override - protected Boolean visitNegation(Negation actual, Expression expectedExpression) - { - if (!(expectedExpression instanceof Negation expected)) { - return false; - } - - return process(actual.value(), expected.value()); - } - - @Override - protected Boolean visitArithmetic(Arithmetic actual, Expression expectedExpression) - { - if (!(expectedExpression instanceof Arithmetic expected)) { - return false; - } - - return actual.operator() == expected.operator() && - process(actual.left(), expected.left()) && - process(actual.right(), expected.right()); - } - @Override protected Boolean visitNot(Not actual, Expression expectedExpression) { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestCanonicalizeExpressionRewriter.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestCanonicalizeExpressionRewriter.java index 6255f73f7280..e1491d6fb1e5 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestCanonicalizeExpressionRewriter.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestCanonicalizeExpressionRewriter.java @@ -20,7 +20,6 @@ import io.trino.spi.function.OperatorType; import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Call; import io.trino.sql.ir.Case; import io.trino.sql.ir.Cast; @@ -47,8 +46,6 @@ import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.sql.ExpressionTestUtils.assertExpressionEquals; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; -import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; @@ -94,20 +91,20 @@ public void testRewriteIfExpression() public void testCanonicalizeArithmetic() { assertRewritten( - new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L)), - new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))); + new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))), + new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 1L)))); assertRewritten( - new Arithmetic(ADD_INTEGER, ADD, new Constant(INTEGER, 1L), new Reference(INTEGER, "a")), - new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))); + new Call(ADD_INTEGER, ImmutableList.of(new Constant(INTEGER, 1L), new Reference(INTEGER, "a"))), + new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 1L)))); assertRewritten( - new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L)), - new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))); + new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))), + new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 1L)))); assertRewritten( - new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Constant(INTEGER, 1L), new Reference(INTEGER, "a")), - new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))); + new Call(MULTIPLY_INTEGER, ImmutableList.of(new Constant(INTEGER, 1L), new Reference(INTEGER, "a"))), + new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 1L)))); } @Test diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateInnerUnnestWithGlobalAggregation.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateInnerUnnestWithGlobalAggregation.java index 084927555ed3..4bd7f0dabe56 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateInnerUnnestWithGlobalAggregation.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateInnerUnnestWithGlobalAggregation.java @@ -19,13 +19,11 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Constant; import io.trino.sql.ir.IsNull; import io.trino.sql.ir.Logical; -import io.trino.sql.ir.Negation; import io.trino.sql.ir.Not; import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; @@ -42,8 +40,6 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; -import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; import static io.trino.sql.ir.Logical.Operator.AND; import static io.trino.sql.planner.assertions.PlanMatchPattern.UnnestMapping.unnestMapping; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; @@ -67,6 +63,7 @@ public class TestDecorrelateInnerUnnestWithGlobalAggregation private static final ResolvedFunction REGEXP_EXTRACT_ALL = FUNCTIONS.resolveFunction("regexp_extract_all", fromTypes(VARCHAR, VARCHAR)); private static final ResolvedFunction ADD_INTEGER = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(INTEGER, INTEGER)); private static final ResolvedFunction MODULUS_INTEGER = FUNCTIONS.resolveOperator(OperatorType.MODULUS, ImmutableList.of(INTEGER, INTEGER)); + private static final ResolvedFunction NEGATION_BIGINT = FUNCTIONS.resolveOperator(OperatorType.NEGATION, ImmutableList.of(BIGINT)); @Test public void doesNotFireWithoutGlobalAggregation() @@ -285,7 +282,7 @@ public void testProjectOverGlobalAggregation() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), p.project( - Assignments.of(p.symbol("sum_1"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("sum_1"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), p.aggregation(innerBuilder -> innerBuilder .globalGrouping() .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "unnested_corr"))), ImmutableList.of(BIGINT)) @@ -298,7 +295,7 @@ public void testProjectOverGlobalAggregation() .matches( project( strictProject( - ImmutableMap.of("corr", expression(new Reference(BIGINT, "corr")), "unique", expression(new Reference(BIGINT, "unique")), "sum_1", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), + ImmutableMap.of("corr", expression(new Reference(BIGINT, "corr")), "unique", expression(new Reference(BIGINT, "unique")), "sum_1", expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))))), aggregation( singleGroupingSet("unique", "corr"), ImmutableMap.of(Optional.of("sum"), aggregationFunction("sum", ImmutableList.of("unnested_corr"))), @@ -372,13 +369,13 @@ public void testMultipleNodesOverUnnestInSubquery() ImmutableList.of(p.symbol("groups"), p.symbol("numbers")), p.values(p.symbol("groups"), p.symbol("numbers")), p.project( - Assignments.of(p.symbol("sum_1"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("sum_1"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), p.aggregation(globalBuilder -> globalBuilder .globalGrouping() .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "negate"))), ImmutableList.of(BIGINT)) .source(p.project( Assignments.builder() - .put(p.symbol("negate"), new Negation(new Reference(BIGINT, "max"))) + .put(p.symbol("negate"), new Call(NEGATION_BIGINT, ImmutableList.of(new Reference(BIGINT, "max")))) .build(), p.aggregation(groupedBuilder -> groupedBuilder .singleGroupingSet(p.symbol("group")) @@ -387,7 +384,7 @@ public void testMultipleNodesOverUnnestInSubquery() p.project( Assignments.builder() .putIdentities(ImmutableList.of(p.symbol("group"), p.symbol("number"))) - .put(p.symbol("modulo"), new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "number"), new Constant(INTEGER, 10L))) + .put(p.symbol("modulo"), new Call(MODULUS_INTEGER, ImmutableList.of(new Reference(INTEGER, "number"), new Constant(INTEGER, 10L)))) .build(), p.unnest( ImmutableList.of(), @@ -400,7 +397,7 @@ public void testMultipleNodesOverUnnestInSubquery() .matches( project( project( - ImmutableMap.of("sum_1", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), + ImmutableMap.of("sum_1", expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))))), aggregation( singleGroupingSet("groups", "numbers", "unique"), ImmutableMap.of(Optional.of("sum"), aggregationFunction("sum", ImmutableList.of("negated"))), @@ -409,7 +406,7 @@ public void testMultipleNodesOverUnnestInSubquery() Optional.empty(), SINGLE, project( - ImmutableMap.of("negated", expression(new Negation(new Reference(BIGINT, "max")))), + ImmutableMap.of("negated", expression(new Call(NEGATION_BIGINT, ImmutableList.of(new Reference(BIGINT, "max"))))), aggregation( singleGroupingSet("groups", "numbers", "unique", "mask", "group"), ImmutableMap.of(Optional.of("max"), aggregationFunction("max", ImmutableList.of("modulo"))), @@ -418,7 +415,7 @@ public void testMultipleNodesOverUnnestInSubquery() Optional.empty(), SINGLE, project( - ImmutableMap.of("modulo", expression(new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "number"), new Constant(INTEGER, 10L)))), + ImmutableMap.of("modulo", expression(new Call(MODULUS_INTEGER, ImmutableList.of(new Reference(INTEGER, "number"), new Constant(INTEGER, 10L))))), project( ImmutableMap.of("mask", expression(new Not(new IsNull(new Reference(BIGINT, "ordinality"))))), unnest( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateLeftUnnestWithGlobalAggregation.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateLeftUnnestWithGlobalAggregation.java index 7b73a14e61a7..e8d0f2b57d5b 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateLeftUnnestWithGlobalAggregation.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateLeftUnnestWithGlobalAggregation.java @@ -19,11 +19,9 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.Negation; import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -38,8 +36,6 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; -import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; import static io.trino.sql.planner.assertions.PlanMatchPattern.UnnestMapping.unnestMapping; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregationFunction; @@ -61,6 +57,7 @@ public class TestDecorrelateLeftUnnestWithGlobalAggregation private static final ResolvedFunction REGEXP_EXTRACT_ALL = FUNCTIONS.resolveFunction("regexp_extract_all", fromTypes(VARCHAR, VARCHAR)); private static final ResolvedFunction ADD_INTEGER = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(INTEGER, INTEGER)); private static final ResolvedFunction MODULUS_INTEGER = FUNCTIONS.resolveOperator(OperatorType.MODULUS, ImmutableList.of(INTEGER, INTEGER)); + private static final ResolvedFunction NEGATION_BIGINT = FUNCTIONS.resolveOperator(OperatorType.NEGATION, ImmutableList.of(BIGINT)); @Test public void doesNotFireWithoutGlobalAggregation() @@ -264,7 +261,7 @@ public void testProjectOverGlobalAggregation() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), p.project( - Assignments.of(p.symbol("sum_1"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("sum_1"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), p.aggregation(innerBuilder -> innerBuilder .globalGrouping() .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "unnested_corr"))), ImmutableList.of(BIGINT)) @@ -280,7 +277,7 @@ public void testProjectOverGlobalAggregation() ImmutableMap.of( "corr", expression(new Reference(BIGINT, "corr")), "unique", expression(new Reference(BIGINT, "unique")), - "sum_1", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), + "sum_1", expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))))), aggregation( singleGroupingSet("unique", "corr"), ImmutableMap.of(Optional.of("sum"), aggregationFunction("sum", ImmutableList.of("unnested_corr"))), @@ -348,13 +345,13 @@ public void testMultipleNodesOverUnnestInSubquery() ImmutableList.of(p.symbol("groups"), p.symbol("numbers")), p.values(p.symbol("groups"), p.symbol("numbers")), p.project( - Assignments.of(p.symbol("sum_1"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("sum_1"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), p.aggregation(globalBuilder -> globalBuilder .globalGrouping() .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "negate"))), ImmutableList.of(BIGINT)) .source(p.project( Assignments.builder() - .put(p.symbol("negate"), new Negation(new Reference(BIGINT, "max"))) + .put(p.symbol("negate"), new Call(NEGATION_BIGINT, ImmutableList.of(new Reference(BIGINT, "max")))) .build(), p.aggregation(groupedBuilder -> groupedBuilder .singleGroupingSet(p.symbol("group")) @@ -363,7 +360,7 @@ public void testMultipleNodesOverUnnestInSubquery() p.project( Assignments.builder() .putIdentities(ImmutableList.of(p.symbol("group"), p.symbol("number"))) - .put(p.symbol("modulo"), new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "number"), new Constant(INTEGER, 10L))) + .put(p.symbol("modulo"), new Call(MODULUS_INTEGER, ImmutableList.of(new Reference(INTEGER, "number"), new Constant(INTEGER, 10L)))) .build(), p.unnest( ImmutableList.of(), @@ -376,7 +373,7 @@ public void testMultipleNodesOverUnnestInSubquery() .matches( project( project( - ImmutableMap.of("sum_1", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), + ImmutableMap.of("sum_1", expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))))), aggregation( singleGroupingSet("groups", "numbers", "unique"), ImmutableMap.of(Optional.of("sum"), aggregationFunction("sum", ImmutableList.of("negated"))), @@ -384,7 +381,7 @@ public void testMultipleNodesOverUnnestInSubquery() Optional.empty(), SINGLE, project( - ImmutableMap.of("negated", expression(new Negation(new Reference(BIGINT, "max")))), + ImmutableMap.of("negated", expression(new Call(NEGATION_BIGINT, ImmutableList.of(new Reference(BIGINT, "max"))))), aggregation( singleGroupingSet("groups", "numbers", "unique", "group"), ImmutableMap.of(Optional.of("max"), aggregationFunction("max", ImmutableList.of("modulo"))), @@ -393,7 +390,7 @@ public void testMultipleNodesOverUnnestInSubquery() Optional.empty(), SINGLE, project( - ImmutableMap.of("modulo", expression(new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "number"), new Constant(INTEGER, 10L)))), + ImmutableMap.of("modulo", expression(new Call(MODULUS_INTEGER, ImmutableList.of(new Reference(INTEGER, "number"), new Constant(INTEGER, 10L))))), unnest( ImmutableList.of("groups", "numbers", "unique"), ImmutableList.of( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDetermineJoinDistributionType.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDetermineJoinDistributionType.java index aeb9883f6e5a..cf09194592f5 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDetermineJoinDistributionType.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDetermineJoinDistributionType.java @@ -24,7 +24,7 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; import io.trino.spi.type.VarcharType; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Reference; @@ -57,7 +57,6 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; -import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.enforceSingleRow; @@ -216,10 +215,10 @@ private void testReplicateNoEquiCriteria(JoinType joinType) ImmutableList.of(), ImmutableList.of(p.symbol("A1", BIGINT)), ImmutableList.of(p.symbol("B1", BIGINT)), - Optional.of(new Comparison(GREATER_THAN, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "A1"), new Reference(INTEGER, "B1")), new Constant(INTEGER, 100L))))) + Optional.of(new Comparison(GREATER_THAN, new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "A1"), new Reference(INTEGER, "B1"))), new Constant(INTEGER, 100L))))) .matches( join(joinType, builder -> builder - .filter(new Comparison(GREATER_THAN, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "A1"), new Reference(INTEGER, "B1")), new Constant(INTEGER, 100L))) + .filter(new Comparison(GREATER_THAN, new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "A1"), new Reference(INTEGER, "B1"))), new Constant(INTEGER, 100L))) .distributionType(REPLICATED) .left(values(ImmutableMap.of("A1", 0))) .right(values(ImmutableMap.of("B1", 0))))); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestEliminateCrossJoins.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestEliminateCrossJoins.java index c9dc6f913825..0aa7ac47dc8c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestEliminateCrossJoins.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestEliminateCrossJoins.java @@ -18,9 +18,8 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.Negation; import io.trino.sql.ir.Reference; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; @@ -47,7 +46,6 @@ import static io.trino.SystemSessionProperties.JOIN_REORDERING_STRATEGY; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.planner.assertions.PlanMatchPattern.any; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.join; @@ -65,6 +63,7 @@ public class TestEliminateCrossJoins { private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); private static final ResolvedFunction ADD_INTEGER = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(INTEGER, INTEGER)); + private static final ResolvedFunction NEGATION_BIGINT = FUNCTIONS.resolveOperator(OperatorType.NEGATION, ImmutableList.of(BIGINT)); private final PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); @@ -229,14 +228,14 @@ public void testEliminateCrossJoinWithNonIdentityProjections() INNER, p.project( Assignments.of( - a2, new Negation(new Reference(BIGINT, "a1")), + a2, new Call(NEGATION_BIGINT, ImmutableList.of(new Reference(BIGINT, "a1"))), f, new Reference(BIGINT, "f")), p.join( INNER, p.project( Assignments.of( a1, new Reference(BIGINT, "a1"), - f, new Negation(new Reference(BIGINT, "b"))), + f, new Call(NEGATION_BIGINT, ImmutableList.of(new Reference(BIGINT, "b")))), p.join( INNER, p.values(a1), @@ -259,7 +258,7 @@ f, new Negation(new Reference(BIGINT, "b"))), .left( strictProject( ImmutableMap.of( - "a2", expression(new Negation(new Reference(BIGINT, "a1"))), + "a2", expression(new Call(NEGATION_BIGINT, ImmutableList.of(new Reference(BIGINT, "a1")))), "a1", expression(new Reference(BIGINT, "a1"))), PlanMatchPattern.values("a1"))) .right( @@ -270,7 +269,7 @@ f, new Negation(new Reference(BIGINT, "b"))), .right(any()))) .right( strictProject( - ImmutableMap.of("f", expression(new Negation(new Reference(BIGINT, "b")))), + ImmutableMap.of("f", expression(new Call(NEGATION_BIGINT, ImmutableList.of(new Reference(BIGINT, "b"))))), PlanMatchPattern.values("b")))))); } @@ -284,7 +283,7 @@ public void testGiveUpOnComplexProjections() values("a1"), values("b")), "a2", - new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a1"), new Reference(INTEGER, "b")), + new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "a1"), new Reference(INTEGER, "b"))), "b", new Reference(INTEGER, "b")), values("c"), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementExceptAll.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementExceptAll.java index 173aa5b8f92c..e225de3f2cb1 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementExceptAll.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementExceptAll.java @@ -19,7 +19,6 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Call; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; @@ -35,7 +34,6 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; @@ -96,7 +94,7 @@ public void test() "a", expression(new Reference(BIGINT, "a")), "b", expression(new Reference(BIGINT, "b"))), filter( - new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "row_number"), new Call(GREATEST, ImmutableList.of(new Arithmetic(SUBTRACT_BIGINT, SUBTRACT, new Reference(BIGINT, "count_1"), new Reference(BIGINT, "count_2")), new Constant(BIGINT, 0L)))), + new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "row_number"), new Call(GREATEST, ImmutableList.of(new Call(SUBTRACT_BIGINT, ImmutableList.of(new Reference(BIGINT, "count_1"), new Reference(BIGINT, "count_2"))), new Constant(BIGINT, 0L)))), strictProject( ImmutableMap.of( "a", expression(new Reference(BIGINT, "a")), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjections.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjections.java index 35b3a6b7fa59..4b188f9abe5f 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjections.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjections.java @@ -20,7 +20,7 @@ import io.trino.spi.function.OperatorType; import io.trino.spi.type.Decimals; import io.trino.spi.type.RowType; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Constant; import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Reference; @@ -37,9 +37,6 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.DecimalType.createDecimalType; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; -import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; -import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -62,19 +59,19 @@ public void test() p.project( Assignments.builder() .put(p.symbol("identity"), new Reference(BIGINT, "symbol")) // identity - .put(p.symbol("multi_complex_1"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "complex"), new Constant(INTEGER, 1L))) // complex expression referenced multiple times - .put(p.symbol("multi_complex_2"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "complex"), new Constant(INTEGER, 2L))) // complex expression referenced multiple times - .put(p.symbol("multi_literal_1"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "literal"), new Constant(INTEGER, 1L))) // literal referenced multiple times - .put(p.symbol("multi_literal_2"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "literal"), new Constant(INTEGER, 2L))) // literal referenced multiple times - .put(p.symbol("single_complex"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "complex_2"), new Constant(INTEGER, 2L))) // complex expression reference only once - .put(p.symbol("msg_xx"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "z"), new Constant(INTEGER, 1L))) - .put(p.symbol("multi_symbol_reference"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "v"), new Reference(INTEGER, "v"))) + .put(p.symbol("multi_complex_1"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "complex"), new Constant(INTEGER, 1L)))) // complex expression referenced multiple times + .put(p.symbol("multi_complex_2"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "complex"), new Constant(INTEGER, 2L)))) // complex expression referenced multiple times + .put(p.symbol("multi_literal_1"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "literal"), new Constant(INTEGER, 1L)))) // literal referenced multiple times + .put(p.symbol("multi_literal_2"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "literal"), new Constant(INTEGER, 2L)))) // literal referenced multiple times + .put(p.symbol("single_complex"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "complex_2"), new Constant(INTEGER, 2L)))) // complex expression reference only once + .put(p.symbol("msg_xx"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "z"), new Constant(INTEGER, 1L)))) + .put(p.symbol("multi_symbol_reference"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "v"), new Reference(INTEGER, "v")))) .build(), p.project(Assignments.builder() .put(p.symbol("symbol"), new Reference(INTEGER, "x")) - .put(p.symbol("complex"), new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "x"), new Constant(INTEGER, 2L))) + .put(p.symbol("complex"), new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "x"), new Constant(INTEGER, 2L)))) .put(p.symbol("literal"), new Constant(INTEGER, 1L)) - .put(p.symbol("complex_2"), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "x"), new Constant(INTEGER, 1L))) + .put(p.symbol("complex_2"), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "x"), new Constant(INTEGER, 1L)))) .put(p.symbol("z"), new FieldReference(new Reference(MSG_TYPE, "msg"), 0)) .put(p.symbol("v"), new Reference(INTEGER, "x")) .build(), @@ -83,18 +80,18 @@ public void test() project( ImmutableMap.builder() .put("out1", PlanMatchPattern.expression(new Reference(INTEGER, "x"))) - .put("out2", PlanMatchPattern.expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "y"), new Constant(INTEGER, 1L)))) - .put("out3", PlanMatchPattern.expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "y"), new Constant(INTEGER, 2L)))) - .put("out4", PlanMatchPattern.expression(new Arithmetic(ADD_INTEGER, ADD, new Constant(INTEGER, 1L), new Constant(INTEGER, 1L)))) - .put("out5", PlanMatchPattern.expression(new Arithmetic(ADD_INTEGER, ADD, new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)))) - .put("out6", PlanMatchPattern.expression(new Arithmetic(ADD_INTEGER, ADD, new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "x"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 2L)))) - .put("out8", PlanMatchPattern.expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "z"), new Constant(INTEGER, 1L)))) - .put("out10", PlanMatchPattern.expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "x"), new Reference(INTEGER, "x")))) + .put("out2", PlanMatchPattern.expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "y"), new Constant(INTEGER, 1L))))) + .put("out3", PlanMatchPattern.expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "y"), new Constant(INTEGER, 2L))))) + .put("out4", PlanMatchPattern.expression(new Call(ADD_INTEGER, ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 1L))))) + .put("out5", PlanMatchPattern.expression(new Call(ADD_INTEGER, ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))))) + .put("out6", PlanMatchPattern.expression(new Call(ADD_INTEGER, ImmutableList.of(new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "x"), new Constant(INTEGER, 1L))), new Constant(INTEGER, 2L))))) + .put("out8", PlanMatchPattern.expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "z"), new Constant(INTEGER, 1L))))) + .put("out10", PlanMatchPattern.expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "x"), new Reference(INTEGER, "x"))))) .buildOrThrow(), project( ImmutableMap.of( "x", PlanMatchPattern.expression(new Reference(INTEGER, "x")), - "y", PlanMatchPattern.expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "x"), new Constant(INTEGER, 2L))), + "y", PlanMatchPattern.expression(new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "x"), new Constant(INTEGER, 2L)))), "z", PlanMatchPattern.expression(new FieldReference(new Reference(MSG_TYPE, "msg"), 0))), values(ImmutableMap.of("x", 0, "msg", 1))))); } @@ -112,8 +109,8 @@ public void testInlineEffectivelyLiteral() p.project( Assignments.builder() // Use the literal-like expression multiple times. Single-use expression may be inlined regardless of whether it's a literal - .put(p.symbol("decimal_multiplication"), new Arithmetic(MULTIPLY_DECIMAL_8_4, MULTIPLY, new Reference(createDecimalType(8, 4), "decimal_literal"), new Reference(createDecimalType(8, 4), "decimal_literal"))) - .put(p.symbol("decimal_addition"), new Arithmetic(ADD_DECIMAL_8_4, ADD, new Reference(createDecimalType(8, 4), "decimal_literal"), new Reference(createDecimalType(8, 4), "decimal_literal"))) + .put(p.symbol("decimal_multiplication"), new Call(MULTIPLY_DECIMAL_8_4, ImmutableList.of(new Reference(createDecimalType(8, 4), "decimal_literal"), new Reference(createDecimalType(8, 4), "decimal_literal")))) + .put(p.symbol("decimal_addition"), new Call(ADD_DECIMAL_8_4, ImmutableList.of(new Reference(createDecimalType(8, 4), "decimal_literal"), new Reference(createDecimalType(8, 4), "decimal_literal")))) .build(), p.project(Assignments.builder() .put(p.symbol("decimal_literal", createDecimalType(8, 4)), new Constant(createDecimalType(8, 4), Decimals.valueOfShort(new BigDecimal("12.5")))) @@ -122,8 +119,8 @@ public void testInlineEffectivelyLiteral() .matches( project( Map.of( - "decimal_multiplication", PlanMatchPattern.expression(new Arithmetic(MULTIPLY_DECIMAL_8_4, MULTIPLY, new Constant(createDecimalType(8, 4), Decimals.valueOfShort(new BigDecimal("12.5"))), new Constant(createDecimalType(8, 4), Decimals.valueOfShort(new BigDecimal("12.5"))))), - "decimal_addition", PlanMatchPattern.expression(new Arithmetic(ADD_DECIMAL_8_4, ADD, new Constant(createDecimalType(8, 4), Decimals.valueOfShort(new BigDecimal("12.5"))), new Constant(createDecimalType(8, 4), Decimals.valueOfShort(new BigDecimal("12.5")))))), + "decimal_multiplication", PlanMatchPattern.expression(new Call(MULTIPLY_DECIMAL_8_4, ImmutableList.of(new Constant(createDecimalType(8, 4), Decimals.valueOfShort(new BigDecimal("12.5"))), new Constant(createDecimalType(8, 4), Decimals.valueOfShort(new BigDecimal("12.5")))))), + "decimal_addition", PlanMatchPattern.expression(new Call(ADD_DECIMAL_8_4, ImmutableList.of(new Constant(createDecimalType(8, 4), Decimals.valueOfShort(new BigDecimal("12.5"))), new Constant(createDecimalType(8, 4), Decimals.valueOfShort(new BigDecimal("12.5"))))))), values(Map.of("x", 0)))); } @@ -134,15 +131,15 @@ public void testEliminatesIdentityProjection() .on(p -> p.project( Assignments.builder() - .put(p.symbol("single_complex", INTEGER), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "complex"), new Constant(INTEGER, 2L))) // complex expression referenced only once + .put(p.symbol("single_complex", INTEGER), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "complex"), new Constant(INTEGER, 2L)))) // complex expression referenced only once .build(), p.project(Assignments.builder() - .put(p.symbol("complex", INTEGER), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "x"), new Constant(INTEGER, 1L))) + .put(p.symbol("complex", INTEGER), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "x"), new Constant(INTEGER, 1L)))) .build(), p.values(p.symbol("x", INTEGER))))) .matches( project( - ImmutableMap.of("out1", PlanMatchPattern.expression(new Arithmetic(ADD_INTEGER, ADD, new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "x"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 2L)))), + ImmutableMap.of("out1", PlanMatchPattern.expression(new Call(ADD_INTEGER, ImmutableList.of(new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "x"), new Constant(INTEGER, 1L))), new Constant(INTEGER, 2L))))), values("x"))); } @@ -194,7 +191,7 @@ public void testSubqueryProjections() p.project( Assignments.identity(p.symbol("fromOuterScope"), p.symbol("value_1")), p.project( - Assignments.of(p.symbol("value_1"), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "value"), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("value_1"), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "value"), new Constant(INTEGER, 1L)))), p.values(p.symbol("value"))))) .matches( project( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinNodeFlattener.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinNodeFlattener.java index d1964f12ee06..0da8c21eb31e 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinNodeFlattener.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinNodeFlattener.java @@ -18,11 +18,10 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.Negation; import io.trino.sql.planner.Plan; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; @@ -53,8 +52,6 @@ import static io.trino.cost.StatsAndCosts.empty; import static io.trino.metadata.AbstractMockMetadata.dummyMetadata; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; -import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; @@ -80,6 +77,7 @@ public class TestJoinNodeFlattener private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); private static final ResolvedFunction ADD_BIGINT = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(BIGINT, BIGINT)); private static final ResolvedFunction SUBTRACT_BIGINT = FUNCTIONS.resolveOperator(OperatorType.SUBTRACT, ImmutableList.of(BIGINT, BIGINT)); + private static final ResolvedFunction NEGATION_BIGINT = FUNCTIONS.resolveOperator(OperatorType.NEGATION, ImmutableList.of(BIGINT)); private static final int DEFAULT_JOIN_LIMIT = 10; @@ -166,7 +164,7 @@ public void testPushesProjectionsThroughJoin() JoinNode joinNode = p.join( INNER, p.project( - Assignments.of(d, new Negation(a.toSymbolReference())), + Assignments.of(d, new Call(NEGATION_BIGINT, ImmutableList.of(a.toSymbolReference()))), p.join( INNER, valuesA, @@ -210,7 +208,7 @@ public void testDoesNotPushStraddlingProjection() JoinNode joinNode = p.join( INNER, p.project( - Assignments.of(d, new Arithmetic(SUBTRACT_BIGINT, SUBTRACT, a.toSymbolReference(), b.toSymbolReference())), + Assignments.of(d, new Call(SUBTRACT_BIGINT, ImmutableList.of(a.toSymbolReference(), b.toSymbolReference()))), p.join( INNER, valuesA, @@ -290,7 +288,7 @@ public void testCombinesCriteriaAndFilters() new Comparison(GREATER_THAN, b2.toSymbolReference(), c2.toSymbolReference())); Comparison abcFilter = new Comparison( LESS_THAN, - new Arithmetic(ADD_BIGINT, ADD, a1.toSymbolReference(), c1.toSymbolReference()), + new Call(ADD_BIGINT, ImmutableList.of(a1.toSymbolReference(), c1.toSymbolReference())), b1.toSymbolReference()); JoinNode joinNode = p.join( INNER, diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestLambdaCaptureDesugaringRewriter.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestLambdaCaptureDesugaringRewriter.java index b1c26ae9d606..f2c868bfb6d5 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestLambdaCaptureDesugaringRewriter.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestLambdaCaptureDesugaringRewriter.java @@ -17,8 +17,8 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Bind; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Lambda; import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; @@ -27,7 +27,6 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.planner.iterative.rule.LambdaCaptureDesugaringRewriter.rewrite; import static org.assertj.core.api.Assertions.assertThat; @@ -43,12 +42,12 @@ public void testRewriteBasicLambda() assertThat( rewrite( - new Lambda(ImmutableList.of(new Symbol(INTEGER, "x")), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a"), new Reference(INTEGER, "x"))), + new Lambda(ImmutableList.of(new Symbol(INTEGER, "x")), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Reference(INTEGER, "x")))), allocator)) .isEqualTo(new Bind( ImmutableList.of(new Reference(INTEGER, "a")), new Lambda( ImmutableList.of(new Symbol(INTEGER, "a_0"), new Symbol(INTEGER, "x")), - new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a_0"), new Reference(INTEGER, "x"))))); + new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "a_0"), new Reference(INTEGER, "x")))))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergePatternRecognitionNodes.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergePatternRecognitionNodes.java index d9808945f59b..0f5cc1fc48b5 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergePatternRecognitionNodes.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergePatternRecognitionNodes.java @@ -19,7 +19,6 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Call; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; @@ -48,8 +47,6 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; -import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; import static io.trino.sql.ir.Booleans.FALSE; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; @@ -281,7 +278,7 @@ public void testParentDependsOnSourceCreatedOutputsWithProject() .pattern(new IrLabel("X")) .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.project( - Assignments.of(p.symbol("projected"), new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "a"), new Reference(BIGINT, "measure"))), + Assignments.of(p.symbol("projected"), new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "a"), new Reference(BIGINT, "measure")))), p.patternRecognition(childBuilder -> childBuilder .addMeasure( p.symbol("measure"), @@ -443,7 +440,7 @@ public void testMergeWithProject() .source(p.project( Assignments.of( p.symbol("a"), new Reference(BIGINT, "a"), - p.symbol("expression"), new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "a"), new Reference(BIGINT, "b"))), + p.symbol("expression"), new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))), p.patternRecognition(childBuilder -> { childBuilder .addMeasure( @@ -469,7 +466,7 @@ public void testMergeWithProject() "b", expression(new Reference(BIGINT, "b")), "parent_measure", expression(new Reference(BIGINT, "parent_measure")), "child_measure", expression(new Reference(BIGINT, "child_measure")), - "expression", expression(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))), + "expression", expression(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "a"), new Reference(BIGINT, "b"))))), patternRecognition(builder -> builder .addMeasure( "parent_measure", @@ -506,7 +503,7 @@ public void testMergeWithProject() .source(p.project( Assignments.of( p.symbol("a"), new Reference(BIGINT, "a"), - p.symbol("expression"), new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")), new Reference(BIGINT, "child_measure"))), + p.symbol("expression"), new Call(MULTIPLY_BIGINT, ImmutableList.of(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "a"), new Reference(BIGINT, "b"))), new Reference(BIGINT, "child_measure")))), p.patternRecognition(childBuilder -> { childBuilder .addMeasure( @@ -532,7 +529,7 @@ public void testMergeWithProject() "b", expression(new Reference(BIGINT, "b")), "parent_measure", expression(new Reference(BIGINT, "parent_measure")), "child_measure", expression(new Reference(BIGINT, "child_measure")), - "expression", expression(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")), new Reference(BIGINT, "child_measure")))), + "expression", expression(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "a"), new Reference(BIGINT, "b"))), new Reference(BIGINT, "child_measure"))))), patternRecognition(builder -> builder .addMeasure( "parent_measure", @@ -574,8 +571,8 @@ public void testMergeWithParentDependingOnProject() .source(p.project( Assignments.builder() .put(p.symbol("a"), new Reference(BIGINT, "a")) - .put(p.symbol("expression_1"), new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "a"), new Reference(BIGINT, "b"))) - .put(p.symbol("expression_2"), new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Reference(BIGINT, "b"))) + .put(p.symbol("expression_1"), new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))) + .put(p.symbol("expression_2"), new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))) .build(), p.patternRecognition(childBuilder -> { childBuilder @@ -604,7 +601,7 @@ public void testMergeWithParentDependingOnProject() .put("parent_measure", expression(new Reference(BIGINT, "parent_measure"))) .put("child_measure", expression(new Reference(BIGINT, "child_measure"))) .put("expression_1", expression(new Reference(BIGINT, "expression_1"))) - .put("expression_2", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))) + .put("expression_2", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "a"), new Reference(BIGINT, "b"))))) .buildOrThrow(), patternRecognition(builder -> builder .addMeasure( @@ -628,7 +625,7 @@ public void testMergeWithParentDependingOnProject() ImmutableMap.of( "a", expression(new Reference(BIGINT, "a")), "b", expression(new Reference(BIGINT, "b")), - "expression_1", expression(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))), + "expression_1", expression(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "a"), new Reference(BIGINT, "b"))))), values("a", "b")))))); } @@ -655,8 +652,8 @@ public void testOneRowPerMatchMergeWithParentDependingOnProject() Assignments.builder() .put(p.symbol("a"), new Reference(BIGINT, "a")) .put(p.symbol("child_measure"), new Reference(BIGINT, "child_measure")) - .put(p.symbol("expression_1"), new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "a"), new Reference(BIGINT, "a"))) - .put(p.symbol("expression_2"), new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Reference(BIGINT, "a"))) + .put(p.symbol("expression_1"), new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "a"), new Reference(BIGINT, "a")))) + .put(p.symbol("expression_2"), new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "a"), new Reference(BIGINT, "a")))) .build(), p.patternRecognition(childBuilder -> { childBuilder @@ -682,7 +679,7 @@ public void testOneRowPerMatchMergeWithParentDependingOnProject() "a", expression(new Reference(BIGINT, "a")), "parent_measure", expression(new Reference(BIGINT, "parent_measure")), "child_measure", expression(new Reference(BIGINT, "child_measure")), - "expression_2", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Reference(BIGINT, "a")))), + "expression_2", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "a"), new Reference(BIGINT, "a"))))), patternRecognition(builder -> builder .specification(specification(ImmutableList.of("a"), ImmutableList.of(), ImmutableMap.of())) .addMeasure( @@ -706,7 +703,7 @@ public void testOneRowPerMatchMergeWithParentDependingOnProject() ImmutableMap.of( "a", expression(new Reference(BIGINT, "a")), "b", expression(new Reference(BIGINT, "b")), - "expression_1", PlanMatchPattern.expression(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "a"), new Reference(BIGINT, "a")))), + "expression_1", PlanMatchPattern.expression(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "a"), new Reference(BIGINT, "a"))))), values("a", "b")))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeProjectWithValues.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeProjectWithValues.java index 92756a2fa515..258951431ac6 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeProjectWithValues.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeProjectWithValues.java @@ -19,12 +19,10 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; import io.trino.spi.type.RowType; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Constant; import io.trino.sql.ir.IsNull; -import io.trino.sql.ir.Negation; import io.trino.sql.ir.Reference; import io.trino.sql.ir.Row; import io.trino.sql.planner.Symbol; @@ -38,7 +36,6 @@ import static io.trino.spi.type.CharType.createCharType; import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.ir.Booleans.FALSE; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.LogicalPlanner.failFunction; @@ -50,6 +47,7 @@ public class TestMergeProjectWithValues private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); private static final ResolvedFunction ADD_INTEGER = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(INTEGER, INTEGER)); private static final ResolvedFunction ADD_DOUBLE = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(DOUBLE, DOUBLE)); + private static final ResolvedFunction NEGATION_DOUBLE = FUNCTIONS.resolveOperator(OperatorType.NEGATION, ImmutableList.of(DOUBLE)); @Test public void testDoesNotFireOnNonRowType() @@ -167,34 +165,34 @@ public void testNonDeterministicValues() ImmutableList.of( new Row(ImmutableList.of(new Constant(DOUBLE, null))), new Row(ImmutableList.of(randomFunction)), - new Row(ImmutableList.of(new Negation(randomFunction))))))) + new Row(ImmutableList.of(new Call(NEGATION_DOUBLE, ImmutableList.of(randomFunction)))))))) .matches( values( ImmutableList.of("output"), ImmutableList.of( ImmutableList.of(new Constant(DOUBLE, null)), ImmutableList.of(randomFunction), - ImmutableList.of(new Negation(randomFunction))))); + ImmutableList.of(new Call(NEGATION_DOUBLE, ImmutableList.of(randomFunction)))))); // ValuesNode has multiple non-deterministic outputs tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( Assignments.of( - p.symbol("x"), new Negation(new Reference(DOUBLE, "a")), - p.symbol("y"), new Reference(DOUBLE, "b")), + p.symbol("x", DOUBLE), new Call(NEGATION_DOUBLE, ImmutableList.of(new Reference(DOUBLE, "a"))), + p.symbol("y", DOUBLE), new Reference(DOUBLE, "b")), p.valuesOfExpressions( ImmutableList.of(p.symbol("a", DOUBLE), p.symbol("b", DOUBLE)), ImmutableList.of( new Row(ImmutableList.of(new Constant(DOUBLE, 1e0), randomFunction)), new Row(ImmutableList.of(randomFunction, new Constant(DOUBLE, null))), - new Row(ImmutableList.of(new Negation(randomFunction), new Constant(DOUBLE, null))))))) + new Row(ImmutableList.of(new Call(NEGATION_DOUBLE, ImmutableList.of(randomFunction)), new Constant(DOUBLE, null))))))) .matches( values( ImmutableList.of("x", "y"), ImmutableList.of( - ImmutableList.of(new Negation(new Constant(DOUBLE, 1e0)), randomFunction), - ImmutableList.of(new Negation(randomFunction), new Constant(DOUBLE, null)), - ImmutableList.of(new Negation(new Negation(randomFunction)), new Constant(DOUBLE, null))))); + ImmutableList.of(new Call(NEGATION_DOUBLE, ImmutableList.of(new Constant(DOUBLE, 1e0))), randomFunction), + ImmutableList.of(new Call(NEGATION_DOUBLE, ImmutableList.of(randomFunction)), new Constant(DOUBLE, null)), + ImmutableList.of(new Call(NEGATION_DOUBLE, ImmutableList.of(new Call(NEGATION_DOUBLE, ImmutableList.of(randomFunction)))), new Constant(DOUBLE, null))))); } @Test @@ -216,7 +214,7 @@ public void testDoNotFireOnNonDeterministicValues() tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( - Assignments.of(p.symbol("x"), new Arithmetic(ADD_DOUBLE, ADD, new Reference(DOUBLE, "rand"), new Reference(DOUBLE, "rand"))), + Assignments.of(p.symbol("x"), new Call(ADD_DOUBLE, ImmutableList.of(new Reference(DOUBLE, "rand"), new Reference(DOUBLE, "rand")))), p.valuesOfExpressions( ImmutableList.of(p.symbol("rand")), ImmutableList.of(new Row(ImmutableList.of(randomFunction)))))) @@ -229,11 +227,11 @@ public void testCorrelation() // correlation symbol in projection (note: the resulting plan is not yet supported in execution) tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( - Assignments.of(p.symbol("x", INTEGER), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a"), new Reference(INTEGER, "corr"))), + Assignments.of(p.symbol("x", INTEGER), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Reference(INTEGER, "corr")))), p.valuesOfExpressions( ImmutableList.of(p.symbol("a", INTEGER)), ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L))))))) - .matches(values(ImmutableList.of("x"), ImmutableList.of(ImmutableList.of(new Arithmetic(ADD_INTEGER, ADD, new Constant(INTEGER, 1L), new Reference(INTEGER, "corr")))))); + .matches(values(ImmutableList.of("x"), ImmutableList.of(ImmutableList.of(new Call(ADD_INTEGER, ImmutableList.of(new Constant(INTEGER, 1L), new Reference(INTEGER, "corr"))))))); // correlation symbol in values (note: the resulting plan is not yet supported in execution) tester().assertThat(new MergeProjectWithValues()) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPreAggregateCaseAggregations.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPreAggregateCaseAggregations.java index 2c5c5cdcf6ed..16f356dcd755 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPreAggregateCaseAggregations.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPreAggregateCaseAggregations.java @@ -25,7 +25,6 @@ import io.trino.spi.connector.SchemaTableName; import io.trino.spi.function.OperatorType; import io.trino.spi.type.Decimals; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Call; import io.trino.sql.ir.Case; import io.trino.sql.ir.Cast; @@ -60,8 +59,6 @@ import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; -import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; @@ -154,7 +151,7 @@ public void testPreAggregatesCaseAggregations() .put("SUM_1_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Reference(BIGINT, "SUM_BIGINT"))), Optional.of(new Constant(BIGINT, 0L))))) .put("SUM_2_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Reference(BIGINT, "SUM_INT_CAST"))), Optional.of(new Constant(BIGINT, 0L))))) .put("SUM_3_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Reference(BIGINT, "SUM_BIGINT"))), Optional.empty()))) - .put("MIN_1_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Constant(BIGINT, 1L)), new Reference(BIGINT, "MIN_BIGINT"))), Optional.empty()))) + .put("MIN_1_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))), new Constant(BIGINT, 1L)), new Reference(BIGINT, "MIN_BIGINT"))), Optional.empty()))) .put("SUM_4_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 3L)), new Reference(BIGINT, "SUM_DECIMAL"))), Optional.empty()))) .put("SUM_5_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 4L)), new Reference(BIGINT, "SUM_DECIMAL_CAST"))), Optional.empty()))) .buildOrThrow(), @@ -171,10 +168,10 @@ public void testPreAggregatesCaseAggregations() exchange( project(ImmutableMap.of( "KEY", expression(new Call(CONCAT, ImmutableList.of(new Reference(VARCHAR, "COL_VARCHAR"), new Constant(VARCHAR, Slices.utf8Slice("a"))))), - "VALUE_BIGINT", expression(new Case(ImmutableList.of(new WhenClause(new In(new Reference(BIGINT, "COL_BIGINT"), ImmutableList.of(new Constant(BIGINT, 1L), new Constant(BIGINT, 2L))), new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)))), Optional.empty())), - "VALUE_INT_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Cast(new Cast(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), INTEGER), BIGINT))), Optional.empty())), - "VALUE_2_BIGINT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Constant(BIGINT, 1L)), new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)))), Optional.empty())), - "VALUE_DECIMAL_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 4L)), new Cast(new Arithmetic(MULTIPLY_DECIMAL_10_0, MULTIPLY, new Reference(createDecimalType(10, 0), "COL_DECIMAL"), new Constant(createDecimalType(10, 0), Decimals.valueOfShort(new BigDecimal("2")))), BIGINT))), Optional.empty()))), + "VALUE_BIGINT", expression(new Case(ImmutableList.of(new WhenClause(new In(new Reference(BIGINT, "COL_BIGINT"), ImmutableList.of(new Constant(BIGINT, 1L), new Constant(BIGINT, 2L))), new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))))), Optional.empty())), + "VALUE_INT_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Cast(new Cast(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))), INTEGER), BIGINT))), Optional.empty())), + "VALUE_2_BIGINT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))), new Constant(BIGINT, 1L)), new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))))), Optional.empty())), + "VALUE_DECIMAL_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 4L)), new Cast(new Call(MULTIPLY_DECIMAL_10_0, ImmutableList.of(new Reference(createDecimalType(10, 0), "COL_DECIMAL"), new Constant(createDecimalType(10, 0), Decimals.valueOfShort(new BigDecimal("2"))))), BIGINT))), Optional.empty()))), tableScan( "t", ImmutableMap.of( @@ -214,7 +211,7 @@ public void testGlobalPreAggregatesCaseAggregations() .put("SUM_1_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Reference(BIGINT, "SUM_BIGINT"))), Optional.of(new Constant(BIGINT, 0L))))) .put("SUM_2_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Reference(BIGINT, "SUM_INT_CAST"))), Optional.of(new Constant(BIGINT, 0L))))) .put("SUM_3_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Reference(BIGINT, "SUM_BIGINT"))), Optional.empty()))) - .put("MIN_1_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Constant(BIGINT, 1L)), new Reference(BIGINT, "MIN_BIGINT"))), Optional.empty()))) + .put("MIN_1_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))), new Constant(BIGINT, 1L)), new Reference(BIGINT, "MIN_BIGINT"))), Optional.empty()))) .put("SUM_4_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 3L)), new Reference(BIGINT, "SUM_DECIMAL"))), Optional.empty()))) .put("SUM_5_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 4L)), new Reference(BIGINT, "SUM_DECIMAL_CAST"))), Optional.empty()))) .buildOrThrow(), @@ -230,10 +227,10 @@ public void testGlobalPreAggregatesCaseAggregations() SINGLE, exchange( project(ImmutableMap.of( - "VALUE_BIGINT", expression(new Case(ImmutableList.of(new WhenClause(new In(new Reference(BIGINT, "COL_BIGINT"), ImmutableList.of(new Constant(BIGINT, 1L), new Constant(BIGINT, 2L))), new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)))), Optional.empty())), - "VALUE_INT_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Cast(new Cast(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), INTEGER), BIGINT))), Optional.empty())), - "VALUE_2_INT_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Constant(BIGINT, 1L)), new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)))), Optional.empty())), - "VALUE_DECIMAL_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 4L)), new Cast(new Arithmetic(MULTIPLY_DECIMAL_10_0, MULTIPLY, new Reference(createDecimalType(10, 0), "COL_DECIMAL"), new Constant(createDecimalType(10, 0), Decimals.valueOfShort(new BigDecimal("2")))), BIGINT))), Optional.empty()))), + "VALUE_BIGINT", expression(new Case(ImmutableList.of(new WhenClause(new In(new Reference(BIGINT, "COL_BIGINT"), ImmutableList.of(new Constant(BIGINT, 1L), new Constant(BIGINT, 2L))), new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))))), Optional.empty())), + "VALUE_INT_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Cast(new Cast(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))), INTEGER), BIGINT))), Optional.empty())), + "VALUE_2_INT_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))), new Constant(BIGINT, 1L)), new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))))), Optional.empty())), + "VALUE_DECIMAL_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 4L)), new Cast(new Call(MULTIPLY_DECIMAL_10_0, ImmutableList.of(new Reference(createDecimalType(10, 0), "COL_DECIMAL"), new Constant(createDecimalType(10, 0), Decimals.valueOfShort(new BigDecimal("2"))))), BIGINT))), Optional.empty()))), tableScan( "t", ImmutableMap.of( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownDereferencesRules.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownDereferencesRules.java index 8746130ee67a..907c0c0e4e5f 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownDereferencesRules.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownDereferencesRules.java @@ -25,7 +25,6 @@ import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.ArrayType; import io.trino.spi.type.RowType; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Comparison; @@ -55,7 +54,6 @@ import static io.trino.spi.type.RowType.field; import static io.trino.spi.type.RowType.rowType; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.ir.Comparison.Operator.NOT_EQUAL; @@ -207,14 +205,14 @@ public void testPushDownDereferenceThroughJoin() p.join(INNER, p.values(p.symbol("msg1", ROW_TYPE)), p.values(p.symbol("msg2", ROW_TYPE)), - new Comparison(GREATER_THAN, new Arithmetic(ADD_BIGINT, ADD, new FieldReference(new Reference(ROW_TYPE, "msg1"), 0), new FieldReference(new Reference(ROW_TYPE, "msg2"), 1)), new Constant(BIGINT, 10L))))) + new Comparison(GREATER_THAN, new Call(ADD_BIGINT, ImmutableList.of(new FieldReference(new Reference(ROW_TYPE, "msg1"), 0), new FieldReference(new Reference(ROW_TYPE, "msg2"), 1))), new Constant(BIGINT, 10L))))) .matches( project( ImmutableMap.of( "expr", expression(new Reference(BIGINT, "msg1_x")), "expr_2", expression(new Reference(ROW_TYPE, "msg2"))), join(INNER, builder -> builder - .filter(new Comparison(GREATER_THAN, new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "msg1_x"), new FieldReference(new Reference(ROW_TYPE, "msg2"), 1)), new Constant(BIGINT, 10L))) + .filter(new Comparison(GREATER_THAN, new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "msg1_x"), new FieldReference(new Reference(ROW_TYPE, "msg2"), 1))), new Constant(BIGINT, 10L))) .left( strictProject( ImmutableMap.of( @@ -734,25 +732,24 @@ public void testMultiLevelPushdown() p.project( Assignments.of( p.symbol("expr_1"), new FieldReference(new Reference(complexType, "a"), 0), - p.symbol("expr_2"), new Arithmetic( + p.symbol("expr_2"), new Call( ADD_BIGINT, - ADD, - new Arithmetic( - ADD_BIGINT, - ADD, - new Arithmetic( + ImmutableList.of( + new Call( ADD_BIGINT, - ADD, + ImmutableList.of( + new Call( + ADD_BIGINT, + ImmutableList.of( + new FieldReference(new FieldReference(new Reference(complexType, "a"), 0), 0), + new Constant(BIGINT, 2L))), + new FieldReference( + new FieldReference(new Reference(complexType, "b"), 0), 0))), + new FieldReference( new FieldReference( - new FieldReference(new Reference(complexType, "a"), 0), + new Reference(complexType, "b"), 0), - new Constant(BIGINT, 2L)), - new FieldReference( - new FieldReference(new Reference(complexType, "b"), 0), - 0)), - new FieldReference( - new FieldReference(new Reference(complexType, "b"), 0), - 1))), + 1)))), p.project( Assignments.identity(ImmutableList.of(p.symbol("a", complexType), p.symbol("b", complexType))), p.values(p.symbol("a", complexType), p.symbol("b", complexType))))) @@ -760,7 +757,7 @@ public void testMultiLevelPushdown() strictProject( ImmutableMap.of( "expr_1", expression(new Reference(complexType.getFields().get(0).getType(), "a_f1")), - "expr_2", expression(new Arithmetic(ADD_BIGINT, ADD, new Arithmetic(ADD_BIGINT, ADD, new Arithmetic(ADD_BIGINT, ADD, new FieldReference(new Reference(complexType.getFields().get(0).getType(), "a_f1"), 0), new Constant(BIGINT, 2L)), new Reference(BIGINT, "b_f1_f1")), new Reference(BIGINT, "b_f1_f2")))), + "expr_2", expression(new Call(ADD_BIGINT, ImmutableList.of(new Call(ADD_BIGINT, ImmutableList.of(new Call(ADD_BIGINT, ImmutableList.of(new FieldReference(new Reference(complexType.getFields().get(0).getType(), "a_f1"), 0), new Constant(BIGINT, 2L))), new Reference(BIGINT, "b_f1_f1"))), new Reference(BIGINT, "b_f1_f2"))))), strictProject( ImmutableMap.of( "a", expression(new Reference(complexType, "a")), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownProjectionsFromPatternRecognition.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownProjectionsFromPatternRecognition.java index 099a6103e995..75d27d67c702 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownProjectionsFromPatternRecognition.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownProjectionsFromPatternRecognition.java @@ -20,7 +20,6 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Call; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; @@ -42,8 +41,6 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; -import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; @@ -82,7 +79,7 @@ public void testDoNotPushRuntimeEvaluatedArguments() .addVariableDefinition( new IrLabel("X"), new Comparison(GREATER_THAN, new Call(MAX_BY_BIGINT_VARCHAR, ImmutableList.of( - new Arithmetic(ADD_BIGINT, ADD, new Constant(BIGINT, 1L), new Reference(BIGINT, "match")), + new Call(ADD_BIGINT, ImmutableList.of(new Constant(BIGINT, 1L), new Reference(BIGINT, "match"))), new Call(CONCAT, ImmutableList.of(new Constant(VARCHAR, Slices.utf8Slice("x")), new Reference(VARCHAR, "classifier"))))), new Constant(BIGINT, 5L)), ImmutableMap.of( @@ -118,7 +115,7 @@ public void testPreProjectArguments() ImmutableMap.of(new Symbol(BIGINT, "agg"), new AggregationValuePointer( maxBy, new AggregatedSetDescriptor(ImmutableSet.of(), true), - ImmutableList.of(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Constant(BIGINT, 1L)), new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "b"), new Constant(BIGINT, 2L))), + ImmutableList.of(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "a"), new Constant(BIGINT, 1L))), new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "b"), new Constant(BIGINT, 2L)))), Optional.empty(), Optional.empty()))) .source(p.values(p.symbol("a", BIGINT), p.symbol("b", BIGINT))))) @@ -136,8 +133,8 @@ public void testPreProjectArguments() Optional.empty()))), project( ImmutableMap.of( - "expr_1", PlanMatchPattern.expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Constant(BIGINT, 1L))), - "expr_2", PlanMatchPattern.expression(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "b"), new Constant(BIGINT, 2L))), + "expr_1", PlanMatchPattern.expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "a"), new Constant(BIGINT, 1L)))), + "expr_2", PlanMatchPattern.expression(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "b"), new Constant(BIGINT, 2L)))), "a", PlanMatchPattern.expression(new Reference(BIGINT, "a")), "b", PlanMatchPattern.expression(new Reference(BIGINT, "b"))), values("a", "b")))); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushFilterThroughCountAggregation.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushFilterThroughCountAggregation.java index 819a2cd5bdb1..4d5d28ef6b0a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushFilterThroughCountAggregation.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushFilterThroughCountAggregation.java @@ -18,7 +18,7 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Logical; @@ -34,7 +34,6 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; @@ -266,7 +265,7 @@ public void testPushDownMaskAndSimplifyFilter() Symbol mask = p.symbol("mask"); Symbol count = p.symbol("count"); return p.filter( - new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "count"), new Constant(BIGINT, 0L)), new Comparison(EQUAL, new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "count"), new Constant(BIGINT, 2L)), new Constant(BIGINT, 0L)))), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "count"), new Constant(BIGINT, 0L)), new Comparison(EQUAL, new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "count"), new Constant(BIGINT, 2L))), new Constant(BIGINT, 0L)))), p.aggregation(builder -> builder .singleGroupingSet(g) .addAggregation(count, PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of(), mask) @@ -274,7 +273,7 @@ public void testPushDownMaskAndSimplifyFilter() }) .matches( filter( - new Comparison(EQUAL, new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "count"), new Constant(BIGINT, 2L)), new Constant(BIGINT, 0L)), + new Comparison(EQUAL, new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "count"), new Constant(BIGINT, 2L))), new Constant(BIGINT, 0L)), aggregation( ImmutableMap.of("count", aggregationFunction("count", ImmutableList.of())), filter( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushInequalityFilterExpressionBelowJoinRuleSet.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushInequalityFilterExpressionBelowJoinRuleSet.java index 17369236c89b..4ad0d4212ada 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushInequalityFilterExpressionBelowJoinRuleSet.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushInequalityFilterExpressionBelowJoinRuleSet.java @@ -18,7 +18,7 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; @@ -30,7 +30,6 @@ import org.junit.jupiter.api.Test; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.ir.Comparison.Operator; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.ir.Comparison.Operator.IS_DISTINCT_FROM; @@ -92,7 +91,7 @@ public void testJoinFilterExpressionPushedDownToRightJoinSource() .filter(new Comparison(LESS_THAN, new Reference(BIGINT, "expr"), new Reference(BIGINT, "a"))) .left(values("a")) .right(project( - ImmutableMap.of("expr", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "b"), new Constant(BIGINT, 1L)))), + ImmutableMap.of("expr", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "b"), new Constant(BIGINT, 1L))))), values("b"))))); } @@ -118,8 +117,8 @@ public void testManyJoinFilterExpressionsPushedDownToRightJoinSource() .right( project( ImmutableMap.of( - "expr_less", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "b"), new Constant(BIGINT, 1L))), - "expr_greater", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "b"), new Constant(BIGINT, 10L)))), + "expr_less", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "b"), new Constant(BIGINT, 1L)))), + "expr_greater", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "b"), new Constant(BIGINT, 10L))))), values("b"))))); } @@ -138,11 +137,11 @@ public void testOnlyRightJoinFilterExpressionPushedDownToRightJoinSource() }) .matches( join(INNER, builder -> builder - .filter(new Comparison(LESS_THAN, new Reference(BIGINT, "expr"), new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Constant(BIGINT, 2L)))) + .filter(new Comparison(LESS_THAN, new Reference(BIGINT, "expr"), new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "a"), new Constant(BIGINT, 2L))))) .left(values("a")) .right( project( - ImmutableMap.of("expr", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "b"), new Constant(BIGINT, 1L)))), + ImmutableMap.of("expr", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "b"), new Constant(BIGINT, 1L))))), values("b"))))); } @@ -186,7 +185,7 @@ public void testParentFilterExpressionPushedDownToRightJoinSource() values("a")) .right( project( - ImmutableMap.of("expr", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "b"), new Constant(BIGINT, 1L)))), + ImmutableMap.of("expr", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "b"), new Constant(BIGINT, 1L))))), values("b"))))))); } @@ -214,8 +213,8 @@ public void testManyParentFilterExpressionsPushedDownToRightJoinSource() .right( project( ImmutableMap.of( - "expr_less", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "b"), new Constant(BIGINT, 1L))), - "expr_greater", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "b"), new Constant(BIGINT, 10L)))), + "expr_less", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "b"), new Constant(BIGINT, 1L)))), + "expr_greater", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "b"), new Constant(BIGINT, 10L))))), values("b"))))))); } @@ -244,8 +243,8 @@ public void testOnlyParentFilterExpressionExposedInaJoin() .right( project( ImmutableMap.of( - "join_expression", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "b"), new Constant(BIGINT, 2L))), - "parent_expression", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "b"), new Constant(BIGINT, 1L)))), + "join_expression", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "b"), new Constant(BIGINT, 2L)))), + "parent_expression", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "b"), new Constant(BIGINT, 1L))))), values("b")))) .withExactOutputs("a", "b", "parent_expression")))); } @@ -285,12 +284,8 @@ private static Comparison comparison(Operator operator, Expression left, Express return new Comparison(operator, left, right); } - private Arithmetic add(Symbol symbol, long value) + private Call add(Symbol symbol, long value) { - return new Arithmetic( - ADD_BIGINT, - ADD, - symbol.toSymbolReference(), - new Constant(BIGINT, value)); + return new Call(ADD_BIGINT, ImmutableList.of(symbol.toSymbolReference(), new Constant(BIGINT, value))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushJoinIntoTableScan.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushJoinIntoTableScan.java index a2ac73cf8b12..50e7fe7719a8 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushJoinIntoTableScan.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushJoinIntoTableScan.java @@ -37,7 +37,6 @@ import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.NullableValue; import io.trino.spi.predicate.TupleDomain; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.planner.Symbol; @@ -60,7 +59,6 @@ import static io.trino.spi.expression.StandardFunctions.MULTIPLY_FUNCTION_NAME; import static io.trino.spi.predicate.Domain.onlyNull; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan; @@ -277,7 +275,7 @@ public void testPushJoinIntoTableScanWithComplexFilter() right, new Comparison( Comparison.Operator.GREATER_THAN, - new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Constant(BIGINT, 44L), columnA1Symbol.toSymbolReference()), + new io.trino.sql.ir.Call(MULTIPLY_BIGINT, ImmutableList.of(new Constant(BIGINT, 44L), columnA1Symbol.toSymbolReference())), columnB1Symbol.toSymbolReference())); }) .matches( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushLimitThroughProject.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushLimitThroughProject.java index bcc70b52fbbc..eb32a25409f3 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushLimitThroughProject.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushLimitThroughProject.java @@ -19,7 +19,7 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; import io.trino.spi.type.RowType; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; @@ -30,7 +30,6 @@ import java.util.Optional; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.limit; @@ -101,14 +100,14 @@ public void testPushdownLimitWithTiesThroughProjectionWithExpression() p.project( Assignments.of( projectedA, new Reference(BIGINT, "a"), - projectedC, new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Reference(BIGINT, "b"))), + projectedC, new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))), p.values(a, b))); }) .matches( project( ImmutableMap.of( "projectedA", expression(new Reference(BIGINT, "a")), - "projectedC", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))), + "projectedC", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "a"), new Reference(BIGINT, "b"))))), limit(1, ImmutableList.of(sort("a", ASCENDING, FIRST)), values("a", "b")))); } @@ -127,7 +126,7 @@ public void testDoNotPushdownLimitWithTiesThroughProjectionWithExpression() p.project( Assignments.of( projectedA, new Reference(BIGINT, "a"), - projectedC, new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Reference(BIGINT, "b"))), + projectedC, new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))), p.values(a, b))); }) .doesNotFire(); @@ -181,7 +180,7 @@ public void testLimitWithPreSortedInputs() p.project( Assignments.of( projectedA, new Reference(BIGINT, "a"), - projectedC, new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Reference(BIGINT, "b"))), + projectedC, new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))), p.values(a, b))); }) .doesNotFire(); @@ -200,12 +199,12 @@ projectedC, new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Refe p.project( Assignments.of( projectedA, new Reference(BIGINT, "a"), - projectedC, new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Reference(BIGINT, "b"))), + projectedC, new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))), p.values(a, b))); }) .matches( project( - ImmutableMap.of("projectedA", expression(new Reference(BIGINT, "a")), "projectedC", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))), + ImmutableMap.of("projectedA", expression(new Reference(BIGINT, "a")), "projectedC", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "a"), new Reference(BIGINT, "b"))))), limit(1, ImmutableList.of(), true, ImmutableList.of("a"), values("a", "b")))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushMergeWriterUpdateIntoConnector.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushMergeWriterUpdateIntoConnector.java index b513c78b5b87..5b9c0c99523f 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushMergeWriterUpdateIntoConnector.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushMergeWriterUpdateIntoConnector.java @@ -26,7 +26,6 @@ import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.TestingColumnHandle; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Call; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; @@ -110,7 +109,7 @@ public void testPushUpdateIntoConnectorArithmeticExpression() Symbol rowCount = p.symbol("row_count"); // set arithmetic expression which we don't support yet Expression updateMergeRowExpression = new Row(ImmutableList.of(p.symbol("column_1").toSymbolReference(), - new Arithmetic(MULTIPLY_BIGINT, Arithmetic.Operator.MULTIPLY, p.symbol("col1").toSymbolReference(), new Constant(BIGINT, 5L)))); + new Call(MULTIPLY_BIGINT, ImmutableList.of(p.symbol("col1").toSymbolReference(), new Constant(BIGINT, 5L))))); return p.tableFinish( p.merge( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateIntoTableScan.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateIntoTableScan.java index 76146f807e9a..4bd6c2747a13 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateIntoTableScan.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateIntoTableScan.java @@ -39,7 +39,7 @@ import io.trino.spi.predicate.NullableValue; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.Type; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Coalesce; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; @@ -60,7 +60,6 @@ import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; -import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.Logical.Operator.AND; import static io.trino.sql.ir.Logical.Operator.OR; @@ -229,7 +228,7 @@ public void testDoesNotConsumeRemainingPredicateIfNewDomainIsWider() new Constant(BOOLEAN, null), new Comparison( EQUAL, - new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 17L)), + new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 17L))), new Constant(BIGINT, 44L))), Logical.or( new Comparison(EQUAL, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 44L)), @@ -251,7 +250,7 @@ public void testDoesNotConsumeRemainingPredicateIfNewDomainIsWider() new Constant(DOUBLE, 42.0)), new Comparison( EQUAL, - new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 17L)), + new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 17L))), new Constant(BIGINT, 44L))), constrainedTableScanWithTableLayout( "nation", @@ -284,7 +283,7 @@ public void testDoesNotFireIfRuleNotChangePlan() { tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter( - new Logical(AND, ImmutableList.of(new Comparison(EQUAL, new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 17L)), new Constant(BIGINT, 44L)), new Comparison(EQUAL, new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 15L)), new Constant(BIGINT, 43L)))), + new Logical(AND, ImmutableList.of(new Comparison(EQUAL, new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 17L))), new Constant(BIGINT, 44L)), new Comparison(EQUAL, new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 15L))), new Constant(BIGINT, 43L)))), p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateThroughProjectIntoRowNumber.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateThroughProjectIntoRowNumber.java index 10dd81c3b44d..0f191faa09a4 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateThroughProjectIntoRowNumber.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateThroughProjectIntoRowNumber.java @@ -19,7 +19,7 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Logical; @@ -34,7 +34,6 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; @@ -239,7 +238,7 @@ public void testPredicatePartiallySatisfied() Symbol a = p.symbol("a"); Symbol rowNumber = p.symbol("row_number"); return p.filter( - new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN, new Reference(BIGINT, "row_number"), new Constant(BIGINT, 5L)), new Comparison(EQUAL, new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "row_number"), new Constant(INTEGER, 2L)), new Constant(INTEGER, 0L)))), + new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN, new Reference(BIGINT, "row_number"), new Constant(BIGINT, 5L)), new Comparison(EQUAL, new Call(MODULUS_INTEGER, ImmutableList.of(new Reference(INTEGER, "row_number"), new Constant(INTEGER, 2L))), new Constant(INTEGER, 0L)))), p.project( Assignments.identity(rowNumber), p.rowNumber( @@ -249,7 +248,7 @@ public void testPredicatePartiallySatisfied() p.values(a)))); }) .matches(filter( - new Comparison(EQUAL, new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "row_number"), new Constant(INTEGER, 2L)), new Constant(INTEGER, 0L)), + new Comparison(EQUAL, new Call(MODULUS_INTEGER, ImmutableList.of(new Reference(INTEGER, "row_number"), new Constant(INTEGER, 2L))), new Constant(INTEGER, 0L)), project( ImmutableMap.of("row_number", io.trino.sql.planner.assertions.PlanMatchPattern.expression(new Reference(BIGINT, "row_number"))), rowNumber( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateThroughProjectIntoWindow.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateThroughProjectIntoWindow.java index fa459e791289..a3321fd0637a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateThroughProjectIntoWindow.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateThroughProjectIntoWindow.java @@ -19,7 +19,7 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Logical; @@ -40,7 +40,6 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; @@ -266,7 +265,7 @@ private void assertPredicatePartiallySatisfied(Function rankingFunction, Ranking Symbol a = p.symbol("a"); Symbol ranking = p.symbol("ranking"); return p.filter( - new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN, new Reference(BIGINT, "ranking"), new Constant(BIGINT, 5L)), new Comparison(EQUAL, new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "ranking"), new Constant(INTEGER, 2L)), new Constant(INTEGER, 0L)))), + new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN, new Reference(BIGINT, "ranking"), new Constant(BIGINT, 5L)), new Comparison(EQUAL, new Call(MODULUS_INTEGER, ImmutableList.of(new Reference(INTEGER, "ranking"), new Constant(INTEGER, 2L))), new Constant(INTEGER, 0L)))), p.project( Assignments.identity(ranking), p.window( @@ -277,7 +276,7 @@ private void assertPredicatePartiallySatisfied(Function rankingFunction, Ranking p.values(a)))); }) .matches(filter( - new Comparison(EQUAL, new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "ranking"), new Constant(INTEGER, 2L)), new Constant(INTEGER, 0L)), + new Comparison(EQUAL, new Call(MODULUS_INTEGER, ImmutableList.of(new Reference(INTEGER, "ranking"), new Constant(INTEGER, 2L))), new Constant(INTEGER, 0L)), project( ImmutableMap.of("ranking", expression(new Reference(BIGINT, "ranking"))), topNRanking( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java index 0e106b81fe7a..c566150bbba4 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java @@ -19,7 +19,7 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.connector.SortOrder; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Reference; import io.trino.sql.planner.OrderingScheme; @@ -30,7 +30,6 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; @@ -130,7 +129,7 @@ public void testHashMapping() Symbol cTimes5 = p.symbol("c_times_5"); return p.project( Assignments.of( - cTimes5, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "c"), new Constant(INTEGER, 5L))), + cTimes5, new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "c"), new Constant(INTEGER, 5L)))), p.exchange(e -> e .addSource( p.values(a, h1)) @@ -147,7 +146,7 @@ cTimes5, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "c"), ImmutableMap.of( "a", expression(new Reference(INTEGER, "a")), "h_1", expression(new Reference(BIGINT, "h_1")), - "a_times_5", expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 5L)))), + "a_times_5", expression(new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 5L))))), values(ImmutableList.of("a", "h_1")))))); } @@ -166,7 +165,7 @@ public void testSkipIdentityProjectionIfOutputPresent() Symbol aTimes5 = p.symbol("a_times_5", INTEGER); return p.project( Assignments.of( - aTimes5, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 5L)), + aTimes5, new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 5L))), a, a.toSymbolReference()), p.exchange(e -> e .addSource(p.values(a)) @@ -176,7 +175,7 @@ aTimes5, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), .matches( exchange( strictProject( - ImmutableMap.of("a_0", expression(new Reference(INTEGER, "a")), "a_times_5", expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 5L)))), + ImmutableMap.of("a_0", expression(new Reference(INTEGER, "a")), "a_times_5", expression(new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 5L))))), values(ImmutableList.of("a"))))); // In the following example, the Projection over Exchange has got an identity assignment (b -> b). @@ -192,7 +191,7 @@ aTimes5, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), Symbol b = p.symbol("b", BIGINT); return p.project( Assignments.of( - bTimes5, new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "b"), new Constant(BIGINT, 5L)), + bTimes5, new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "b"), new Constant(BIGINT, 5L))), b, b.toSymbolReference()), p.exchange(e -> e .addSource(p.values(a)) @@ -202,7 +201,7 @@ bTimes5, new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "b"), n .matches( exchange( strictProject( - ImmutableMap.of("a_0", expression(new Reference(BIGINT, "a")), "a_times_5", expression(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "a"), new Constant(BIGINT, 5L)))), + ImmutableMap.of("a_0", expression(new Reference(BIGINT, "a")), "a_times_5", expression(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "a"), new Constant(BIGINT, 5L))))), values(ImmutableList.of("a"))))); } @@ -221,7 +220,7 @@ public void testDoNotSkipIdentityProjectionIfOutputAbsent() Symbol aTimes5 = p.symbol("a_times_5", INTEGER); return p.project( Assignments.of( - aTimes5, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 5L)), + aTimes5, new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 5L))), a, a.toSymbolReference()), p.exchange(e -> e .addSource(p.values(a)) @@ -231,7 +230,7 @@ aTimes5, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), .matches( exchange( strictProject( - ImmutableMap.of("a_0", expression(new Reference(INTEGER, "a")), "a_times_5", expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 5L)))), + ImmutableMap.of("a_0", expression(new Reference(INTEGER, "a")), "a_times_5", expression(new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 5L))))), values(ImmutableList.of("a"))))); // In the following example, the Projection over Exchange has got an identity assignment (b -> b). @@ -247,7 +246,7 @@ aTimes5, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), Symbol b = p.symbol("b", BIGINT); return p.project( Assignments.of( - bTimes5, new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "b"), new Constant(BIGINT, 5L)), + bTimes5, new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "b"), new Constant(BIGINT, 5L))), b, b.toSymbolReference()), p.exchange(e -> e .addSource(p.values(a)) @@ -259,7 +258,7 @@ bTimes5, new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "b"), n strictProject( ImmutableMap.of( "a_0", expression(new Reference(BIGINT, "a")), - "a_times_5", expression(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "a"), new Constant(BIGINT, 5L)))), + "a_times_5", expression(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "a"), new Constant(BIGINT, 5L))))), values(ImmutableList.of("a"))))); } @@ -276,9 +275,9 @@ public void testPartitioningColumnAndHashWithoutIdentityMappingInProjection() Symbol hTimes5 = p.symbol("h_times_5", INTEGER); return p.project( Assignments.builder() - .put(aTimes5, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 5L))) - .put(bTimes5, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "b"), new Constant(INTEGER, 5L))) - .put(hTimes5, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "h"), new Constant(INTEGER, 5L))) + .put(aTimes5, new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 5L)))) + .put(bTimes5, new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "b"), new Constant(INTEGER, 5L)))) + .put(hTimes5, new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "h"), new Constant(INTEGER, 5L)))) .build(), p.exchange(e -> e .addSource( @@ -298,9 +297,9 @@ public void testPartitioningColumnAndHashWithoutIdentityMappingInProjection() ).withNumberOfOutputColumns(5) .withAlias("b", expression(new Reference(INTEGER, "b"))) .withAlias("h", expression(new Reference(INTEGER, "h"))) - .withAlias("a_times_5", expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 5L)))) - .withAlias("b_times_5", expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "b"), new Constant(INTEGER, 5L)))) - .withAlias("h_times_5", expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "h"), new Constant(INTEGER, 5L))))) + .withAlias("a_times_5", expression(new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 5L))))) + .withAlias("b_times_5", expression(new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "b"), new Constant(INTEGER, 5L))))) + .withAlias("h_times_5", expression(new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "h"), new Constant(INTEGER, 5L)))))) ).withNumberOfOutputColumns(3) .withExactOutputs("a_times_5", "b_times_5", "h_times_5")); } @@ -320,9 +319,9 @@ public void testOrderingColumnsArePreserved() OrderingScheme orderingScheme = new OrderingScheme(ImmutableList.of(sortSymbol), ImmutableMap.of(sortSymbol, SortOrder.ASC_NULLS_FIRST)); return p.project( Assignments.builder() - .put(aTimes5, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 5L))) - .put(bTimes5, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "b"), new Constant(INTEGER, 5L))) - .put(hTimes5, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "h"), new Constant(INTEGER, 5L))) + .put(aTimes5, new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 5L)))) + .put(bTimes5, new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "b"), new Constant(INTEGER, 5L)))) + .put(hTimes5, new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "h"), new Constant(INTEGER, 5L)))) .build(), p.exchange(e -> e .addSource( @@ -339,9 +338,9 @@ public void testOrderingColumnsArePreserved() values( ImmutableList.of("a", "b", "h", "sortSymbol"))) .withNumberOfOutputColumns(4) - .withAlias("a_times_5", expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 5L)))) - .withAlias("b_times_5", expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "b"), new Constant(INTEGER, 5L)))) - .withAlias("h_times_5", expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "h"), new Constant(INTEGER, 5L)))) + .withAlias("a_times_5", expression(new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 5L))))) + .withAlias("b_times_5", expression(new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "b"), new Constant(INTEGER, 5L))))) + .withAlias("h_times_5", expression(new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "h"), new Constant(INTEGER, 5L))))) .withAlias("sortSymbol", expression(new Reference(INTEGER, "sortSymbol")))) ).withNumberOfOutputColumns(3) .withExactOutputs("a_times_5", "b_times_5", "h_times_5")); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughJoin.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughJoin.java index 0cee5e499821..643927232728 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughJoin.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughJoin.java @@ -19,8 +19,7 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; -import io.trino.sql.ir.Negation; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Reference; import io.trino.sql.planner.Plan; import io.trino.sql.planner.PlanNodeIdAllocator; @@ -41,7 +40,6 @@ import static io.trino.metadata.AbstractMockMetadata.dummyMetadata; import static io.trino.metadata.FunctionManager.createTestingFunctionManager; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT; import static io.trino.sql.planner.assertions.PlanAssert.assertPlan; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; @@ -59,6 +57,7 @@ public class TestPushProjectionThroughJoin { private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); private static final ResolvedFunction ADD_BIGINT = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(BIGINT, BIGINT)); + private static final ResolvedFunction NEGATION_BIGINT = FUNCTIONS.resolveOperator(OperatorType.NEGATION, ImmutableList.of(BIGINT)); @Test public void testPushesProjectionThroughJoin() @@ -75,14 +74,14 @@ public void testPushesProjectionThroughJoin() ProjectNode planNode = p.project( Assignments.of( - a3, new Negation(a2.toSymbolReference()), - b2, new Negation(b1.toSymbolReference())), + a3, new Call(NEGATION_BIGINT, ImmutableList.of(a2.toSymbolReference())), + b2, new Call(NEGATION_BIGINT, ImmutableList.of(b1.toSymbolReference()))), p.join( INNER, // intermediate non-identity projections should be fully inlined p.project( Assignments.of( - a2, new Negation(a0.toSymbolReference()), + a2, new Call(NEGATION_BIGINT, ImmutableList.of(a0.toSymbolReference())), a1, a1.toSymbolReference()), p.project( Assignments.builder() @@ -106,7 +105,7 @@ a2, new Negation(a0.toSymbolReference()), .equiCriteria(ImmutableList.of(aliases -> new JoinNode.EquiJoinClause(new Symbol(UNKNOWN, "a1"), new Symbol(UNKNOWN, "b1")))) .left( strictProject(ImmutableMap.of( - "a3", expression(new Negation(new Negation(new Reference(BIGINT, "a0")))), + "a3", expression(new Call(NEGATION_BIGINT, ImmutableList.of(new Call(NEGATION_BIGINT, ImmutableList.of(new Reference(BIGINT, "a0")))))), "a1", expression(new Reference(BIGINT, "a1"))), strictProject(ImmutableMap.of( "a0", expression(new Reference(BIGINT, "a0")), @@ -114,7 +113,7 @@ a2, new Negation(a0.toSymbolReference()), PlanMatchPattern.values("a0", "a1")))) .right( strictProject(ImmutableMap.of( - "b2", expression(new Negation(new Reference(BIGINT, "b1"))), + "b2", expression(new Call(NEGATION_BIGINT, ImmutableList.of(new Reference(BIGINT, "b1")))), "b1", expression(new Reference(BIGINT, "b1"))), PlanMatchPattern.values("b0", "b1")))) .withExactOutputs("a3", "b2")); @@ -130,7 +129,7 @@ public void testDoesNotPushStraddlingProjection() ProjectNode planNode = p.project( Assignments.of( - c, new Arithmetic(ADD_BIGINT, ADD, a.toSymbolReference(), b.toSymbolReference())), + c, new Call(ADD_BIGINT, ImmutableList.of(a.toSymbolReference(), b.toSymbolReference()))), p.join( INNER, p.values(a), @@ -149,7 +148,7 @@ public void testDoesNotPushProjectionThroughOuterJoin() ProjectNode planNode = p.project( Assignments.of( - c, new Negation(a.toSymbolReference())), + c, new Call(NEGATION_BIGINT, ImmutableList.of(a.toSymbolReference()))), p.join( LEFT, p.values(a), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java index eedecbf42db7..5b9b358470a3 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java @@ -20,7 +20,7 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; import io.trino.spi.type.RowType; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Constant; import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Reference; @@ -33,7 +33,6 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.union; @@ -95,7 +94,7 @@ public void test() Symbol w = p.symbol("w", ROW_TYPE); return p.project( Assignments.of( - cTimes3, new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, c.toSymbolReference(), new Constant(BIGINT, 3L)), + cTimes3, new Call(MULTIPLY_BIGINT, ImmutableList.of(c.toSymbolReference(), new Constant(BIGINT, 3L))), dX, new FieldReference(new Reference(ROW_TYPE, "d"), 0)), p.union( ImmutableListMultimap.builder() @@ -111,10 +110,10 @@ dX, new FieldReference(new Reference(ROW_TYPE, "d"), 0)), .matches( union( project( - ImmutableMap.of("a_times_3", expression(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "a"), new Constant(BIGINT, 3L))), "z_x", expression(new FieldReference(new Reference(ROW_TYPE, "z"), 0))), + ImmutableMap.of("a_times_3", expression(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "a"), new Constant(BIGINT, 3L)))), "z_x", expression(new FieldReference(new Reference(ROW_TYPE, "z"), 0))), values(ImmutableList.of("a", "z"))), project( - ImmutableMap.of("b_times_3", expression(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "b"), new Constant(BIGINT, 3L))), "w_x", expression(new FieldReference(new Reference(ROW_TYPE, "w"), 0))), + ImmutableMap.of("b_times_3", expression(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "b"), new Constant(BIGINT, 3L)))), "w_x", expression(new FieldReference(new Reference(ROW_TYPE, "w"), 0))), values(ImmutableList.of("b", "w")))) .withNumberOfOutputColumns(2) .withAlias("a_times_3") diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushTopNThroughProject.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushTopNThroughProject.java index 73ff5d53ecba..23489a1e85e1 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushTopNThroughProject.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushTopNThroughProject.java @@ -19,8 +19,8 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; import io.trino.spi.type.RowType; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Booleans; +import io.trino.sql.ir.Call; import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; @@ -32,7 +32,6 @@ import java.util.Optional; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.sort; @@ -88,14 +87,14 @@ public void testPushdownTopNNonIdentityProjectionWithExpression() p.project( Assignments.of( projectedA, new Reference(BIGINT, "a"), - projectedC, new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Reference(BIGINT, "b"))), + projectedC, new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))), p.values(a, b))); }) .matches( project( ImmutableMap.of( "projectedA", expression(new Reference(BIGINT, "a")), - "projectedC", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))), + "projectedC", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "a"), new Reference(BIGINT, "b"))))), topN(1, ImmutableList.of(sort("a", ASCENDING, FIRST)), values("a", "b")))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantPredicateAboveTableScan.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantPredicateAboveTableScan.java index 312fc2061c72..b8a6c7c6d705 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantPredicateAboveTableScan.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantPredicateAboveTableScan.java @@ -28,7 +28,7 @@ import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.NullableValue; import io.trino.spi.predicate.TupleDomain; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.In; @@ -44,7 +44,6 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; -import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; import static io.trino.sql.ir.Booleans.FALSE; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.Comparison.Operator.EQUAL; @@ -172,7 +171,7 @@ public void doesNotConsumeRemainingPredicateIfNewDomainIsWider() new Constant(DOUBLE, 42.0)), new Comparison( EQUAL, - new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 17L)), + new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 17L))), new Constant(BIGINT, 44L)), Logical.or( new Comparison(EQUAL, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 44L)), @@ -194,7 +193,7 @@ public void doesNotConsumeRemainingPredicateIfNewDomainIsWider() new Constant(DOUBLE, 42.0)), new Comparison( EQUAL, - new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 17L)), + new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 17L))), new Constant(BIGINT, 44L))), constrainedTableScanWithTableLayout( "nation", @@ -227,7 +226,7 @@ public void doesNotFireIfRuleNotChangePlan() { tester().assertThat(removeRedundantPredicateAboveTableScan) .on(p -> p.filter( - new Logical(AND, ImmutableList.of(new Comparison(EQUAL, new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "nationkey"), new Constant(INTEGER, 17L)), new Constant(INTEGER, 44L)), new Comparison(EQUAL, new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "nationkey"), new Constant(INTEGER, 15L)), new Constant(INTEGER, 43L)))), + new Logical(AND, ImmutableList.of(new Comparison(EQUAL, new Call(MODULUS_INTEGER, ImmutableList.of(new Reference(INTEGER, "nationkey"), new Constant(INTEGER, 17L))), new Constant(INTEGER, 44L)), new Comparison(EQUAL, new Call(MODULUS_INTEGER, ImmutableList.of(new Reference(INTEGER, "nationkey"), new Constant(INTEGER, 15L))), new Constant(INTEGER, 43L)))), p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReorderJoins.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReorderJoins.java index 1873b024e66c..bfeb4ed3a898 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReorderJoins.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReorderJoins.java @@ -18,10 +18,12 @@ import io.trino.cost.CostComparator; import io.trino.cost.PlanNodeStatsEstimate; import io.trino.cost.SymbolStatsEstimate; +import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; +import io.trino.spi.function.OperatorType; import io.trino.spi.type.Type; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Comparison; -import io.trino.sql.ir.Negation; import io.trino.sql.ir.Reference; import io.trino.sql.planner.OptimizerConfig.JoinDistributionType; import io.trino.sql.planner.OptimizerConfig.JoinReorderingStrategy; @@ -66,6 +68,9 @@ @Execution(CONCURRENT) public class TestReorderJoins { + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction NEGATION_BIGINT = FUNCTIONS.resolveOperator(OperatorType.NEGATION, ImmutableList.of(BIGINT)); + private RuleTester tester; @BeforeAll @@ -422,7 +427,7 @@ public void testPushesProjectionsThroughJoin() INNER, p.project( Assignments.of( - p.symbol("P1"), new Negation(p.symbol("B1").toSymbolReference()), + p.symbol("P1"), new Call(NEGATION_BIGINT, ImmutableList.of(p.symbol("B1").toSymbolReference())), p.symbol("P2"), p.symbol("A1").toSymbolReference()), p.join( INNER, @@ -450,7 +455,7 @@ public void testPushesProjectionsThroughJoin() values("A1"))) .right( strictProject( - ImmutableMap.of("P1", expression(new Negation(new Reference(BIGINT, "B1")))), + ImmutableMap.of("P1", expression(new Call(NEGATION_BIGINT, ImmutableList.of(new Reference(BIGINT, "B1"))))), values("B1"))))))); } @@ -476,7 +481,7 @@ public void testDoesNotPushProjectionThroughJoinIfTooExpensive() INNER, p.project( Assignments.of( - p.symbol("P1"), new Negation(p.symbol("B1").toSymbolReference())), + p.symbol("P1"), new Call(NEGATION_BIGINT, ImmutableList.of(p.symbol("B1").toSymbolReference()))), p.join( INNER, p.values(new PlanNodeId("valuesA"), 2, p.symbol("A1")), @@ -496,7 +501,7 @@ public void testDoesNotPushProjectionThroughJoinIfTooExpensive() .left(values("C1")) .right( strictProject( - ImmutableMap.of("P1", expression(new Negation(new Reference(BIGINT, "B1")))), + ImmutableMap.of("P1", expression(new Call(NEGATION_BIGINT, ImmutableList.of(new Reference(BIGINT, "B1"))))), join(INNER, rightJoinBuilder -> rightJoinBuilder .equiCriteria("A1", "B1") .left(values("A1")) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyExpressions.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyExpressions.java index 00086e2b5b94..b926275b3899 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyExpressions.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyExpressions.java @@ -19,7 +19,7 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; import io.trino.spi.type.Decimals; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; @@ -48,7 +48,6 @@ import static io.trino.spi.type.SmallintType.SMALLINT; import static io.trino.spi.type.TinyintType.TINYINT; import static io.trino.spi.type.VarcharType.createVarcharType; -import static io.trino.sql.ir.Arithmetic.Operator.DIVIDE; import static io.trino.sql.ir.Booleans.FALSE; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.Comparison.Operator.EQUAL; @@ -437,7 +436,7 @@ public void testCastDoubleToBoundedVarchar() new Cast(new Constant(DOUBLE, -0.0), createVarcharType(4)), new Constant(createVarcharType(4), Slices.utf8Slice("-0E0"))); assertSimplifies( - new Cast(new Arithmetic(DIVIDE_DOUBLE, DIVIDE, new Constant(DOUBLE, 0.0), new Constant(DOUBLE, 0.0)), createVarcharType(3)), + new Cast(new Call(DIVIDE_DOUBLE, ImmutableList.of(new Constant(DOUBLE, 0.0), new Constant(DOUBLE, 0.0))), createVarcharType(3)), new Constant(createVarcharType(3), Slices.utf8Slice("NaN"))); assertSimplifies( new Cast(new Constant(DOUBLE, Double.POSITIVE_INFINITY), createVarcharType(8)), @@ -478,7 +477,7 @@ public void testCastRealToBoundedVarchar() new Cast(new Constant(REAL, Reals.toReal(-0.0f)), createVarcharType(4)), new Constant(createVarcharType(4), Slices.utf8Slice("-0E0"))); assertSimplifies( - new Cast(new Arithmetic(DIVIDE_REAL, DIVIDE, new Constant(REAL, Reals.toReal(0.0f)), new Constant(REAL, Reals.toReal(0.0f))), createVarcharType(3)), + new Cast(new Call(DIVIDE_REAL, ImmutableList.of(new Constant(REAL, Reals.toReal(0.0f)), new Constant(REAL, Reals.toReal(0.0f)))), createVarcharType(3)), new Constant(createVarcharType(3), Slices.utf8Slice("NaN"))); assertSimplifies( new Cast(new Constant(REAL, Reals.toReal(Float.POSITIVE_INFINITY)), createVarcharType(8)), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyFilterPredicate.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyFilterPredicate.java index b9cc0057b5b9..febbe2fbaded 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyFilterPredicate.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyFilterPredicate.java @@ -17,7 +17,6 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Call; import io.trino.sql.ir.Case; import io.trino.sql.ir.Comparison; @@ -37,7 +36,6 @@ import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.ir.Booleans.FALSE; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.Comparison.Operator.EQUAL; @@ -407,7 +405,7 @@ public void testSimplifySimpleCaseExpression() new Reference(BOOLEAN, "a"), ImmutableList.of( new WhenClause(new Reference(BOOLEAN, "b"), TRUE), - new WhenClause(new Comparison(EQUAL, new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "b"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 0L)), FALSE)), + new WhenClause(new Comparison(EQUAL, new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "b"), new Constant(INTEGER, 1L))), new Constant(INTEGER, 0L)), FALSE)), Optional.of(TRUE)), p.values(p.symbol("a"), p.symbol("b")))) .doesNotFire(); @@ -448,8 +446,8 @@ public void testSimplifySimpleCaseExpression() new Switch( new Reference(INTEGER, "a"), ImmutableList.of( - new WhenClause(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "b"), new Constant(INTEGER, 1L)), TRUE), - new WhenClause(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "b"), new Constant(INTEGER, 2L)), TRUE)), + new WhenClause(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "b"), new Constant(INTEGER, 1L))), TRUE), + new WhenClause(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "b"), new Constant(INTEGER, 2L))), TRUE)), Optional.of(TRUE)), p.values(p.symbol("a"), p.symbol("b")))) .matches( @@ -463,8 +461,8 @@ public void testSimplifySimpleCaseExpression() new Switch( new Reference(INTEGER, "a"), ImmutableList.of( - new WhenClause(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "b"), new Constant(INTEGER, 1L)), FALSE), - new WhenClause(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "b"), new Constant(INTEGER, 2L)), new Constant(BOOLEAN, null))), + new WhenClause(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "b"), new Constant(INTEGER, 1L))), FALSE), + new WhenClause(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "b"), new Constant(INTEGER, 2L))), new Constant(BOOLEAN, null))), Optional.of(FALSE)), p.values(p.symbol("a"), p.symbol("b")))) .matches( @@ -478,8 +476,8 @@ public void testSimplifySimpleCaseExpression() new Switch( new Reference(INTEGER, "a"), ImmutableList.of( - new WhenClause(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "b"), new Constant(INTEGER, 1L)), FALSE), - new WhenClause(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "b"), new Constant(INTEGER, 2L)), new Constant(BOOLEAN, null))), + new WhenClause(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "b"), new Constant(INTEGER, 1L))), FALSE), + new WhenClause(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "b"), new Constant(INTEGER, 2L))), new Constant(BOOLEAN, null))), Optional.empty()), p.values(p.symbol("a"), p.symbol("b")))) .matches( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedDistinctAggregationWithProjection.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedDistinctAggregationWithProjection.java index 1adb23d5801b..9321b192058f 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedDistinctAggregationWithProjection.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedDistinctAggregationWithProjection.java @@ -18,7 +18,7 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Reference; @@ -31,7 +31,6 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; @@ -87,7 +86,7 @@ public void rewritesOnSubqueryWithDistinct() JoinType.LEFT, TRUE, p.project( - Assignments.of(p.symbol("x"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a"), new Constant(INTEGER, 100L))), + Assignments.of(p.symbol("x"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 100L)))), p.aggregation(innerBuilder -> innerBuilder .singleGroupingSet(p.symbol("a")) .source(p.filter( @@ -96,7 +95,7 @@ public void rewritesOnSubqueryWithDistinct() .matches( project(ImmutableMap.of( "corr", expression(new Reference(BIGINT, "corr")), - "x", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a"), new Constant(INTEGER, 100L)))), + "x", expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 100L))))), aggregation( singleGroupingSet("corr", "unique", "a"), ImmutableMap.of(), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithProjection.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithProjection.java index 85c4f22fca7d..bb8ab0dc05aa 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithProjection.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithProjection.java @@ -18,7 +18,7 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Logical; @@ -33,8 +33,6 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; -import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; @@ -110,9 +108,9 @@ public void doesNotFireOnMultipleProjections() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), p.project( - Assignments.of(p.symbol("expr_2"), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "expr"), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("expr_2"), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "expr"), new Constant(INTEGER, 1L)))), p.project( - Assignments.of(p.symbol("expr"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("expr"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), p.aggregation(ab -> ab .source(p.values(p.symbol("a"), p.symbol("b"))) .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) @@ -141,13 +139,13 @@ public void rewritesOnSubqueryWithProjection() .on(p -> p.correlatedJoin( ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), - p.project(Assignments.of(p.symbol("expr"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), + p.project(Assignments.of(p.symbol("expr"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), p.aggregation(ab -> ab .source(p.values(p.symbol("a"), p.symbol("b"))) .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .globalGrouping())))) .matches( - project(ImmutableMap.of("corr", expression(new Reference(BIGINT, "corr")), "expr", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum_1"), new Constant(INTEGER, 1L)))), + project(ImmutableMap.of("corr", expression(new Reference(BIGINT, "corr")), "expr", expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum_1"), new Constant(INTEGER, 1L))))), aggregation(ImmutableMap.of("sum_1", aggregationFunction("sum", ImmutableList.of("a"))), join(LEFT, builder -> builder .left(assignUniqueId("unique", @@ -165,8 +163,8 @@ public void rewritesOnSubqueryWithDistinct() p.values(p.symbol("corr")), p.project( Assignments.of( - p.symbol("expr_sum"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)), - p.symbol("expr_count"), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "count"), new Constant(INTEGER, 1L))), + p.symbol("expr_sum"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), + p.symbol("expr_count"), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "count"), new Constant(INTEGER, 1L)))), p.aggregation(outerBuilder -> outerBuilder .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .addAggregation(p.symbol("count"), PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of()) @@ -179,8 +177,8 @@ public void rewritesOnSubqueryWithDistinct() .matches( project(ImmutableMap.of( "corr", expression(new Reference(BIGINT, "corr")), - "expr_sum", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum_agg"), new Constant(INTEGER, 1L))), - "expr_count", expression(new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "count_agg"), new Constant(INTEGER, 1L)))), + "expr_sum", expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum_agg"), new Constant(INTEGER, 1L)))), + "expr_count", expression(new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "count_agg"), new Constant(INTEGER, 1L))))), aggregation( singleGroupingSet("corr", "unique"), ImmutableMap.of(Optional.of("sum_agg"), aggregationFunction("sum", ImmutableList.of("a")), Optional.of("count_agg"), aggregationFunction("count", ImmutableList.of())), @@ -218,8 +216,8 @@ public void rewritesOnSubqueryWithDecorrelatableDistinct() p.values(p.symbol("corr")), p.project( Assignments.of( - p.symbol("expr_sum"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)), - p.symbol("expr_count"), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "count"), new Constant(INTEGER, 1L))), + p.symbol("expr_sum"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), + p.symbol("expr_count"), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "count"), new Constant(INTEGER, 1L)))), p.aggregation(outerBuilder -> outerBuilder .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .addAggregation(p.symbol("count"), PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of()) @@ -232,8 +230,8 @@ public void rewritesOnSubqueryWithDecorrelatableDistinct() .matches( project(ImmutableMap.of( "corr", expression(new Reference(BIGINT, "corr")), - "expr_sum", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum_agg"), new Constant(INTEGER, 1L))), - "expr_count", expression(new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "count_agg"), new Constant(INTEGER, 1L)))), + "expr_sum", expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum_agg"), new Constant(INTEGER, 1L)))), + "expr_count", expression(new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "count_agg"), new Constant(INTEGER, 1L))))), aggregation( singleGroupingSet("corr", "unique"), ImmutableMap.of(Optional.of("sum_agg"), aggregationFunction("sum", ImmutableList.of("a")), Optional.of("count_agg"), aggregationFunction("count", ImmutableList.of())), @@ -267,13 +265,13 @@ public void testWithPreexistingMask() .on(p -> p.correlatedJoin( ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), - p.project(Assignments.of(p.symbol("expr"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), + p.project(Assignments.of(p.symbol("expr"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), p.aggregation(ab -> ab .source(p.values(p.symbol("a"), p.symbol("mask", BOOLEAN))) .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT), p.symbol("mask", BOOLEAN)) .globalGrouping())))) .matches( - project(ImmutableMap.of("corr", expression(new Reference(BIGINT, "corr")), "expr", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum_1"), new Constant(INTEGER, 1L)))), + project(ImmutableMap.of("corr", expression(new Reference(BIGINT, "corr")), "expr", expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum_1"), new Constant(INTEGER, 1L))))), aggregation( singleGroupingSet("unique", "corr"), ImmutableMap.of(Optional.of("sum_1"), aggregationFunction("sum", ImmutableList.of("a"))), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithoutProjection.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithoutProjection.java index aa3d462d8e18..5aebf76ded0c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithoutProjection.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithoutProjection.java @@ -18,7 +18,7 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Logical; @@ -33,8 +33,6 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; -import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; @@ -110,9 +108,9 @@ public void doesNotFireOnMultipleProjections() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), p.project( - Assignments.of(p.symbol("expr_2"), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "expr"), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("expr_2"), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "expr"), new Constant(INTEGER, 1L)))), p.project( - Assignments.of(p.symbol("expr"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("expr"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), p.aggregation(ab -> ab .source(p.values(p.symbol("a"), p.symbol("b"))) .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) @@ -148,7 +146,7 @@ public void rewritesOnSubqueryWithProjection() .on(p -> p.correlatedJoin( ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), - p.project(Assignments.of(p.symbol("expr"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), + p.project(Assignments.of(p.symbol("expr"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), p.aggregation(ab -> ab .source(p.values(p.symbol("a"), p.symbol("b"))) .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGroupedAggregationWithProjection.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGroupedAggregationWithProjection.java index 5ebd751c797a..fa0389e8a41d 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGroupedAggregationWithProjection.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGroupedAggregationWithProjection.java @@ -18,7 +18,7 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Reference; @@ -32,8 +32,6 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; -import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; @@ -94,8 +92,8 @@ public void rewritesOnSubqueryWithoutDistinct() TRUE, p.project( Assignments.of( - p.symbol("expr_sum"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)), - p.symbol("expr_count"), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "count"), new Constant(INTEGER, 1L))), + p.symbol("expr_sum"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), + p.symbol("expr_count"), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "count"), new Constant(INTEGER, 1L)))), p.aggregation(outerBuilder -> outerBuilder .singleGroupingSet(p.symbol("a")) .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) @@ -106,8 +104,8 @@ public void rewritesOnSubqueryWithoutDistinct() .matches( project(ImmutableMap.of( "corr", expression(new Reference(BIGINT, "corr")), - "expr_sum", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum_agg"), new Constant(INTEGER, 1L))), - "expr_count", expression(new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "count_agg"), new Constant(INTEGER, 1L)))), + "expr_sum", expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum_agg"), new Constant(INTEGER, 1L)))), + "expr_count", expression(new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "count_agg"), new Constant(INTEGER, 1L))))), aggregation( singleGroupingSet("corr", "unique", "a"), ImmutableMap.of(Optional.of("sum_agg"), aggregationFunction("sum", ImmutableList.of("a")), Optional.of("count_agg"), aggregationFunction("count", ImmutableList.of())), @@ -136,8 +134,8 @@ public void rewritesOnSubqueryWithDistinct() TRUE, p.project( Assignments.of( - p.symbol("expr_sum"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)), - p.symbol("expr_count"), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "count"), new Constant(INTEGER, 1L))), + p.symbol("expr_sum"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), + p.symbol("expr_count"), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "count"), new Constant(INTEGER, 1L)))), p.aggregation(outerBuilder -> outerBuilder .singleGroupingSet(p.symbol("a")) .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) @@ -150,8 +148,8 @@ public void rewritesOnSubqueryWithDistinct() .matches( project(ImmutableMap.of( "corr", expression(new Reference(BIGINT, "corr")), - "expr_sum", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum_agg"), new Constant(INTEGER, 1L))), - "expr_count", expression(new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "count_agg"), new Constant(INTEGER, 1L)))), + "expr_sum", expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum_agg"), new Constant(INTEGER, 1L)))), + "expr_count", expression(new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "count_agg"), new Constant(INTEGER, 1L))))), aggregation( singleGroupingSet("corr", "unique", "a"), ImmutableMap.of(Optional.of("sum_agg"), aggregationFunction("sum", ImmutableList.of("a")), Optional.of("count_agg"), aggregationFunction("count", ImmutableList.of())), @@ -187,8 +185,8 @@ public void rewritesOnSubqueryWithDecorrelatableDistinct() TRUE, p.project( Assignments.of( - p.symbol("expr_sum"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)), - p.symbol("expr_count"), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "count"), new Constant(INTEGER, 1L))), + p.symbol("expr_sum"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), + p.symbol("expr_count"), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "count"), new Constant(INTEGER, 1L)))), p.aggregation(outerBuilder -> outerBuilder .singleGroupingSet(p.symbol("a")) .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) @@ -201,8 +199,8 @@ public void rewritesOnSubqueryWithDecorrelatableDistinct() .matches( project(ImmutableMap.of( "corr", expression(new Reference(BIGINT, "corr")), - "expr_sum", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum_agg"), new Constant(INTEGER, 1L))), - "expr_count", expression(new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "count_agg"), new Constant(INTEGER, 1L)))), + "expr_sum", expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum_agg"), new Constant(INTEGER, 1L)))), + "expr_count", expression(new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "count_agg"), new Constant(INTEGER, 1L))))), aggregation( singleGroupingSet("corr", "unique", "a"), ImmutableMap.of(Optional.of("sum_agg"), aggregationFunction("sum", ImmutableList.of("a")), Optional.of("count_agg"), aggregationFunction("count", ImmutableList.of())), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedScalarSubquery.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedScalarSubquery.java index 67bbb84c080a..90e6e6c1bece 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedScalarSubquery.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedScalarSubquery.java @@ -19,7 +19,7 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; @@ -41,8 +41,6 @@ import static io.trino.spi.StandardErrorCode.SUBQUERY_MULTIPLE_ROWS; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; -import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.planner.LogicalPlanner.failFunction; @@ -135,7 +133,7 @@ public void rewritesOnSubqueryWithProjection() p.values(p.symbol("corr")), p.enforceSingleRow( p.project( - Assignments.of(p.symbol("a2"), new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 2L))), + Assignments.of(p.symbol("a2"), new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 2L)))), p.filter( new Comparison(EQUAL, new Constant(INTEGER, 1L), new Reference(INTEGER, "a")), // TODO use correlated predicate, it requires support for correlated subqueries in plan matchers p.values(ImmutableList.of(p.symbol("a")), TWO_ROWS)))))) @@ -151,7 +149,7 @@ public void rewritesOnSubqueryWithProjection() assignUniqueId( "unique", values("corr")), - project(ImmutableMap.of("a2", expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 2L)))), + project(ImmutableMap.of("a2", expression(new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 2L))))), filter( new Comparison(EQUAL, new Constant(INTEGER, 1L), new Reference(INTEGER, "a")), values("a")))))))); @@ -165,10 +163,10 @@ public void rewritesOnSubqueryWithProjectionOnTopEnforceSingleNode() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), p.project( - Assignments.of(p.symbol("a3"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a2"), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("a3"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "a2"), new Constant(INTEGER, 1L)))), p.enforceSingleRow( p.project( - Assignments.of(p.symbol("a2"), new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 2L))), + Assignments.of(p.symbol("a2"), new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 2L)))), p.filter( new Comparison(EQUAL, new Constant(INTEGER, 1L), new Reference(INTEGER, "a")), // TODO use correlated predicate, it requires support for correlated subqueries in plan matchers p.values(ImmutableList.of(p.symbol("a")), TWO_ROWS))))))) @@ -185,9 +183,9 @@ public void rewritesOnSubqueryWithProjectionOnTopEnforceSingleNode() "unique", values("corr")), project( - ImmutableMap.of("a3", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a2"), new Constant(INTEGER, 1L)))), + ImmutableMap.of("a3", expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "a2"), new Constant(INTEGER, 1L))))), project( - ImmutableMap.of("a2", expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 2L)))), + ImmutableMap.of("a2", expression(new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 2L))))), filter( new Comparison(EQUAL, new Constant(INTEGER, 1L), new Reference(INTEGER, "a")), values("a"))))))))); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java index b67432bcb041..5dd040b1f04f 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java @@ -21,7 +21,7 @@ import io.trino.plugin.tpch.TpchColumnHandle; import io.trino.spi.function.OperatorType; import io.trino.spi.type.VarcharType; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Reference; @@ -38,7 +38,6 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.RowType.field; import static io.trino.spi.type.RowType.rowType; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -70,14 +69,14 @@ public void testRewrite() ImmutableMap.of(p.symbol("l_nationkey"), new TpchColumnHandle("nationkey", BIGINT))), p.project( - Assignments.of(p.symbol("l_expr2"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "l_nationkey"), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("l_expr2"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "l_nationkey"), new Constant(INTEGER, 1L)))), p.values( ImmutableList.of(), ImmutableList.of( ImmutableList.of()))))) .matches(project( ImmutableMap.of( - "l_expr2", PlanMatchPattern.expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "l_nationkey"), new Constant(INTEGER, 1L))), + "l_expr2", PlanMatchPattern.expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "l_nationkey"), new Constant(INTEGER, 1L)))), "l_nationkey", PlanMatchPattern.expression(new Reference(BIGINT, "l_nationkey"))), tableScan("nation", ImmutableMap.of("l_nationkey", "nationkey")))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestAddExchangesPlans.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestAddExchangesPlans.java index 008028252095..ba1010925609 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestAddExchangesPlans.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestAddExchangesPlans.java @@ -29,7 +29,7 @@ import io.trino.spi.function.OperatorType; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.BigintType; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Reference; @@ -65,7 +65,6 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; -import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; import static io.trino.sql.planner.OptimizerConfig.JoinDistributionType.PARTITIONED; @@ -705,7 +704,7 @@ SELECT suppkey, partkey, count(*) as count Optional.empty(), PARTIAL, project( - ImmutableMap.of("partkey_expr", expression(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "partkey"), new Constant(BIGINT, 10L)))), + ImmutableMap.of("partkey_expr", expression(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "partkey"), new Constant(BIGINT, 10L))))), tableScan("lineitem", ImmutableMap.of( "partkey", "partkey", "suppkey", "suppkey")))))))))))))); @@ -735,7 +734,7 @@ SELECT suppkey, partkey, count(*) as count Optional.empty(), Step.PARTIAL, project( - ImmutableMap.of("orderkey_expr", expression(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "orderkey"), new Constant(BIGINT, 10000L)))), + ImmutableMap.of("orderkey_expr", expression(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "orderkey"), new Constant(BIGINT, 10000L))))), tableScan("lineitem", ImmutableMap.of( "partkey", "partkey", "orderkey", "orderkey", diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestFullOuterJoinWithCoalesce.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestFullOuterJoinWithCoalesce.java index ca561070ee9e..d0b808be55ff 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestFullOuterJoinWithCoalesce.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestFullOuterJoinWithCoalesce.java @@ -18,7 +18,7 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Coalesce; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Reference; @@ -27,7 +27,6 @@ import org.junit.jupiter.api.Test; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange; @@ -173,7 +172,7 @@ public void testComplexArgumentToCoalesce() ImmutableMap.of(), PARTIAL, project( - ImmutableMap.of("expr", expression(new Coalesce(new Reference(INTEGER, "l"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "m"), new Constant(INTEGER, 1L)), new Reference(INTEGER, "r")))), + ImmutableMap.of("expr", expression(new Coalesce(new Reference(INTEGER, "l"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "m"), new Constant(INTEGER, 1L))), new Reference(INTEGER, "r")))), join(FULL, builder -> builder .equiCriteria("l", "r") .left( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestRemoveUnsupportedDynamicFilters.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestRemoveUnsupportedDynamicFilters.java index 4247dd0240f5..ff1572c3d2bd 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestRemoveUnsupportedDynamicFilters.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestRemoveUnsupportedDynamicFilters.java @@ -28,7 +28,7 @@ import io.trino.spi.connector.CatalogHandle; import io.trino.spi.function.OperatorType; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; @@ -64,7 +64,6 @@ import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.sql.DynamicFilters.createDynamicFilterExpression; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.ir.IrUtils.combineConjuncts; @@ -368,7 +367,7 @@ public void testSpatialJoin() builder.values(leftSymbol), builder.values(rightSymbol), ImmutableList.of(leftSymbol, rightSymbol), - createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "LEFT_SYMBOL"), new Reference(BIGINT, "RIGHT_SYMBOL"))))); + createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "LEFT_SYMBOL"), new Reference(BIGINT, "RIGHT_SYMBOL")))))); assertPlan( removeUnsupportedDynamicFilters(root), output( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestPatternRecognitionNodeSerialization.java b/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestPatternRecognitionNodeSerialization.java index 14468928110f..29a88b998769 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestPatternRecognitionNodeSerialization.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestPatternRecognitionNodeSerialization.java @@ -25,6 +25,7 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.block.Block; import io.trino.spi.block.TestingBlockEncodingSerde; +import io.trino.spi.function.OperatorType; import io.trino.spi.type.TestingTypeManager; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -32,7 +33,6 @@ import io.trino.sql.ir.Cast; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.Negation; import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolKeyDeserializer; @@ -74,6 +74,8 @@ public class TestPatternRecognitionNodeSerialization { private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); private static final ResolvedFunction RANDOM = FUNCTIONS.resolveFunction("random", fromTypes()); + private static final ResolvedFunction NEGATION_BIGINT = FUNCTIONS.resolveOperator(OperatorType.NEGATION, ImmutableList.of(BIGINT)); + private static final ResolvedFunction NEGATION_INTEGER = FUNCTIONS.resolveOperator(OperatorType.NEGATION, ImmutableList.of(INTEGER)); private static final JsonCodec VALUE_POINTER_CODEC; private static final JsonCodec EXPRESSION_AND_VALUE_POINTERS_CODEC; @@ -141,7 +143,7 @@ public void testExpressionAndValuePointersRoundtrip() ifExpression( new Comparison(GREATER_THAN, new Reference(VARCHAR, "classifier"), new Reference(VARCHAR, "x")), new Cast(new Call(RANDOM, ImmutableList.of()), INTEGER), - new Negation(new Reference(INTEGER, "match_number"))), + new Call(NEGATION_INTEGER, ImmutableList.of(new Reference(INTEGER, "match_number")))), ImmutableList.of( new ExpressionAndValuePointers.Assignment( new Symbol(VARCHAR, "classifier"), @@ -169,7 +171,7 @@ public void testMeasureRoundtrip() ifExpression( new Comparison(GREATER_THAN, new Reference(INTEGER, "match_number"), new Reference(INTEGER, "x")), new Constant(BIGINT, 10L), - new Negation(new Reference(BIGINT, "y"))), + new Call(NEGATION_BIGINT, ImmutableList.of(new Reference(BIGINT, "y")))), ImmutableList.of( new ExpressionAndValuePointers.Assignment( new Symbol(BIGINT, "match_number"), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/sanity/TestDynamicFiltersChecker.java b/core/trino-main/src/test/java/io/trino/sql/planner/sanity/TestDynamicFiltersChecker.java index 730e4902a8ce..438e3b0fbe76 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/sanity/TestDynamicFiltersChecker.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/sanity/TestDynamicFiltersChecker.java @@ -26,7 +26,7 @@ import io.trino.spi.connector.CatalogHandle; import io.trino.spi.function.OperatorType; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; @@ -51,7 +51,6 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.sql.DynamicFilters.createDynamicFilterExpression; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.ir.IrUtils.combineConjuncts; import static io.trino.sql.ir.IrUtils.combineDisjuncts; @@ -226,7 +225,7 @@ public void testUnsupportedDynamicFilterExpression() builder.join( INNER, builder.filter( - createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "LINEITEM_OK"), new Constant(BIGINT, 1L))), + createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "LINEITEM_OK"), new Constant(BIGINT, 1L)))), lineitemTableScanNode), ordersTableScanNode, ImmutableList.of(new JoinNode.EquiJoinClause(lineitemOrderKeySymbol, ordersOrderKeySymbol)), diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestSubqueries.java b/core/trino-main/src/test/java/io/trino/sql/query/TestSubqueries.java index 13d45030c442..e4c120a7179c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestSubqueries.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestSubqueries.java @@ -20,7 +20,7 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.plugin.tpch.TpchPlugin; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Reference; @@ -42,8 +42,6 @@ import static io.trino.spi.type.RowType.field; import static io.trino.spi.type.RowType.rowType; import static io.trino.spi.type.VarcharType.createVarcharType; -import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; -import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregationFunction; @@ -247,7 +245,7 @@ public void testCorrelatedSubqueriesWithTopN() .equiCriteria("expr", "a") .left( project( - ImmutableMap.of("expr", expression(new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "b"), new Reference(INTEGER, "c")), new Constant(INTEGER, 1L)))), + ImmutableMap.of("expr", expression(new Call(SUBTRACT_INTEGER, ImmutableList.of(new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "b"), new Reference(INTEGER, "c"))), new Constant(INTEGER, 1L))))), any( values("b", "c")))) .right( diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeProjectionPushdownPlans.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeProjectionPushdownPlans.java index 6e451ff4d867..bd96e9f46a1e 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeProjectionPushdownPlans.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeProjectionPushdownPlans.java @@ -30,7 +30,7 @@ import io.trino.spi.predicate.TupleDomain; import io.trino.spi.security.PrincipalType; import io.trino.spi.type.RowType; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; @@ -59,7 +59,6 @@ import static io.trino.plugin.deltalake.DeltaLakeQueryRunner.DELTA_CATALOG; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.Logical.Operator.AND; import static io.trino.sql.planner.assertions.PlanMatchPattern.any; @@ -200,7 +199,7 @@ public void testDereferencePushdown() filter( new Logical(AND, ImmutableList.of( new Comparison(EQUAL, new Reference(BIGINT, "y"), new Constant(BIGINT, 2L)), - new Comparison(EQUAL, new Reference(BIGINT, "x"), new Cast(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "col1"), new Constant(INTEGER, 3L)), BIGINT)))), + new Comparison(EQUAL, new Reference(BIGINT, "x"), new Cast(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "col1"), new Constant(INTEGER, 3L))), BIGINT)))), source2))); // Projection and predicate pushdown with overlapping columns diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestConnectorPushdownRulesWithHive.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestConnectorPushdownRulesWithHive.java index 77b8dde6e3f8..11854529a8c6 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestConnectorPushdownRulesWithHive.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestConnectorPushdownRulesWithHive.java @@ -36,12 +36,11 @@ import io.trino.spi.security.PrincipalType; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.ir.FieldReference; -import io.trino.sql.ir.Negation; import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.PruneTableScanColumns; @@ -70,7 +69,6 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.RowType.field; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; @@ -86,6 +84,7 @@ public class TestConnectorPushdownRulesWithHive { private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); private static final ResolvedFunction ADD_BIGINT = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(BIGINT, BIGINT)); + private static final ResolvedFunction NEGATION_BIGINT = FUNCTIONS.resolveOperator(OperatorType.NEGATION, ImmutableList.of(BIGINT)); private static final String SCHEMA_NAME = "test_schema"; @@ -313,7 +312,7 @@ public void testPushdownWithDuplicateExpressions() tester().assertThat(pushProjectionIntoTableScan) .on(p -> { Reference column = p.symbol("just_bigint", BIGINT).toSymbolReference(); - Expression negation = new Negation(column); + Expression negation = new Call(NEGATION_BIGINT, ImmutableList.of(column)); return p.project( Assignments.of( // The column reference is part of both the assignments @@ -327,7 +326,7 @@ public void testPushdownWithDuplicateExpressions() .matches(project( ImmutableMap.of( "column_ref", expression(new Reference(BIGINT, "just_bigint_0")), - "negated_column_ref", expression(new Negation(new Reference(BIGINT, "just_bigint_0")))), + "negated_column_ref", expression(new Call(NEGATION_BIGINT, ImmutableList.of(new Reference(BIGINT, "just_bigint_0"))))), tableScan( hiveTable.withProjectedColumns(ImmutableSet.of(bigintColumn))::equals, TupleDomain.all(), @@ -337,7 +336,7 @@ public void testPushdownWithDuplicateExpressions() tester().assertThat(pushProjectionIntoTableScan) .on(p -> { FieldReference fieldReference = new FieldReference(p.symbol("struct_of_bigint", ROW_TYPE).toSymbolReference(), 0); - Expression sum = new Arithmetic(ADD_BIGINT, ADD, fieldReference, new Constant(BIGINT, 2L)); + Expression sum = new Call(ADD_BIGINT, ImmutableList.of(fieldReference, new Constant(BIGINT, 2L))); return p.project( Assignments.of( // The subscript expression instance is part of both the assignments @@ -351,7 +350,7 @@ public void testPushdownWithDuplicateExpressions() .matches(project( ImmutableMap.of( "expr_deref", expression(new Reference(BIGINT, "struct_of_bigint#a")), - "expr_deref_2", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "struct_of_bigint#a"), new Constant(BIGINT, 2L)))), + "expr_deref_2", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "struct_of_bigint#a"), new Constant(BIGINT, 2L))))), tableScan( hiveTable.withProjectedColumns(ImmutableSet.of(partialColumn))::equals, TupleDomain.all(), diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestHivePlans.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestHivePlans.java index 7098ee0687ea..d881bf343270 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestHivePlans.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestHivePlans.java @@ -24,7 +24,6 @@ import io.trino.plugin.hive.metastore.HiveMetastoreFactory; import io.trino.spi.function.OperatorType; import io.trino.spi.security.PrincipalType; -import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Between; import io.trino.sql.ir.Call; import io.trino.sql.ir.Comparison; @@ -57,7 +56,6 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.Comparison.Operator.NOT_EQUAL; @@ -281,13 +279,13 @@ public void testSubsumePartitionFilterNotConvertibleToTupleDomain() .left( exchange(REMOTE, REPARTITION, filter( - new Comparison(EQUAL, new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "L_INT_PART"), new Constant(INTEGER, 2L)), new Constant(INTEGER, 0L)), + new Comparison(EQUAL, new Call(MODULUS_INTEGER, ImmutableList.of(new Reference(INTEGER, "L_INT_PART"), new Constant(INTEGER, 2L))), new Constant(INTEGER, 0L)), tableScan("table_int_partitioned", Map.of("L_INT_PART", "int_part", "L_STR_COL", "str_col"))))) .right( exchange(LOCAL, exchange(REMOTE, REPARTITION, filter( - new Logical(AND, ImmutableList.of(new In(new Reference(INTEGER, "R_INT_COL"), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L))), new Comparison(EQUAL, new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "R_INT_COL"), new Constant(INTEGER, 2L)), new Constant(INTEGER, 0L)))), + new Logical(AND, ImmutableList.of(new In(new Reference(INTEGER, "R_INT_COL"), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L))), new Comparison(EQUAL, new Call(MODULUS_INTEGER, ImmutableList.of(new Reference(INTEGER, "R_INT_COL"), new Constant(INTEGER, 2L))), new Constant(INTEGER, 0L)))), tableScan("table_unpartitioned", Map.of("R_STR_COL", "str_col", "R_INT_COL", "int_col"))))))))); } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestHiveProjectionPushdownIntoTableScan.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestHiveProjectionPushdownIntoTableScan.java index 095e536d59fa..f643bab995e7 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestHiveProjectionPushdownIntoTableScan.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestHiveProjectionPushdownIntoTableScan.java @@ -33,7 +33,7 @@ import io.trino.spi.predicate.TupleDomain; import io.trino.spi.security.PrincipalType; import io.trino.spi.type.RowType; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; @@ -58,7 +58,6 @@ import static io.trino.plugin.hive.TestingHiveUtils.getConnectorService; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.Logical.Operator.AND; import static io.trino.sql.planner.assertions.PlanMatchPattern.any; @@ -179,7 +178,7 @@ public void testDereferencePushdown() format("SELECT col0.x FROM %s WHERE col0.x = col1 + 3 and col0.y = 2", testTable), anyTree( filter( - new Logical(AND, ImmutableList.of(new Comparison(EQUAL, new Reference(BIGINT, "col0_y"), new Constant(BIGINT, 2L)), new Comparison(EQUAL, new Reference(BIGINT, "col0_x"), new Cast(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "col1"), new Constant(INTEGER, 3L)), BIGINT)))), + new Logical(AND, ImmutableList.of(new Comparison(EQUAL, new Reference(BIGINT, "col0_y"), new Constant(BIGINT, 2L)), new Comparison(EQUAL, new Reference(BIGINT, "col0_x"), new Cast(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "col1"), new Constant(INTEGER, 3L))), BIGINT)))), tableScan( table -> { HiveTableHandle hiveTableHandle = (HiveTableHandle) table; diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergProjectionPushdownPlans.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergProjectionPushdownPlans.java index 7d079c954167..7bffb6dba579 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergProjectionPushdownPlans.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergProjectionPushdownPlans.java @@ -30,7 +30,7 @@ import io.trino.spi.predicate.TupleDomain; import io.trino.spi.security.PrincipalType; import io.trino.spi.type.RowType; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; @@ -55,7 +55,6 @@ import static com.google.common.io.RecursiveDeleteOption.ALLOW_INSECURE; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.Logical.Operator.AND; import static io.trino.sql.planner.assertions.PlanMatchPattern.any; @@ -193,7 +192,7 @@ public void testDereferencePushdown() format("SELECT col0.x FROM %s WHERE col0.x = col1 + 3 and col0.y = 2", testTable), anyTree( filter( - new Logical(AND, ImmutableList.of(new Comparison(EQUAL, new Reference(BIGINT, "y"), new Constant(BIGINT, 2L)), new Comparison(EQUAL, new Reference(BIGINT, "x"), new Cast(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "col1"), new Constant(INTEGER, 3L)), BIGINT)))), + new Logical(AND, ImmutableList.of(new Comparison(EQUAL, new Reference(BIGINT, "y"), new Constant(BIGINT, 2L)), new Comparison(EQUAL, new Reference(BIGINT, "x"), new Cast(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "col1"), new Constant(INTEGER, 3L))), BIGINT)))), tableScan( table -> { IcebergTableHandle icebergTableHandle = (IcebergTableHandle) table; diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/optimizer/TestConnectorPushdownRulesWithIceberg.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/optimizer/TestConnectorPushdownRulesWithIceberg.java index f6946df6a5c1..29516fff4b57 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/optimizer/TestConnectorPushdownRulesWithIceberg.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/optimizer/TestConnectorPushdownRulesWithIceberg.java @@ -39,12 +39,11 @@ import io.trino.spi.security.PrincipalType; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.ir.FieldReference; -import io.trino.sql.ir.Negation; import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.PruneTableScanColumns; @@ -71,7 +70,6 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.RowType.field; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; @@ -88,6 +86,7 @@ public class TestConnectorPushdownRulesWithIceberg { private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); private static final ResolvedFunction ADD_BIGINT = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(BIGINT, BIGINT)); + private static final ResolvedFunction NEGATION_BIGINT = FUNCTIONS.resolveOperator(OperatorType.NEGATION, ImmutableList.of(BIGINT)); private static final String SCHEMA_NAME = "test_schema"; @@ -391,7 +390,7 @@ public void testPushdownWithDuplicateExpressions() tester().assertThat(pushProjectionIntoTableScan) .on(p -> { Reference column = p.symbol("just_bigint", BIGINT).toSymbolReference(); - Expression negation = new Negation(column); + Expression negation = new Call(NEGATION_BIGINT, ImmutableList.of(column)); return p.project( Assignments.of( // The column reference is part of both the assignments @@ -405,7 +404,7 @@ public void testPushdownWithDuplicateExpressions() .matches(project( ImmutableMap.of( "column_ref", expression(new Reference(BIGINT, "just_bigint_0")), - "negated_column_ref", expression(new Negation(new Reference(BIGINT, "just_bigint_0")))), + "negated_column_ref", expression(new Call(NEGATION_BIGINT, ImmutableList.of(new Reference(BIGINT, "just_bigint_0"))))), tableScan( icebergTable.withProjectedColumns(ImmutableSet.of(bigintColumn))::equals, TupleDomain.all(), @@ -415,7 +414,7 @@ public void testPushdownWithDuplicateExpressions() tester().assertThat(pushProjectionIntoTableScan) .on(p -> { FieldReference fieldReference = new FieldReference(p.symbol("struct_of_bigint", ROW_TYPE).toSymbolReference(), 0); - Expression sum = new Arithmetic(ADD_BIGINT, ADD, fieldReference, new Constant(BIGINT, 2L)); + Expression sum = new Call(ADD_BIGINT, ImmutableList.of(fieldReference, new Constant(BIGINT, 2L))); return p.project( Assignments.of( // The subscript expression instance is part of both the assignments @@ -429,7 +428,7 @@ public void testPushdownWithDuplicateExpressions() .matches(project( ImmutableMap.of( "expr_deref", expression(new Reference(BIGINT, "struct_of_bigint#a")), - "expr_deref_2", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "struct_of_bigint#a"), new Constant(BIGINT, 2L)))), + "expr_deref_2", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "struct_of_bigint#a"), new Constant(BIGINT, 2L))))), tableScan( icebergTable.withProjectedColumns(ImmutableSet.of(partialColumn))::equals, TupleDomain.all(), diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoProjectionPushdownPlans.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoProjectionPushdownPlans.java index a4c46d1c206e..b8c29a84953c 100644 --- a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoProjectionPushdownPlans.java +++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoProjectionPushdownPlans.java @@ -29,7 +29,7 @@ import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.FieldReference; @@ -53,7 +53,6 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; -import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.planner.assertions.PlanMatchPattern.any; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; @@ -174,7 +173,7 @@ public void testDereferencePushdown() "SELECT col0.x FROM " + tableName + " WHERE col0.x = col1 + 3 and col0.y = 2", anyTree( filter( - new Comparison(EQUAL, new Reference(BIGINT, "x"), new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "col1"), new Constant(BIGINT, 3L))), + new Comparison(EQUAL, new Reference(BIGINT, "x"), new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "col1"), new Constant(BIGINT, 3L)))), tableScan( table -> { MongoTableHandle actualTableHandle = (MongoTableHandle) table; diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java index e0c229eaf36c..7c76b3f4b229 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.plugin.base.mapping.DefaultIdentifierMapping; import io.trino.plugin.jdbc.BaseJdbcConfig; @@ -37,14 +38,13 @@ import io.trino.spi.expression.Variable; import io.trino.spi.function.OperatorType; import io.trino.spi.session.PropertyMetadata; -import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.ir.In; import io.trino.sql.ir.IsNull; import io.trino.sql.ir.Logical; -import io.trino.sql.ir.Negation; import io.trino.sql.ir.Not; import io.trino.sql.ir.NullIf; import io.trino.sql.ir.Reference; @@ -53,12 +53,18 @@ import org.junit.jupiter.api.Test; import java.sql.Types; +import java.util.EnumSet; import java.util.List; import java.util.Map; import java.util.Optional; import static io.airlift.slice.Slices.utf8Slice; import static io.trino.SessionTestUtils.TEST_SESSION; +import static io.trino.spi.function.OperatorType.ADD; +import static io.trino.spi.function.OperatorType.DIVIDE; +import static io.trino.spi.function.OperatorType.MODULUS; +import static io.trino.spi.function.OperatorType.MULTIPLY; +import static io.trino.spi.function.OperatorType.SUBTRACT; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.DoubleType.DOUBLE; @@ -70,6 +76,9 @@ public class TestPostgreSqlClient { + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction NEGATION_BIGINT = FUNCTIONS.resolveOperator(OperatorType.NEGATION, ImmutableList.of(BIGINT)); + private static final JdbcColumnHandle BIGINT_COLUMN = JdbcColumnHandle.builder() .setColumnName("c_bigint") @@ -298,23 +307,16 @@ public void testConvertArithmeticBinary() { TestingFunctionResolution resolver = new TestingFunctionResolution(); - for (Arithmetic.Operator operator : Arithmetic.Operator.values()) { + for (OperatorType operator : EnumSet.of(ADD, SUBTRACT, MULTIPLY, DIVIDE, MODULUS)) { ParameterizedExpression converted = JDBC_CLIENT.convertPredicate( SESSION, translateToConnectorExpression( - new Arithmetic(resolver.resolveOperator( - switch (operator) { - case ADD -> OperatorType.ADD; - case SUBTRACT -> OperatorType.SUBTRACT; - case MULTIPLY -> OperatorType.MULTIPLY; - case DIVIDE -> OperatorType.DIVIDE; - case MODULUS -> OperatorType.MODULUS; - }, - ImmutableList.of(BIGINT, BIGINT)), operator, new Reference(BIGINT, "c_bigint_symbol"), new Constant(BIGINT, 42L))), + new Call(resolver.resolveOperator( + operator, + ImmutableList.of(BIGINT, BIGINT)), ImmutableList.of(new Reference(BIGINT, "c_bigint_symbol"), new Constant(BIGINT, 42L)))), Map.of("c_bigint_symbol", BIGINT_COLUMN)) .orElseThrow(); - assertThat(converted.expression()).isEqualTo(format("(\"c_bigint\") %s (?)", operator.getValue())); assertThat(converted.parameters()).isEqualTo(List.of(new QueryParameter(BIGINT, Optional.of(42L)))); } } @@ -325,7 +327,7 @@ public void testConvertArithmeticUnaryMinus() ParameterizedExpression converted = JDBC_CLIENT.convertPredicate( SESSION, translateToConnectorExpression( - new Negation(new Reference(BIGINT, "c_bigint_symbol"))), + new Call(NEGATION_BIGINT, ImmutableList.of(new Reference(BIGINT, "c_bigint_symbol")))), Map.of("c_bigint_symbol", BIGINT_COLUMN)) .orElseThrow();