diff --git a/core/trino-main/src/main/java/io/trino/cost/ComparisonStatsCalculator.java b/core/trino-main/src/main/java/io/trino/cost/ComparisonStatsCalculator.java index 7696a67919f3..92844f89ee3d 100644 --- a/core/trino-main/src/main/java/io/trino/cost/ComparisonStatsCalculator.java +++ b/core/trino-main/src/main/java/io/trino/cost/ComparisonStatsCalculator.java @@ -13,7 +13,7 @@ */ package io.trino.cost; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.planner.Symbol; import java.util.Optional; @@ -45,7 +45,7 @@ public static PlanNodeStatsEstimate estimateExpressionToLiteralComparison( SymbolStatsEstimate expressionStatistics, Optional expressionSymbol, OptionalDouble literalValue, - ComparisonExpression.Operator operator) + Comparison.Operator operator) { switch (operator) { case EQUAL: @@ -160,7 +160,7 @@ public static PlanNodeStatsEstimate estimateExpressionToExpressionComparison( Optional leftExpressionSymbol, SymbolStatsEstimate rightExpressionStatistics, Optional rightExpressionSymbol, - ComparisonExpression.Operator operator) + Comparison.Operator operator) { switch (operator) { case EQUAL: @@ -255,7 +255,7 @@ private static PlanNodeStatsEstimate estimateExpressionNotEqualToExpression( } private static PlanNodeStatsEstimate estimateExpressionToExpressionInequality( - ComparisonExpression.Operator operator, + Comparison.Operator operator, PlanNodeStatsEstimate inputStatistics, SymbolStatsEstimate leftExpressionStatistics, Optional leftExpressionSymbol, diff --git a/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java b/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java index be0251dfcfdf..032363f29a06 100644 --- a/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java +++ b/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java @@ -20,18 +20,18 @@ import io.trino.Session; import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.BetweenPredicate; -import io.trino.sql.ir.BooleanLiteral; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Between; +import io.trino.sql.ir.Booleans; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.InPredicate; +import io.trino.sql.ir.In; import io.trino.sql.ir.IrVisitor; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.IrExpressionInterpreter; import io.trino.sql.planner.NoOpSymbolResolver; import io.trino.sql.planner.Symbol; @@ -58,9 +58,9 @@ import static io.trino.spi.statistics.StatsUtil.toStatsRepresentation; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.sql.DynamicFilters.isDynamicFilter; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.ir.IrUtils.and; import static io.trino.sql.planner.IrExpressionInterpreter.evaluateConstantExpression; import static io.trino.sql.planner.SymbolsExtractor.extractUnique; @@ -148,11 +148,11 @@ protected PlanNodeStatsEstimate visitExpression(Expression node, Void context) } @Override - protected PlanNodeStatsEstimate visitNotExpression(NotExpression node, Void context) + protected PlanNodeStatsEstimate visitNot(Not node, Void context) { - if (node.getValue() instanceof IsNullPredicate inner) { - if (inner.getValue() instanceof SymbolReference) { - Symbol symbol = Symbol.from(inner.getValue()); + if (node.value() instanceof IsNull inner) { + if (inner.value() instanceof Reference) { + Symbol symbol = Symbol.from(inner.value()); SymbolStatsEstimate symbolStats = input.getSymbolStatistics(symbol); PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(input); result.setOutputRowCount(input.getOutputRowCount() * (1 - symbolStats.getNullsFraction())); @@ -161,19 +161,19 @@ protected PlanNodeStatsEstimate visitNotExpression(NotExpression node, Void cont } return PlanNodeStatsEstimate.unknown(); } - return subtractSubsetStats(input, process(node.getValue())); + return subtractSubsetStats(input, process(node.value())); } @Override - protected PlanNodeStatsEstimate visitLogicalExpression(LogicalExpression node, Void context) + protected PlanNodeStatsEstimate visitLogical(Logical node, Void context) { - switch (node.getOperator()) { + switch (node.operator()) { case AND: - return estimateLogicalAnd(node.getTerms()); + return estimateLogicalAnd(node.terms()); case OR: - return estimateLogicalOr(node.getTerms()); + return estimateLogicalOr(node.terms()); } - throw new IllegalArgumentException("Unexpected binary operator: " + node.getOperator()); + throw new IllegalArgumentException("Unexpected binary operator: " + node.operator()); } private PlanNodeStatsEstimate estimateLogicalAnd(List terms) @@ -262,8 +262,8 @@ private PlanNodeStatsEstimate estimateLogicalOr(List terms) @Override protected PlanNodeStatsEstimate visitConstant(Constant node, Void context) { - if (node.getType().equals(BOOLEAN) && node.getValue() != null) { - if ((boolean) node.getValue()) { + if (node.type().equals(BOOLEAN) && node.value() != null) { + if ((boolean) node.value()) { return input; } @@ -277,10 +277,10 @@ protected PlanNodeStatsEstimate visitConstant(Constant node, Void context) } @Override - protected PlanNodeStatsEstimate visitIsNullPredicate(IsNullPredicate node, Void context) + protected PlanNodeStatsEstimate visitIsNull(IsNull node, Void context) { - if (node.getValue() instanceof SymbolReference) { - Symbol symbol = Symbol.from(node.getValue()); + if (node.value() instanceof Reference) { + Symbol symbol = Symbol.from(node.value()); SymbolStatsEstimate symbolStats = input.getSymbolStatistics(symbol); PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(input); result.setOutputRowCount(input.getOutputRowCount() * symbolStats.getNullsFraction()); @@ -296,21 +296,21 @@ protected PlanNodeStatsEstimate visitIsNullPredicate(IsNullPredicate node, Void } @Override - protected PlanNodeStatsEstimate visitBetweenPredicate(BetweenPredicate node, Void context) + protected PlanNodeStatsEstimate visitBetween(Between node, Void context) { - SymbolStatsEstimate valueStats = getExpressionStats(node.getValue()); + SymbolStatsEstimate valueStats = getExpressionStats(node.value()); if (valueStats.isUnknown()) { return PlanNodeStatsEstimate.unknown(); } - if (!getExpressionStats(node.getMin()).isSingleValue()) { + if (!getExpressionStats(node.min()).isSingleValue()) { return PlanNodeStatsEstimate.unknown(); } - if (!getExpressionStats(node.getMax()).isSingleValue()) { + if (!getExpressionStats(node.max()).isSingleValue()) { return PlanNodeStatsEstimate.unknown(); } - Expression lowerBound = new ComparisonExpression(GREATER_THAN_OR_EQUAL, node.getValue(), node.getMin()); - Expression upperBound = new ComparisonExpression(LESS_THAN_OR_EQUAL, node.getValue(), node.getMax()); + Expression lowerBound = new Comparison(GREATER_THAN_OR_EQUAL, node.value(), node.min()); + Expression upperBound = new Comparison(LESS_THAN_OR_EQUAL, node.value(), node.max()); Expression transformed; if (isInfinite(valueStats.getLowValue())) { @@ -325,10 +325,10 @@ protected PlanNodeStatsEstimate visitBetweenPredicate(BetweenPredicate node, Voi } @Override - protected PlanNodeStatsEstimate visitInPredicate(InPredicate node, Void context) + protected PlanNodeStatsEstimate visitIn(In node, Void context) { - ImmutableList equalityEstimates = node.getValueList().stream() - .map(inValue -> process(new ComparisonExpression(EQUAL, node.getValue(), inValue))) + ImmutableList equalityEstimates = node.valueList().stream() + .map(inValue -> process(new Comparison(EQUAL, node.value(), inValue))) .collect(toImmutableList()); if (equalityEstimates.stream().anyMatch(PlanNodeStatsEstimate::isOutputRowCountUnknown)) { @@ -343,7 +343,7 @@ protected PlanNodeStatsEstimate visitInPredicate(InPredicate node, Void context) return PlanNodeStatsEstimate.unknown(); } - SymbolStatsEstimate valueStats = getExpressionStats(node.getValue()); + SymbolStatsEstimate valueStats = getExpressionStats(node.value()); if (valueStats.isUnknown()) { return PlanNodeStatsEstimate.unknown(); } @@ -353,8 +353,8 @@ protected PlanNodeStatsEstimate visitInPredicate(InPredicate node, Void context) PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(input); result.setOutputRowCount(min(inEstimate.getOutputRowCount(), notNullValuesBeforeIn)); - if (node.getValue() instanceof SymbolReference) { - Symbol valueSymbol = Symbol.from(node.getValue()); + if (node.value() instanceof Reference) { + Symbol valueSymbol = Symbol.from(node.value()); SymbolStatsEstimate newSymbolStats = inEstimate.getSymbolStatistics(valueSymbol) .mapDistinctValuesCount(newDistinctValuesCount -> min(newDistinctValuesCount, valueStats.getDistinctValuesCount())); result.addSymbolStatistics(valueSymbol, newSymbolStats); @@ -364,30 +364,30 @@ protected PlanNodeStatsEstimate visitInPredicate(InPredicate node, Void context) @SuppressWarnings("ArgumentSelectionDefectChecker") @Override - protected PlanNodeStatsEstimate visitComparisonExpression(ComparisonExpression node, Void context) + protected PlanNodeStatsEstimate visitComparison(Comparison node, Void context) { - ComparisonExpression.Operator operator = node.getOperator(); - Expression left = node.getLeft(); - Expression right = node.getRight(); + Comparison.Operator operator = node.operator(); + Expression left = node.left(); + Expression right = node.right(); checkArgument(!(left instanceof Constant && right instanceof Constant), "Literal-to-literal not supported here, should be eliminated earlier"); - if (!(left instanceof SymbolReference) && right instanceof SymbolReference) { + if (!(left instanceof Reference) && right instanceof Reference) { // normalize so that symbol is on the left - return process(new ComparisonExpression(operator.flip(), right, left)); + return process(new Comparison(operator.flip(), right, left)); } if (left instanceof Constant) { // normalize so that literal is on the right - return process(new ComparisonExpression(operator.flip(), right, left)); + return process(new Comparison(operator.flip(), right, left)); } - if (left instanceof SymbolReference && left.equals(right)) { - return process(new NotExpression(new IsNullPredicate(left))); + if (left instanceof Reference && left.equals(right)) { + return process(new Not(new IsNull(left))); } SymbolStatsEstimate leftStats = getExpressionStats(left); - Optional leftSymbol = left instanceof SymbolReference ? Optional.of(Symbol.from(left)) : Optional.empty(); + Optional leftSymbol = left instanceof Reference ? Optional.of(Symbol.from(left)) : Optional.empty(); if (right instanceof Constant) { Type type = left.type(); Object literalValue = evaluateConstantExpression(right, plannerContext, session); @@ -405,22 +405,22 @@ protected PlanNodeStatsEstimate visitComparisonExpression(ComparisonExpression n return estimateExpressionToLiteralComparison(input, leftStats, leftSymbol, value, operator); } - Optional rightSymbol = right instanceof SymbolReference ? Optional.of(Symbol.from(right)) : Optional.empty(); + Optional rightSymbol = right instanceof Reference ? Optional.of(Symbol.from(right)) : Optional.empty(); return estimateExpressionToExpressionComparison(input, leftStats, leftSymbol, rightStats, rightSymbol, operator); } @Override - protected PlanNodeStatsEstimate visitFunctionCall(FunctionCall node, Void context) + protected PlanNodeStatsEstimate visitCall(Call node, Void context) { if (isDynamicFilter(node)) { - return process(BooleanLiteral.TRUE_LITERAL, context); + return process(Booleans.TRUE, context); } return PlanNodeStatsEstimate.unknown(); } private SymbolStatsEstimate getExpressionStats(Expression expression) { - if (expression instanceof SymbolReference) { + if (expression instanceof Reference) { Symbol symbol = Symbol.from(expression); return requireNonNull(input.getSymbolStatistics(symbol), () -> format("No statistics for symbol %s", symbol)); } diff --git a/core/trino-main/src/main/java/io/trino/cost/JoinStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/JoinStatsRule.java index 2cd530775750..5dcb936dc404 100644 --- a/core/trino-main/src/main/java/io/trino/cost/JoinStatsRule.java +++ b/core/trino-main/src/main/java/io/trino/cost/JoinStatsRule.java @@ -17,7 +17,7 @@ import io.trino.Session; import io.trino.cost.StatsCalculator.Context; import io.trino.matching.Pattern; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Expression; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.plan.JoinNode; @@ -36,7 +36,7 @@ import static io.trino.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT; import static io.trino.cost.PlanNodeStatsEstimateMath.estimateCorrelatedConjunctionRowCount; import static io.trino.cost.SymbolStatsEstimate.buildFrom; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.IrUtils.extractConjuncts; import static io.trino.sql.planner.plan.Patterns.join; import static io.trino.util.MoreMath.firstNonNaN; @@ -183,7 +183,7 @@ private PlanNodeStatsEstimate filterByEquiJoinClauses( // clause separately because stats estimates would be way off. List knownEstimates = clauses.stream() .map(clause -> { - ComparisonExpression predicate = new ComparisonExpression(EQUAL, clause.getLeft().toSymbolReference(), clause.getRight().toSymbolReference()); + Comparison predicate = new Comparison(EQUAL, clause.getLeft().toSymbolReference(), clause.getRight().toSymbolReference()); return new PlanNodeStatsEstimateWithClause(filterStatsCalculator.filterStats(stats, predicate, session), clause); }) .collect(toImmutableList()); diff --git a/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java b/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java index 456aae04f467..50996634fefc 100644 --- a/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java +++ b/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java @@ -22,15 +22,15 @@ import io.trino.spi.type.TinyintType; import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ArithmeticNegation; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.CoalesceExpression; +import io.trino.sql.ir.Coalesce; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; import io.trino.sql.ir.IrVisitor; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Negation; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.IrExpressionInterpreter; import io.trino.sql.planner.NoOpSymbolResolver; import io.trino.sql.planner.Symbol; @@ -80,7 +80,7 @@ protected SymbolStatsEstimate visitExpression(Expression node, Void context) } @Override - protected SymbolStatsEstimate visitSymbolReference(SymbolReference node, Void context) + protected SymbolStatsEstimate visitReference(Reference node, Void context) { return input.getSymbolStatistics(Symbol.from(node)); } @@ -88,8 +88,8 @@ protected SymbolStatsEstimate visitSymbolReference(SymbolReference node, Void co @Override protected SymbolStatsEstimate visitConstant(Constant node, Void context) { - Type type = node.getType(); - Object value = node.getValue(); + Type type = node.type(); + Object value = node.value(); if (value == null) { return nullStatsEstimate(); } @@ -107,7 +107,7 @@ protected SymbolStatsEstimate visitConstant(Constant node, Void context) } @Override - protected SymbolStatsEstimate visitFunctionCall(FunctionCall node, Void context) + protected SymbolStatsEstimate visitCall(Call node, Void context) { IrExpressionInterpreter interpreter = new IrExpressionInterpreter(node, plannerContext, session); Object value = interpreter.optimize(NoOpSymbolResolver.INSTANCE); @@ -131,7 +131,7 @@ protected SymbolStatsEstimate visitFunctionCall(FunctionCall node, Void context) @Override protected SymbolStatsEstimate visitCast(Cast node, Void context) { - SymbolStatsEstimate sourceStats = process(node.getExpression()); + SymbolStatsEstimate sourceStats = process(node.expression()); // todo - make this general postprocessing rule. double distinctValuesCount = sourceStats.getDistinctValuesCount(); @@ -176,9 +176,9 @@ private boolean isIntegralType(Type type) } @Override - protected SymbolStatsEstimate visitArithmeticNegation(ArithmeticNegation node, Void context) + protected SymbolStatsEstimate visitNegation(Negation node, Void context) { - SymbolStatsEstimate stats = process(node.getValue()); + SymbolStatsEstimate stats = process(node.value()); return SymbolStatsEstimate.buildFrom(stats) .setLowValue(-stats.getHighValue()) .setHighValue(-stats.getLowValue()) @@ -186,11 +186,11 @@ protected SymbolStatsEstimate visitArithmeticNegation(ArithmeticNegation node, V } @Override - protected SymbolStatsEstimate visitArithmeticBinary(ArithmeticBinaryExpression node, Void context) + protected SymbolStatsEstimate visitArithmetic(Arithmetic node, Void context) { requireNonNull(node, "node is null"); - SymbolStatsEstimate left = process(node.getLeft()); - SymbolStatsEstimate right = process(node.getRight()); + SymbolStatsEstimate left = process(node.left()); + SymbolStatsEstimate right = process(node.right()); if (left.isUnknown() || right.isUnknown()) { return SymbolStatsEstimate.unknown(); } @@ -208,11 +208,11 @@ protected SymbolStatsEstimate visitArithmeticBinary(ArithmeticBinaryExpression n result.setLowValue(NaN) .setHighValue(NaN); } - else if (node.getOperator() == ArithmeticBinaryExpression.Operator.DIVIDE && rightLow < 0 && rightHigh > 0) { + else if (node.operator() == Arithmetic.Operator.DIVIDE && rightLow < 0 && rightHigh > 0) { result.setLowValue(Double.NEGATIVE_INFINITY) .setHighValue(Double.POSITIVE_INFINITY); } - else if (node.getOperator() == ArithmeticBinaryExpression.Operator.MODULUS) { + else if (node.operator() == Arithmetic.Operator.MODULUS) { double maxDivisor = max(abs(rightLow), abs(rightHigh)); if (leftHigh <= 0) { result.setLowValue(max(-maxDivisor, leftLow)) @@ -228,10 +228,10 @@ else if (leftLow >= 0) { } } else { - double v1 = operate(node.getOperator(), leftLow, rightLow); - double v2 = operate(node.getOperator(), leftLow, rightHigh); - double v3 = operate(node.getOperator(), leftHigh, rightLow); - double v4 = operate(node.getOperator(), leftHigh, rightHigh); + double v1 = operate(node.operator(), leftLow, rightLow); + double v2 = operate(node.operator(), leftLow, rightHigh); + double v3 = operate(node.operator(), leftHigh, rightLow); + double v4 = operate(node.operator(), leftHigh, rightHigh); double lowValue = min(v1, v2, v3, v4); double highValue = max(v1, v2, v3, v4); @@ -242,7 +242,7 @@ else if (leftLow >= 0) { return result.build(); } - private double operate(ArithmeticBinaryExpression.Operator operator, double left, double right) + private double operate(Arithmetic.Operator operator, double left, double right) { switch (operator) { case ADD: @@ -260,11 +260,11 @@ private double operate(ArithmeticBinaryExpression.Operator operator, double left } @Override - protected SymbolStatsEstimate visitCoalesceExpression(CoalesceExpression node, Void context) + protected SymbolStatsEstimate visitCoalesce(Coalesce node, Void context) { requireNonNull(node, "node is null"); SymbolStatsEstimate result = null; - for (Expression operand : node.getOperands()) { + for (Expression operand : node.operands()) { SymbolStatsEstimate operandEstimates = process(operand); if (result != null) { result = estimateCoalesce(result, operandEstimates); diff --git a/core/trino-main/src/main/java/io/trino/cost/SimpleFilterProjectSemiJoinStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/SimpleFilterProjectSemiJoinStatsRule.java index 939396c9fddf..e99099725e3d 100644 --- a/core/trino-main/src/main/java/io/trino/cost/SimpleFilterProjectSemiJoinStatsRule.java +++ b/core/trino-main/src/main/java/io/trino/cost/SimpleFilterProjectSemiJoinStatsRule.java @@ -18,8 +18,8 @@ import io.trino.cost.StatsCalculator.Context; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.PlanNode; @@ -134,15 +134,15 @@ private Optional extractSemiJoinOutputFilter(Expression pr Expression remainingPredicate = combineConjuncts(conjuncts.stream() .filter(conjunct -> conjunct != semiJoinOutputReference) .collect(toImmutableList())); - boolean negated = semiJoinOutputReference instanceof NotExpression; + boolean negated = semiJoinOutputReference instanceof Not; return Optional.of(new SemiJoinOutputFilter(negated, remainingPredicate)); } private static boolean isSemiJoinOutputReference(Expression conjunct, Symbol semiJoinOutput) { - SymbolReference semiJoinOutputSymbolReference = semiJoinOutput.toSymbolReference(); + Reference semiJoinOutputSymbolReference = semiJoinOutput.toSymbolReference(); return conjunct.equals(semiJoinOutputSymbolReference) || - (conjunct instanceof NotExpression && ((NotExpression) conjunct).getValue().equals(semiJoinOutputSymbolReference)); + (conjunct instanceof Not && ((Not) conjunct).value().equals(semiJoinOutputSymbolReference)); } private static class SemiJoinOutputFilter diff --git a/core/trino-main/src/main/java/io/trino/sql/DynamicFilters.java b/core/trino-main/src/main/java/io/trino/sql/DynamicFilters.java index 9ee847f5f841..e3fec53c5186 100644 --- a/core/trino-main/src/main/java/io/trino/sql/DynamicFilters.java +++ b/core/trino-main/src/main/java/io/trino/sql/DynamicFilters.java @@ -30,12 +30,12 @@ import io.trino.spi.type.BooleanType; import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.BuiltinFunctionCallBuilder; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.plan.DynamicFilterId; @@ -51,9 +51,9 @@ import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName; import static io.trino.spi.type.StandardTypes.BOOLEAN; import static io.trino.spi.type.StandardTypes.VARCHAR; -import static io.trino.sql.ir.BooleanLiteral.FALSE_LITERAL; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; +import static io.trino.sql.ir.Booleans.FALSE; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.IrUtils.extractConjuncts; import static java.util.Objects.requireNonNull; @@ -66,7 +66,7 @@ public static Expression createDynamicFilterExpression( DynamicFilterId id, Type inputType, Expression input, - ComparisonExpression.Operator operator) + Comparison.Operator operator) { return createDynamicFilterExpression(metadata, id, inputType, input, operator, false); } @@ -76,7 +76,7 @@ public static Expression createDynamicFilterExpression( DynamicFilterId id, Type inputType, Expression input, - ComparisonExpression.Operator operator, + Comparison.Operator operator, boolean nullAllowed) { return BuiltinFunctionCallBuilder.resolve(metadata) @@ -84,7 +84,7 @@ public static Expression createDynamicFilterExpression( .addArgument(inputType, input) .addArgument(new Constant(VarcharType.VARCHAR, Slices.utf8Slice(operator.toString()))) .addArgument(new Constant(VarcharType.VARCHAR, Slices.utf8Slice(id.toString()))) - .addArgument(BooleanType.BOOLEAN, nullAllowed ? TRUE_LITERAL : FALSE_LITERAL) + .addArgument(BooleanType.BOOLEAN, nullAllowed ? TRUE : FALSE) .build(); } @@ -129,23 +129,23 @@ public static Multimap extractSourceSymbols(List getDescriptor(Expression expression) { - if (!(expression instanceof FunctionCall functionCall)) { + if (!(expression instanceof Call call)) { return Optional.empty(); } - if (!isDynamicFilterFunction(functionCall)) { + if (!isDynamicFilterFunction(call)) { return Optional.empty(); } - List arguments = functionCall.getArguments(); + List arguments = call.arguments(); checkArgument(arguments.size() == 4, "invalid arguments count: %s", arguments.size()); Expression probeSymbol = arguments.get(0); Expression operatorExpression = arguments.get(1); - checkArgument(operatorExpression instanceof Constant literal && literal.getType().equals(VarcharType.VARCHAR), "operatorExpression is expected to be a varchar: %s", operatorExpression.getClass().getSimpleName()); - String operatorExpressionString = ((Slice) ((Constant) operatorExpression).getValue()).toStringUtf8(); - ComparisonExpression.Operator operator = ComparisonExpression.Operator.valueOf(operatorExpressionString); + checkArgument(operatorExpression instanceof Constant literal && literal.type().equals(VarcharType.VARCHAR), "operatorExpression is expected to be a varchar: %s", operatorExpression.getClass().getSimpleName()); + String operatorExpressionString = ((Slice) ((Constant) operatorExpression).value()).toStringUtf8(); + Comparison.Operator operator = Comparison.Operator.valueOf(operatorExpressionString); Expression idExpression = arguments.get(2); - checkArgument(idExpression instanceof Constant literal && literal.getType().equals(VarcharType.VARCHAR), "id is expected to be a varchar: %s", idExpression.getClass().getSimpleName()); - String id = ((Slice) ((Constant) idExpression).getValue()).toStringUtf8(); + checkArgument(idExpression instanceof Constant literal && literal.type().equals(VarcharType.VARCHAR), "id is expected to be a varchar: %s", idExpression.getClass().getSimpleName()); + String id = ((Slice) ((Constant) idExpression).value()).toStringUtf8(); Expression nullAllowedExpression = arguments.get(3); - checkArgument(nullAllowedExpression instanceof Constant literal && literal.getType().equals(BooleanType.BOOLEAN), "nullAllowedExpression is expected to be a boolean constant: %s", nullAllowedExpression.getClass().getSimpleName()); - boolean nullAllowed = (boolean) ((Constant) nullAllowedExpression).getValue(); + checkArgument(nullAllowedExpression instanceof Constant literal && literal.type().equals(BooleanType.BOOLEAN), "nullAllowedExpression is expected to be a boolean constant: %s", nullAllowedExpression.getClass().getSimpleName()); + boolean nullAllowed = (boolean) ((Constant) nullAllowedExpression).value(); return Optional.of(new Descriptor(new DynamicFilterId(id), probeSymbol, operator, nullAllowed)); } - private static boolean isDynamicFilterFunction(FunctionCall functionCall) + private static boolean isDynamicFilterFunction(Call call) { - return isDynamicFilterFunction(functionCall.getFunction().getName()); + return isDynamicFilterFunction(call.function().getName()); } public static boolean isDynamicFilterFunction(CatalogSchemaFunctionName functionName) @@ -219,10 +219,10 @@ public static final class Descriptor { private final DynamicFilterId id; private final Expression input; - private final ComparisonExpression.Operator operator; + private final Comparison.Operator operator; private final boolean nullAllowed; - public Descriptor(DynamicFilterId id, Expression input, ComparisonExpression.Operator operator, boolean nullAllowed) + public Descriptor(DynamicFilterId id, Expression input, Comparison.Operator operator, boolean nullAllowed) { this.id = requireNonNull(id, "id is null"); this.input = requireNonNull(input, "input is null"); @@ -231,7 +231,7 @@ public Descriptor(DynamicFilterId id, Expression input, ComparisonExpression.Ope this.nullAllowed = nullAllowed; } - public Descriptor(DynamicFilterId id, Expression input, ComparisonExpression.Operator operator) + public Descriptor(DynamicFilterId id, Expression input, Comparison.Operator operator) { this(id, input, operator, false); } @@ -251,7 +251,7 @@ public Expression getInput() return input; } - public ComparisonExpression.Operator getOperator() + public Comparison.Operator getOperator() { return operator; } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/ArithmeticBinaryExpression.java b/core/trino-main/src/main/java/io/trino/sql/ir/Arithmetic.java similarity index 74% rename from core/trino-main/src/main/java/io/trino/sql/ir/ArithmeticBinaryExpression.java rename to core/trino-main/src/main/java/io/trino/sql/ir/Arithmetic.java index c1fdefd3d32c..92dd63fe5142 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/ArithmeticBinaryExpression.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/Arithmetic.java @@ -23,7 +23,7 @@ import static java.util.Objects.requireNonNull; @JsonSerialize -public record ArithmeticBinaryExpression(ResolvedFunction function, Operator operator, Expression left, Expression right) +public record Arithmetic(ResolvedFunction function, Operator operator, Expression left, Expression right) implements Expression { public enum Operator @@ -46,7 +46,7 @@ public String getValue() } } - public ArithmeticBinaryExpression + public Arithmetic { requireNonNull(function, "function is null"); requireNonNull(operator, "operator is null"); @@ -60,38 +60,14 @@ public Type type() return function.getSignature().getReturnType(); } - @Deprecated - public ResolvedFunction getFunction() - { - return function; - } - - @Deprecated - public Operator getOperator() - { - return operator; - } - - @Deprecated - public Expression getLeft() - { - return left; - } - - @Deprecated - public Expression getRight() - { - return right; - } - @Override public R accept(IrVisitor visitor, C context) { - return visitor.visitArithmeticBinary(this, context); + return visitor.visitArithmetic(this, context); } @Override - public List getChildren() + public List children() { return ImmutableList.of(left, right); } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/BetweenPredicate.java b/core/trino-main/src/main/java/io/trino/sql/ir/Between.java similarity index 76% rename from core/trino-main/src/main/java/io/trino/sql/ir/BetweenPredicate.java rename to core/trino-main/src/main/java/io/trino/sql/ir/Between.java index 0591502c4b7e..ef661d2717ab 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/BetweenPredicate.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/Between.java @@ -24,11 +24,11 @@ import static java.util.Objects.requireNonNull; @JsonSerialize -public record BetweenPredicate(Expression value, Expression min, Expression max) +public record Between(Expression value, Expression min, Expression max) implements Expression { @JsonCreator - public BetweenPredicate + public Between { requireNonNull(value, "value is null"); requireNonNull(min, "min is null"); @@ -41,32 +41,14 @@ public Type type() return BOOLEAN; } - @Deprecated - public Expression getValue() - { - return value; - } - - @Deprecated - public Expression getMin() - { - return min; - } - - @Deprecated - public Expression getMax() - { - return max; - } - @Override public R accept(IrVisitor visitor, C context) { - return visitor.visitBetweenPredicate(this, context); + return visitor.visitBetween(this, context); } @Override - public List getChildren() + public List children() { return ImmutableList.of(value, min, max); } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/BindExpression.java b/core/trino-main/src/main/java/io/trino/sql/ir/Bind.java similarity index 87% rename from core/trino-main/src/main/java/io/trino/sql/ir/BindExpression.java rename to core/trino-main/src/main/java/io/trino/sql/ir/Bind.java index 76340a56dfae..eef6aedfcf62 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/BindExpression.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/Bind.java @@ -49,10 +49,10 @@ * This expression facilitates desugaring. */ @JsonSerialize -public record BindExpression(List values, LambdaExpression function) +public record Bind(List values, Lambda function) implements Expression { - public BindExpression + public Bind { requireNonNull(function, "function is null"); values = ImmutableList.copyOf(values); @@ -69,26 +69,14 @@ public Type type() ((FunctionType) function.type()).getReturnType()); } - @Deprecated - public List getValues() - { - return values; - } - - @Deprecated - public Expression getFunction() - { - return function; - } - @Override public R accept(IrVisitor visitor, C context) { - return visitor.visitBindExpression(this, context); + return visitor.visitBind(this, context); } @Override - public List getChildren() + public List children() { return ImmutableList.builder() .addAll(values) diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/BooleanLiteral.java b/core/trino-main/src/main/java/io/trino/sql/ir/Booleans.java similarity index 74% rename from core/trino-main/src/main/java/io/trino/sql/ir/BooleanLiteral.java rename to core/trino-main/src/main/java/io/trino/sql/ir/Booleans.java index 1c7802de037d..80e90aa0707f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/BooleanLiteral.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/Booleans.java @@ -15,10 +15,10 @@ import static io.trino.spi.type.BooleanType.BOOLEAN; -public final class BooleanLiteral +public final class Booleans { - public static final Constant TRUE_LITERAL = new Constant(BOOLEAN, true); - public static final Constant FALSE_LITERAL = new Constant(BOOLEAN, false); + public static final Constant TRUE = new Constant(BOOLEAN, true); + public static final Constant FALSE = new Constant(BOOLEAN, false); - private BooleanLiteral() {} + private Booleans() {} } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/FunctionCall.java b/core/trino-main/src/main/java/io/trino/sql/ir/Call.java similarity index 79% rename from core/trino-main/src/main/java/io/trino/sql/ir/FunctionCall.java rename to core/trino-main/src/main/java/io/trino/sql/ir/Call.java index 0fa1ed410118..1905a1bd68da 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/FunctionCall.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/Call.java @@ -24,10 +24,10 @@ import static java.util.Objects.requireNonNull; @JsonSerialize -public record FunctionCall(ResolvedFunction function, List arguments) +public record Call(ResolvedFunction function, List arguments) implements Expression { - public FunctionCall + public Call { requireNonNull(function, "function is null"); arguments = ImmutableList.copyOf(arguments); @@ -39,26 +39,14 @@ public Type type() return function.getSignature().getReturnType(); } - @Deprecated - public ResolvedFunction getFunction() - { - return function; - } - - @Deprecated - public List getArguments() - { - return arguments; - } - @Override public R accept(IrVisitor visitor, C context) { - return visitor.visitFunctionCall(this, context); + return visitor.visitCall(this, context); } @Override - public List getChildren() + public List children() { return arguments; } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/SearchedCaseExpression.java b/core/trino-main/src/main/java/io/trino/sql/ir/Case.java similarity index 78% rename from core/trino-main/src/main/java/io/trino/sql/ir/SearchedCaseExpression.java rename to core/trino-main/src/main/java/io/trino/sql/ir/Case.java index 20cc5731adab..02a3621f1cb9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/SearchedCaseExpression.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/Case.java @@ -24,10 +24,10 @@ import static java.util.Objects.requireNonNull; @JsonSerialize -public record SearchedCaseExpression(List whenClauses, Optional defaultValue) +public record Case(List whenClauses, Optional defaultValue) implements Expression { - public SearchedCaseExpression + public Case { whenClauses = ImmutableList.copyOf(whenClauses); requireNonNull(defaultValue, "defaultValue is null"); @@ -39,26 +39,14 @@ public Type type() return whenClauses.getFirst().getResult().type(); } - @Deprecated - public List getWhenClauses() - { - return whenClauses; - } - - @Deprecated - public Optional getDefaultValue() - { - return defaultValue; - } - @Override public R accept(IrVisitor visitor, C context) { - return visitor.visitSearchedCaseExpression(this, context); + return visitor.visitCase(this, context); } @Override - public List getChildren() + public List children() { ImmutableList.Builder builder = ImmutableList.builder(); whenClauses.forEach(clause -> { @@ -73,7 +61,7 @@ public List getChildren() @Override public String toString() { - return "SearchedCase(%s, %s)".formatted( + return "Case(%s, %s)".formatted( whenClauses.stream() .map(WhenClause::toString) .collect(Collectors.joining(", ")), diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/Cast.java b/core/trino-main/src/main/java/io/trino/sql/ir/Cast.java index 07cb54f92483..8de9188e57aa 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/Cast.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/Cast.java @@ -42,24 +42,6 @@ public Type type() return type; } - @Deprecated - public Expression getExpression() - { - return expression; - } - - @Deprecated - public Type getType() - { - return type; - } - - @Deprecated - public boolean isSafe() - { - return safe; - } - @Override public R accept(IrVisitor visitor, C context) { @@ -67,7 +49,7 @@ public R accept(IrVisitor visitor, C context) } @Override - public List getChildren() + public List children() { return ImmutableList.of(expression); } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/CoalesceExpression.java b/core/trino-main/src/main/java/io/trino/sql/ir/Coalesce.java similarity index 78% rename from core/trino-main/src/main/java/io/trino/sql/ir/CoalesceExpression.java rename to core/trino-main/src/main/java/io/trino/sql/ir/Coalesce.java index d12b3b64214c..ddea1ca908b6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/CoalesceExpression.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/Coalesce.java @@ -22,10 +22,10 @@ import static com.google.common.base.Preconditions.checkArgument; @JsonSerialize -public record CoalesceExpression(List operands) +public record Coalesce(List operands) implements Expression { - public CoalesceExpression(Expression first, Expression second, Expression... additional) + public Coalesce(Expression first, Expression second, Expression... additional) { this(ImmutableList.builder() .add(first, second) @@ -39,26 +39,20 @@ public Type type() return operands.getFirst().type(); } - public CoalesceExpression + public Coalesce { checkArgument(operands.size() >= 2, "must have at least two operands"); operands = ImmutableList.copyOf(operands); } - @Deprecated - public List getOperands() - { - return operands; - } - @Override public R accept(IrVisitor visitor, C context) { - return visitor.visitCoalesceExpression(this, context); + return visitor.visitCoalesce(this, context); } @Override - public List getChildren() + public List children() { return operands; } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/ComparisonExpression.java b/core/trino-main/src/main/java/io/trino/sql/ir/Comparison.java similarity index 86% rename from core/trino-main/src/main/java/io/trino/sql/ir/ComparisonExpression.java rename to core/trino-main/src/main/java/io/trino/sql/ir/Comparison.java index 3b316a3acff0..fb9bc4bab028 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/ComparisonExpression.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/Comparison.java @@ -23,7 +23,7 @@ import static java.util.Objects.requireNonNull; @JsonSerialize -public record ComparisonExpression(Operator operator, Expression left, Expression right) +public record Comparison(Operator operator, Expression left, Expression right) implements Expression { public enum Operator @@ -84,7 +84,7 @@ public Operator negate() } } - public ComparisonExpression + public Comparison { requireNonNull(operator, "operator is null"); requireNonNull(left, "left is null"); @@ -97,32 +97,14 @@ public Type type() return BOOLEAN; } - @Deprecated - public Operator getOperator() - { - return operator; - } - - @Deprecated - public Expression getLeft() - { - return left; - } - - @Deprecated - public Expression getRight() - { - return right; - } - @Override public R accept(IrVisitor visitor, C context) { - return visitor.visitComparisonExpression(this, context); + return visitor.visitComparison(this, context); } @Override - public List getChildren() + public List children() { return ImmutableList.of(left, right); } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/Constant.java b/core/trino-main/src/main/java/io/trino/sql/ir/Constant.java index d1b87c9e3f4d..ed161641a434 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/Constant.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/Constant.java @@ -47,13 +47,6 @@ public static Constant fromJson( } } - @Deprecated - @JsonProperty - public Type getType() - { - return type; - } - @JsonProperty public Block getValueAsBlock() { @@ -62,11 +55,6 @@ public Block getValueAsBlock() return blockBuilder.build(); } - public Object getValue() - { - return value; - } - @Override public R accept(IrVisitor visitor, C context) { @@ -74,7 +62,7 @@ public R accept(IrVisitor visitor, C context) } @Override - public List getChildren() + public List children() { return ImmutableList.of(); } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/DefaultTraversalVisitor.java b/core/trino-main/src/main/java/io/trino/sql/ir/DefaultTraversalVisitor.java index 21a0a32c633c..4b8b20d4fc02 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/DefaultTraversalVisitor.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/DefaultTraversalVisitor.java @@ -19,33 +19,33 @@ public abstract class DefaultTraversalVisitor @Override protected Void visitCast(Cast node, C context) { - process(node.getExpression(), context); + process(node.expression(), context); return null; } @Override - protected Void visitArithmeticBinary(ArithmeticBinaryExpression node, C context) + protected Void visitArithmetic(Arithmetic node, C context) { - process(node.getLeft(), context); - process(node.getRight(), context); + process(node.left(), context); + process(node.right(), context); return null; } @Override - protected Void visitBetweenPredicate(BetweenPredicate node, C context) + protected Void visitBetween(Between node, C context) { - process(node.getValue(), context); - process(node.getMin(), context); - process(node.getMax(), context); + process(node.value(), context); + process(node.min(), context); + process(node.max(), context); return null; } @Override - protected Void visitCoalesceExpression(CoalesceExpression node, C context) + protected Void visitCoalesce(Coalesce node, C context) { - for (Expression operand : node.getOperands()) { + for (Expression operand : node.operands()) { process(operand, context); } @@ -53,28 +53,28 @@ protected Void visitCoalesceExpression(CoalesceExpression node, C context) } @Override - protected Void visitSubscriptExpression(SubscriptExpression node, C context) + protected Void visitSubscript(Subscript node, C context) { - process(node.getBase(), context); - process(node.getIndex(), context); + process(node.base(), context); + process(node.index(), context); return null; } @Override - protected Void visitComparisonExpression(ComparisonExpression node, C context) + protected Void visitComparison(Comparison node, C context) { - process(node.getLeft(), context); - process(node.getRight(), context); + process(node.left(), context); + process(node.right(), context); return null; } @Override - protected Void visitInPredicate(InPredicate node, C context) + protected Void visitIn(In node, C context) { - process(node.getValue(), context); - for (Expression argument : node.getValueList()) { + process(node.value(), context); + for (Expression argument : node.valueList()) { process(argument, context); } @@ -82,9 +82,9 @@ protected Void visitInPredicate(InPredicate node, C context) } @Override - protected Void visitFunctionCall(FunctionCall node, C context) + protected Void visitCall(Call node, C context) { - for (Expression argument : node.getArguments()) { + for (Expression argument : node.arguments()) { process(argument, context); } @@ -92,78 +92,78 @@ protected Void visitFunctionCall(FunctionCall node, C context) } @Override - protected Void visitSimpleCaseExpression(SimpleCaseExpression node, C context) + protected Void visitSwitch(Switch node, C context) { - process(node.getOperand(), context); - for (WhenClause clause : node.getWhenClauses()) { + process(node.operand(), context); + for (WhenClause clause : node.whenClauses()) { process(clause.getOperand(), context); process(clause.getResult(), context); } - node.getDefaultValue() + node.defaultValue() .ifPresent(value -> process(value, context)); return null; } @Override - protected Void visitNullIfExpression(NullIfExpression node, C context) + protected Void visitNullIf(NullIf node, C context) { - process(node.getFirst(), context); - process(node.getSecond(), context); + process(node.first(), context); + process(node.second(), context); return null; } @Override - protected Void visitBindExpression(BindExpression node, C context) + protected Void visitBind(Bind node, C context) { - for (Expression value : node.getValues()) { + for (Expression value : node.values()) { process(value, context); } - process(node.getFunction(), context); + process(node.function(), context); return null; } @Override - protected Void visitArithmeticNegation(ArithmeticNegation node, C context) + protected Void visitNegation(Negation node, C context) { - process(node.getValue(), context); + process(node.value(), context); return null; } @Override - protected Void visitNotExpression(NotExpression node, C context) + protected Void visitNot(Not node, C context) { - process(node.getValue(), context); + process(node.value(), context); return null; } @Override - protected Void visitSearchedCaseExpression(SearchedCaseExpression node, C context) + protected Void visitCase(Case node, C context) { - for (WhenClause clause : node.getWhenClauses()) { + for (WhenClause clause : node.whenClauses()) { process(clause.getOperand(), context); process(clause.getResult(), context); } - node.getDefaultValue() + node.defaultValue() .ifPresent(value -> process(value, context)); return null; } @Override - protected Void visitIsNullPredicate(IsNullPredicate node, C context) + protected Void visitIsNull(IsNull node, C context) { - process(node.getValue(), context); + process(node.value(), context); return null; } @Override - protected Void visitLogicalExpression(LogicalExpression node, C context) + protected Void visitLogical(Logical node, C context) { - for (Expression child : node.getTerms()) { + for (Expression child : node.terms()) { process(child, context); } @@ -173,16 +173,16 @@ protected Void visitLogicalExpression(LogicalExpression node, C context) @Override protected Void visitRow(Row node, C context) { - for (Expression expression : node.getItems()) { + for (Expression expression : node.items()) { process(expression, context); } return null; } @Override - protected Void visitLambdaExpression(LambdaExpression node, C context) + protected Void visitLambda(Lambda node, C context) { - process(node.getBody(), context); + process(node.body(), context); return null; } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/Expression.java b/core/trino-main/src/main/java/io/trino/sql/ir/Expression.java index fcc28a80e624..4c4860f44b69 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/Expression.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/Expression.java @@ -24,33 +24,31 @@ @Immutable @JsonTypeInfo(use = JsonTypeInfo.Id.NAME) @JsonSubTypes({ - @JsonSubTypes.Type(value = ArithmeticBinaryExpression.class, name = "arithmeticBinary"), - @JsonSubTypes.Type(value = ArithmeticNegation.class, name = "arithmeticUnary"), - @JsonSubTypes.Type(value = BetweenPredicate.class, name = "between"), - @JsonSubTypes.Type(value = BindExpression.class, name = "bind"), + @JsonSubTypes.Type(value = Arithmetic.class, name = "arithmetic"), + @JsonSubTypes.Type(value = Negation.class, name = "negation"), + @JsonSubTypes.Type(value = Between.class, name = "between"), + @JsonSubTypes.Type(value = Bind.class, name = "bind"), @JsonSubTypes.Type(value = Cast.class, name = "cast"), - @JsonSubTypes.Type(value = CoalesceExpression.class, name = "coalesce"), - @JsonSubTypes.Type(value = ComparisonExpression.class, name = "comparison"), - @JsonSubTypes.Type(value = FunctionCall.class, name = "call"), + @JsonSubTypes.Type(value = Coalesce.class, name = "coalesce"), + @JsonSubTypes.Type(value = Comparison.class, name = "comparison"), + @JsonSubTypes.Type(value = Call.class, name = "call"), @JsonSubTypes.Type(value = Constant.class, name = "constant"), - @JsonSubTypes.Type(value = InPredicate.class, name = "in"), - @JsonSubTypes.Type(value = IsNullPredicate.class, name = "isNull"), - @JsonSubTypes.Type(value = LambdaExpression.class, name = "lambda"), - @JsonSubTypes.Type(value = LogicalExpression.class, name = "logicalBinary"), - @JsonSubTypes.Type(value = NotExpression.class, name = "not"), - @JsonSubTypes.Type(value = NullIfExpression.class, name = "nullif"), + @JsonSubTypes.Type(value = In.class, name = "in"), + @JsonSubTypes.Type(value = IsNull.class, name = "isnull"), + @JsonSubTypes.Type(value = Lambda.class, name = "lambda"), + @JsonSubTypes.Type(value = Logical.class, name = "logical"), + @JsonSubTypes.Type(value = Not.class, name = "not"), + @JsonSubTypes.Type(value = NullIf.class, name = "nullif"), @JsonSubTypes.Type(value = Row.class, name = "row"), - @JsonSubTypes.Type(value = SearchedCaseExpression.class, name = "searchedCase"), - @JsonSubTypes.Type(value = SimpleCaseExpression.class, name = "simpleCase"), - @JsonSubTypes.Type(value = SubscriptExpression.class, name = "subscript"), - @JsonSubTypes.Type(value = SymbolReference.class, name = "symbol"), + @JsonSubTypes.Type(value = Case.class, name = "case"), + @JsonSubTypes.Type(value = Switch.class, name = "switch"), + @JsonSubTypes.Type(value = Subscript.class, name = "subscript"), + @JsonSubTypes.Type(value = Reference.class, name = "reference"), }) public sealed interface Expression - permits ArithmeticBinaryExpression, ArithmeticNegation, BetweenPredicate, - BindExpression, Cast, CoalesceExpression, ComparisonExpression, FunctionCall, InPredicate, - IsNullPredicate, LambdaExpression, Constant, LogicalExpression, - NotExpression, NullIfExpression, Row, SearchedCaseExpression, SimpleCaseExpression, - SubscriptExpression, SymbolReference + permits Arithmetic, Between, Bind, Call, Case, Cast, Coalesce, + Comparison, Constant, In, IsNull, Lambda, Logical, Negation, + Not, NullIf, Reference, Row, Subscript, Switch { Type type(); @@ -63,5 +61,5 @@ default R accept(IrVisitor visitor, C context) } @JsonIgnore - List getChildren(); + List children(); } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionFormatter.java b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionFormatter.java index 84c05fef98b4..d501cce26119 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionFormatter.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionFormatter.java @@ -37,11 +37,11 @@ public static class Formatter extends IrVisitor { private final Optional> literalFormatter; - private final Optional> symbolReferenceFormatter; + private final Optional> symbolReferenceFormatter; public Formatter( Optional> literalFormatter, - Optional> symbolReferenceFormatter) + Optional> symbolReferenceFormatter) { this.literalFormatter = requireNonNull(literalFormatter, "literalFormatter is null"); this.symbolReferenceFormatter = requireNonNull(symbolReferenceFormatter, "symbolReferenceFormatter is null"); @@ -50,7 +50,7 @@ public Formatter( @Override protected String visitRow(Row node, Void context) { - return node.getItems().stream() + return node.items().stream() .map(child -> process(child, context)) .collect(joining(", ", "ROW (", ")")); } @@ -62,9 +62,9 @@ protected String visitExpression(Expression node, Void context) } @Override - protected String visitSubscriptExpression(SubscriptExpression node, Void context) + protected String visitSubscript(Subscript node, Void context) { - return formatExpression(node.getBase()) + "[" + formatExpression(node.getIndex()) + "]"; + return formatExpression(node.base()) + "[" + formatExpression(node.index()) + "]"; } @Override @@ -73,125 +73,125 @@ protected String visitConstant(Constant node, Void context) return literalFormatter .map(formatter -> formatter.apply(node)) .orElseGet(() -> { - if (node.getValue() == null) { - return "null::" + node.getType(); + if (node.value() == null) { + return "null::" + node.type(); } else { - return node.getType() + " '" + node.getType().getObjectValue(null, node.getValueAsBlock(), 0) + "'"; + return node.type() + " '" + node.type().getObjectValue(null, node.getValueAsBlock(), 0) + "'"; } }); } @Override - protected String visitFunctionCall(FunctionCall node, Void context) + protected String visitCall(Call node, Void context) { - return node.getFunction().getName().toString() + '(' + joinExpressions(node.getArguments()) + ')'; + return node.function().getName().toString() + '(' + joinExpressions(node.arguments()) + ')'; } @Override - protected String visitLambdaExpression(LambdaExpression node, Void context) + protected String visitLambda(Lambda node, Void context) { return "(" + node.arguments().stream() .map(Symbol::getName) .collect(joining(", ")) + ") -> " + - process(node.getBody(), context); + process(node.body(), context); } @Override - protected String visitSymbolReference(SymbolReference node, Void context) + protected String visitReference(Reference node, Void context) { if (symbolReferenceFormatter.isPresent()) { return symbolReferenceFormatter.get().apply(node); } - return node.getName(); + return node.name(); } @Override - protected String visitBindExpression(BindExpression node, Void context) + protected String visitBind(Bind node, Void context) { StringBuilder builder = new StringBuilder(); builder.append("\"$bind\"("); - for (Expression value : node.getValues()) { + for (Expression value : node.values()) { builder.append(process(value, context)) .append(", "); } - builder.append(process(node.getFunction(), context)) + builder.append(process(node.function(), context)) .append(")"); return builder.toString(); } @Override - protected String visitLogicalExpression(LogicalExpression node, Void context) + protected String visitLogical(Logical node, Void context) { return "(" + - node.getTerms().stream() + node.terms().stream() .map(term -> process(term, context)) - .collect(joining(" " + node.getOperator().toString() + " ")) + + .collect(joining(" " + node.operator().toString() + " ")) + ")"; } @Override - protected String visitNotExpression(NotExpression node, Void context) + protected String visitNot(Not node, Void context) { - return "(NOT " + process(node.getValue(), context) + ")"; + return "(NOT " + process(node.value(), context) + ")"; } @Override - protected String visitComparisonExpression(ComparisonExpression node, Void context) + protected String visitComparison(Comparison node, Void context) { - return formatBinaryExpression(node.getOperator().getValue(), node.getLeft(), node.getRight()); + return formatBinaryExpression(node.operator().getValue(), node.left(), node.right()); } @Override - protected String visitIsNullPredicate(IsNullPredicate node, Void context) + protected String visitIsNull(IsNull node, Void context) { - return "(" + process(node.getValue(), context) + " IS NULL)"; + return "(" + process(node.value(), context) + " IS NULL)"; } @Override - protected String visitNullIfExpression(NullIfExpression node, Void context) + protected String visitNullIf(NullIf node, Void context) { - return "NULLIF(" + process(node.getFirst(), context) + ", " + process(node.getSecond(), context) + ')'; + return "NULLIF(" + process(node.first(), context) + ", " + process(node.second(), context) + ')'; } @Override - protected String visitCoalesceExpression(CoalesceExpression node, Void context) + protected String visitCoalesce(Coalesce node, Void context) { - return "COALESCE(" + joinExpressions(node.getOperands()) + ")"; + return "COALESCE(" + joinExpressions(node.operands()) + ")"; } @Override - protected String visitArithmeticNegation(ArithmeticNegation node, Void context) + protected String visitNegation(Negation node, Void context) { - return "-(" + process(node.getValue(), context) + ")"; + return "-(" + process(node.value(), context) + ")"; } @Override - protected String visitArithmeticBinary(ArithmeticBinaryExpression node, Void context) + protected String visitArithmetic(Arithmetic node, Void context) { - return formatBinaryExpression(node.getOperator().getValue(), node.getLeft(), node.getRight()); + return formatBinaryExpression(node.operator().getValue(), node.left(), node.right()); } @Override public String visitCast(Cast node, Void context) { - return (node.isSafe() ? "TRY_CAST" : "CAST") + - "(" + process(node.getExpression(), context) + " AS " + node.getType().getDisplayName() + ")"; + return (node.safe() ? "TRY_CAST" : "CAST") + + "(" + process(node.expression(), context) + " AS " + node.type().getDisplayName() + ")"; } @Override - protected String visitSearchedCaseExpression(SearchedCaseExpression node, Void context) + protected String visitCase(Case node, Void context) { ImmutableList.Builder parts = ImmutableList.builder(); parts.add("CASE"); - for (WhenClause whenClause : node.getWhenClauses()) { + for (WhenClause whenClause : node.whenClauses()) { parts.add(format(whenClause, context)); } - node.getDefaultValue() + node.defaultValue() .ifPresent(value -> parts.add("ELSE").add(process(value, context))); parts.add("END"); @@ -200,18 +200,18 @@ protected String visitSearchedCaseExpression(SearchedCaseExpression node, Void c } @Override - protected String visitSimpleCaseExpression(SimpleCaseExpression node, Void context) + protected String visitSwitch(Switch node, Void context) { ImmutableList.Builder parts = ImmutableList.builder(); parts.add("CASE") - .add(process(node.getOperand(), context)); + .add(process(node.operand(), context)); - for (WhenClause whenClause : node.getWhenClauses()) { + for (WhenClause whenClause : node.whenClauses()) { parts.add(format(whenClause, context)); } - node.getDefaultValue() + node.defaultValue() .ifPresent(value -> parts.add("ELSE").add(process(value, context))); parts.add("END"); @@ -225,16 +225,16 @@ protected String format(WhenClause node, Void context) } @Override - protected String visitBetweenPredicate(BetweenPredicate node, Void context) + protected String visitBetween(Between node, Void context) { - return "(" + process(node.getValue(), context) + " BETWEEN " + - process(node.getMin(), context) + " AND " + process(node.getMax(), context) + ")"; + return "(" + process(node.value(), context) + " BETWEEN " + + process(node.min(), context) + " AND " + process(node.max(), context) + ")"; } @Override - protected String visitInPredicate(InPredicate node, Void context) + protected String visitIn(In node, Void context) { - return "(" + process(node.getValue(), context) + " IN " + joinExpressions(node.getValueList()) + ")"; + return "(" + process(node.value(), context) + " IN " + joinExpressions(node.valueList()) + ")"; } private String formatBinaryExpression(String operator, Expression left, Expression right) diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionRewriter.java b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionRewriter.java index fd36acc23eb1..282c70525c38 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionRewriter.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionRewriter.java @@ -25,77 +25,77 @@ public Expression rewriteRow(Row node, C context, ExpressionTreeRewriter tree return rewriteExpression(node, context, treeRewriter); } - public Expression rewriteArithmeticUnary(ArithmeticNegation node, C context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteNegation(Negation node, C context, ExpressionTreeRewriter treeRewriter) { return rewriteExpression(node, context, treeRewriter); } - public Expression rewriteArithmeticBinary(ArithmeticBinaryExpression node, C context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteArithmetic(Arithmetic node, C context, ExpressionTreeRewriter treeRewriter) { return rewriteExpression(node, context, treeRewriter); } - public Expression rewriteComparisonExpression(ComparisonExpression node, C context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteComparison(Comparison node, C context, ExpressionTreeRewriter treeRewriter) { return rewriteExpression(node, context, treeRewriter); } - public Expression rewriteBetweenPredicate(BetweenPredicate node, C context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteBetween(Between node, C context, ExpressionTreeRewriter treeRewriter) { return rewriteExpression(node, context, treeRewriter); } - public Expression rewriteLogicalExpression(LogicalExpression node, C context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteLogical(Logical node, C context, ExpressionTreeRewriter treeRewriter) { return rewriteExpression(node, context, treeRewriter); } - public Expression rewriteNotExpression(NotExpression node, C context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteNot(Not node, C context, ExpressionTreeRewriter treeRewriter) { return rewriteExpression(node, context, treeRewriter); } - public Expression rewriteIsNullPredicate(IsNullPredicate node, C context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteIsNull(IsNull node, C context, ExpressionTreeRewriter treeRewriter) { return rewriteExpression(node, context, treeRewriter); } - public Expression rewriteNullIfExpression(NullIfExpression node, C context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteNullIf(NullIf node, C context, ExpressionTreeRewriter treeRewriter) { return rewriteExpression(node, context, treeRewriter); } - public Expression rewriteSearchedCaseExpression(SearchedCaseExpression node, C context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteCase(Case node, C context, ExpressionTreeRewriter treeRewriter) { return rewriteExpression(node, context, treeRewriter); } - public Expression rewriteSimpleCaseExpression(SimpleCaseExpression node, C context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteSwitch(Switch node, C context, ExpressionTreeRewriter treeRewriter) { return rewriteExpression(node, context, treeRewriter); } - public Expression rewriteCoalesceExpression(CoalesceExpression node, C context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteCoalesce(Coalesce node, C context, ExpressionTreeRewriter treeRewriter) { return rewriteExpression(node, context, treeRewriter); } - public Expression rewriteFunctionCall(FunctionCall node, C context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteCall(Call node, C context, ExpressionTreeRewriter treeRewriter) { return rewriteExpression(node, context, treeRewriter); } - public Expression rewriteLambdaExpression(LambdaExpression node, C context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteLambda(Lambda node, C context, ExpressionTreeRewriter treeRewriter) { return rewriteExpression(node, context, treeRewriter); } - public Expression rewriteBindExpression(BindExpression node, C context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteBind(Bind node, C context, ExpressionTreeRewriter treeRewriter) { return rewriteExpression(node, context, treeRewriter); } - public Expression rewriteInPredicate(InPredicate node, C context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteIn(In node, C context, ExpressionTreeRewriter treeRewriter) { return rewriteExpression(node, context, treeRewriter); } @@ -105,7 +105,7 @@ public Expression rewriteConstant(Constant node, C context, ExpressionTreeRewrit return rewriteExpression(node, context, treeRewriter); } - public Expression rewriteSubscriptExpression(SubscriptExpression node, C context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteSubscript(Subscript node, C context, ExpressionTreeRewriter treeRewriter) { return rewriteExpression(node, context, treeRewriter); } @@ -115,7 +115,7 @@ public Expression rewriteCast(Cast node, C context, ExpressionTreeRewriter tr return rewriteExpression(node, context, treeRewriter); } - public Expression rewriteSymbolReference(SymbolReference node, C context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteReference(Reference node, C context, ExpressionTreeRewriter treeRewriter) { return rewriteExpression(node, context, treeRewriter); } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionTreeRewriter.java b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionTreeRewriter.java index 82b774f00a1c..28f0f3fcee06 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionTreeRewriter.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionTreeRewriter.java @@ -88,9 +88,9 @@ protected Expression visitRow(Row node, Context context) } } - List items = rewrite(node.getItems(), context); + List items = rewrite(node.items(), context); - if (!sameElements(node.getItems(), items)) { + if (!sameElements(node.items(), items)) { return new Row(items); } @@ -98,229 +98,229 @@ protected Expression visitRow(Row node, Context context) } @Override - protected Expression visitArithmeticNegation(ArithmeticNegation node, Context context) + protected Expression visitNegation(Negation node, Context context) { if (!context.isDefaultRewrite()) { - Expression result = rewriter.rewriteArithmeticUnary(node, context.get(), ExpressionTreeRewriter.this); + Expression result = rewriter.rewriteNegation(node, context.get(), ExpressionTreeRewriter.this); if (result != null) { return result; } } - Expression child = rewrite(node.getValue(), context.get()); - if (child != node.getValue()) { - return new ArithmeticNegation(child); + Expression child = rewrite(node.value(), context.get()); + if (child != node.value()) { + return new Negation(child); } return node; } @Override - public Expression visitArithmeticBinary(ArithmeticBinaryExpression node, Context context) + public Expression visitArithmetic(Arithmetic node, Context context) { if (!context.isDefaultRewrite()) { - Expression result = rewriter.rewriteArithmeticBinary(node, context.get(), ExpressionTreeRewriter.this); + Expression result = rewriter.rewriteArithmetic(node, context.get(), ExpressionTreeRewriter.this); if (result != null) { return result; } } - Expression left = rewrite(node.getLeft(), context.get()); - Expression right = rewrite(node.getRight(), context.get()); + Expression left = rewrite(node.left(), context.get()); + Expression right = rewrite(node.right(), context.get()); - if (left != node.getLeft() || right != node.getRight()) { - return new ArithmeticBinaryExpression(node.getFunction(), node.getOperator(), left, right); + if (left != node.left() || right != node.right()) { + return new Arithmetic(node.function(), node.operator(), left, right); } return node; } @Override - protected Expression visitSubscriptExpression(SubscriptExpression node, Context context) + protected Expression visitSubscript(Subscript node, Context context) { if (!context.isDefaultRewrite()) { - Expression result = rewriter.rewriteSubscriptExpression(node, context.get(), ExpressionTreeRewriter.this); + Expression result = rewriter.rewriteSubscript(node, context.get(), ExpressionTreeRewriter.this); if (result != null) { return result; } } - Expression base = rewrite(node.getBase(), context.get()); - Expression index = rewrite(node.getIndex(), context.get()); + Expression base = rewrite(node.base(), context.get()); + Expression index = rewrite(node.index(), context.get()); - if (base != node.getBase() || index != node.getIndex()) { - return new SubscriptExpression(node.type(), base, index); + if (base != node.base() || index != node.index()) { + return new Subscript(node.type(), base, index); } return node; } @Override - public Expression visitComparisonExpression(ComparisonExpression node, Context context) + public Expression visitComparison(Comparison node, Context context) { if (!context.isDefaultRewrite()) { - Expression result = rewriter.rewriteComparisonExpression(node, context.get(), ExpressionTreeRewriter.this); + Expression result = rewriter.rewriteComparison(node, context.get(), ExpressionTreeRewriter.this); if (result != null) { return result; } } - Expression left = rewrite(node.getLeft(), context.get()); - Expression right = rewrite(node.getRight(), context.get()); + Expression left = rewrite(node.left(), context.get()); + Expression right = rewrite(node.right(), context.get()); - if (left != node.getLeft() || right != node.getRight()) { - return new ComparisonExpression(node.getOperator(), left, right); + if (left != node.left() || right != node.right()) { + return new Comparison(node.operator(), left, right); } return node; } @Override - protected Expression visitBetweenPredicate(BetweenPredicate node, Context context) + protected Expression visitBetween(Between node, Context context) { if (!context.isDefaultRewrite()) { - Expression result = rewriter.rewriteBetweenPredicate(node, context.get(), ExpressionTreeRewriter.this); + Expression result = rewriter.rewriteBetween(node, context.get(), ExpressionTreeRewriter.this); if (result != null) { return result; } } - Expression value = rewrite(node.getValue(), context.get()); - Expression min = rewrite(node.getMin(), context.get()); - Expression max = rewrite(node.getMax(), context.get()); + Expression value = rewrite(node.value(), context.get()); + Expression min = rewrite(node.min(), context.get()); + Expression max = rewrite(node.max(), context.get()); - if (value != node.getValue() || min != node.getMin() || max != node.getMax()) { - return new BetweenPredicate(value, min, max); + if (value != node.value() || min != node.min() || max != node.max()) { + return new Between(value, min, max); } return node; } @Override - public Expression visitLogicalExpression(LogicalExpression node, Context context) + public Expression visitLogical(Logical node, Context context) { if (!context.isDefaultRewrite()) { - Expression result = rewriter.rewriteLogicalExpression(node, context.get(), ExpressionTreeRewriter.this); + Expression result = rewriter.rewriteLogical(node, context.get(), ExpressionTreeRewriter.this); if (result != null) { return result; } } - List terms = rewrite(node.getTerms(), context); - if (!sameElements(node.getTerms(), terms)) { - return new LogicalExpression(node.getOperator(), terms); + List terms = rewrite(node.terms(), context); + if (!sameElements(node.terms(), terms)) { + return new Logical(node.operator(), terms); } return node; } @Override - public Expression visitNotExpression(NotExpression node, Context context) + public Expression visitNot(Not node, Context context) { if (!context.isDefaultRewrite()) { - Expression result = rewriter.rewriteNotExpression(node, context.get(), ExpressionTreeRewriter.this); + Expression result = rewriter.rewriteNot(node, context.get(), ExpressionTreeRewriter.this); if (result != null) { return result; } } - Expression value = rewrite(node.getValue(), context.get()); + Expression value = rewrite(node.value(), context.get()); - if (value != node.getValue()) { - return new NotExpression(value); + if (value != node.value()) { + return new Not(value); } return node; } @Override - protected Expression visitIsNullPredicate(IsNullPredicate node, Context context) + protected Expression visitIsNull(IsNull node, Context context) { if (!context.isDefaultRewrite()) { - Expression result = rewriter.rewriteIsNullPredicate(node, context.get(), ExpressionTreeRewriter.this); + Expression result = rewriter.rewriteIsNull(node, context.get(), ExpressionTreeRewriter.this); if (result != null) { return result; } } - Expression value = rewrite(node.getValue(), context.get()); + Expression value = rewrite(node.value(), context.get()); - if (value != node.getValue()) { - return new IsNullPredicate(value); + if (value != node.value()) { + return new IsNull(value); } return node; } @Override - protected Expression visitNullIfExpression(NullIfExpression node, Context context) + protected Expression visitNullIf(NullIf node, Context context) { if (!context.isDefaultRewrite()) { - Expression result = rewriter.rewriteNullIfExpression(node, context.get(), ExpressionTreeRewriter.this); + Expression result = rewriter.rewriteNullIf(node, context.get(), ExpressionTreeRewriter.this); if (result != null) { return result; } } - Expression first = rewrite(node.getFirst(), context.get()); - Expression second = rewrite(node.getSecond(), context.get()); + Expression first = rewrite(node.first(), context.get()); + Expression second = rewrite(node.second(), context.get()); - if (first != node.getFirst() || second != node.getSecond()) { - return new NullIfExpression(first, second); + if (first != node.first() || second != node.second()) { + return new NullIf(first, second); } return node; } @Override - protected Expression visitSearchedCaseExpression(SearchedCaseExpression node, Context context) + protected Expression visitCase(Case node, Context context) { if (!context.isDefaultRewrite()) { - Expression result = rewriter.rewriteSearchedCaseExpression(node, context.get(), ExpressionTreeRewriter.this); + Expression result = rewriter.rewriteCase(node, context.get(), ExpressionTreeRewriter.this); if (result != null) { return result; } } ImmutableList.Builder builder = ImmutableList.builder(); - for (WhenClause expression : node.getWhenClauses()) { + for (WhenClause expression : node.whenClauses()) { builder.add(rewriteWhenClause(expression, context)); } - Optional defaultValue = node.getDefaultValue() + Optional defaultValue = node.defaultValue() .map(value -> rewrite(value, context.get())); - if (!sameElements(node.getDefaultValue(), defaultValue) || !sameElements(node.getWhenClauses(), builder.build())) { - return new SearchedCaseExpression(builder.build(), defaultValue); + if (!sameElements(node.defaultValue(), defaultValue) || !sameElements(node.whenClauses(), builder.build())) { + return new Case(builder.build(), defaultValue); } return node; } @Override - protected Expression visitSimpleCaseExpression(SimpleCaseExpression node, Context context) + protected Expression visitSwitch(Switch node, Context context) { if (!context.isDefaultRewrite()) { - Expression result = rewriter.rewriteSimpleCaseExpression(node, context.get(), ExpressionTreeRewriter.this); + Expression result = rewriter.rewriteSwitch(node, context.get(), ExpressionTreeRewriter.this); if (result != null) { return result; } } - Expression operand = rewrite(node.getOperand(), context.get()); + Expression operand = rewrite(node.operand(), context.get()); ImmutableList.Builder builder = ImmutableList.builder(); - for (WhenClause expression : node.getWhenClauses()) { + for (WhenClause expression : node.whenClauses()) { builder.add(rewriteWhenClause(expression, context)); } - Optional defaultValue = node.getDefaultValue() + Optional defaultValue = node.defaultValue() .map(value -> rewrite(value, context.get())); - if (operand != node.getOperand() || - !sameElements(node.getDefaultValue(), defaultValue) || - !sameElements(node.getWhenClauses(), builder.build())) { - return new SimpleCaseExpression(operand, builder.build(), defaultValue); + if (operand != node.operand() || + !sameElements(node.defaultValue(), defaultValue) || + !sameElements(node.whenClauses(), builder.build())) { + return new Switch(operand, builder.build(), defaultValue); } return node; @@ -338,98 +338,98 @@ protected WhenClause rewriteWhenClause(WhenClause node, Context context) } @Override - protected Expression visitCoalesceExpression(CoalesceExpression node, Context context) + protected Expression visitCoalesce(Coalesce node, Context context) { if (!context.isDefaultRewrite()) { - Expression result = rewriter.rewriteCoalesceExpression(node, context.get(), ExpressionTreeRewriter.this); + Expression result = rewriter.rewriteCoalesce(node, context.get(), ExpressionTreeRewriter.this); if (result != null) { return result; } } - List operands = rewrite(node.getOperands(), context); + List operands = rewrite(node.operands(), context); - if (!sameElements(node.getOperands(), operands)) { - return new CoalesceExpression(operands); + if (!sameElements(node.operands(), operands)) { + return new Coalesce(operands); } return node; } @Override - public Expression visitFunctionCall(FunctionCall node, Context context) + public Expression visitCall(Call node, Context context) { if (!context.isDefaultRewrite()) { - Expression result = rewriter.rewriteFunctionCall(node, context.get(), ExpressionTreeRewriter.this); + Expression result = rewriter.rewriteCall(node, context.get(), ExpressionTreeRewriter.this); if (result != null) { return result; } } - List arguments = rewrite(node.getArguments(), context); + List arguments = rewrite(node.arguments(), context); - if (!sameElements(node.getArguments(), arguments)) { - return new FunctionCall(node.getFunction(), arguments); + if (!sameElements(node.arguments(), arguments)) { + return new Call(node.function(), arguments); } return node; } @Override - protected Expression visitLambdaExpression(LambdaExpression node, Context context) + protected Expression visitLambda(Lambda node, Context context) { if (!context.isDefaultRewrite()) { - Expression result = rewriter.rewriteLambdaExpression(node, context.get(), ExpressionTreeRewriter.this); + Expression result = rewriter.rewriteLambda(node, context.get(), ExpressionTreeRewriter.this); if (result != null) { return result; } } - Expression body = rewrite(node.getBody(), context.get()); - if (body != node.getBody()) { - return new LambdaExpression(node.arguments(), body); + Expression body = rewrite(node.body(), context.get()); + if (body != node.body()) { + return new Lambda(node.arguments(), body); } return node; } @Override - protected Expression visitBindExpression(BindExpression node, Context context) + protected Expression visitBind(Bind node, Context context) { if (!context.isDefaultRewrite()) { - Expression result = rewriter.rewriteBindExpression(node, context.get(), ExpressionTreeRewriter.this); + Expression result = rewriter.rewriteBind(node, context.get(), ExpressionTreeRewriter.this); if (result != null) { return result; } } - List values = node.getValues().stream() + List values = node.values().stream() .map(value -> rewrite(value, context.get())) .collect(toImmutableList()); - Expression function = rewrite(node.getFunction(), context.get()); + Expression function = rewrite(node.function(), context.get()); - if (!sameElements(values, node.getValues()) || (function != node.getFunction())) { - return new BindExpression(values, (LambdaExpression) function); + if (!sameElements(values, node.values()) || (function != node.function())) { + return new Bind(values, (Lambda) function); } return node; } @Override - public Expression visitInPredicate(InPredicate node, Context context) + public Expression visitIn(In node, Context context) { if (!context.isDefaultRewrite()) { - Expression result = rewriter.rewriteInPredicate(node, context.get(), ExpressionTreeRewriter.this); + Expression result = rewriter.rewriteIn(node, context.get(), ExpressionTreeRewriter.this); if (result != null) { return result; } } - Expression value = rewrite(node.getValue(), context.get()); - List values = node.getValueList().stream() + Expression value = rewrite(node.value(), context.get()); + List values = node.valueList().stream() .map(entry -> rewrite(entry, context.get())) .collect(toImmutableList()); - if (node.getValue() != value || !sameElements(values, node.getValueList())) { - return new InPredicate(value, values); + if (node.value() != value || !sameElements(values, node.valueList())) { + return new In(value, values); } return node; @@ -458,20 +458,20 @@ public Expression visitCast(Cast node, Context context) } } - Expression expression = rewrite(node.getExpression(), context.get()); + Expression expression = rewrite(node.expression(), context.get()); - if (node.getExpression() != expression) { - return new Cast(expression, node.getType(), node.isSafe()); + if (node.expression() != expression) { + return new Cast(expression, node.type(), node.safe()); } return node; } @Override - protected Expression visitSymbolReference(SymbolReference node, Context context) + protected Expression visitReference(Reference node, Context context) { if (!context.isDefaultRewrite()) { - Expression result = rewriter.rewriteSymbolReference(node, context.get(), ExpressionTreeRewriter.this); + Expression result = rewriter.rewriteReference(node, context.get(), ExpressionTreeRewriter.this); if (result != null) { return result; } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/InPredicate.java b/core/trino-main/src/main/java/io/trino/sql/ir/In.java similarity index 76% rename from core/trino-main/src/main/java/io/trino/sql/ir/InPredicate.java rename to core/trino-main/src/main/java/io/trino/sql/ir/In.java index 3dbb863f8e2b..8daf559f8eb9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/InPredicate.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/In.java @@ -22,10 +22,10 @@ import static io.trino.spi.type.BooleanType.BOOLEAN; @JsonSerialize -public record InPredicate(Expression value, List valueList) +public record In(Expression value, List valueList) implements Expression { - public InPredicate + public In { valueList = ImmutableList.copyOf(valueList); } @@ -36,26 +36,14 @@ public Type type() return BOOLEAN; } - @Deprecated - public Expression getValue() - { - return value; - } - - @Deprecated - public List getValueList() - { - return valueList; - } - @Override public R accept(IrVisitor visitor, C context) { - return visitor.visitInPredicate(this, context); + return visitor.visitIn(this, context); } @Override - public List getChildren() + public List children() { return ImmutableList.builder() .add(value) diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/IrExpressions.java b/core/trino-main/src/main/java/io/trino/sql/ir/IrExpressions.java index d880d21c01dc..76409b1c356f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/IrExpressions.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/IrExpressions.java @@ -23,11 +23,11 @@ private IrExpressions() {} public static Expression ifExpression(Expression condition, Expression trueCase) { - return new SearchedCaseExpression(ImmutableList.of(new WhenClause(condition, trueCase)), Optional.empty()); + return new Case(ImmutableList.of(new WhenClause(condition, trueCase)), Optional.empty()); } public static Expression ifExpression(Expression condition, Expression trueCase, Expression falseCase) { - return new SearchedCaseExpression(ImmutableList.of(new WhenClause(condition, trueCase)), Optional.of(falseCase)); + return new Case(ImmutableList.of(new WhenClause(condition, trueCase)), Optional.of(falseCase)); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/IrUtils.java b/core/trino-main/src/main/java/io/trino/sql/ir/IrUtils.java index f6e1c1624e3b..6274da4abda6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/IrUtils.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/IrUtils.java @@ -34,8 +34,8 @@ import static com.google.common.base.Predicates.not; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Streams.stream; -import static io.trino.sql.ir.BooleanLiteral.FALSE_LITERAL; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.FALSE; +import static io.trino.sql.ir.Booleans.TRUE; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toList; @@ -45,30 +45,30 @@ private IrUtils() {} public static List extractConjuncts(Expression expression) { - return extractPredicates(LogicalExpression.Operator.AND, expression); + return extractPredicates(Logical.Operator.AND, expression); } public static List extractDisjuncts(Expression expression) { - return extractPredicates(LogicalExpression.Operator.OR, expression); + return extractPredicates(Logical.Operator.OR, expression); } - public static List extractPredicates(LogicalExpression expression) + public static List extractPredicates(Logical expression) { - return extractPredicates(expression.getOperator(), expression); + return extractPredicates(expression.operator(), expression); } - public static List extractPredicates(LogicalExpression.Operator operator, Expression expression) + public static List extractPredicates(Logical.Operator operator, Expression expression) { ImmutableList.Builder resultBuilder = ImmutableList.builder(); extractPredicates(operator, expression, resultBuilder); return resultBuilder.build(); } - private static void extractPredicates(LogicalExpression.Operator operator, Expression expression, ImmutableList.Builder resultBuilder) + private static void extractPredicates(Logical.Operator operator, Expression expression, ImmutableList.Builder resultBuilder) { - if (expression instanceof LogicalExpression logicalExpression && logicalExpression.getOperator() == operator) { - for (Expression term : logicalExpression.getTerms()) { + if (expression instanceof Logical logical && logical.operator() == operator) { + for (Expression term : logical.terms()) { extractPredicates(operator, term, resultBuilder); } } @@ -84,7 +84,7 @@ public static Expression and(Expression... expressions) public static Expression and(Collection expressions) { - return logicalExpression(LogicalExpression.Operator.AND, expressions); + return logicalExpression(Logical.Operator.AND, expressions); } public static Expression or(Expression... expressions) @@ -94,10 +94,10 @@ public static Expression or(Expression... expressions) public static Expression or(Collection expressions) { - return logicalExpression(LogicalExpression.Operator.OR, expressions); + return logicalExpression(Logical.Operator.OR, expressions); } - public static Expression logicalExpression(LogicalExpression.Operator operator, Collection expressions) + public static Expression logicalExpression(Logical.Operator operator, Collection expressions) { requireNonNull(operator, "operator is null"); requireNonNull(expressions, "expressions is null"); @@ -105,9 +105,9 @@ public static Expression logicalExpression(LogicalExpression.Operator operator, if (expressions.isEmpty()) { switch (operator) { case AND: - return TRUE_LITERAL; + return TRUE; case OR: - return FALSE_LITERAL; + return FALSE; } throw new IllegalArgumentException("Unsupported LogicalExpression operator"); } @@ -116,12 +116,12 @@ public static Expression logicalExpression(LogicalExpression.Operator operator, return Iterables.getOnlyElement(expressions); } - return new LogicalExpression(operator, ImmutableList.copyOf(expressions)); + return new Logical(operator, ImmutableList.copyOf(expressions)); } - public static Expression combinePredicates(LogicalExpression.Operator operator, Collection expressions) + public static Expression combinePredicates(Logical.Operator operator, Collection expressions) { - if (operator == LogicalExpression.Operator.AND) { + if (operator == Logical.Operator.AND) { return combineConjuncts(expressions); } @@ -139,13 +139,13 @@ public static Expression combineConjuncts(Collection expressions) List conjuncts = expressions.stream() .flatMap(e -> extractConjuncts(e).stream()) - .filter(e -> !e.equals(TRUE_LITERAL)) + .filter(e -> !e.equals(TRUE)) .collect(toList()); conjuncts = removeDuplicates(conjuncts); - if (conjuncts.contains(FALSE_LITERAL)) { - return FALSE_LITERAL; + if (conjuncts.contains(FALSE)) { + return FALSE; } return and(conjuncts); @@ -157,11 +157,11 @@ public static Expression combineConjunctsWithDuplicates(Collection e List conjuncts = expressions.stream() .flatMap(e -> extractConjuncts(e).stream()) - .filter(e -> !e.equals(TRUE_LITERAL)) + .filter(e -> !e.equals(TRUE)) .collect(toList()); - if (conjuncts.contains(FALSE_LITERAL)) { - return FALSE_LITERAL; + if (conjuncts.contains(FALSE)) { + return FALSE; } return and(conjuncts); @@ -174,7 +174,7 @@ public static Expression combineDisjuncts(Expression... expressions) public static Expression combineDisjuncts(Collection expressions) { - return combineDisjunctsWithDefault(expressions, FALSE_LITERAL); + return combineDisjunctsWithDefault(expressions, FALSE); } public static Expression combineDisjunctsWithDefault(Collection expressions, Expression emptyDefault) @@ -183,13 +183,13 @@ public static Expression combineDisjunctsWithDefault(Collection expr List disjuncts = expressions.stream() .flatMap(e -> extractDisjuncts(e).stream()) - .filter(e -> !e.equals(FALSE_LITERAL)) + .filter(e -> !e.equals(FALSE)) .collect(toList()); disjuncts = removeDuplicates(disjuncts); - if (disjuncts.contains(TRUE_LITERAL)) { - return TRUE_LITERAL; + if (disjuncts.contains(TRUE)) { + return TRUE; } return disjuncts.isEmpty() ? emptyDefault : or(disjuncts); @@ -232,7 +232,7 @@ public static Function expressionOrNullSymbols(Predicate ImmutableList.Builder nullConjuncts = ImmutableList.builder(); for (Symbol symbol : symbols) { - nullConjuncts.add(new IsNullPredicate(symbol.toSymbolReference())); + nullConjuncts.add(new IsNull(symbol.toSymbolReference())); } resultDisjunct.add(and(nullConjuncts.build())); @@ -267,7 +267,7 @@ else if (!seen.contains(expression)) { public static Stream preOrder(Expression node) { return stream( - Traverser.forTree((SuccessorsFunction) Expression::getChildren) + Traverser.forTree((SuccessorsFunction) Expression::children) .depthFirstPreOrder(requireNonNull(node, "node is null"))); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/IrVisitor.java b/core/trino-main/src/main/java/io/trino/sql/ir/IrVisitor.java index 33cfe72f851b..d6662ebc0127 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/IrVisitor.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/IrVisitor.java @@ -32,22 +32,22 @@ protected R visitExpression(Expression node, C context) return null; } - protected R visitArithmeticBinary(ArithmeticBinaryExpression node, C context) + protected R visitArithmetic(Arithmetic node, C context) { return visitExpression(node, context); } - protected R visitBetweenPredicate(BetweenPredicate node, C context) + protected R visitBetween(Between node, C context) { return visitExpression(node, context); } - protected R visitCoalesceExpression(CoalesceExpression node, C context) + protected R visitCoalesce(Coalesce node, C context) { return visitExpression(node, context); } - protected R visitComparisonExpression(ComparisonExpression node, C context) + protected R visitComparison(Comparison node, C context) { return visitExpression(node, context); } @@ -57,57 +57,57 @@ protected R visitConstant(Constant node, C context) return visitExpression(node, context); } - protected R visitInPredicate(InPredicate node, C context) + protected R visitIn(In node, C context) { return visitExpression(node, context); } - protected R visitFunctionCall(FunctionCall node, C context) + protected R visitCall(Call node, C context) { return visitExpression(node, context); } - protected R visitLambdaExpression(LambdaExpression node, C context) + protected R visitLambda(Lambda node, C context) { return visitExpression(node, context); } - protected R visitSimpleCaseExpression(SimpleCaseExpression node, C context) + protected R visitSwitch(Switch node, C context) { return visitExpression(node, context); } - protected R visitNullIfExpression(NullIfExpression node, C context) + protected R visitNullIf(NullIf node, C context) { return visitExpression(node, context); } - protected R visitArithmeticNegation(ArithmeticNegation node, C context) + protected R visitNegation(Negation node, C context) { return visitExpression(node, context); } - protected R visitNotExpression(NotExpression node, C context) + protected R visitNot(Not node, C context) { return visitExpression(node, context); } - protected R visitSearchedCaseExpression(SearchedCaseExpression node, C context) + protected R visitCase(Case node, C context) { return visitExpression(node, context); } - protected R visitIsNullPredicate(IsNullPredicate node, C context) + protected R visitIsNull(IsNull node, C context) { return visitExpression(node, context); } - protected R visitSubscriptExpression(SubscriptExpression node, C context) + protected R visitSubscript(Subscript node, C context) { return visitExpression(node, context); } - protected R visitLogicalExpression(LogicalExpression node, C context) + protected R visitLogical(Logical node, C context) { return visitExpression(node, context); } @@ -122,12 +122,12 @@ protected R visitCast(Cast node, C context) return visitExpression(node, context); } - protected R visitSymbolReference(SymbolReference node, C context) + protected R visitReference(Reference node, C context) { return visitExpression(node, context); } - protected R visitBindExpression(BindExpression node, C context) + protected R visitBind(Bind node, C context) { return visitExpression(node, context); } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/NotExpression.java b/core/trino-main/src/main/java/io/trino/sql/ir/IsNull.java similarity index 81% rename from core/trino-main/src/main/java/io/trino/sql/ir/NotExpression.java rename to core/trino-main/src/main/java/io/trino/sql/ir/IsNull.java index 4ba207b408f6..6d1602a4d296 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/NotExpression.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/IsNull.java @@ -23,10 +23,10 @@ import static java.util.Objects.requireNonNull; @JsonSerialize -public record NotExpression(Expression value) +public record IsNull(Expression value) implements Expression { - public NotExpression + public IsNull { requireNonNull(value, "value is null"); } @@ -37,20 +37,14 @@ public Type type() return BOOLEAN; } - @Deprecated - public Expression getValue() - { - return value; - } - @Override public R accept(IrVisitor visitor, C context) { - return visitor.visitNotExpression(this, context); + return visitor.visitIsNull(this, context); } @Override - public List getChildren() + public List children() { return ImmutableList.of(value); } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/LambdaExpression.java b/core/trino-main/src/main/java/io/trino/sql/ir/Lambda.java similarity index 81% rename from core/trino-main/src/main/java/io/trino/sql/ir/LambdaExpression.java rename to core/trino-main/src/main/java/io/trino/sql/ir/Lambda.java index 2346cb646990..0023f7a5cbe8 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/LambdaExpression.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/Lambda.java @@ -25,10 +25,10 @@ import static java.util.Objects.requireNonNull; @JsonSerialize -public record LambdaExpression(List arguments, Expression body) +public record Lambda(List arguments, Expression body) implements Expression { - public LambdaExpression + public Lambda { requireNonNull(arguments, "arguments is null"); requireNonNull(body, "body is null"); @@ -42,26 +42,14 @@ public Type type() body.type()); } - @Deprecated - public List getArguments() - { - return arguments; - } - - @Deprecated - public Expression getBody() - { - return body; - } - @Override public R accept(IrVisitor visitor, C context) { - return visitor.visitLambdaExpression(this, context); + return visitor.visitLambda(this, context); } @Override - public List getChildren() + public List children() { return ImmutableList.of(body); } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/LogicalExpression.java b/core/trino-main/src/main/java/io/trino/sql/ir/Logical.java similarity index 75% rename from core/trino-main/src/main/java/io/trino/sql/ir/LogicalExpression.java rename to core/trino-main/src/main/java/io/trino/sql/ir/Logical.java index 1d4a500182f9..25d647c8c8a7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/LogicalExpression.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/Logical.java @@ -25,7 +25,7 @@ import static java.util.Objects.requireNonNull; @JsonSerialize -public record LogicalExpression(Operator operator, List terms) +public record Logical(Operator operator, List terms) implements Expression { public enum Operator @@ -44,7 +44,7 @@ public Operator flip() } } - public LogicalExpression + public Logical { requireNonNull(operator, "operator is null"); checkArgument(terms.size() >= 2, "Expected at least 2 terms"); @@ -57,38 +57,26 @@ public Type type() return BOOLEAN; } - @Deprecated - public Operator getOperator() - { - return operator; - } - - @Deprecated - public List getTerms() - { - return terms; - } - @Override public R accept(IrVisitor visitor, C context) { - return visitor.visitLogicalExpression(this, context); + return visitor.visitLogical(this, context); } @Override - public List getChildren() + public List children() { return terms; } - public static LogicalExpression and(Expression left, Expression right) + public static Logical and(Expression left, Expression right) { - return new LogicalExpression(Operator.AND, ImmutableList.of(left, right)); + return new Logical(Operator.AND, ImmutableList.of(left, right)); } - public static LogicalExpression or(Expression left, Expression right) + public static Logical or(Expression left, Expression right) { - return new LogicalExpression(Operator.OR, ImmutableList.of(left, right)); + return new Logical(Operator.OR, ImmutableList.of(left, right)); } @Override diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/ArithmeticNegation.java b/core/trino-main/src/main/java/io/trino/sql/ir/Negation.java similarity index 81% rename from core/trino-main/src/main/java/io/trino/sql/ir/ArithmeticNegation.java rename to core/trino-main/src/main/java/io/trino/sql/ir/Negation.java index 6967c7e51a27..6bf3f7f9f45c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/ArithmeticNegation.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/Negation.java @@ -22,10 +22,10 @@ import static java.util.Objects.requireNonNull; @JsonSerialize -public record ArithmeticNegation(Expression value) +public record Negation(Expression value) implements Expression { - public ArithmeticNegation + public Negation { requireNonNull(value, "value is null"); } @@ -36,20 +36,14 @@ public Type type() return value.type(); } - @Deprecated - public Expression getValue() - { - return value; - } - @Override public R accept(IrVisitor visitor, C context) { - return visitor.visitArithmeticNegation(this, context); + return visitor.visitNegation(this, context); } @Override - public List getChildren() + public List children() { return ImmutableList.of(value); } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/IsNullPredicate.java b/core/trino-main/src/main/java/io/trino/sql/ir/Not.java similarity index 81% rename from core/trino-main/src/main/java/io/trino/sql/ir/IsNullPredicate.java rename to core/trino-main/src/main/java/io/trino/sql/ir/Not.java index db1b8fce2587..1f5b3e1295f5 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/IsNullPredicate.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/Not.java @@ -23,10 +23,10 @@ import static java.util.Objects.requireNonNull; @JsonSerialize -public record IsNullPredicate(Expression value) +public record Not(Expression value) implements Expression { - public IsNullPredicate + public Not { requireNonNull(value, "value is null"); } @@ -37,20 +37,14 @@ public Type type() return BOOLEAN; } - @Deprecated - public Expression getValue() - { - return value; - } - @Override public R accept(IrVisitor visitor, C context) { - return visitor.visitIsNullPredicate(this, context); + return visitor.visitNot(this, context); } @Override - public List getChildren() + public List children() { return ImmutableList.of(value); } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/NullIfExpression.java b/core/trino-main/src/main/java/io/trino/sql/ir/NullIf.java similarity index 78% rename from core/trino-main/src/main/java/io/trino/sql/ir/NullIfExpression.java rename to core/trino-main/src/main/java/io/trino/sql/ir/NullIf.java index 1c866b1eb48b..02da058677f7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/NullIfExpression.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/NullIf.java @@ -25,10 +25,10 @@ * NULLIF(V1,V2): CASE WHEN V1=V2 THEN NULL ELSE V1 END */ @JsonSerialize -public record NullIfExpression(Expression first, Expression second) +public record NullIf(Expression first, Expression second) implements Expression { - public NullIfExpression + public NullIf { requireNonNull(first, "first is null"); requireNonNull(second, "second is null"); @@ -40,26 +40,14 @@ public Type type() return first.type(); } - @Deprecated - public Expression getFirst() - { - return first; - } - - @Deprecated - public Expression getSecond() - { - return second; - } - @Override public R accept(IrVisitor visitor, C context) { - return visitor.visitNullIfExpression(this, context); + return visitor.visitNullIf(this, context); } @Override - public List getChildren() + public List children() { return ImmutableList.of(first, second); } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/SymbolReference.java b/core/trino-main/src/main/java/io/trino/sql/ir/Reference.java similarity index 80% rename from core/trino-main/src/main/java/io/trino/sql/ir/SymbolReference.java rename to core/trino-main/src/main/java/io/trino/sql/ir/Reference.java index bf4c7049f029..e5a70897b37b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/SymbolReference.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/Reference.java @@ -22,10 +22,10 @@ import static java.util.Objects.requireNonNull; @JsonSerialize -public record SymbolReference(Type type, String name) +public record Reference(Type type, String name) implements Expression { - public SymbolReference + public Reference { requireNonNull(name, "name is null"); } @@ -36,20 +36,14 @@ public Type type() return type; } - @Deprecated - public String getName() - { - return name; - } - @Override public R accept(IrVisitor visitor, C context) { - return visitor.visitSymbolReference(this, context); + return visitor.visitReference(this, context); } @Override - public List getChildren() + public List children() { return ImmutableList.of(); } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/Row.java b/core/trino-main/src/main/java/io/trino/sql/ir/Row.java index b3f4eccfc29e..b8d14a919d45 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/Row.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/Row.java @@ -39,12 +39,6 @@ public Type type() return RowType.anonymous(items.stream().map(Expression::type).collect(Collectors.toList())); } - @Deprecated - public List getItems() - { - return items; - } - @Override public R accept(IrVisitor visitor, C context) { @@ -52,7 +46,7 @@ public R accept(IrVisitor visitor, C context) } @Override - public List getChildren() + public List children() { return items; } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/SubscriptExpression.java b/core/trino-main/src/main/java/io/trino/sql/ir/Subscript.java similarity index 75% rename from core/trino-main/src/main/java/io/trino/sql/ir/SubscriptExpression.java rename to core/trino-main/src/main/java/io/trino/sql/ir/Subscript.java index a22d72717430..5319298861c5 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/SubscriptExpression.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/Subscript.java @@ -22,10 +22,10 @@ import static java.util.Objects.requireNonNull; @JsonSerialize -public record SubscriptExpression(Type type, Expression base, Expression index) +public record Subscript(Type type, Expression base, Expression index) implements Expression { - public SubscriptExpression + public Subscript { requireNonNull(base, "base is null"); requireNonNull(index, "index is null"); @@ -40,24 +40,12 @@ public Type type() @Override public R accept(IrVisitor visitor, C context) { - return visitor.visitSubscriptExpression(this, context); + return visitor.visitSubscript(this, context); } @Override - public List getChildren() + public List children() { return ImmutableList.of(base, index); } - - @Deprecated - public Expression getBase() - { - return base; - } - - @Deprecated - public Expression getIndex() - { - return index; - } } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/SimpleCaseExpression.java b/core/trino-main/src/main/java/io/trino/sql/ir/Switch.java similarity index 74% rename from core/trino-main/src/main/java/io/trino/sql/ir/SimpleCaseExpression.java rename to core/trino-main/src/main/java/io/trino/sql/ir/Switch.java index 765ffa16e187..75abbaf459b9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/SimpleCaseExpression.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/Switch.java @@ -23,10 +23,10 @@ import static java.util.Objects.requireNonNull; @JsonSerialize -public record SimpleCaseExpression(Expression operand, List whenClauses, Optional defaultValue) +public record Switch(Expression operand, List whenClauses, Optional defaultValue) implements Expression { - public SimpleCaseExpression + public Switch { requireNonNull(operand, "operand is null"); whenClauses = ImmutableList.copyOf(whenClauses); @@ -39,32 +39,14 @@ public Type type() return whenClauses.getFirst().getResult().type(); } - @Deprecated - public Expression getOperand() - { - return operand; - } - - @Deprecated - public List getWhenClauses() - { - return whenClauses; - } - - @Deprecated - public Optional getDefaultValue() - { - return defaultValue; - } - @Override public R accept(IrVisitor visitor, C context) { - return visitor.visitSimpleCaseExpression(this, context); + return visitor.visitSwitch(this, context); } @Override - public List getChildren() + public List children() { ImmutableList.Builder builder = ImmutableList.builder() .add(operand); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/BuiltinFunctionCallBuilder.java b/core/trino-main/src/main/java/io/trino/sql/planner/BuiltinFunctionCallBuilder.java index 76629948cf9e..0c75a7ba231b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/BuiltinFunctionCallBuilder.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/BuiltinFunctionCallBuilder.java @@ -18,9 +18,9 @@ import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; import io.trino.sql.analyzer.TypeSignatureProvider; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; import java.util.ArrayList; import java.util.List; @@ -54,7 +54,7 @@ public BuiltinFunctionCallBuilder setName(String name) public BuiltinFunctionCallBuilder addArgument(Constant value) { requireNonNull(value, "value is null"); - return addArgument(value.getType().getTypeSignature(), value); + return addArgument(value.type().getTypeSignature(), value); } public BuiltinFunctionCallBuilder addArgument(Type type, Expression value) @@ -83,9 +83,9 @@ public BuiltinFunctionCallBuilder setArguments(List types, List EQUAL_OPERATOR_FUNCTION_NAME; @@ -162,7 +161,7 @@ static FunctionName functionNameForComparisonOperator(ComparisonExpression.Opera } @VisibleForTesting - static FunctionName functionNameForArithmeticBinaryOperator(ArithmeticBinaryExpression.Operator operator) + static FunctionName functionNameForArithmeticBinaryOperator(Arithmetic.Operator operator) { return switch (operator) { case ADD -> ADD_FUNCTION_NAME; @@ -208,17 +207,17 @@ public Optional translate(ConnectorExpression expression) if (expression instanceof FieldDereference dereference) { return translate(dereference.getTarget()) - .map(base -> new SubscriptExpression(dereference.getType(), base, new Constant(INTEGER, (long) (dereference.getField() + 1)))); + .map(base -> new Subscript(dereference.getType(), base, new Constant(INTEGER, (long) (dereference.getField() + 1)))); } - if (expression instanceof Call) { - return translateCall((Call) expression); + if (expression instanceof io.trino.spi.expression.Call) { + return translateCall((io.trino.spi.expression.Call) expression); } return Optional.empty(); } - protected Optional translateCall(Call call) + protected Optional translateCall(io.trino.spi.expression.Call call) { if (call.getFunctionName().getCatalogSchema().isPresent()) { CatalogSchemaName catalogSchemaName = call.getFunctionName().getCatalogSchema().get(); @@ -234,15 +233,15 @@ protected Optional translateCall(Call call) } if (AND_FUNCTION_NAME.equals(call.getFunctionName())) { - return translateLogicalExpression(LogicalExpression.Operator.AND, call.getArguments()); + return translateLogicalExpression(Logical.Operator.AND, call.getArguments()); } if (OR_FUNCTION_NAME.equals(call.getFunctionName())) { - return translateLogicalExpression(LogicalExpression.Operator.OR, call.getArguments()); + return translateLogicalExpression(Logical.Operator.OR, call.getArguments()); } if (NOT_FUNCTION_NAME.equals(call.getFunctionName()) && call.getArguments().size() == 1) { ConnectorExpression expression = getOnlyElement(call.getArguments()); - if (expression instanceof Call innerCall) { + if (expression instanceof io.trino.spi.expression.Call innerCall) { if (innerCall.getFunctionName().equals(IS_NULL_FUNCTION_NAME) && innerCall.getArguments().size() == 1) { return translateIsNotNull(innerCall.getArguments().get(0)); } @@ -263,7 +262,7 @@ protected Optional translateCall(Call call) // comparisons if (call.getArguments().size() == 2) { - Optional operator = comparisonOperatorForFunctionName(call.getFunctionName()); + Optional operator = comparisonOperatorForFunctionName(call.getFunctionName()); if (operator.isPresent()) { return translateComparison(operator.get(), call.getArguments().get(0), call.getArguments().get(1)); } @@ -271,7 +270,7 @@ protected Optional translateCall(Call call) // arithmetic binary if (call.getArguments().size() == 2) { - Optional operator = arithmeticBinaryOperatorForFunctionName(call.getFunctionName()); + Optional operator = arithmeticBinaryOperatorForFunctionName(call.getFunctionName()); if (operator.isPresent()) { return translateArithmeticBinary(operator.get(), call.getArguments().get(0), call.getArguments().get(1)); } @@ -279,7 +278,7 @@ protected Optional translateCall(Call call) // arithmetic unary if (NEGATE_FUNCTION_NAME.equals(call.getFunctionName()) && call.getArguments().size() == 1) { - return translate(getOnlyElement(call.getArguments())).map(argument -> new ArithmeticNegation(argument)); + return translate(getOnlyElement(call.getArguments())).map(argument -> new Negation(argument)); } if (StandardFunctions.LIKE_FUNCTION_NAME.equals(call.getFunctionName())) { @@ -331,7 +330,7 @@ private Optional translateIsNotNull(ConnectorExpression argument) { Optional translatedArgument = translate(argument); if (translatedArgument.isPresent()) { - return Optional.of(new NotExpression(new IsNullPredicate(translatedArgument.get()))); + return Optional.of(new Not(new IsNull(translatedArgument.get()))); } return Optional.empty(); @@ -341,7 +340,7 @@ private Optional translateIsNull(ConnectorExpression argument) { Optional translatedArgument = translate(argument); if (translatedArgument.isPresent()) { - return Optional.of(new IsNullPredicate(translatedArgument.get())); + return Optional.of(new IsNull(translatedArgument.get())); } return Optional.empty(); @@ -351,7 +350,7 @@ private Optional translateNot(ConnectorExpression argument) { Optional translatedArgument = translate(argument); if (argument.getType().equals(BOOLEAN) && translatedArgument.isPresent()) { - return Optional.of(new NotExpression(translatedArgument.get())); + return Optional.of(new Not(translatedArgument.get())); } return Optional.empty(); } @@ -367,17 +366,17 @@ private Optional translateCast(Type type, ConnectorExpression expres return Optional.empty(); } - private Optional translateLogicalExpression(LogicalExpression.Operator operator, List arguments) + private Optional translateLogicalExpression(Logical.Operator operator, List arguments) { Optional> translatedArguments = translateExpressions(arguments); - return translatedArguments.map(expressions -> new LogicalExpression(operator, expressions)); + return translatedArguments.map(expressions -> new Logical(operator, expressions)); } - private Optional translateComparison(ComparisonExpression.Operator operator, ConnectorExpression left, ConnectorExpression right) + private Optional translateComparison(Comparison.Operator operator, ConnectorExpression left, ConnectorExpression right) { return translate(left).flatMap(leftTranslated -> translate(right).map(rightTranslated -> - new ComparisonExpression(operator, leftTranslated, rightTranslated))); + new Comparison(operator, leftTranslated, rightTranslated))); } private Optional translateNullIf(ConnectorExpression first, ConnectorExpression second) @@ -385,39 +384,39 @@ private Optional translateNullIf(ConnectorExpression first, Connecto Optional firstExpression = translate(first); Optional secondExpression = translate(second); if (firstExpression.isPresent() && secondExpression.isPresent()) { - return Optional.of(new NullIfExpression(firstExpression.get(), secondExpression.get())); + return Optional.of(new NullIf(firstExpression.get(), secondExpression.get())); } return Optional.empty(); } - private Optional comparisonOperatorForFunctionName(FunctionName functionName) + private Optional comparisonOperatorForFunctionName(FunctionName functionName) { if (EQUAL_OPERATOR_FUNCTION_NAME.equals(functionName)) { - return Optional.of(ComparisonExpression.Operator.EQUAL); + return Optional.of(Comparison.Operator.EQUAL); } if (NOT_EQUAL_OPERATOR_FUNCTION_NAME.equals(functionName)) { - return Optional.of(ComparisonExpression.Operator.NOT_EQUAL); + return Optional.of(Comparison.Operator.NOT_EQUAL); } if (LESS_THAN_OPERATOR_FUNCTION_NAME.equals(functionName)) { - return Optional.of(ComparisonExpression.Operator.LESS_THAN); + return Optional.of(Comparison.Operator.LESS_THAN); } if (LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME.equals(functionName)) { - return Optional.of(ComparisonExpression.Operator.LESS_THAN_OR_EQUAL); + return Optional.of(Comparison.Operator.LESS_THAN_OR_EQUAL); } if (GREATER_THAN_OPERATOR_FUNCTION_NAME.equals(functionName)) { - return Optional.of(ComparisonExpression.Operator.GREATER_THAN); + return Optional.of(Comparison.Operator.GREATER_THAN); } if (GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME.equals(functionName)) { - return Optional.of(ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL); + return Optional.of(Comparison.Operator.GREATER_THAN_OR_EQUAL); } if (IS_DISTINCT_FROM_OPERATOR_FUNCTION_NAME.equals(functionName)) { - return Optional.of(ComparisonExpression.Operator.IS_DISTINCT_FROM); + return Optional.of(Comparison.Operator.IS_DISTINCT_FROM); } return Optional.empty(); } - private Optional translateArithmeticBinary(ArithmeticBinaryExpression.Operator operator, ConnectorExpression left, ConnectorExpression right) + private Optional translateArithmeticBinary(Arithmetic.Operator operator, ConnectorExpression left, ConnectorExpression right) { OperatorType operatorType = switch (operator) { case ADD -> OperatorType.ADD; @@ -430,25 +429,25 @@ private Optional translateArithmeticBinary(ArithmeticBinaryExpressio return translate(left).flatMap(leftTranslated -> translate(right).map(rightTranslated -> - new ArithmeticBinaryExpression(function, operator, leftTranslated, rightTranslated))); + new Arithmetic(function, operator, leftTranslated, rightTranslated))); } - private Optional arithmeticBinaryOperatorForFunctionName(FunctionName functionName) + private Optional arithmeticBinaryOperatorForFunctionName(FunctionName functionName) { if (ADD_FUNCTION_NAME.equals(functionName)) { - return Optional.of(ArithmeticBinaryExpression.Operator.ADD); + return Optional.of(Arithmetic.Operator.ADD); } if (SUBTRACT_FUNCTION_NAME.equals(functionName)) { - return Optional.of(ArithmeticBinaryExpression.Operator.SUBTRACT); + return Optional.of(Arithmetic.Operator.SUBTRACT); } if (MULTIPLY_FUNCTION_NAME.equals(functionName)) { - return Optional.of(ArithmeticBinaryExpression.Operator.MULTIPLY); + return Optional.of(Arithmetic.Operator.MULTIPLY); } if (DIVIDE_FUNCTION_NAME.equals(functionName)) { - return Optional.of(ArithmeticBinaryExpression.Operator.DIVIDE); + return Optional.of(Arithmetic.Operator.DIVIDE); } if (MODULUS_FUNCTION_NAME.equals(functionName)) { - return Optional.of(ArithmeticBinaryExpression.Operator.MODULUS); + return Optional.of(Arithmetic.Operator.MODULUS); } return Optional.empty(); } @@ -459,7 +458,7 @@ protected Optional translateLike(ConnectorExpression value, Connecto Optional translatedPattern = translate(pattern); if (translatedValue.isPresent() && translatedPattern.isPresent()) { - FunctionCall patternCall; + Call patternCall; if (escape.isPresent()) { Optional translatedEscape = translate(escape.get()); if (translatedEscape.isEmpty()) { @@ -479,7 +478,7 @@ protected Optional translateLike(ConnectorExpression value, Connecto .build(); } - FunctionCall call = BuiltinFunctionCallBuilder.resolve(plannerContext.getMetadata()) + Call call = BuiltinFunctionCallBuilder.resolve(plannerContext.getMetadata()) .setName(LIKE_FUNCTION_NAME) .addArgument(value.getType(), translatedValue.get()) .addArgument(LIKE_PATTERN, patternCall) @@ -497,7 +496,7 @@ protected Optional translateInPredicate(ConnectorExpression value, C Optional> translatedValues = extractExpressionsFromArrayCall(values); if (translatedValue.isPresent() && translatedValues.isPresent()) { - return Optional.of(new InPredicate(translatedValue.get(), translatedValues.get())); + return Optional.of(new In(translatedValue.get(), translatedValues.get())); } return Optional.empty(); @@ -505,7 +504,7 @@ protected Optional translateInPredicate(ConnectorExpression value, C protected Optional> extractExpressionsFromArrayCall(ConnectorExpression expression) { - if (!(expression instanceof Call call)) { + if (!(expression instanceof io.trino.spi.expression.Call call)) { return Optional.empty(); } @@ -542,83 +541,83 @@ public SqlToConnectorExpressionTranslator(Session session) } @Override - protected Optional visitSymbolReference(SymbolReference node, Void context) + protected Optional visitReference(Reference node, Void context) { - return Optional.of(new Variable(node.getName(), ((Expression) node).type())); + return Optional.of(new Variable(node.name(), ((Expression) node).type())); } @Override protected Optional visitConstant(Constant node, Void context) { - return Optional.of(constantFor(node.getType(), node.getValue())); + return Optional.of(constantFor(node.type(), node.value())); } @Override - protected Optional visitLogicalExpression(LogicalExpression node, Void context) + protected Optional visitLogical(Logical node, Void context) { if (!isComplexExpressionPushdown(session)) { return Optional.empty(); } - ImmutableList.Builder arguments = ImmutableList.builderWithExpectedSize(node.getTerms().size()); - for (Expression argument : node.getTerms()) { + ImmutableList.Builder arguments = ImmutableList.builderWithExpectedSize(node.terms().size()); + for (Expression argument : node.terms()) { Optional translated = process(argument); if (translated.isEmpty()) { return Optional.empty(); } arguments.add(translated.get()); } - return switch (node.getOperator()) { - case AND -> Optional.of(new Call(BOOLEAN, AND_FUNCTION_NAME, arguments.build())); - case OR -> Optional.of(new Call(BOOLEAN, OR_FUNCTION_NAME, arguments.build())); + return switch (node.operator()) { + case AND -> Optional.of(new io.trino.spi.expression.Call(BOOLEAN, AND_FUNCTION_NAME, arguments.build())); + case OR -> Optional.of(new io.trino.spi.expression.Call(BOOLEAN, OR_FUNCTION_NAME, arguments.build())); }; } @Override - protected Optional visitComparisonExpression(ComparisonExpression node, Void context) + protected Optional visitComparison(Comparison node, Void context) { if (!isComplexExpressionPushdown(session)) { return Optional.empty(); } - return process(node.getLeft()).flatMap(left -> process(node.getRight()).map(right -> - new Call(((Expression) node).type(), functionNameForComparisonOperator(node.getOperator()), ImmutableList.of(left, right)))); + return process(node.left()).flatMap(left -> process(node.right()).map(right -> + new io.trino.spi.expression.Call(((Expression) node).type(), functionNameForComparisonOperator(node.operator()), ImmutableList.of(left, right)))); } @Override - protected Optional visitArithmeticBinary(ArithmeticBinaryExpression node, Void context) + protected Optional visitArithmetic(Arithmetic node, Void context) { if (!isComplexExpressionPushdown(session)) { return Optional.empty(); } - return process(node.getLeft()).flatMap(left -> process(node.getRight()).map(right -> - new Call(((Expression) node).type(), functionNameForArithmeticBinaryOperator(node.getOperator()), ImmutableList.of(left, right)))); + return process(node.left()).flatMap(left -> process(node.right()).map(right -> + new io.trino.spi.expression.Call(((Expression) node).type(), functionNameForArithmeticBinaryOperator(node.operator()), ImmutableList.of(left, right)))); } @Override - protected Optional visitBetweenPredicate(BetweenPredicate node, Void context) + protected Optional visitBetween(Between node, Void context) { if (!isComplexExpressionPushdown(session)) { return Optional.empty(); } - return process(node.getValue()).flatMap(value -> - process(node.getMin()).flatMap(min -> - process(node.getMax()).map(max -> - new Call( + return process(node.value()).flatMap(value -> + process(node.min()).flatMap(min -> + process(node.max()).map(max -> + new io.trino.spi.expression.Call( BOOLEAN, AND_FUNCTION_NAME, ImmutableList.of( - new Call(BOOLEAN, GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME, ImmutableList.of(value, min)), - new Call(BOOLEAN, LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME, ImmutableList.of(value, max))))))); + new io.trino.spi.expression.Call(BOOLEAN, GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME, ImmutableList.of(value, min)), + new io.trino.spi.expression.Call(BOOLEAN, LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME, ImmutableList.of(value, max))))))); } @Override - protected Optional visitArithmeticNegation(ArithmeticNegation node, Void context) + protected Optional visitNegation(Negation node, Void context) { if (!isComplexExpressionPushdown(session)) { return Optional.empty(); } - return process(node.getValue()).map(value -> new Call(((Expression) node).type(), NEGATE_FUNCTION_NAME, ImmutableList.of(value))); + return process(node.value()).map(value -> new io.trino.spi.expression.Call(((Expression) node).type(), NEGATE_FUNCTION_NAME, ImmutableList.of(value))); } @Override @@ -631,7 +630,7 @@ protected Optional visitCast(Cast node, Void context) return Optional.empty(); } - if (node.isSafe()) { + if (node.safe()) { // try_cast would need to be modeled separately return Optional.empty(); } @@ -640,22 +639,22 @@ protected Optional visitCast(Cast node, Void context) return Optional.empty(); } - Optional translatedExpression = process(node.getExpression()); + Optional translatedExpression = process(node.expression()); if (translatedExpression.isPresent()) { - return Optional.of(new Call(node.getType(), CAST_FUNCTION_NAME, List.of(translatedExpression.get()))); + return Optional.of(new io.trino.spi.expression.Call(node.type(), CAST_FUNCTION_NAME, List.of(translatedExpression.get()))); } return Optional.empty(); } @Override - protected Optional visitFunctionCall(FunctionCall node, Void context) + protected Optional visitCall(Call node, Void context) { if (!isComplexExpressionPushdown(session)) { return Optional.empty(); } - CatalogSchemaFunctionName functionName = node.getFunction().getName(); + CatalogSchemaFunctionName functionName = node.function().getName(); checkArgument(!isDynamicFilterFunction(functionName), "Dynamic filter has no meaning for a connector, it should not be translated into ConnectorExpression"); if (functionName.equals(builtinFunctionName(LIKE_FUNCTION_NAME))) { @@ -663,7 +662,7 @@ protected Optional visitFunctionCall(FunctionCall node, Voi } ImmutableList.Builder arguments = ImmutableList.builder(); - for (Expression argumentExpression : node.getArguments()) { + for (Expression argumentExpression : node.arguments()) { Optional argument = process(argumentExpression); if (argument.isEmpty()) { return Optional.empty(); @@ -681,40 +680,40 @@ protected Optional visitFunctionCall(FunctionCall node, Voi else { name = new FunctionName(Optional.of(new CatalogSchemaName(functionName.getCatalogName(), functionName.getSchemaName())), functionName.getFunctionName()); } - return Optional.of(new Call(((Expression) node).type(), name, arguments.build())); + return Optional.of(new io.trino.spi.expression.Call(((Expression) node).type(), name, arguments.build())); } - private Optional translateLike(FunctionCall node) + private Optional translateLike(Call node) { // we need special handling for LIKE because within the engine IR a LIKE expression // is modeled as $like(value, $like_pattern(pattern, escape)) and we want // to expose it to connectors as if if were $like(value, pattern, escape) ImmutableList.Builder arguments = ImmutableList.builder(); - Optional value = process(node.getArguments().get(0)); + Optional value = process(node.arguments().get(0)); if (value.isEmpty()) { return Optional.empty(); } arguments.add(value.get()); - Expression patternArgument = node.getArguments().get(1); + Expression patternArgument = node.arguments().get(1); if (patternArgument instanceof Constant constant) { - LikePattern matcher = (LikePattern) constant.getValue(); + LikePattern matcher = (LikePattern) constant.value(); arguments.add(new io.trino.spi.expression.Constant(Slices.utf8Slice(matcher.getPattern()), createVarcharType(matcher.getPattern().length()))); if (matcher.getEscape().isPresent()) { arguments.add(new io.trino.spi.expression.Constant(Slices.utf8Slice(matcher.getEscape().get().toString()), createVarcharType(1))); } } - else if (patternArgument instanceof FunctionCall call && call.getFunction().getName().equals(builtinFunctionName(LIKE_PATTERN_FUNCTION_NAME))) { - Optional translatedPattern = process(call.getArguments().get(0)); + else if (patternArgument instanceof Call call && call.function().getName().equals(builtinFunctionName(LIKE_PATTERN_FUNCTION_NAME))) { + Optional translatedPattern = process(call.arguments().get(0)); if (translatedPattern.isEmpty()) { return Optional.empty(); } arguments.add(translatedPattern.get()); - if (call.getArguments().size() == 2) { - Optional translatedEscape = process(call.getArguments().get(1)); + if (call.arguments().size() == 2) { + Optional translatedEscape = process(call.arguments().get(1)); if (translatedEscape.isEmpty()) { return Optional.empty(); } @@ -725,25 +724,25 @@ else if (patternArgument instanceof FunctionCall call && call.getFunction().getN return Optional.empty(); } - return Optional.of(new Call(((Expression) node).type(), StandardFunctions.LIKE_FUNCTION_NAME, arguments.build())); + return Optional.of(new io.trino.spi.expression.Call(((Expression) node).type(), StandardFunctions.LIKE_FUNCTION_NAME, arguments.build())); } @Override - protected Optional visitIsNullPredicate(IsNullPredicate node, Void context) + protected Optional visitIsNull(IsNull node, Void context) { - Optional translatedValue = process(node.getValue()); + Optional translatedValue = process(node.value()); if (translatedValue.isPresent()) { - return Optional.of(new Call(BOOLEAN, IS_NULL_FUNCTION_NAME, ImmutableList.of(translatedValue.get()))); + return Optional.of(new io.trino.spi.expression.Call(BOOLEAN, IS_NULL_FUNCTION_NAME, ImmutableList.of(translatedValue.get()))); } return Optional.empty(); } @Override - protected Optional visitNotExpression(NotExpression node, Void context) + protected Optional visitNot(Not node, Void context) { - Optional translatedValue = process(node.getValue()); + Optional translatedValue = process(node.value()); if (translatedValue.isPresent()) { - return Optional.of(new Call(BOOLEAN, NOT_FUNCTION_NAME, List.of(translatedValue.get()))); + return Optional.of(new io.trino.spi.expression.Call(BOOLEAN, NOT_FUNCTION_NAME, List.of(translatedValue.get()))); } return Optional.empty(); } @@ -773,42 +772,42 @@ private ConnectorExpression constantFor(Type type, Object value) } @Override - protected Optional visitNullIfExpression(NullIfExpression node, Void context) + protected Optional visitNullIf(NullIf node, Void context) { - Optional firstValue = process(node.getFirst()); - Optional secondValue = process(node.getSecond()); + Optional firstValue = process(node.first()); + Optional secondValue = process(node.second()); if (firstValue.isPresent() && secondValue.isPresent()) { - return Optional.of(new Call(((Expression) node).type(), NULLIF_FUNCTION_NAME, ImmutableList.of(firstValue.get(), secondValue.get()))); + return Optional.of(new io.trino.spi.expression.Call(((Expression) node).type(), NULLIF_FUNCTION_NAME, ImmutableList.of(firstValue.get(), secondValue.get()))); } return Optional.empty(); } @Override - protected Optional visitSubscriptExpression(SubscriptExpression node, Void context) + protected Optional visitSubscript(Subscript node, Void context) { - if (!(node.getBase().type() instanceof RowType)) { + if (!(node.base().type() instanceof RowType)) { return Optional.empty(); } - Optional translatedBase = process(node.getBase()); + Optional translatedBase = process(node.base()); if (translatedBase.isEmpty()) { return Optional.empty(); } - return Optional.of(new FieldDereference(((Expression) node).type(), translatedBase.get(), (int) ((long) ((Constant) node.getIndex()).getValue() - 1))); + return Optional.of(new FieldDereference(((Expression) node).type(), translatedBase.get(), (int) ((long) ((Constant) node.index()).value() - 1))); } @Override - protected Optional visitInPredicate(InPredicate node, Void context) + protected Optional visitIn(In node, Void context) { - Optional valueExpression = process(node.getValue()); + Optional valueExpression = process(node.value()); if (valueExpression.isEmpty()) { return Optional.empty(); } - ImmutableList.Builder values = ImmutableList.builderWithExpectedSize(node.getValueList().size()); - for (Expression value : node.getValueList()) { + ImmutableList.Builder values = ImmutableList.builderWithExpectedSize(node.valueList().size()); + for (Expression value : node.valueList()) { // TODO: NULL should be eliminated on the engine side (within a rule) if (value == null) { return Optional.empty(); @@ -823,8 +822,8 @@ protected Optional visitInPredicate(InPredicate node, Void values.add(processedValue.get()); } - ConnectorExpression arrayExpression = new Call(new ArrayType(node.getValue().type()), ARRAY_CONSTRUCTOR_FUNCTION_NAME, values.build()); - return Optional.of(new Call(((Expression) node).type(), IN_PREDICATE_FUNCTION_NAME, List.of(valueExpression.get(), arrayExpression))); + ConnectorExpression arrayExpression = new io.trino.spi.expression.Call(new ArrayType(node.value().type()), ARRAY_CONSTRUCTOR_FUNCTION_NAME, values.build()); + return Optional.of(new io.trino.spi.expression.Call(((Expression) node).type(), IN_PREDICATE_FUNCTION_NAME, List.of(valueExpression.get(), arrayExpression))); } @Override diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/DeterminismEvaluator.java b/core/trino-main/src/main/java/io/trino/sql/planner/DeterminismEvaluator.java index 959f91411270..81b8064efa31 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/DeterminismEvaluator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/DeterminismEvaluator.java @@ -13,9 +13,9 @@ */ package io.trino.sql.planner; +import io.trino.sql.ir.Call; import io.trino.sql.ir.DefaultTraversalVisitor; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; import java.util.concurrent.atomic.AtomicBoolean; @@ -37,13 +37,13 @@ private static class Visitor extends DefaultTraversalVisitor { @Override - protected Void visitFunctionCall(FunctionCall node, AtomicBoolean deterministic) + protected Void visitCall(Call node, AtomicBoolean deterministic) { - if (!node.getFunction().isDeterministic()) { + if (!node.function().isDeterministic()) { deterministic.set(false); return null; } - return super.visitFunctionCall(node, deterministic); + return super.visitCall(node, deterministic); } } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java index bfcf250d3795..8e6f9d1c16fa 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java @@ -39,18 +39,18 @@ import io.trino.spi.type.VarcharType; import io.trino.sql.InterpretedFunctionInvoker; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.BetweenPredicate; +import io.trino.sql.ir.Between; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.InPredicate; +import io.trino.sql.ir.In; import io.trino.sql.ir.IrVisitor; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.type.LikeFunctions; import io.trino.type.LikePattern; import io.trino.type.TypeCoercion; @@ -85,15 +85,15 @@ import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.DateType.DATE; import static io.trino.spi.type.TypeUtils.isFloatingPointNaN; -import static io.trino.sql.ir.BooleanLiteral.FALSE_LITERAL; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.IS_DISTINCT_FROM; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.NOT_EQUAL; +import static io.trino.sql.ir.Booleans.FALSE; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.IS_DISTINCT_FROM; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.NOT_EQUAL; import static io.trino.sql.ir.IrUtils.and; import static io.trino.sql.ir.IrUtils.combineConjuncts; import static io.trino.sql.ir.IrUtils.combineDisjunctsWithDefault; @@ -108,7 +108,7 @@ public final class DomainTranslator public Expression toPredicate(TupleDomain tupleDomain) { if (tupleDomain.isNone()) { - return FALSE_LITERAL; + return FALSE; } Map domains = tupleDomain.getDomains().get(); @@ -117,14 +117,14 @@ public Expression toPredicate(TupleDomain tupleDomain) .collect(collectingAndThen(toImmutableList(), expressions -> combineConjuncts(expressions))); } - private Expression toPredicate(Domain domain, SymbolReference reference) + private Expression toPredicate(Domain domain, Reference reference) { if (domain.getValues().isNone()) { - return domain.isNullAllowed() ? new IsNullPredicate(reference) : FALSE_LITERAL; + return domain.isNullAllowed() ? new IsNull(reference) : FALSE; } if (domain.getValues().isAll()) { - return domain.isNullAllowed() ? TRUE_LITERAL : new NotExpression(new IsNullPredicate(reference)); + return domain.isNullAllowed() ? TRUE : new Not(new IsNull(reference)); } List disjuncts = new ArrayList<>(); @@ -138,21 +138,21 @@ private Expression toPredicate(Domain domain, SymbolReference reference) // Add nullability disjuncts if (domain.isNullAllowed()) { - disjuncts.add(new IsNullPredicate(reference)); + disjuncts.add(new IsNull(reference)); } - return combineDisjunctsWithDefault(disjuncts, TRUE_LITERAL); + return combineDisjunctsWithDefault(disjuncts, TRUE); } - private Expression processRange(Type type, Range range, SymbolReference reference) + private Expression processRange(Type type, Range range, Reference reference) { if (range.isAll()) { - return TRUE_LITERAL; + return TRUE; } if (isBetween(range)) { // specialize the range with BETWEEN expression if possible b/c it is currently more efficient - return new BetweenPredicate( + return new Between( reference, new Constant(type, range.getLowBoundedValue()), new Constant(type, range.getHighBoundedValue())); @@ -160,13 +160,13 @@ private Expression processRange(Type type, Range range, SymbolReference referenc List rangeConjuncts = new ArrayList<>(); if (!range.isLowUnbounded()) { - rangeConjuncts.add(new ComparisonExpression( + rangeConjuncts.add(new Comparison( range.isLowInclusive() ? GREATER_THAN_OR_EQUAL : GREATER_THAN, reference, new Constant(type, range.getLowBoundedValue()))); } if (!range.isHighUnbounded()) { - rangeConjuncts.add(new ComparisonExpression( + rangeConjuncts.add(new Comparison( range.isHighInclusive() ? LESS_THAN_OR_EQUAL : LESS_THAN, reference, new Constant(type, range.getHighBoundedValue()))); @@ -176,21 +176,21 @@ private Expression processRange(Type type, Range range, SymbolReference referenc return combineConjuncts(rangeConjuncts); } - private Expression combineRangeWithExcludedPoints(Type type, SymbolReference reference, Range range, List excludedPoints) + private Expression combineRangeWithExcludedPoints(Type type, Reference reference, Range range, List excludedPoints) { if (excludedPoints.isEmpty()) { return processRange(type, range, reference); } - Expression excludedPointsExpression = new NotExpression(new InPredicate(reference, excludedPoints)); + Expression excludedPointsExpression = new Not(new In(reference, excludedPoints)); if (excludedPoints.size() == 1) { - excludedPointsExpression = new ComparisonExpression(NOT_EQUAL, reference, getOnlyElement(excludedPoints)); + excludedPointsExpression = new Comparison(NOT_EQUAL, reference, getOnlyElement(excludedPoints)); } return combineConjuncts(processRange(type, range, reference), excludedPointsExpression); } - private List extractDisjuncts(Type type, Ranges ranges, SymbolReference reference) + private List extractDisjuncts(Type type, Ranges ranges, Reference reference) { List disjuncts = new ArrayList<>(); List singleValues = new ArrayList<>(); @@ -244,15 +244,15 @@ private List extractDisjuncts(Type type, Ranges ranges, SymbolRefere // Add back all of the possible single values either as an equality or an IN predicate if (singleValues.size() == 1) { - disjuncts.add(new ComparisonExpression(EQUAL, reference, getOnlyElement(singleValues))); + disjuncts.add(new Comparison(EQUAL, reference, getOnlyElement(singleValues))); } else if (singleValues.size() > 1) { - disjuncts.add(new InPredicate(reference, singleValues)); + disjuncts.add(new In(reference, singleValues)); } return disjuncts; } - private List extractDisjuncts(Type type, DiscreteValues discreteValues, SymbolReference reference) + private List extractDisjuncts(Type type, DiscreteValues discreteValues, Reference reference) { List values = discreteValues.getValues().stream() .map(object -> new Constant(type, object)) @@ -263,14 +263,14 @@ private List extractDisjuncts(Type type, DiscreteValues discreteValu Expression predicate; if (values.size() == 1) { - predicate = new ComparisonExpression(EQUAL, reference, getOnlyElement(values)); + predicate = new Comparison(EQUAL, reference, getOnlyElement(values)); } else { - predicate = new InPredicate(reference, values); + predicate = new In(reference, values); } if (!discreteValues.isInclusive()) { - predicate = new NotExpression(predicate); + predicate = new Not(predicate); } return ImmutableList.of(predicate); } @@ -321,7 +321,7 @@ private static Domain complementIfNecessary(Domain domain, boolean complement) private static Expression complementIfNecessary(Expression expression, boolean complement) { - return complement ? new NotExpression(expression) : expression; + return complement ? new Not(expression) : expression; } @Override @@ -332,9 +332,9 @@ protected ExtractionResult visitExpression(Expression node, Boolean complement) } @Override - protected ExtractionResult visitLogicalExpression(LogicalExpression node, Boolean complement) + protected ExtractionResult visitLogical(Logical node, Boolean complement) { - List results = node.getTerms().stream() + List results = node.terms().stream() .map(term -> process(term, complement)) .collect(toImmutableList()); @@ -346,7 +346,7 @@ protected ExtractionResult visitLogicalExpression(LogicalExpression node, Boolea .map(ExtractionResult::getRemainingExpression) .collect(toImmutableList()); - LogicalExpression.Operator operator = complement ? node.getOperator().flip() : node.getOperator(); + Logical.Operator operator = complement ? node.operator().flip() : node.operator(); switch (operator) { case AND: return new ExtractionResult( @@ -418,48 +418,48 @@ else if (matchingSingleSymbolDomains) { return new ExtractionResult(columnUnionedTupleDomain, remainingExpression); } - throw new AssertionError("Unknown operator: " + node.getOperator()); + throw new AssertionError("Unknown operator: " + node.operator()); } @Override - protected ExtractionResult visitNotExpression(NotExpression node, Boolean complement) + protected ExtractionResult visitNot(Not node, Boolean complement) { - return process(node.getValue(), !complement); + return process(node.value(), !complement); } @Override - protected ExtractionResult visitSymbolReference(SymbolReference node, Boolean complement) + protected ExtractionResult visitReference(Reference node, Boolean complement) { if (node.type().equals(BOOLEAN)) { - ComparisonExpression newNode = new ComparisonExpression(EQUAL, node, TRUE_LITERAL); - return visitComparisonExpression(newNode, complement); + Comparison newNode = new Comparison(EQUAL, node, TRUE); + return visitComparison(newNode, complement); } return visitExpression(node, complement); } @Override - protected ExtractionResult visitComparisonExpression(ComparisonExpression node, Boolean complement) + protected ExtractionResult visitComparison(Comparison node, Boolean complement) { Optional optionalNormalized = toNormalizedSimpleComparison(node); if (optionalNormalized.isEmpty()) { - return super.visitComparisonExpression(node, complement); + return super.visitComparison(node, complement); } NormalizedSimpleComparison normalized = optionalNormalized.get(); Expression symbolExpression = normalized.getSymbolExpression(); - if (symbolExpression instanceof SymbolReference) { + if (symbolExpression instanceof Reference) { Symbol symbol = Symbol.from(symbolExpression); NullableValue value = normalized.getValue(); Type type = value.getType(); // common type for symbol and value return createComparisonExtractionResult(normalized.getComparisonOperator(), symbol, type, value.getValue(), complement) - .orElseGet(() -> super.visitComparisonExpression(node, complement)); + .orElseGet(() -> super.visitComparison(node, complement)); } if (symbolExpression instanceof Cast castExpression) { // type of expression which is then cast to type of value Type castSourceType = castExpression.expression().type(); Type castTargetType = castExpression.type(); - if (castSourceType instanceof VarcharType && castTargetType == DATE && !castExpression.isSafe()) { + if (castSourceType instanceof VarcharType && castTargetType == DATE && !castExpression.safe()) { Optional result = createVarcharCastToDateComparisonExtractionResult( normalized, (VarcharType) castSourceType, @@ -488,29 +488,29 @@ protected ExtractionResult visitComparisonExpression(ComparisonExpression node, // end up with tuple domain with just one time point (cast(date_literal as timestamp). // While we need range which maps to single date pointed by date_literal. // - return super.visitComparisonExpression(node, complement); + return super.visitComparison(node, complement); } // we use saturated floor cast value -> castSourceType to rewrite original expression to new one with one cast peeled off the symbol side Optional coercedExpression = coerceComparisonWithRounding( - castSourceType, castExpression.getExpression(), normalized.getValue(), normalized.getComparisonOperator()); + castSourceType, castExpression.expression(), normalized.getValue(), normalized.getComparisonOperator()); if (coercedExpression.isPresent()) { return process(coercedExpression.get(), complement); } - return super.visitComparisonExpression(node, complement); + return super.visitComparison(node, complement); } - return super.visitComparisonExpression(node, complement); + return super.visitComparison(node, complement); } /** * Extract a normalized simple comparison between a QualifiedNameReference and a native value if possible. */ - private Optional toNormalizedSimpleComparison(ComparisonExpression comparison) + private Optional toNormalizedSimpleComparison(Comparison comparison) { - Expression left = comparison.getLeft(); - Expression right = comparison.getRight(); + Expression left = comparison.left(); + Expression right = comparison.right(); if (left instanceof Constant == right instanceof Constant) { // One of the terms must be a constant and the other a non-constant @@ -518,10 +518,10 @@ private Optional toNormalizedSimpleComparison(Compar } if (left instanceof Constant constant) { - return Optional.of(new NormalizedSimpleComparison(right, comparison.getOperator().flip(), new NullableValue(left.type(), constant.getValue()))); + return Optional.of(new NormalizedSimpleComparison(right, comparison.operator().flip(), new NullableValue(left.type(), constant.value()))); } else { - return Optional.of(new NormalizedSimpleComparison(left, comparison.getOperator(), new NullableValue(right.type(), ((Constant) right).getValue()))); + return Optional.of(new NormalizedSimpleComparison(left, comparison.operator(), new NullableValue(right.type(), ((Constant) right).value()))); } } @@ -534,16 +534,16 @@ private Optional createVarcharCastToDateComparisonExtractionRe NormalizedSimpleComparison comparison, VarcharType sourceType, boolean complement, - ComparisonExpression originalExpression) + Comparison originalExpression) { - Expression sourceExpression = ((Cast) comparison.getSymbolExpression()).getExpression(); - ComparisonExpression.Operator operator = comparison.getComparisonOperator(); + Expression sourceExpression = ((Cast) comparison.getSymbolExpression()).expression(); + Comparison.Operator operator = comparison.getComparisonOperator(); NullableValue value = comparison.getValue(); if (complement || value.isNull()) { return Optional.empty(); } - if (!(sourceExpression instanceof SymbolReference)) { + if (!(sourceExpression instanceof Reference)) { // Calculation is not useful return Optional.empty(); } @@ -638,7 +638,7 @@ private static SortedRangeSet dateStringRanges(LocalDate date, VarcharType domai return (SortedRangeSet) ValueSet.ofRanges(valueRanges); } - private static Optional createComparisonExtractionResult(ComparisonExpression.Operator comparisonOperator, Symbol column, Type type, @Nullable Object value, boolean complement) + private static Optional createComparisonExtractionResult(Comparison.Operator comparisonOperator, Symbol column, Type type, @Nullable Object value, boolean complement) { if (value == null) { switch (comparisonOperator) { @@ -648,30 +648,30 @@ private static Optional createComparisonExtractionResult(Compa case LESS_THAN: case LESS_THAN_OR_EQUAL: case NOT_EQUAL: - return Optional.of(new ExtractionResult(TupleDomain.none(), TRUE_LITERAL)); + return Optional.of(new ExtractionResult(TupleDomain.none(), TRUE)); case IS_DISTINCT_FROM: Domain domain = complementIfNecessary(Domain.notNull(type), complement); return Optional.of(new ExtractionResult( TupleDomain.withColumnDomains(ImmutableMap.of(column, domain)), - TRUE_LITERAL)); + TRUE)); } throw new AssertionError("Unhandled operator: " + comparisonOperator); } if (type.isOrderable()) { return extractOrderableDomain(comparisonOperator, type, value, complement) - .map(domain -> new ExtractionResult(TupleDomain.withColumnDomains(ImmutableMap.of(column, domain)), TRUE_LITERAL)); + .map(domain -> new ExtractionResult(TupleDomain.withColumnDomains(ImmutableMap.of(column, domain)), TRUE)); } if (type.isComparable()) { Domain domain = extractEquatableDomain(comparisonOperator, type, value, complement); return Optional.of(new ExtractionResult( TupleDomain.withColumnDomains(ImmutableMap.of(column, domain)), - TRUE_LITERAL)); + TRUE)); } throw new AssertionError("Type cannot be used in a comparison expression (should have been caught in analysis): " + type); } - private static Optional extractOrderableDomain(ComparisonExpression.Operator comparisonOperator, Type type, Object value, boolean complement) + private static Optional extractOrderableDomain(Comparison.Operator comparisonOperator, Type type, Object value, boolean complement) { checkArgument(value != null); @@ -735,7 +735,7 @@ the Domain should consist of ranges (which do not sum to the whole ValueSet), an }; } - private static Domain extractEquatableDomain(ComparisonExpression.Operator comparisonOperator, Type type, Object value, boolean complement) + private static Domain extractEquatableDomain(Comparison.Operator comparisonOperator, Type type, Object value, boolean complement) { checkArgument(value != null); return switch (comparisonOperator) { @@ -750,7 +750,7 @@ private Optional coerceComparisonWithRounding( Type symbolExpressionType, Expression symbolExpression, NullableValue nullableValue, - ComparisonExpression.Operator comparisonOperator) + Comparison.Operator comparisonOperator) { requireNonNull(nullableValue, "nullableValue is null"); if (nullableValue.isNull()) { @@ -766,7 +766,7 @@ private Optional coerceComparisonWithRounding( ErrorCode errorCode = e.getErrorCode(); if (INVALID_CAST_ARGUMENT.toErrorCode().equals(errorCode)) { // There's no such value at symbolExpressionType - return Optional.of(FALSE_LITERAL); + return Optional.of(FALSE); } throw e; } @@ -779,7 +779,7 @@ private Expression rewriteComparisonExpression( Type valueType, Object originalValue, Object coercedValue, - ComparisonExpression.Operator comparisonOperator) + Comparison.Operator comparisonOperator) { int originalComparedToCoerced = compareOriginalValueToCoerced(valueType, originalValue, symbolExpressionType, coercedValue); boolean coercedValueIsEqualToOriginal = originalComparedToCoerced == 0; @@ -790,47 +790,47 @@ private Expression rewriteComparisonExpression( return switch (comparisonOperator) { case GREATER_THAN_OR_EQUAL, GREATER_THAN -> { if (coercedValueIsGreaterThanOriginal) { - yield new ComparisonExpression(GREATER_THAN_OR_EQUAL, symbolExpression, coercedLiteral); + yield new Comparison(GREATER_THAN_OR_EQUAL, symbolExpression, coercedLiteral); } if (coercedValueIsEqualToOriginal) { - yield new ComparisonExpression(comparisonOperator, symbolExpression, coercedLiteral); + yield new Comparison(comparisonOperator, symbolExpression, coercedLiteral); } if (coercedValueIsLessThanOriginal) { - yield new ComparisonExpression(GREATER_THAN, symbolExpression, coercedLiteral); + yield new Comparison(GREATER_THAN, symbolExpression, coercedLiteral); } throw new AssertionError("Unreachable"); } case LESS_THAN_OR_EQUAL, LESS_THAN -> { if (coercedValueIsLessThanOriginal) { - yield new ComparisonExpression(LESS_THAN_OR_EQUAL, symbolExpression, coercedLiteral); + yield new Comparison(LESS_THAN_OR_EQUAL, symbolExpression, coercedLiteral); } if (coercedValueIsEqualToOriginal) { - yield new ComparisonExpression(comparisonOperator, symbolExpression, coercedLiteral); + yield new Comparison(comparisonOperator, symbolExpression, coercedLiteral); } if (coercedValueIsGreaterThanOriginal) { - yield new ComparisonExpression(LESS_THAN, symbolExpression, coercedLiteral); + yield new Comparison(LESS_THAN, symbolExpression, coercedLiteral); } throw new AssertionError("Unreachable"); } case EQUAL -> { if (coercedValueIsEqualToOriginal) { - yield new ComparisonExpression(EQUAL, symbolExpression, coercedLiteral); + yield new Comparison(EQUAL, symbolExpression, coercedLiteral); } // Return something that is false for all non-null values - yield and(new ComparisonExpression(GREATER_THAN, symbolExpression, coercedLiteral), - new ComparisonExpression(LESS_THAN, symbolExpression, coercedLiteral)); + yield and(new Comparison(GREATER_THAN, symbolExpression, coercedLiteral), + new Comparison(LESS_THAN, symbolExpression, coercedLiteral)); } case NOT_EQUAL -> { if (coercedValueIsEqualToOriginal) { - yield new ComparisonExpression(comparisonOperator, symbolExpression, coercedLiteral); + yield new Comparison(comparisonOperator, symbolExpression, coercedLiteral); } // Return something that is true for all non-null values - yield or(new ComparisonExpression(EQUAL, symbolExpression, coercedLiteral), - new ComparisonExpression(NOT_EQUAL, symbolExpression, coercedLiteral)); + yield or(new Comparison(EQUAL, symbolExpression, coercedLiteral), + new Comparison(NOT_EQUAL, symbolExpression, coercedLiteral)); } case IS_DISTINCT_FROM -> coercedValueIsEqualToOriginal ? - new ComparisonExpression(comparisonOperator, symbolExpression, coercedLiteral) : - TRUE_LITERAL; + new Comparison(comparisonOperator, symbolExpression, coercedLiteral) : + TRUE; }; } @@ -868,9 +868,9 @@ private int compareOriginalValueToCoerced(Type originalValueType, Object origina } @Override - protected ExtractionResult visitInPredicate(InPredicate node, Boolean complement) + protected ExtractionResult visitIn(In node, Boolean complement) { - checkState(!node.getValueList().isEmpty(), "InListExpression should never be empty"); + checkState(!node.valueList().isEmpty(), "InListExpression should never be empty"); Optional directExtractionResult = processSimpleInPredicate(node, complement); if (directExtractionResult.isPresent()) { @@ -878,8 +878,8 @@ protected ExtractionResult visitInPredicate(InPredicate node, Boolean complement } ImmutableList.Builder disjuncts = ImmutableList.builder(); - for (Expression expression : node.getValueList()) { - disjuncts.add(new ComparisonExpression(EQUAL, node.getValue(), expression)); + for (Expression expression : node.valueList()) { + disjuncts.add(new Comparison(EQUAL, node.value(), expression)); } ExtractionResult extractionResult = process(or(disjuncts.build()), complement); @@ -887,49 +887,49 @@ protected ExtractionResult visitInPredicate(InPredicate node, Boolean complement if (extractionResult.tupleDomain.isAll()) { Expression originalPredicate = node; if (complement) { - originalPredicate = new NotExpression(originalPredicate); + originalPredicate = new Not(originalPredicate); } return new ExtractionResult(extractionResult.tupleDomain, originalPredicate); } return extractionResult; } - private Optional processSimpleInPredicate(InPredicate node, Boolean complement) + private Optional processSimpleInPredicate(In node, Boolean complement) { - if (!(node.getValue() instanceof SymbolReference)) { + if (!(node.value() instanceof Reference)) { return Optional.empty(); } - Symbol symbol = Symbol.from(node.getValue()); + Symbol symbol = Symbol.from(node.value()); Type type = node.value().type(); - List inValues = new ArrayList<>(node.getValueList().size()); + List inValues = new ArrayList<>(node.valueList().size()); List excludedExpressions = new ArrayList<>(); - for (Expression expression : node.getValueList()) { + for (Expression expression : node.valueList()) { if (expression instanceof Constant constant) { - if (constant.getValue() == null) { + if (constant.value() == null) { if (complement) { // NOT IN is equivalent to NOT(s eq v1) AND NOT(s eq v2). When any right value is NULL, the comparison result is NULL, so AND's result can be at most // NULL (effectively false in predicate context) - return Optional.of(new ExtractionResult(TupleDomain.none(), TRUE_LITERAL)); + return Optional.of(new ExtractionResult(TupleDomain.none(), TRUE)); } // in case of IN, NULL on the right results with NULL comparison result (effectively false in predicate context), so can be ignored, as the // comparison results are OR-ed } else if (type instanceof RealType || type instanceof DoubleType) { // NaN can be ignored: it always compares to false, as if it was not among IN's values - if (!isFloatingPointNaN(type, constant.getValue())) { + if (!isFloatingPointNaN(type, constant.value())) { if (complement) { // in case of NOT IN with floating point, the NaN on the left passes the test (unless a NULL is found, and we exited earlier) // but this cannot currently be described with a Domain other than Domain.all excludedExpressions.add(expression); } else { - inValues.add(constant.getValue()); + inValues.add(constant.value()); } } } else { - inValues.add(constant.getValue()); + inValues.add(constant.value()); } } else { @@ -950,33 +950,33 @@ else if (type instanceof RealType || type instanceof DoubleType) { Expression remainingExpression; if (excludedExpressions.isEmpty()) { - remainingExpression = TRUE_LITERAL; + remainingExpression = TRUE; } else if (excludedExpressions.size() == 1) { - remainingExpression = new NotExpression(new ComparisonExpression(EQUAL, node.getValue(), getOnlyElement(excludedExpressions))); + remainingExpression = new Not(new Comparison(EQUAL, node.value(), getOnlyElement(excludedExpressions))); } else { - remainingExpression = new NotExpression(new InPredicate(node.getValue(), excludedExpressions)); + remainingExpression = new Not(new In(node.value(), excludedExpressions)); } return Optional.of(new ExtractionResult(tupleDomain, remainingExpression)); } @Override - protected ExtractionResult visitBetweenPredicate(BetweenPredicate node, Boolean complement) + protected ExtractionResult visitBetween(Between node, Boolean complement) { // Re-write as two comparison expressions return process(and( - new ComparisonExpression(GREATER_THAN_OR_EQUAL, node.getValue(), node.getMin()), - new ComparisonExpression(LESS_THAN_OR_EQUAL, node.getValue(), node.getMax())), complement); + new Comparison(GREATER_THAN_OR_EQUAL, node.value(), node.min()), + new Comparison(LESS_THAN_OR_EQUAL, node.value(), node.max())), complement); } - private Optional tryVisitLikeFunction(FunctionCall node, Boolean complement) + private Optional tryVisitLikeFunction(Call node, Boolean complement) { - Expression value = node.getArguments().get(0); - Expression patternArgument = node.getArguments().get(1); + Expression value = node.arguments().get(0); + Expression patternArgument = node.arguments().get(1); - if (!(value instanceof SymbolReference)) { + if (!(value instanceof Reference)) { // LIKE not on a symbol return Optional.empty(); } @@ -989,12 +989,12 @@ private Optional tryVisitLikeFunction(FunctionCall node, Boole Symbol symbol = Symbol.from(value); - if (node.getArguments().size() > 2 || !(patternArgument instanceof Constant patternConstant)) { + if (node.arguments().size() > 2 || !(patternArgument instanceof Constant patternConstant)) { // dynamic pattern or escape return Optional.empty(); } - LikePattern matcher = (LikePattern) patternConstant.getValue(); + LikePattern matcher = (LikePattern) patternConstant.value(); Slice pattern = utf8Slice(matcher.getPattern()); Optional escape = matcher.getEscape() @@ -1014,7 +1014,7 @@ private Optional tryVisitLikeFunction(FunctionCall node, Boole valueSet = ValueSet.none(type); } Domain domain = Domain.create(complementIfNecessary(valueSet, complement), false); - return Optional.of(new ExtractionResult(TupleDomain.withColumnDomains(ImmutableMap.of(symbol, domain)), TRUE_LITERAL)); + return Optional.of(new ExtractionResult(TupleDomain.withColumnDomains(ImmutableMap.of(symbol, domain)), TRUE)); } if (complement || patternConstantPrefixBytes == 0) { @@ -1027,9 +1027,9 @@ private Optional tryVisitLikeFunction(FunctionCall node, Boole } @Override - protected ExtractionResult visitFunctionCall(FunctionCall node, Boolean complement) + protected ExtractionResult visitCall(Call node, Boolean complement) { - CatalogSchemaFunctionName name = node.getFunction().getName(); + CatalogSchemaFunctionName name = node.function().getName(); if (name.equals(builtinFunctionName("starts_with"))) { Optional result = tryVisitStartsWithFunction(node, complement); if (result.isPresent()) { @@ -1045,21 +1045,21 @@ else if (name.equals(builtinFunctionName(LIKE_FUNCTION_NAME))) { return visitExpression(node, complement); } - private Optional tryVisitStartsWithFunction(FunctionCall node, Boolean complement) + private Optional tryVisitStartsWithFunction(Call node, Boolean complement) { - List args = node.getArguments(); + List args = node.arguments(); if (args.size() != 2) { return Optional.empty(); } Expression target = args.get(0); - if (!(target instanceof SymbolReference)) { + if (!(target instanceof Reference)) { // Target is not a symbol return Optional.empty(); } Expression prefix = args.get(1); - if (!(prefix instanceof Constant literal && literal.getType().equals(VarcharType.VARCHAR))) { + if (!(prefix instanceof Constant literal && literal.type().equals(VarcharType.VARCHAR))) { // dynamic pattern return Optional.empty(); } @@ -1074,7 +1074,7 @@ private Optional tryVisitStartsWithFunction(FunctionCall node, } Symbol symbol = Symbol.from(target); - Slice constantPrefix = (Slice) ((Constant) prefix).getValue(); + Slice constantPrefix = (Slice) ((Constant) prefix).value(); return createRangeDomain(type, constantPrefix).map(domain -> new ExtractionResult(TupleDomain.withColumnDomains(ImmutableMap.of(symbol, domain)), node)); } @@ -1104,31 +1104,31 @@ private Optional createRangeDomain(Type type, Slice constantPrefix) } @Override - protected ExtractionResult visitIsNullPredicate(IsNullPredicate node, Boolean complement) + protected ExtractionResult visitIsNull(IsNull node, Boolean complement) { - if (!(node.getValue() instanceof SymbolReference)) { - return super.visitIsNullPredicate(node, complement); + if (!(node.value() instanceof Reference)) { + return super.visitIsNull(node, complement); } - Symbol symbol = Symbol.from(node.getValue()); + Symbol symbol = Symbol.from(node.value()); Type columnType = symbol.getType(); Domain domain = complementIfNecessary(Domain.onlyNull(columnType), complement); return new ExtractionResult( TupleDomain.withColumnDomains(ImmutableMap.of(symbol, domain)), - TRUE_LITERAL); + TRUE); } @Override protected ExtractionResult visitConstant(Constant node, Boolean complement) { - if (node.getValue() == null) { - return new ExtractionResult(TupleDomain.none(), TRUE_LITERAL); + if (node.value() == null) { + return new ExtractionResult(TupleDomain.none(), TRUE); } - if (node.getType().equals(BOOLEAN)) { - boolean value = (boolean) node.getValue(); + if (node.type().equals(BOOLEAN)) { + boolean value = (boolean) node.value(); value = complement != value; - return new ExtractionResult(value ? TupleDomain.all() : TupleDomain.none(), TRUE_LITERAL); + return new ExtractionResult(value ? TupleDomain.all() : TupleDomain.none(), TRUE); } return super.visitConstant(node, complement); @@ -1138,10 +1138,10 @@ protected ExtractionResult visitConstant(Constant node, Boolean complement) private static class NormalizedSimpleComparison { private final Expression symbolExpression; - private final ComparisonExpression.Operator comparisonOperator; + private final Comparison.Operator comparisonOperator; private final NullableValue value; - public NormalizedSimpleComparison(Expression symbolExpression, ComparisonExpression.Operator comparisonOperator, NullableValue value) + public NormalizedSimpleComparison(Expression symbolExpression, Comparison.Operator comparisonOperator, NullableValue value) { this.symbolExpression = requireNonNull(symbolExpression, "symbolExpression is null"); this.comparisonOperator = requireNonNull(comparisonOperator, "comparisonOperator is null"); @@ -1153,7 +1153,7 @@ public Expression getSymbolExpression() return symbolExpression; } - public ComparisonExpression.Operator getComparisonOperator() + public Comparison.Operator getComparisonOperator() { return comparisonOperator; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/EffectivePredicateExtractor.java b/core/trino-main/src/main/java/io/trino/sql/planner/EffectivePredicateExtractor.java index 685532bf5748..06ca2fa3e5b3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/EffectivePredicateExtractor.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/EffectivePredicateExtractor.java @@ -29,10 +29,10 @@ import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.Reference; import io.trino.sql.ir.Row; -import io.trino.sql.ir.SymbolReference; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.AssignUniqueId; import io.trino.sql.planner.plan.DistinctLimitNode; @@ -69,8 +69,8 @@ import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.spi.type.TypeUtils.isFloatingPointNaN; import static io.trino.spi.type.TypeUtils.readNativeValue; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.IrUtils.combineConjuncts; import static io.trino.sql.ir.IrUtils.expressionOrNullSymbols; import static io.trino.sql.ir.IrUtils.extractConjuncts; @@ -89,11 +89,11 @@ public class EffectivePredicateExtractor private static final Function, Expression> ENTRY_TO_EQUALITY = entry -> { - SymbolReference reference = entry.getKey().toSymbolReference(); + Reference reference = entry.getKey().toSymbolReference(); Expression expression = entry.getValue(); // TODO: this is not correct with respect to NULLs ('reference IS NULL' would be correct, rather than 'reference = NULL') // TODO: switch this to 'IS NOT DISTINCT FROM' syntax when EqualityInference properly supports it - return new ComparisonExpression(EQUAL, reference, expression); + return new Comparison(EQUAL, reference, expression); }; private final PlannerContext plannerContext; @@ -133,7 +133,7 @@ public Visitor(DomainTranslator domainTranslator, PlannerContext plannerContext, @Override protected Expression visitPlan(PlanNode node, Void context) { - return TRUE_LITERAL; + return TRUE; } @Override @@ -145,7 +145,7 @@ public Expression visitAggregation(AggregationNode node, Void context) // Therefore, we can't say anything about the effective predicate of the // output of such an aggregation. if (node.getGroupingKeys().isEmpty()) { - return TRUE_LITERAL; + return TRUE; } Expression underlyingPredicate = node.getSource().accept(this, context); @@ -170,7 +170,7 @@ public Expression visitFilter(FilterNode node, Void context) public Expression visitExchange(ExchangeNode node, Void context) { return deriveCommonPredicates(node, source -> { - Map mappings = new HashMap<>(); + Map mappings = new HashMap<>(); for (int i = 0; i < node.getInputs().get(source).size(); i++) { mappings.put( node.getOutputSymbols().get(i), @@ -285,7 +285,7 @@ public Expression visitUnion(UnionNode node, Void context) @Override public Expression visitUnnest(UnnestNode node, Void context) { - return TRUE_LITERAL; + return TRUE; } @Override @@ -303,7 +303,7 @@ public Expression visitJoin(JoinNode node, Void context) .add(leftPredicate) .add(rightPredicate) .add(combineConjuncts(joinConjuncts)) - .add(node.getFilter().orElse(TRUE_LITERAL)) + .add(node.getFilter().orElse(TRUE)) .build()), node.getOutputSymbols()); case LEFT -> combineConjuncts(ImmutableList.builder() .add(pullExpressionThroughSymbols(leftPredicate, node.getOutputSymbols())) @@ -327,7 +327,7 @@ public Expression visitJoin(JoinNode node, Void context) public Expression visitValues(ValuesNode node, Void context) { if (node.getOutputSymbols().isEmpty()) { - return TRUE_LITERAL; + return TRUE; } // for each row of Values, get all expressions that will be evaluated: @@ -347,7 +347,7 @@ public Expression visitValues(ValuesNode node, Void context) for (Expression row : node.getRows().get()) { if (row instanceof Row) { for (int i = 0; i < node.getOutputSymbols().size(); i++) { - Expression value = ((Row) row).getItems().get(i); + Expression value = ((Row) row).items().get(i); if (!DeterminismEvaluator.isDeterministic(value)) { nonDeterministic[i] = true; } @@ -355,7 +355,7 @@ public Expression visitValues(ValuesNode node, Void context) IrExpressionInterpreter interpreter = new IrExpressionInterpreter(value, plannerContext, session); Object item = interpreter.optimize(NoOpSymbolResolver.INSTANCE); if (item instanceof Expression) { - return TRUE_LITERAL; + return TRUE; } if (item == null) { hasNull[i] = true; @@ -363,12 +363,12 @@ public Expression visitValues(ValuesNode node, Void context) else { Type type = node.getOutputSymbols().get(i).getType(); if (!type.isComparable() && !type.isOrderable()) { - return TRUE_LITERAL; + return TRUE; } if (hasNestedNulls(type, item)) { // Workaround solution to deal with array and row comparisons don't support null elements currently. // TODO: remove when comparisons are fixed - return TRUE_LITERAL; + return TRUE; } if (isFloatingPointNaN(type, item)) { hasNaN[i] = true; @@ -380,12 +380,12 @@ public Expression visitValues(ValuesNode node, Void context) } else { if (!DeterminismEvaluator.isDeterministic(row)) { - return TRUE_LITERAL; + return TRUE; } IrExpressionInterpreter interpreter = new IrExpressionInterpreter(row, plannerContext, session); Object evaluated = interpreter.optimize(NoOpSymbolResolver.INSTANCE); if (evaluated instanceof Expression) { - return TRUE_LITERAL; + return TRUE; } SqlRow sqlRow = (SqlRow) evaluated; int rawIndex = sqlRow.getRawIndex(); @@ -398,12 +398,12 @@ public Expression visitValues(ValuesNode node, Void context) } else { if (!type.isComparable() && !type.isOrderable()) { - return TRUE_LITERAL; + return TRUE; } if (hasNestedNulls(type, item)) { // Workaround solution to deal with array and row comparisons don't support null elements currently. // TODO: remove when comparisons are fixed - return TRUE_LITERAL; + return TRUE; } if (isFloatingPointNaN(type, item)) { hasNaN[i] = true; @@ -496,7 +496,7 @@ private Iterable pullNullableConjunctsThroughOuterJoin(List pullExpressionThroughSymbols(expression, outputSymbols)) - .map(expression -> SymbolsExtractor.extractAll(expression).isEmpty() ? TRUE_LITERAL : expression) + .map(expression -> SymbolsExtractor.extractAll(expression).isEmpty() ? TRUE : expression) .map(expressionOrNullSymbols(nullSymbolScopes)) .collect(toImmutableList()); } @@ -526,7 +526,7 @@ public Expression visitSpatialJoin(SpatialJoinNode node, Void context) }; } - private Expression deriveCommonPredicates(PlanNode node, Function>> mapping) + private Expression deriveCommonPredicates(PlanNode node, Function>> mapping) { // Find the predicates that can be pulled up from each source List> sourceOutputConjuncts = new ArrayList<>(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/EqualityInference.java b/core/trino-main/src/main/java/io/trino/sql/planner/EqualityInference.java index c75b20a5449d..c37fbb97c4f5 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/EqualityInference.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/EqualityInference.java @@ -19,10 +19,10 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSetMultimap; import com.google.common.collect.Multimap; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Expression; import io.trino.sql.ir.IrUtils; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.util.DisjointSet; import java.util.ArrayList; @@ -73,9 +73,9 @@ public EqualityInference(Collection expressions) .flatMap(expression -> extractConjuncts(expression).stream()) .filter(expression -> isInferenceCandidate(expression)) .forEach(expression -> { - ComparisonExpression comparison = (ComparisonExpression) expression; - Expression expression1 = comparison.getLeft(); - Expression expression2 = comparison.getRight(); + Comparison comparison = (Comparison) expression; + Expression expression1 = comparison.left(); + Expression expression2 = comparison.right(); equalities.findAndUnion(expression1, expression2); }); @@ -203,14 +203,14 @@ public EqualityPartition generateEqualitiesPartitionedBy(Set scope) if (scopeExpressions.size() >= 2) { scopeExpressions.stream() .filter(expression -> !expression.equals(matchingCanonical)) - .map(expression -> new ComparisonExpression(ComparisonExpression.Operator.EQUAL, matchingCanonical, expression)) + .map(expression -> new Comparison(Comparison.Operator.EQUAL, matchingCanonical, expression)) .forEach(scopeEqualities::add); } Expression complementCanonical = getCanonical(scopeComplementExpressions.stream()); if (scopeComplementExpressions.size() >= 2) { scopeComplementExpressions.stream() .filter(expression -> !expression.equals(complementCanonical)) - .map(expression -> new ComparisonExpression(ComparisonExpression.Operator.EQUAL, complementCanonical, expression)) + .map(expression -> new Comparison(Comparison.Operator.EQUAL, complementCanonical, expression)) .forEach(scopeComplementEqualities::add); } @@ -224,7 +224,7 @@ public EqualityPartition generateEqualitiesPartitionedBy(Set scope) .filter(expression -> SymbolsExtractor.extractAll(expression).isEmpty() || rewrite(expression, scope::contains, false) == null) .min(canonicalComparator); if (matchingConnecting.isPresent() && complementConnecting.isPresent() && !matchingConnecting.equals(complementConnecting)) { - scopeStraddlingEqualities.add(new ComparisonExpression(ComparisonExpression.Operator.EQUAL, matchingConnecting.get(), complementConnecting.get())); + scopeStraddlingEqualities.add(new Comparison(Comparison.Operator.EQUAL, matchingConnecting.get(), complementConnecting.get())); } // Compile the scope straddling equality expressions. @@ -243,7 +243,7 @@ else if (complementCanonical != null) { if (connectingCanonical != null) { straddlingExpressions.stream() .filter(expression -> !expression.equals(connectingCanonical)) - .map(expression -> new ComparisonExpression(ComparisonExpression.Operator.EQUAL, connectingCanonical, expression)) + .map(expression -> new Comparison(Comparison.Operator.EQUAL, connectingCanonical, expression)) .forEach(scopeStraddlingEqualities::add); } } @@ -256,12 +256,12 @@ else if (complementCanonical != null) { */ public static boolean isInferenceCandidate(Expression expression) { - if (expression instanceof ComparisonExpression comparison && + if (expression instanceof Comparison comparison && isDeterministic(expression) && !mayReturnNullOnNonNullInput(expression)) { - if (comparison.getOperator() == ComparisonExpression.Operator.EQUAL) { + if (comparison.operator() == Comparison.Operator.EQUAL) { // We should only consider equalities that have distinct left and right components - return !comparison.getLeft().equals(comparison.getRight()); + return !comparison.left().equals(comparison.right()); } } return false; @@ -323,9 +323,9 @@ Expression getScopedCanonical(Expression expression, Predicate symbolSco } Collection equivalences = equalitySets.get(canonicalIndex); - if (expression instanceof SymbolReference) { + if (expression instanceof Reference) { boolean inScope = equivalences.stream() - .filter(SymbolReference.class::isInstance) + .filter(Reference.class::isInstance) .map(Symbol::from) .anyMatch(symbolScope); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/ExpressionSymbolInliner.java b/core/trino-main/src/main/java/io/trino/sql/planner/ExpressionSymbolInliner.java index 3a4041c91a28..f64992e80e12 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/ExpressionSymbolInliner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/ExpressionSymbolInliner.java @@ -18,8 +18,8 @@ import io.trino.sql.ir.Expression; import io.trino.sql.ir.ExpressionRewriter; import io.trino.sql.ir.ExpressionTreeRewriter; -import io.trino.sql.ir.LambdaExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Lambda; +import io.trino.sql.ir.Reference; import java.util.Map; import java.util.function.Function; @@ -57,19 +57,19 @@ private class Visitor private final Multiset excludedNames = HashMultiset.create(); @Override - public Expression rewriteSymbolReference(SymbolReference node, Void context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteReference(Reference node, Void context, ExpressionTreeRewriter treeRewriter) { - if (excludedNames.contains(node.getName())) { + if (excludedNames.contains(node.name())) { return node; } Expression expression = mapping.apply(Symbol.from(node)); - checkState(expression != null, "Cannot resolve symbol %s", node.getName()); + checkState(expression != null, "Cannot resolve symbol %s", node.name()); return expression; } @Override - public Expression rewriteLambdaExpression(LambdaExpression node, Void context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteLambda(Lambda node, Void context, ExpressionTreeRewriter treeRewriter) { excludedNames.addAll(node.arguments().stream() .map(Symbol::getName) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/GroupingOperationRewriter.java b/core/trino-main/src/main/java/io/trino/sql/planner/GroupingOperationRewriter.java index e3858145af79..f0dd17a61d9a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/GroupingOperationRewriter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/GroupingOperationRewriter.java @@ -23,10 +23,10 @@ import io.trino.sql.analyzer.FieldId; import io.trino.sql.analyzer.RelationId; import io.trino.sql.analyzer.ResolvedField; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SubscriptExpression; +import io.trino.sql.ir.Subscript; import io.trino.sql.tree.GroupingOperation; import io.trino.sql.tree.NodeRef; @@ -40,7 +40,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static java.util.Objects.requireNonNull; public final class GroupingOperationRewriter @@ -91,13 +91,13 @@ public static Expression rewriteGroupingOperation( .collect(toImmutableList()); // It is necessary to add a 1 to the groupId because the underlying array is indexed starting at 1 - return new SubscriptExpression( + return new Subscript( type, BuiltinFunctionCallBuilder.resolve(metadata) .setName(ArrayConstructor.NAME) .setArguments(Collections.nCopies(groupingResults.size(), type), groupingResults) .build(), - new ArithmeticBinaryExpression( + new Arithmetic( metadata.resolveOperator(OperatorType.ADD, ImmutableList.of(BIGINT, BIGINT)), ADD, groupIdSymbol.get().toSymbolReference(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java b/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java index dfbd9cc3ff14..2806624d034e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java @@ -32,29 +32,29 @@ import io.trino.spi.type.Type; import io.trino.sql.InterpretedFunctionInvoker; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ArithmeticNegation; -import io.trino.sql.ir.BetweenPredicate; -import io.trino.sql.ir.BindExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Between; +import io.trino.sql.ir.Bind; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Case; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.CoalesceExpression; -import io.trino.sql.ir.ComparisonExpression; -import io.trino.sql.ir.ComparisonExpression.Operator; +import io.trino.sql.ir.Coalesce; +import io.trino.sql.ir.Comparison; +import io.trino.sql.ir.Comparison.Operator; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.InPredicate; +import io.trino.sql.ir.In; import io.trino.sql.ir.IrVisitor; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.LambdaExpression; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.NullIfExpression; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Lambda; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Negation; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.NullIf; +import io.trino.sql.ir.Reference; import io.trino.sql.ir.Row; -import io.trino.sql.ir.SearchedCaseExpression; -import io.trino.sql.ir.SimpleCaseExpression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Subscript; +import io.trino.sql.ir.Switch; import io.trino.sql.ir.WhenClause; import io.trino.type.FunctionType; import io.trino.type.TypeCoercion; @@ -179,7 +179,7 @@ private Object processWithExceptionHandling(Expression expression, Object contex } @Override - protected Object visitSymbolReference(SymbolReference node, Object context) + protected Object visitReference(Reference node, Object context) { return ((SymbolResolver) context).getValue(Symbol.from(node)); } @@ -187,29 +187,29 @@ protected Object visitSymbolReference(SymbolReference node, Object context) @Override protected Object visitConstant(Constant node, Object context) { - return node.getValue(); + return node.value(); } @Override - protected Object visitIsNullPredicate(IsNullPredicate node, Object context) + protected Object visitIsNull(IsNull node, Object context) { - Object value = processWithExceptionHandling(node.getValue(), context); + Object value = processWithExceptionHandling(node.value(), context); if (value instanceof Expression) { - return new IsNullPredicate(toExpression(value, node.getValue().type())); + return new IsNull(toExpression(value, node.value().type())); } return value == null; } @Override - protected Object visitSearchedCaseExpression(SearchedCaseExpression node, Object context) + protected Object visitCase(Case node, Object context) { Object newDefault = null; boolean foundNewDefault = false; List whenClauses = new ArrayList<>(); - for (WhenClause whenClause : node.getWhenClauses()) { + for (WhenClause whenClause : node.whenClauses()) { Object whenOperand = processWithExceptionHandling(whenClause.getOperand(), context); if (whenOperand instanceof Expression) { @@ -231,7 +231,7 @@ else if (Boolean.TRUE.equals(whenOperand)) { defaultResult = newDefault; } else { - defaultResult = processWithExceptionHandling(node.getDefaultValue().orElse(null), context); + defaultResult = processWithExceptionHandling(node.defaultValue().orElse(null), context); } if (whenClauses.isEmpty()) { @@ -240,25 +240,25 @@ else if (Boolean.TRUE.equals(whenOperand)) { Expression defaultExpression; defaultExpression = defaultResult == null ? null : toExpression(defaultResult, ((Expression) node).type()); - return new SearchedCaseExpression(whenClauses, Optional.ofNullable(defaultExpression)); + return new Case(whenClauses, Optional.ofNullable(defaultExpression)); } @Override - protected Object visitSimpleCaseExpression(SimpleCaseExpression node, Object context) + protected Object visitSwitch(Switch node, Object context) { - Object operand = processWithExceptionHandling(node.getOperand(), context); - Type operandType = node.getOperand().type(); + Object operand = processWithExceptionHandling(node.operand(), context); + Type operandType = node.operand().type(); // if operand is null, return defaultValue if (operand == null) { - return processWithExceptionHandling(node.getDefaultValue().orElse(null), context); + return processWithExceptionHandling(node.defaultValue().orElse(null), context); } Object newDefault = null; boolean foundNewDefault = false; List whenClauses = new ArrayList<>(); - for (WhenClause whenClause : node.getWhenClauses()) { + for (WhenClause whenClause : node.whenClauses()) { Object whenOperand = processWithExceptionHandling(whenClause.getOperand(), context); if (whenOperand instanceof Expression || operand instanceof Expression) { @@ -282,7 +282,7 @@ protected Object visitSimpleCaseExpression(SimpleCaseExpression node, Object con defaultResult = newDefault; } else { - defaultResult = processWithExceptionHandling(node.getDefaultValue().orElse(null), context); + defaultResult = processWithExceptionHandling(node.defaultValue().orElse(null), context); } if (whenClauses.isEmpty()) { @@ -291,7 +291,7 @@ protected Object visitSimpleCaseExpression(SimpleCaseExpression node, Object con Expression defaultExpression; defaultExpression = defaultResult == null ? null : toExpression(defaultResult, ((Expression) node).type()); - return new SimpleCaseExpression(toExpression(operand, node.getOperand().type()), whenClauses, Optional.ofNullable(defaultExpression)); + return new Switch(toExpression(operand, node.operand().type()), whenClauses, Optional.ofNullable(defaultExpression)); } private boolean isEqual(Object operand1, Type type1, Object operand2, Type type2) @@ -300,7 +300,7 @@ private boolean isEqual(Object operand1, Type type1, Object operand2, Type type2 } @Override - protected Object visitCoalesceExpression(CoalesceExpression node, Object context) + protected Object visitCoalesce(Coalesce node, Object context) { List newOperands = processOperands(node, context); if (newOperands.isEmpty()) { @@ -309,26 +309,26 @@ protected Object visitCoalesceExpression(CoalesceExpression node, Object context if (newOperands.size() == 1) { return getOnlyElement(newOperands); } - return new CoalesceExpression(newOperands.stream() + return new Coalesce(newOperands.stream() .map(value -> toExpression(value, ((Expression) node).type())) .collect(toImmutableList())); } - private List processOperands(CoalesceExpression node, Object context) + private List processOperands(Coalesce node, Object context) { List newOperands = new ArrayList<>(); Set uniqueNewOperands = new HashSet<>(); - for (Expression operand : node.getOperands()) { + for (Expression operand : node.operands()) { Object value = processWithExceptionHandling(operand, context); - if (value instanceof CoalesceExpression) { + if (value instanceof Coalesce) { // The nested CoalesceExpression was recursively processed. It does not contain null. - for (Expression nestedOperand : ((CoalesceExpression) value).getOperands()) { + for (Expression nestedOperand : ((Coalesce) value).operands()) { // Skip duplicates unless they are non-deterministic. if (!isDeterministic(nestedOperand) || uniqueNewOperands.add(nestedOperand)) { newOperands.add(nestedOperand); } // This operand can be evaluated to a non-null value. Remaining operands can be skipped. - if (nestedOperand instanceof Constant constant && constant.getValue() != null) { + if (nestedOperand instanceof Constant constant && constant.value() != null) { return newOperands; } } @@ -349,17 +349,17 @@ else if (value != null) { } @Override - protected Object visitInPredicate(InPredicate node, Object context) + protected Object visitIn(In node, Object context) { - Object value = processWithExceptionHandling(node.getValue(), context); + Object value = processWithExceptionHandling(node.value(), context); - List valueList = node.getValueList(); + List valueList = node.valueList(); // `NULL IN ()` would be false, but InListExpression cannot be empty by construction if (value == null) { return null; } - Type type = node.getValue().type(); + Type type = node.value().type(); if (!(value instanceof Expression) && !(type instanceof ArrayType) && // equals/hashcode doesn't work for complex types that may contain nulls !(type instanceof MapType) && @@ -373,7 +373,7 @@ protected Object visitInPredicate(InPredicate node, Object context) boolean nonNullConstants = valueList.stream().allMatch(Constant.class::isInstance) && valueList.stream() .map(Constant.class::cast) - .map(Constant::getValue) + .map(Constant::value) .noneMatch(Objects::isNull); if (nonNullConstants) { Set objectSet = valueList.stream().map(expression -> processWithExceptionHandling(expression, context)).collect(Collectors.toSet()); @@ -397,7 +397,7 @@ protected Object visitInPredicate(InPredicate node, Object context) List values = new ArrayList<>(valueList.size()); List types = new ArrayList<>(valueList.size()); - ResolvedFunction equalsOperator = metadata.resolveOperator(OperatorType.EQUAL, types(node.getValue(), node.getValue())); + ResolvedFunction equalsOperator = metadata.resolveOperator(OperatorType.EQUAL, types(node.value(), node.value())); for (Expression expression : valueList) { if (value instanceof Expression && expression instanceof Constant) { // skip interpreting of literal IN term since it cannot be compared @@ -449,10 +449,10 @@ else if (!found && result) { .collect(toImmutableList()); if (simplifiedExpressionValues.size() == 1) { - return new ComparisonExpression(Operator.EQUAL, toExpression(value, type), simplifiedExpressionValues.get(0)); + return new Comparison(Operator.EQUAL, toExpression(value, type), simplifiedExpressionValues.get(0)); } - return new InPredicate(toExpression(value, type), simplifiedExpressionValues); + return new In(toExpression(value, type), simplifiedExpressionValues); } if (hasNullValue) { return null; @@ -461,21 +461,21 @@ else if (!found && result) { } @Override - protected Object visitArithmeticNegation(ArithmeticNegation node, Object context) + protected Object visitNegation(Negation node, Object context) { - Object value = processWithExceptionHandling(node.getValue(), context); + Object value = processWithExceptionHandling(node.value(), context); if (value == null) { return null; } if (value instanceof Expression) { - Expression valueExpression = toExpression(value, node.getValue().type()); - if (valueExpression instanceof ArithmeticNegation argument) { - return argument.getValue(); + Expression valueExpression = toExpression(value, node.value().type()); + if (valueExpression instanceof Negation argument) { + return argument.value(); } - return new ArithmeticNegation(valueExpression); + return new Negation(valueExpression); } - ResolvedFunction resolvedOperator = metadata.resolveOperator(OperatorType.NEGATION, types(node.getValue())); + ResolvedFunction resolvedOperator = metadata.resolveOperator(OperatorType.NEGATION, types(node.value())); InvocationConvention invocationConvention = new InvocationConvention(ImmutableList.of(NEVER_NULL), FAIL_ON_NULL, true, false); MethodHandle handle = plannerContext.getFunctionManager().getScalarFunctionImplementation(resolvedOperator, invocationConvention).getMethodHandle(); @@ -493,30 +493,30 @@ protected Object visitArithmeticNegation(ArithmeticNegation node, Object context } @Override - protected Object visitArithmeticBinary(ArithmeticBinaryExpression node, Object context) + protected Object visitArithmetic(Arithmetic node, Object context) { - Object left = processWithExceptionHandling(node.getLeft(), context); + Object left = processWithExceptionHandling(node.left(), context); if (left == null) { return null; } - Object right = processWithExceptionHandling(node.getRight(), context); + Object right = processWithExceptionHandling(node.right(), context); if (right == null) { return null; } if (hasUnresolvedValue(left, right)) { - return new ArithmeticBinaryExpression(node.getFunction(), node.getOperator(), toExpression(left, node.getLeft().type()), toExpression(right, node.getRight().type())); + return new Arithmetic(node.function(), node.operator(), toExpression(left, node.left().type()), toExpression(right, node.right().type())); } - return functionInvoker.invoke(node.getFunction(), connectorSession, ImmutableList.of(left, right)); + return functionInvoker.invoke(node.function(), connectorSession, ImmutableList.of(left, right)); } @Override - protected Object visitComparisonExpression(ComparisonExpression node, Object context) + protected Object visitComparison(Comparison node, Object context) { - Operator operator = node.getOperator(); - Expression left = node.getLeft(); - Expression right = node.getRight(); + Operator operator = node.operator(); + Expression left = node.left(); + Expression right = node.right(); if (operator == Operator.IS_DISTINCT_FROM) { return processIsDistinctFrom(context, left, right); @@ -524,20 +524,20 @@ protected Object visitComparisonExpression(ComparisonExpression node, Object con // Execution engine does not have not equal and greater than operators, so interpret with // equal or less than, but do not flip operator in result, as many optimizers depend on // operators not flipping - if (node.getOperator() == Operator.NOT_EQUAL) { - Object result = visitComparisonExpression(flipComparison(node), context); + if (node.operator() == Operator.NOT_EQUAL) { + Object result = visitComparison(flipComparison(node), context); if (result == null) { return null; } - if (result instanceof ComparisonExpression) { - return flipComparison((ComparisonExpression) result); + if (result instanceof Comparison) { + return flipComparison((Comparison) result); } return !(Boolean) result; } - if (node.getOperator() == Operator.GREATER_THAN || node.getOperator() == Operator.GREATER_THAN_OR_EQUAL) { - Object result = visitComparisonExpression(flipComparison(node), context); - if (result instanceof ComparisonExpression) { - return flipComparison((ComparisonExpression) result); + if (node.operator() == Operator.GREATER_THAN || node.operator() == Operator.GREATER_THAN_OR_EQUAL) { + Object result = visitComparison(flipComparison(node), context); + if (result instanceof Comparison) { + return flipComparison((Comparison) result); } return result; } @@ -551,15 +551,15 @@ private Object processIsDistinctFrom(Object context, Expression leftExpression, Object right = processWithExceptionHandling(rightExpression, context); if (left == null && right instanceof Expression) { - return new NotExpression(new IsNullPredicate(((Expression) right))); + return new Not(new IsNull(((Expression) right))); } if (right == null && left instanceof Expression) { - return new NotExpression(new IsNullPredicate(((Expression) left))); + return new Not(new IsNull(((Expression) left))); } if (left instanceof Expression || right instanceof Expression) { - return new ComparisonExpression(Operator.IS_DISTINCT_FROM, toExpression(left, leftExpression.type()), toExpression(right, rightExpression.type())); + return new Comparison(Operator.IS_DISTINCT_FROM, toExpression(left, leftExpression.type()), toExpression(right, rightExpression.type())); } return invokeOperator(OperatorType.valueOf(Operator.IS_DISTINCT_FROM.name()), types(leftExpression, rightExpression), Arrays.asList(left, right)); @@ -578,50 +578,50 @@ private Object processComparisonExpression(Object context, Operator operator, Ex } if (left instanceof Expression || right instanceof Expression) { - return new ComparisonExpression(operator, toExpression(left, leftExpression.type()), toExpression(right, rightExpression.type())); + return new Comparison(operator, toExpression(left, leftExpression.type()), toExpression(right, rightExpression.type())); } return invokeOperator(OperatorType.valueOf(operator.name()), types(leftExpression, rightExpression), ImmutableList.of(left, right)); } // TODO define method contract or split into separate methods, as flip(EQUAL) is a negation, while flip(LESS_THAN) is just flipping sides - private ComparisonExpression flipComparison(ComparisonExpression comparisonExpression) - { - return switch (comparisonExpression.getOperator()) { - case EQUAL -> new ComparisonExpression(Operator.NOT_EQUAL, comparisonExpression.getLeft(), comparisonExpression.getRight()); - case NOT_EQUAL -> new ComparisonExpression(Operator.EQUAL, comparisonExpression.getLeft(), comparisonExpression.getRight()); - case LESS_THAN -> new ComparisonExpression(Operator.GREATER_THAN, comparisonExpression.getRight(), comparisonExpression.getLeft()); - case LESS_THAN_OR_EQUAL -> new ComparisonExpression(Operator.GREATER_THAN_OR_EQUAL, comparisonExpression.getRight(), comparisonExpression.getLeft()); - case GREATER_THAN -> new ComparisonExpression(Operator.LESS_THAN, comparisonExpression.getRight(), comparisonExpression.getLeft()); - case GREATER_THAN_OR_EQUAL -> new ComparisonExpression(Operator.LESS_THAN_OR_EQUAL, comparisonExpression.getRight(), comparisonExpression.getLeft()); - default -> throw new IllegalStateException("Unexpected value: " + comparisonExpression.getOperator()); + private Comparison flipComparison(Comparison comparison) + { + return switch (comparison.operator()) { + case EQUAL -> new Comparison(Operator.NOT_EQUAL, comparison.left(), comparison.right()); + case NOT_EQUAL -> new Comparison(Operator.EQUAL, comparison.left(), comparison.right()); + case LESS_THAN -> new Comparison(Operator.GREATER_THAN, comparison.right(), comparison.left()); + case LESS_THAN_OR_EQUAL -> new Comparison(Operator.GREATER_THAN_OR_EQUAL, comparison.right(), comparison.left()); + case GREATER_THAN -> new Comparison(Operator.LESS_THAN, comparison.right(), comparison.left()); + case GREATER_THAN_OR_EQUAL -> new Comparison(Operator.LESS_THAN_OR_EQUAL, comparison.right(), comparison.left()); + default -> throw new IllegalStateException("Unexpected value: " + comparison.operator()); }; } @Override - protected Object visitBetweenPredicate(BetweenPredicate node, Object context) + protected Object visitBetween(Between node, Object context) { - Object value = processWithExceptionHandling(node.getValue(), context); + Object value = processWithExceptionHandling(node.value(), context); if (value == null) { return null; } - Object min = processWithExceptionHandling(node.getMin(), context); - Object max = processWithExceptionHandling(node.getMax(), context); + Object min = processWithExceptionHandling(node.min(), context); + Object max = processWithExceptionHandling(node.max(), context); if (value instanceof Expression || min instanceof Expression || max instanceof Expression) { - return new BetweenPredicate( - toExpression(value, node.getValue().type()), - toExpression(min, node.getMin().type()), - toExpression(max, node.getMax().type())); + return new Between( + toExpression(value, node.value().type()), + toExpression(min, node.min().type()), + toExpression(max, node.max().type())); } Boolean greaterOrEqualToMin = null; if (min != null) { - greaterOrEqualToMin = (Boolean) invokeOperator(OperatorType.LESS_THAN_OR_EQUAL, types(node.getMin(), node.getValue()), ImmutableList.of(min, value)); + greaterOrEqualToMin = (Boolean) invokeOperator(OperatorType.LESS_THAN_OR_EQUAL, types(node.min(), node.value()), ImmutableList.of(min, value)); } Boolean lessThanOrEqualToMax = null; if (max != null) { - lessThanOrEqualToMax = (Boolean) invokeOperator(OperatorType.LESS_THAN_OR_EQUAL, types(node.getValue(), node.getMax()), ImmutableList.of(value, max)); + lessThanOrEqualToMax = (Boolean) invokeOperator(OperatorType.LESS_THAN_OR_EQUAL, types(node.value(), node.max()), ImmutableList.of(value, max)); } if (greaterOrEqualToMin == null) { @@ -634,22 +634,22 @@ protected Object visitBetweenPredicate(BetweenPredicate node, Object context) } @Override - protected Object visitNullIfExpression(NullIfExpression node, Object context) + protected Object visitNullIf(NullIf node, Object context) { - Object first = processWithExceptionHandling(node.getFirst(), context); + Object first = processWithExceptionHandling(node.first(), context); if (first == null) { return null; } - Object second = processWithExceptionHandling(node.getSecond(), context); + Object second = processWithExceptionHandling(node.second(), context); if (second == null) { return first; } - Type firstType = node.getFirst().type(); - Type secondType = node.getSecond().type(); + Type firstType = node.first().type(); + Type secondType = node.second().type(); if (hasUnresolvedValue(first, second)) { - return new NullIfExpression(toExpression(first, firstType), toExpression(second, secondType)); + return new NullIf(toExpression(first, firstType), toExpression(second, secondType)); } Type commonType = typeCoercion.getCommonSuperType(firstType, secondType).get(); @@ -672,30 +672,30 @@ protected Object visitNullIfExpression(NullIfExpression node, Object context) } @Override - protected Object visitNotExpression(NotExpression node, Object context) + protected Object visitNot(Not node, Object context) { - Object value = processWithExceptionHandling(node.getValue(), context); + Object value = processWithExceptionHandling(node.value(), context); if (value == null) { return null; } if (value instanceof Expression) { - return new NotExpression(toExpression(value, node.getValue().type())); + return new Not(toExpression(value, node.value().type())); } return !(Boolean) value; } @Override - protected Object visitLogicalExpression(LogicalExpression node, Object context) + protected Object visitLogical(Logical node, Object context) { List terms = new ArrayList<>(); List types = new ArrayList<>(); - for (Expression term : node.getTerms()) { + for (Expression term : node.terms()) { Object processed = processWithExceptionHandling(term, context); - switch (node.getOperator()) { + switch (node.operator()) { case AND -> { if (Boolean.FALSE.equals(processed)) { return false; @@ -718,7 +718,7 @@ protected Object visitLogicalExpression(LogicalExpression node, Object context) } if (terms.isEmpty()) { - return switch (node.getOperator()) { + return switch (node.operator()) { case AND -> true; // terms are true case OR -> false; // all terms are false }; @@ -736,22 +736,22 @@ protected Object visitLogicalExpression(LogicalExpression node, Object context) for (int i = 0; i < terms.size(); i++) { expressions.add(toExpression(terms.get(i), types.get(i))); } - return new LogicalExpression(node.getOperator(), expressions.build()); + return new Logical(node.operator(), expressions.build()); } @Override - protected Object visitFunctionCall(FunctionCall node, Object context) + protected Object visitCall(Call node, Object context) { List argumentTypes = new ArrayList<>(); List argumentValues = new ArrayList<>(); - for (Expression expression : node.getArguments()) { + for (Expression expression : node.arguments()) { Object value = processWithExceptionHandling(expression, context); Type type = expression.type(); argumentValues.add(value); argumentTypes.add(type); } - ResolvedFunction resolvedFunction = node.getFunction(); + ResolvedFunction resolvedFunction = node.function(); FunctionNullability functionNullability = resolvedFunction.getFunctionNullability(); for (int i = 0; i < argumentValues.size(); i++) { Object value = argumentValues.get(i); @@ -773,12 +773,12 @@ protected Object visitFunctionCall(FunctionCall node, Object context) } @Override - protected Object visitLambdaExpression(LambdaExpression node, Object context) + protected Object visitLambda(Lambda node, Object context) { if (optimize) { // TODO: enable optimization related to lambda expression // A mechanism to convert function type back into lambda expression need to exist to enable optimization - Object value = processWithExceptionHandling(node.getBody(), context); + Object value = processWithExceptionHandling(node.body(), context); Expression optimizedBody; // value may be null, converted to an expression by toExpression(value, type) @@ -786,14 +786,14 @@ protected Object visitLambdaExpression(LambdaExpression node, Object context) optimizedBody = (Expression) value; } else { - Type type = node.getBody().type(); + Type type = node.body().type(); optimizedBody = toExpression(value, type); } - return new LambdaExpression(node.getArguments(), optimizedBody); + return new Lambda(node.arguments(), optimizedBody); } - Expression body = node.getBody(); - List argumentNames = node.getArguments().stream() + Expression body = node.body(); + List argumentNames = node.arguments().stream() .map(Symbol::getName) .toList(); FunctionType functionType = (FunctionType) node.type(); @@ -810,20 +810,20 @@ protected Object visitLambdaExpression(LambdaExpression node, Object context) } @Override - protected Object visitBindExpression(BindExpression node, Object context) + protected Object visitBind(Bind node, Object context) { - List values = node.getValues().stream() + List values = node.values().stream() .map(value -> processWithExceptionHandling(value, context)) .collect(toList()); // values are nullable - Object function = processWithExceptionHandling(node.getFunction(), context); + Object function = processWithExceptionHandling(node.function(), context); if (hasUnresolvedValue(values) || hasUnresolvedValue(function)) { ImmutableList.Builder builder = ImmutableList.builder(); for (int i = 0; i < values.size(); i++) { - builder.add(toExpression(values.get(i), node.getValues().get(i).type())); + builder.add(toExpression(values.get(i), node.values().get(i).type())); } - return new BindExpression(builder.build(), (LambdaExpression) function); + return new Bind(builder.build(), (Lambda) function); } return MethodHandles.insertArguments((MethodHandle) function, 0, values.toArray()); @@ -832,15 +832,15 @@ protected Object visitBindExpression(BindExpression node, Object context) @Override public Object visitCast(Cast node, Object context) { - Object value = processWithExceptionHandling(node.getExpression(), context); - Type targetType = node.getType(); - Type sourceType = node.getExpression().type(); + Object value = processWithExceptionHandling(node.expression(), context); + Type targetType = node.type(); + Type sourceType = node.expression().type(); if (value instanceof Expression) { if (targetType.equals(sourceType)) { return value; } - return new Cast((Expression) value, node.getType(), node.isSafe()); + return new Cast((Expression) value, node.type(), node.safe()); } if (value == null) { @@ -853,7 +853,7 @@ public Object visitCast(Cast node, Object context) return functionInvoker.invoke(operator, connectorSession, ImmutableList.of(value)); } catch (RuntimeException e) { - if (node.isSafe()) { + if (node.safe()) { return null; } throw e; @@ -865,7 +865,7 @@ protected Object visitRow(Row node, Object context) { RowType rowType = (RowType) ((Expression) node).type(); List parameterTypes = rowType.getTypeParameters(); - List arguments = node.getItems(); + List arguments = node.items(); int cardinality = arguments.size(); List values = new ArrayList<>(cardinality); @@ -883,22 +883,22 @@ protected Object visitRow(Row node, Object context) } @Override - protected Object visitSubscriptExpression(SubscriptExpression node, Object context) + protected Object visitSubscript(Subscript node, Object context) { - Object base = processWithExceptionHandling(node.getBase(), context); + Object base = processWithExceptionHandling(node.base(), context); if (base == null) { return null; } - Object index = processWithExceptionHandling(node.getIndex(), context); + Object index = processWithExceptionHandling(node.index(), context); if (index == null) { return null; } - if ((index instanceof Long) && isArray(node.getBase().type())) { + if ((index instanceof Long) && isArray(node.base().type())) { ArraySubscriptOperator.checkArrayIndex((Long) index); } if (hasUnresolvedValue(base, index)) { - return new SubscriptExpression(node.type(), toExpression(base, node.getBase().type()), toExpression(index, node.getIndex().type())); + return new Subscript(node.type(), toExpression(base, node.base().type()), toExpression(index, node.index().type())); } // Subscript on Row hasn't got a dedicated operator. It is interpreted by hand. @@ -907,12 +907,12 @@ protected Object visitSubscriptExpression(SubscriptExpression node, Object conte if (fieldIndex < 0 || fieldIndex >= row.getFieldCount()) { throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "ROW index out of bounds: " + (fieldIndex + 1)); } - Type returnType = node.getBase().type().getTypeParameters().get(fieldIndex); + Type returnType = node.base().type().getTypeParameters().get(fieldIndex); return readNativeValue(returnType, row.getRawFieldBlock(fieldIndex), row.getRawIndex()); } // Subscript on Array or Map is interpreted using operator. - return invokeOperator(OperatorType.SUBSCRIPT, types(node.getBase(), node.getIndex()), ImmutableList.of(base, index)); + return invokeOperator(OperatorType.SUBSCRIPT, types(node.base(), node.index()), ImmutableList.of(base, index)); } @Override diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index 29b4a046534f..e47c533803b4 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -188,11 +188,11 @@ import io.trino.sql.gen.JoinFilterFunctionCompiler.JoinFilterFunctionFactory; import io.trino.sql.gen.OrderingCompiler; import io.trino.sql.gen.PageFunctionCompiler; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.LambdaExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Lambda; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.optimizations.IndexJoinOptimizer; import io.trino.sql.planner.plan.AdaptivePlanNode; import io.trino.sql.planner.plan.AggregationNode; @@ -345,9 +345,9 @@ import static io.trino.spiller.PartitioningSpillerFactory.unsupportedPartitioningSpillerFactory; import static io.trino.sql.DynamicFilters.extractDynamicFilters; import static io.trino.sql.gen.LambdaBytecodeGenerator.compileLambdaProvider; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.ir.IrUtils.combineConjuncts; import static io.trino.sql.planner.ExpressionExtractor.extractExpressions; import static io.trino.sql.planner.ExpressionNodeInliner.replaceExpression; @@ -1120,7 +1120,7 @@ public PhysicalOperation visitWindow(WindowNode node, LocalExecutionPlanContext ResolvedFunction resolvedFunction = function.getResolvedFunction(); ImmutableList.Builder arguments = ImmutableList.builder(); for (Expression argument : function.getArguments()) { - if (!(argument instanceof LambdaExpression)) { + if (!(argument instanceof Lambda)) { Symbol argumentSymbol = Symbol.from(argument); arguments.add(source.getLayout().get(argumentSymbol)); } @@ -1129,16 +1129,16 @@ public PhysicalOperation visitWindow(WindowNode node, LocalExecutionPlanContext WindowFunctionSupplier windowFunctionSupplier = getWindowFunctionImplementation(resolvedFunction); Type type = resolvedFunction.getSignature().getReturnType(); - List lambdaExpressions = function.getArguments().stream() - .filter(LambdaExpression.class::isInstance) - .map(LambdaExpression.class::cast) + List lambdas = function.getArguments().stream() + .filter(Lambda.class::isInstance) + .map(Lambda.class::cast) .collect(toImmutableList()); List functionTypes = resolvedFunction.getSignature().getArgumentTypes().stream() .filter(FunctionType.class::isInstance) .map(FunctionType.class::cast) .collect(toImmutableList()); - List> lambdaProviders = makeLambdaProviders(lambdaExpressions, windowFunctionSupplier.getLambdaInterfaces(), functionTypes); + List> lambdaProviders = makeLambdaProviders(lambdas, windowFunctionSupplier.getLambdaInterfaces(), functionTypes); windowFunctionsBuilder.add(window(windowFunctionSupplier, type, frameInfo, function.isIgnoreNulls(), lambdaProviders, arguments.build())); windowFunctionOutputSymbolsBuilder.add(symbol); } @@ -1261,7 +1261,7 @@ public PhysicalOperation visitPatternRecognition(PatternRecognitionNode node, Lo ResolvedFunction resolvedFunction = function.getResolvedFunction(); ImmutableList.Builder arguments = ImmutableList.builder(); for (Expression argument : function.getArguments()) { - if (!(argument instanceof LambdaExpression)) { + if (!(argument instanceof Lambda)) { Symbol argumentSymbol = Symbol.from(argument); arguments.add(source.getLayout().get(argumentSymbol)); } @@ -1269,16 +1269,16 @@ public PhysicalOperation visitPatternRecognition(PatternRecognitionNode node, Lo WindowFunctionSupplier windowFunctionSupplier = getWindowFunctionImplementation(resolvedFunction); Type type = resolvedFunction.getSignature().getReturnType(); - List lambdaExpressions = function.getArguments().stream() - .filter(LambdaExpression.class::isInstance) - .map(LambdaExpression.class::cast) + List lambdas = function.getArguments().stream() + .filter(Lambda.class::isInstance) + .map(Lambda.class::cast) .collect(toImmutableList()); List functionTypes = resolvedFunction.getSignature().getArgumentTypes().stream() .filter(FunctionType.class::isInstance) .map(FunctionType.class::cast) .collect(toImmutableList()); - List> lambdaProviders = makeLambdaProviders(lambdaExpressions, windowFunctionSupplier.getLambdaInterfaces(), functionTypes); + List> lambdaProviders = makeLambdaProviders(lambdas, windowFunctionSupplier.getLambdaInterfaces(), functionTypes); windowFunctionsBuilder.add(window(windowFunctionSupplier, type, function.isIgnoreNulls(), lambdaProviders, arguments.build())); } @@ -1512,12 +1512,12 @@ private ValueAccessors preparePhysicalValuePointers( builder.add(new SimpleEntry<>(pointer.getArguments().get(i), signatureTypes.get(i))); } Map>> arguments = builder.build().stream() - .collect(partitioningBy(entry -> entry.getKey() instanceof LambdaExpression)); + .collect(partitioningBy(entry -> entry.getKey() instanceof Lambda)); // handle lambda arguments - List lambdaExpressions = arguments.get(true).stream() + List lambdas = arguments.get(true).stream() .map(Map.Entry::getKey) - .map(LambdaExpression.class::cast) + .map(Lambda.class::cast) .collect(toImmutableList()); List functionTypes = resolvedFunction.getSignature().getArgumentTypes().stream() @@ -1526,7 +1526,7 @@ private ValueAccessors preparePhysicalValuePointers( .collect(toImmutableList()); // TODO when we support lambda arguments: lambda cannot have runtime-evaluated symbols -- add check in the Analyzer - List> lambdaProviders = makeLambdaProviders(lambdaExpressions, aggregationImplementation.getLambdaInterfaces(), functionTypes); + List> lambdaProviders = makeLambdaProviders(lambdas, aggregationImplementation.getLambdaInterfaces(), functionTypes); // handle non-lambda arguments List valueChannels = new ArrayList<>(); @@ -1539,7 +1539,7 @@ private ValueAccessors preparePhysicalValuePointers( for (Map.Entry argumentWithType : arguments.get(false)) { Expression argument = argumentWithType.getKey(); - boolean isRuntimeEvaluated = !(argument instanceof SymbolReference) || runtimeEvaluatedSymbols.contains(Symbol.from(argument)); + boolean isRuntimeEvaluated = !(argument instanceof Reference) || runtimeEvaluatedSymbols.contains(Symbol.from(argument)); if (isRuntimeEvaluated) { List argumentInputSymbols = ImmutableList.copyOf(SymbolsExtractor.extractUnique(argument)); Supplier argumentProjectionSupplier = prepareArgumentProjection(argument, argumentInputSymbols); @@ -2049,7 +2049,7 @@ private RowExpression toRowExpression(Expression expression, Map getStaticFilter(Expression filterExpression) { DynamicFilters.ExtractResult extractDynamicFilterResult = extractDynamicFilters(filterExpression); Expression staticFilter = combineConjuncts(extractDynamicFilterResult.getStaticConjuncts()); - if (staticFilter.equals(TRUE_LITERAL)) { + if (staticFilter.equals(TRUE)) { return Optional.empty(); } return Optional.of(staticFilter); @@ -2438,22 +2438,22 @@ public PhysicalOperation visitJoin(JoinNode node, LocalExecutionPlanContext cont public PhysicalOperation visitSpatialJoin(SpatialJoinNode node, LocalExecutionPlanContext context) { Expression filterExpression = node.getFilter(); - List spatialFunctions = extractSupportedSpatialFunctions(filterExpression); - for (FunctionCall spatialFunction : spatialFunctions) { + List spatialFunctions = extractSupportedSpatialFunctions(filterExpression); + for (Call spatialFunction : spatialFunctions) { Optional operation = tryCreateSpatialJoin(context, node, removeExpressionFromFilter(filterExpression, spatialFunction), spatialFunction, Optional.empty(), Optional.empty()); if (operation.isPresent()) { return operation.get(); } } - List spatialComparisons = extractSupportedSpatialComparisons(filterExpression); - for (ComparisonExpression spatialComparison : spatialComparisons) { - if (spatialComparison.getOperator() == LESS_THAN || spatialComparison.getOperator() == LESS_THAN_OR_EQUAL) { + List spatialComparisons = extractSupportedSpatialComparisons(filterExpression); + for (Comparison spatialComparison : spatialComparisons) { + if (spatialComparison.operator() == LESS_THAN || spatialComparison.operator() == LESS_THAN_OR_EQUAL) { // ST_Distance(a, b) <= r - Expression radius = spatialComparison.getRight(); - if (radius instanceof SymbolReference && getSymbolReferences(node.getRight().getOutputSymbols()).contains(radius)) { - FunctionCall spatialFunction = (FunctionCall) spatialComparison.getLeft(); - Optional operation = tryCreateSpatialJoin(context, node, removeExpressionFromFilter(filterExpression, spatialComparison), spatialFunction, Optional.of(radius), Optional.of(spatialComparison.getOperator())); + Expression radius = spatialComparison.right(); + if (radius instanceof Reference && getSymbolReferences(node.getRight().getOutputSymbols()).contains(radius)) { + Call spatialFunction = (Call) spatialComparison.left(); + Optional operation = tryCreateSpatialJoin(context, node, removeExpressionFromFilter(filterExpression, spatialComparison), spatialFunction, Optional.of(radius), Optional.of(spatialComparison.operator())); if (operation.isPresent()) { return operation.get(); } @@ -2468,22 +2468,22 @@ private Optional tryCreateSpatialJoin( LocalExecutionPlanContext context, SpatialJoinNode node, Optional filterExpression, - FunctionCall spatialFunction, + Call spatialFunction, Optional radius, - Optional comparisonOperator) + Optional comparisonOperator) { - List arguments = spatialFunction.getArguments(); + List arguments = spatialFunction.arguments(); verify(arguments.size() == 2); - if (!(arguments.get(0) instanceof SymbolReference firstSymbol) || !(arguments.get(1) instanceof SymbolReference secondSymbol)) { + if (!(arguments.get(0) instanceof Reference firstSymbol) || !(arguments.get(1) instanceof Reference secondSymbol)) { return Optional.empty(); } PlanNode probeNode = node.getLeft(); - Set probeSymbols = getSymbolReferences(probeNode.getOutputSymbols()); + Set probeSymbols = getSymbolReferences(probeNode.getOutputSymbols()); PlanNode buildNode = node.getRight(); - Set buildSymbols = getSymbolReferences(buildNode.getOutputSymbols()); + Set buildSymbols = getSymbolReferences(buildNode.getOutputSymbols()); if (probeSymbols.contains(firstSymbol) && buildSymbols.contains(secondSymbol)) { return Optional.of(createSpatialLookupJoin( @@ -2514,13 +2514,13 @@ private Optional tryCreateSpatialJoin( private Optional removeExpressionFromFilter(Expression filter, Expression expression) { - Expression updatedJoinFilter = replaceExpression(filter, ImmutableMap.of(expression, TRUE_LITERAL)); - return updatedJoinFilter.equals(TRUE_LITERAL) ? Optional.empty() : Optional.of(updatedJoinFilter); + Expression updatedJoinFilter = replaceExpression(filter, ImmutableMap.of(expression, TRUE)); + return updatedJoinFilter.equals(TRUE) ? Optional.empty() : Optional.of(updatedJoinFilter); } - private SpatialPredicate spatialTest(FunctionCall functionCall, boolean probeFirst, Optional comparisonOperator) + private SpatialPredicate spatialTest(Call call, boolean probeFirst, Optional comparisonOperator) { - CatalogSchemaFunctionName functionName = functionCall.getFunction().getName(); + CatalogSchemaFunctionName functionName = call.function().getName(); if (functionName.equals(builtinFunctionName(ST_CONTAINS))) { if (probeFirst) { return (buildGeometry, probeGeometry, radius) -> probeGeometry.contains(buildGeometry); @@ -2548,7 +2548,7 @@ private SpatialPredicate spatialTest(FunctionCall functionCall, boolean probeFir throw new UnsupportedOperationException("Unsupported spatial function: " + functionName); } - private Set getSymbolReferences(Collection symbols) + private Set getSymbolReferences(Collection symbols) { return symbols.stream().map(Symbol::toSymbolReference).collect(toImmutableSet()); } @@ -3709,7 +3709,7 @@ private AggregatorFactory buildAggregatorFactory( { List argumentChannels = new ArrayList<>(); for (Expression argument : aggregation.getArguments()) { - if (!(argument instanceof LambdaExpression)) { + if (!(argument instanceof Lambda)) { Symbol argumentSymbol = Symbol.from(argument); argumentChannels.add(source.getLayout().get(argumentSymbol)); } @@ -3779,15 +3779,15 @@ private AggregatorFactory buildAggregatorFactory( .mapToInt(value -> source.getLayout().get(value)) .findAny(); - List lambdaExpressions = aggregation.getArguments().stream() - .filter(LambdaExpression.class::isInstance) - .map(LambdaExpression.class::cast) + List lambdas = aggregation.getArguments().stream() + .filter(Lambda.class::isInstance) + .map(Lambda.class::cast) .collect(toImmutableList()); List functionTypes = resolvedFunction.getSignature().getArgumentTypes().stream() .filter(FunctionType.class::isInstance) .map(FunctionType.class::cast) .collect(toImmutableList()); - List> lambdaProviders = makeLambdaProviders(lambdaExpressions, aggregationImplementation.getLambdaInterfaces(), functionTypes); + List> lambdaProviders = makeLambdaProviders(lambdas, aggregationImplementation.getLambdaInterfaces(), functionTypes); return new AggregatorFactory( accumulatorFactory, @@ -3800,15 +3800,15 @@ private AggregatorFactory buildAggregatorFactory( lambdaProviders); } - private List> makeLambdaProviders(List lambdaExpressions, List> lambdaInterfaces, List functionTypes) + private List> makeLambdaProviders(List lambdas, List> lambdaInterfaces, List functionTypes) { List> lambdaProviders = new ArrayList<>(); - if (!lambdaExpressions.isEmpty()) { - verify(lambdaExpressions.size() == functionTypes.size()); - verify(lambdaExpressions.size() == lambdaInterfaces.size()); + if (!lambdas.isEmpty()) { + verify(lambdas.size() == functionTypes.size()); + verify(lambdas.size() == lambdaInterfaces.size()); - for (int i = 0; i < lambdaExpressions.size(); i++) { - LambdaExpression lambdaExpression = lambdaExpressions.get(i); + for (int i = 0; i < lambdas.size(); i++) { + Lambda lambdaExpression = lambdas.get(i); FunctionType functionType = functionTypes.get(i); // To compile lambda, LambdaDefinitionExpression needs to be generated from LambdaExpression, @@ -3823,11 +3823,11 @@ private List> makeLambdaProviders(List lambda // // TODO: Once the final aggregation function call representation is fixed, // the same mechanism in project and filter expression should be used here. - verify(lambdaExpression.getArguments().size() == functionType.getArgumentTypes().size()); + verify(lambdaExpression.arguments().size() == functionType.getArgumentTypes().size()); Map lambdaArgumentSymbolTypes = new HashMap<>(); - for (int j = 0; j < lambdaExpression.getArguments().size(); j++) { + for (int j = 0; j < lambdaExpression.arguments().size(); j++) { lambdaArgumentSymbolTypes.put( - lambdaExpression.getArguments().get(j), + lambdaExpression.arguments().get(j), functionType.getArgumentTypes().get(j)); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java index a5b4521f5ec3..a00e83f3e73a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java @@ -61,12 +61,12 @@ import io.trino.sql.analyzer.RelationId; import io.trino.sql.analyzer.RelationType; import io.trino.sql.analyzer.Scope; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.CoalesceExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Coalesce; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; import io.trino.sql.ir.Row; import io.trino.sql.planner.StatisticsAggregationPlanner.TableStatisticAggregation; import io.trino.sql.planner.iterative.IterativeOptimizer; @@ -144,8 +144,8 @@ import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.analyzer.SemanticExceptions.semanticException; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; import static io.trino.sql.ir.IrExpressions.ifExpression; import static io.trino.sql.planner.LogicalPlanner.Stage.OPTIMIZED; import static io.trino.sql.planner.LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED; @@ -543,7 +543,7 @@ private RelationPlan getInsertPlan( expression = coerceOrCastToTableType(input, tableType, queryType); } if (!column.isNullable()) { - expression = new CoalesceExpression(expression, createNullNotAllowedFailExpression(column.getName(), tableType)); + expression = new Coalesce(expression, createNullNotAllowedFailExpression(column.getName(), tableType)); } assignments.put(output, expression); insertedColumnsBuilder.add(column); @@ -630,11 +630,11 @@ private Expression createNullNotAllowedFailExpression(String columnName, Type ty private static Function failIfPredicateIsNotMet(Metadata metadata, ErrorCodeSupplier errorCode, String errorMessage) { - FunctionCall fail = failFunction(metadata, errorCode, errorMessage); - return predicate -> ifExpression(predicate, TRUE_LITERAL, new Cast(fail, BOOLEAN)); + Call fail = failFunction(metadata, errorCode, errorMessage); + return predicate -> ifExpression(predicate, TRUE, new Cast(fail, BOOLEAN)); } - public static FunctionCall failFunction(Metadata metadata, ErrorCodeSupplier errorCode, String errorMessage) + public static Call failFunction(Metadata metadata, ErrorCodeSupplier errorCode, String errorMessage) { Object rawValue = Slices.utf8Slice(errorMessage); return BuiltinFunctionCallBuilder.resolve(metadata) @@ -812,11 +812,11 @@ private Expression noTruncationCast(Expression expression, Type fromType, Type t return ifExpression( // check if the trimmed value fits in the target type - new ComparisonExpression( + new Comparison( GREATER_THAN_OR_EQUAL, new Constant(BIGINT, (long) targetLength), - new CoalesceExpression( - new FunctionCall( + new Coalesce( + new Call( spaceTrimmedLength, ImmutableList.of(new Cast(expression, VARCHAR))), new Constant(BIGINT, 0L))), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/NullabilityAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/planner/NullabilityAnalyzer.java index 26a7d463ce00..5504f8c51911 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/NullabilityAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/NullabilityAnalyzer.java @@ -13,16 +13,16 @@ */ package io.trino.sql.planner; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Case; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Constant; import io.trino.sql.ir.DefaultTraversalVisitor; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.InPredicate; -import io.trino.sql.ir.NullIfExpression; -import io.trino.sql.ir.SearchedCaseExpression; -import io.trino.sql.ir.SimpleCaseExpression; -import io.trino.sql.ir.SubscriptExpression; +import io.trino.sql.ir.In; +import io.trino.sql.ir.NullIf; +import io.trino.sql.ir.Subscript; +import io.trino.sql.ir.Switch; import java.util.concurrent.atomic.AtomicBoolean; @@ -63,48 +63,48 @@ protected Void visitCast(Cast node, AtomicBoolean result) // except for the CAST(NULL AS x) case -- we should fix this at some point) // // Also, try_cast (i.e., safe cast) can return null - process(node.getExpression(), result); - result.compareAndSet(false, node.isSafe()); + process(node.expression(), result); + result.compareAndSet(false, node.safe()); return null; } @Override - protected Void visitNullIfExpression(NullIfExpression node, AtomicBoolean result) + protected Void visitNullIf(NullIf node, AtomicBoolean result) { result.set(true); return null; } @Override - protected Void visitInPredicate(InPredicate node, AtomicBoolean result) + protected Void visitIn(In node, AtomicBoolean result) { result.set(true); return null; } @Override - protected Void visitSearchedCaseExpression(SearchedCaseExpression node, AtomicBoolean result) + protected Void visitCase(Case node, AtomicBoolean result) { result.set(true); return null; } @Override - protected Void visitSimpleCaseExpression(SimpleCaseExpression node, AtomicBoolean result) + protected Void visitSwitch(Switch node, AtomicBoolean result) { result.set(true); return null; } @Override - protected Void visitSubscriptExpression(SubscriptExpression node, AtomicBoolean result) + protected Void visitSubscript(Subscript node, AtomicBoolean result) { result.set(true); return null; } @Override - protected Void visitFunctionCall(FunctionCall node, AtomicBoolean result) + protected Void visitCall(Call node, AtomicBoolean result) { // TODO: this should look at whether the return type of the function is annotated with @SqlNullable result.set(true); @@ -114,7 +114,7 @@ protected Void visitFunctionCall(FunctionCall node, AtomicBoolean result) @Override protected Void visitConstant(Constant node, AtomicBoolean result) { - if (node.getValue() == null) { + if (node.value() == null) { result.set(true); } return null; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PartialTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/PartialTranslator.java index 33bc4a9a0c72..32ed719e033a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PartialTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PartialTranslator.java @@ -18,7 +18,7 @@ import io.trino.spi.expression.ConnectorExpression; import io.trino.sql.ir.Expression; import io.trino.sql.ir.IrVisitor; -import io.trino.sql.ir.LambdaExpression; +import io.trino.sql.ir.Lambda; import io.trino.sql.ir.NodeRef; import java.util.HashMap; @@ -68,7 +68,7 @@ public Void visitExpression(Expression node, Void context) translatedSubExpressions.put(NodeRef.of(node), result.get()); } else { - node.getChildren().forEach(this::process); + node.children().forEach(this::process); } return null; @@ -76,7 +76,7 @@ public Void visitExpression(Expression node, Void context) // TODO support lambda expressions for partial projection @Override - public Void visitLambdaExpression(LambdaExpression functionCall, Void context) + public Void visitLambda(Lambda functionCall, Void context) { return null; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/Partitioning.java b/core/trino-main/src/main/java/io/trino/sql/planner/Partitioning.java index d6813acb7ad0..85608d290a22 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/Partitioning.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/Partitioning.java @@ -23,7 +23,7 @@ import io.trino.metadata.Metadata; import io.trino.spi.predicate.NullableValue; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import java.util.Collection; import java.util.HashSet; @@ -306,12 +306,12 @@ public boolean isConstant() public boolean isVariable() { - return expression instanceof SymbolReference; + return expression instanceof Reference; } public Symbol getColumn() { - verify(expression instanceof SymbolReference, "Expect the expression to be a SymbolReference"); + verify(expression instanceof Reference, "Expect the expression to be a SymbolReference"); return Symbol.from(expression); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java index ddfb09d3cbe1..ed00243c210b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java @@ -41,18 +41,18 @@ import io.trino.sql.analyzer.Analysis.SelectExpression; import io.trino.sql.analyzer.FieldId; import io.trino.sql.analyzer.RelationType; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Case; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.CoalesceExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Coalesce; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.NotExpression; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Not; import io.trino.sql.ir.Row; -import io.trino.sql.ir.SearchedCaseExpression; -import io.trino.sql.ir.SubscriptExpression; +import io.trino.sql.ir.Subscript; import io.trino.sql.ir.WhenClause; import io.trino.sql.planner.RelationPlanner.PatternRecognitionComponents; import io.trino.sql.planner.plan.AggregationNode; @@ -149,9 +149,9 @@ import static io.trino.spi.type.TinyintType.TINYINT; import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.sql.NodeUtils.getSortItemsFromOrderBy; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.ir.IrExpressions.ifExpression; import static io.trino.sql.ir.IrUtils.and; import static io.trino.sql.planner.GroupingOperationRewriter.rewriteGroupingOperation; @@ -326,14 +326,14 @@ public RelationPlan planExpand(Query query) // 2. append filter to fail on non-empty result String recursionLimitExceededMessage = format("Recursion depth limit exceeded (%s). Use 'max_recursion_depth' session property to modify the limit.", maxRecursionDepth); Expression predicate = ifExpression( - new ComparisonExpression( + new Comparison( GREATER_THAN_OR_EQUAL, countSymbol.toSymbolReference(), new Constant(BIGINT, 0L)), new Cast( failFunction(plannerContext.getMetadata(), NOT_SUPPORTED, recursionLimitExceededMessage), BOOLEAN), - TRUE_LITERAL); + TRUE); FilterNode filterNode = new FilterNode(idAllocator.getNextId(), windowNode, predicate); recursionSteps.add(new NodeAndMappings(filterNode, checkConvergenceStep.getFields())); @@ -643,7 +643,7 @@ public PlanNode plan(Update node) // If the updated column is non-null, check that the value is not null if (mergeAnalysis.getNonNullableColumnHandles().contains(dataColumnHandle)) { String columnName = columnSchema.getName(); - rewritten = new CoalesceExpression(rewritten, new Cast(failFunction(metadata, INVALID_ARGUMENTS, "NULL value not allowed for NOT NULL column: " + columnName), columnSchema.getType())); + rewritten = new Coalesce(rewritten, new Cast(failFunction(metadata, INVALID_ARGUMENTS, "NULL value not allowed for NOT NULL column: " + columnName), columnSchema.getType())); } rowBuilder.add(rewritten); assignments.put(field, rewritten); @@ -659,7 +659,7 @@ public PlanNode plan(Update node) assignments.putIdentity(relationPlan.getFieldMappings().get(rowIdReference.getFieldIndex())); // Add the "present" field - rowBuilder.add(TRUE_LITERAL); + rowBuilder.add(TRUE); // Add the operation number rowBuilder.add(new Constant(TINYINT, (long) UPDATE_OPERATION_NUMBER)); @@ -704,7 +704,7 @@ public PlanNode plan(Update node) projectionAssignmentsBuilder.putIdentity(rowIdSymbol); projectionAssignmentsBuilder.put(mergeRowSymbol, mergeRow); projectionAssignmentsBuilder.put(caseNumberSymbol, new Constant(INTEGER, 0L)); - projectionAssignmentsBuilder.put(isDistinctSymbol, TRUE_LITERAL); + projectionAssignmentsBuilder.put(isDistinctSymbol, TRUE); ProjectNode projectNode = new ProjectNode(idAllocator.getNextId(), subPlanBuilder.getRoot(), projectionAssignmentsBuilder.build()); @@ -721,8 +721,8 @@ private PlanBuilder addCheckConstraints(List const Expression predicate = ifExpression( // When predicate evaluates to UNKNOWN (e.g. NULL > 100), it should not violate the check constraint. - new CoalesceExpression(coerceIfNecessary(analysis, constraint, symbol), TRUE_LITERAL), - TRUE_LITERAL, + new Coalesce(coerceIfNecessary(analysis, constraint, symbol), TRUE), + TRUE, new Cast(failFunction(plannerContext.getMetadata(), CONSTRAINT_VIOLATION, "Check constraint violation: " + constraint), BOOLEAN)); predicates.add(predicate); @@ -754,7 +754,7 @@ public MergeWriterNode plan(Merge merge) projections.putIdentities(planWithUniqueId.getRoot().getOutputSymbols()); Symbol presentColumn = symbolAllocator.newSymbol("present", BOOLEAN); - projections.put(presentColumn, TRUE_LITERAL); + projections.put(presentColumn, TRUE); RelationPlan planWithPresentColumn = new RelationPlan( new ProjectNode(idAllocator.getNextId(), planWithUniqueId.getRoot(), projections.build()), @@ -803,7 +803,7 @@ public MergeWriterNode plan(Merge merge) if (nonNullableColumnHandles.contains(dataColumnHandle)) { ColumnSchema columnSchema = dataColumnSchemas.get(fieldNumber); String columnName = columnSchema.getName(); - rewritten = new CoalesceExpression(rewritten, new Cast(failFunction(metadata, INVALID_ARGUMENTS, "Assigning NULL to non-null MERGE target table column " + columnName), columnSchema.getType())); + rewritten = new Coalesce(rewritten, new Cast(failFunction(metadata, INVALID_ARGUMENTS, "Assigning NULL to non-null MERGE target table column " + columnName), columnSchema.getType())); } rowBuilder.add(rewritten); assignments.put(field, rewritten); @@ -817,7 +817,7 @@ public MergeWriterNode plan(Merge merge) // Build the match condition for the MERGE case // Add a boolean column which is true if a target table row was matched - rowBuilder.add(new NotExpression(new IsNullPredicate(presentColumn.toSymbolReference()))); + rowBuilder.add(new Not(new IsNull(presentColumn.toSymbolReference()))); // Add the operation number rowBuilder.add(new Constant(TINYINT, (long) getMergeCaseOperationNumber(mergeCase))); @@ -827,7 +827,7 @@ public MergeWriterNode plan(Merge merge) Expression condition = presentColumn.toSymbolReference(); if (mergeCase instanceof MergeInsert) { - condition = new IsNullPredicate(presentColumn.toSymbolReference()); + condition = new IsNull(presentColumn.toSymbolReference()); } if (casePredicate.isPresent()) { @@ -856,13 +856,13 @@ public MergeWriterNode plan(Merge merge) ImmutableList.Builder rowBuilder = ImmutableList.builder(); dataColumnSchemas.forEach(columnSchema -> rowBuilder.add(new Constant(columnSchema.getType(), null))); - rowBuilder.add(new NotExpression(new IsNullPredicate(presentColumn.toSymbolReference()))); + rowBuilder.add(new Not(new IsNull(presentColumn.toSymbolReference()))); // The operation number rowBuilder.add(new Constant(TINYINT, -1L)); // The case number rowBuilder.add(new Constant(INTEGER, -1L)); - SearchedCaseExpression caseExpression = new SearchedCaseExpression(whenClauses.build(), Optional.of(new Row(rowBuilder.build()))); + Case caseExpression = new Case(whenClauses.build(), Optional.of(new Row(rowBuilder.build()))); Symbol mergeRowSymbol = symbolAllocator.newSymbol("merge_row", mergeAnalysis.getMergeRowType()); Symbol caseNumberSymbol = symbolAllocator.newSymbol("case_number", INTEGER); @@ -889,7 +889,7 @@ public MergeWriterNode plan(Merge merge) subPlanProject, Assignments.builder() .putIdentities(subPlanProject.getOutputSymbols()) - .put(caseNumberSymbol, new SubscriptExpression(INTEGER, mergeRowSymbol.toSymbolReference(), new Constant(INTEGER, (long) mergeAnalysis.getMergeRowType().getFields().size()))) + .put(caseNumberSymbol, new Subscript(INTEGER, mergeRowSymbol.toSymbolReference(), new Constant(INTEGER, (long) mergeAnalysis.getMergeRowType().getFields().size()))) .build()); // Mark distinct combinations of the unique_id value and the case_number @@ -898,13 +898,13 @@ public MergeWriterNode plan(Merge merge) // Raise an error if unique_id symbol is non-null and the unique_id/case_number combination was not distinct Expression filter = ifExpression( - LogicalExpression.and( - new NotExpression(isDistinctSymbol.toSymbolReference()), - new NotExpression(new IsNullPredicate(uniqueIdSymbol.toSymbolReference()))), + Logical.and( + new Not(isDistinctSymbol.toSymbolReference()), + new Not(new IsNull(uniqueIdSymbol.toSymbolReference()))), new Cast( failFunction(metadata, MERGE_TARGET_ROW_MULTIPLE_MATCHES, "One MERGE target table row matched more than one source row"), BOOLEAN), - TRUE_LITERAL); + TRUE); FilterNode filterNode = new FilterNode(idAllocator.getNextId(), markDistinctNode, filter); @@ -1521,11 +1521,11 @@ private FrameBoundPlanAndSymbols planFrameBound(PlanBuilder subPlan, PlanAndMapp Symbol offsetSymbol = coercions.get(frameOffset.get()); Expression zeroOffset = zeroOfType(offsetSymbol.getType()); Expression predicate = ifExpression( - new ComparisonExpression( + new Comparison( GREATER_THAN_OR_EQUAL, offsetSymbol.toSymbolReference(), zeroOffset), - TRUE_LITERAL, + TRUE, new Cast( failFunction(plannerContext.getMetadata(), INVALID_WINDOW_FRAME, "Window frame offset value must not be negative or null"), BOOLEAN)); @@ -1566,7 +1566,7 @@ private FrameBoundPlanAndSymbols planFrameBound(PlanBuilder subPlan, PlanAndMapp // Next, pre-project the function which combines sortKey with the offset. // Note: if frameOffset needs a coercion, it was added before by a call to coerce() method. ResolvedFunction function = frameBoundCalculationFunction.get(); - Expression functionCall = new FunctionCall( + Expression functionCall = new Call( function, ImmutableList.of( sortKeyCoercedForFrameBoundCalculation.toSymbolReference(), @@ -1622,8 +1622,8 @@ private FrameOffsetPlanAndSymbol planFrameOffset(PlanBuilder subPlan, Optional MAX_BIGINT_PRECISION) { } else { offsetToBigint = ifExpression( - new ComparisonExpression(LESS_THAN_OR_EQUAL, offsetSymbol.toSymbolReference(), new Constant(decimalType, Int128.valueOf(Long.MAX_VALUE))), + new Comparison(LESS_THAN_OR_EQUAL, offsetSymbol.toSymbolReference(), new Constant(decimalType, Int128.valueOf(Long.MAX_VALUE))), new Cast(offsetSymbol.toSymbolReference(), BIGINT), new Constant(BIGINT, Long.MAX_VALUE)); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index d3d522b58c42..b725a61d4d63 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -53,12 +53,13 @@ import io.trino.sql.analyzer.PatternRecognitionAnalysis.ScalarInputDescriptor; import io.trino.sql.analyzer.RelationType; import io.trino.sql.analyzer.Scope; +import io.trino.sql.ir.Booleans; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.CoalesceExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Coalesce; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; import io.trino.sql.ir.IrUtils; import io.trino.sql.ir.Row; import io.trino.sql.planner.QueryPlanner.PlanAndMappings; @@ -176,7 +177,6 @@ import static io.trino.sql.analyzer.PatternRecognitionAnalysis.NavigationAnchor.LAST; import static io.trino.sql.analyzer.PatternRecognitionAnalysis.NavigationMode.RUNNING; import static io.trino.sql.analyzer.SemanticExceptions.semanticException; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; import static io.trino.sql.ir.IrExpressions.ifExpression; import static io.trino.sql.planner.LogicalPlanner.failFunction; import static io.trino.sql.planner.PlanBuilder.newPlanBuilder; @@ -203,7 +203,6 @@ import static io.trino.sql.tree.PatternSearchMode.Mode.INITIAL; import static io.trino.sql.tree.SkipTo.Position.PAST_LAST; import static io.trino.type.Json2016Type.JSON_2016; -import static java.lang.Boolean.TRUE; import static java.util.Objects.requireNonNull; class RelationPlanner @@ -259,16 +258,16 @@ public static JoinType mapJoinType(Join.Type joinType) }; } - private ComparisonExpression.Operator mapComparisonOperator(io.trino.sql.tree.ComparisonExpression.Operator operator) + private Comparison.Operator mapComparisonOperator(io.trino.sql.tree.ComparisonExpression.Operator operator) { return switch (operator) { - case EQUAL -> ComparisonExpression.Operator.EQUAL; - case NOT_EQUAL -> ComparisonExpression.Operator.NOT_EQUAL; - case LESS_THAN -> ComparisonExpression.Operator.LESS_THAN; - case LESS_THAN_OR_EQUAL -> ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; - case GREATER_THAN -> ComparisonExpression.Operator.GREATER_THAN; - case GREATER_THAN_OR_EQUAL -> ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; - case IS_DISTINCT_FROM -> ComparisonExpression.Operator.IS_DISTINCT_FROM; + case EQUAL -> Comparison.Operator.EQUAL; + case NOT_EQUAL -> Comparison.Operator.NOT_EQUAL; + case LESS_THAN -> Comparison.Operator.LESS_THAN; + case LESS_THAN_OR_EQUAL -> Comparison.Operator.LESS_THAN_OR_EQUAL; + case GREATER_THAN -> Comparison.Operator.GREATER_THAN; + case GREATER_THAN_OR_EQUAL -> Comparison.Operator.GREATER_THAN_OR_EQUAL; + case IS_DISTINCT_FROM -> Comparison.Operator.IS_DISTINCT_FROM; }; } @@ -411,8 +410,8 @@ public RelationPlan addCheckConstraints(List const Expression predicate = ifExpression( // When predicate evaluates to UNKNOWN (e.g. NULL > 100), it should not violate the check constraint. - new CoalesceExpression(coerceIfNecessary(analysis, constraint, planBuilder.rewrite(constraint)), TRUE_LITERAL), - TRUE_LITERAL, + new Coalesce(coerceIfNecessary(analysis, constraint, planBuilder.rewrite(constraint)), Booleans.TRUE), + Booleans.TRUE, new Cast(failFunction(plannerContext.getMetadata(), CONSTRAINT_VIOLATION, "Check constraint violation: " + constraint), BOOLEAN)); planBuilder = planBuilder.withNewRoot(new FilterNode( @@ -752,7 +751,7 @@ public PatternRecognitionComponents planPatternRecognitionComponents( measureOutputs.build(), skipToLabels, mapSkipToPosition(skipTo.map(SkipTo::getPosition).orElse(PAST_LAST)), - searchMode.map(mode -> mode.getMode() == INITIAL).orElse(TRUE), + searchMode.map(mode -> mode.getMode() == INITIAL).orElse(Boolean.TRUE), RowPatternToIrRewriter.rewrite(pattern, analysis), rewrittenVariableDefinitions.buildOrThrow()); } @@ -996,7 +995,7 @@ else if (firstDependencies.stream().allMatch(right::canResolve) && secondDepende } else { postInnerJoinConditions.add( - new ComparisonExpression(mapComparisonOperator(joinConditionComparisonOperators.get(i)), + new Comparison(mapComparisonOperator(joinConditionComparisonOperators.get(i)), leftCoercions.get(leftComparisonExpressions.get(i)).toSymbolReference(), rightCoercions.get(rightComparisonExpressions.get(i)).toSymbolReference())); } @@ -1176,7 +1175,7 @@ If casts are redundant (due to column type and common type being equal), for (Identifier column : joinColumns) { Symbol output = symbolAllocator.newSymbol(column.getValue(), analysis.getType(column)); outputs.add(output); - assignments.put(output, new CoalesceExpression( + assignments.put(output, new Coalesce( leftJoinColumns.get(column).toSymbolReference(), rightJoinColumns.get(column).toSymbolReference())); } @@ -1252,7 +1251,7 @@ private RelationPlan planCorrelatedJoin(Join join, RelationPlan leftPlan, Latera Expression rewrittenFilterCondition; if (join.getCriteria().isEmpty()) { - rewrittenFilterCondition = TRUE_LITERAL; + rewrittenFilterCondition = Booleans.TRUE; } else { JoinCriteria criteria = join.getCriteria().get(); @@ -1367,7 +1366,7 @@ private RelationPlan planJoinJsonTable(PlanBuilder leftPlan, List leftFi // apply the input function to the input expression Constant failOnError = new Constant(BOOLEAN, jsonTable.getErrorBehavior().orElse(JsonTable.ErrorBehavior.EMPTY) == JsonTable.ErrorBehavior.ERROR); ResolvedFunction inputToJson = analysis.getJsonInputFunction(inputExpression); - Expression inputJson = new FunctionCall(inputToJson, ImmutableList.of(coerced.get(inputExpression).toSymbolReference(), failOnError)); + Expression inputJson = new Call(inputToJson, ImmutableList.of(coerced.get(inputExpression).toSymbolReference(), failOnError)); // apply the input functions to the JSON path parameters having FORMAT, // and collect all JSON path parameters in a Row @@ -1498,7 +1497,7 @@ else if (jsonTable.getPlan().orElseThrow() instanceof JsonTableDefaultPlan defau Constant errorBehavior = new Constant(TINYINT, (long) queryColumn.getErrorBehavior().orElse(defaultErrorOnError ? ERROR : NULL).ordinal()); Constant omitQuotes = new Constant(BOOLEAN, queryColumn.getQuotesBehavior().orElse(KEEP) == OMIT); ResolvedFunction outputFunction = analysis.getJsonOutputFunction(queryColumn); - Expression result = new FunctionCall(outputFunction, ImmutableList.of(properOutput.toSymbolReference(), errorBehavior, omitQuotes)); + Expression result = new Call(outputFunction, ImmutableList.of(properOutput.toSymbolReference(), errorBehavior, omitQuotes)); // cast to declared returned type Type expectedType = jsonTableRelationType.getFieldByIndex(i).getType(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/ResolvedFunctionCallBuilder.java b/core/trino-main/src/main/java/io/trino/sql/planner/ResolvedFunctionCallBuilder.java index cdfbbf0586ce..b5e7c5cc9b87 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/ResolvedFunctionCallBuilder.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/ResolvedFunctionCallBuilder.java @@ -14,8 +14,8 @@ package io.trino.sql.planner; import io.trino.metadata.ResolvedFunction; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; import java.util.ArrayList; import java.util.List; @@ -51,8 +51,8 @@ public ResolvedFunctionCallBuilder setArguments(List values) return this; } - public FunctionCall build() + public Call build() { - return new FunctionCall(resolvedFunction, argumentValues); + return new Call(resolvedFunction, argumentValues); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/SortExpressionExtractor.java b/core/trino-main/src/main/java/io/trino/sql/planner/SortExpressionExtractor.java index fa50dbaf2867..a0ec155fb424 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/SortExpressionExtractor.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/SortExpressionExtractor.java @@ -15,12 +15,12 @@ import com.google.common.collect.ImmutableList; import io.trino.operator.join.SortedPositionLinks; -import io.trino.sql.ir.BetweenPredicate; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Between; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Expression; import io.trino.sql.ir.IrUtils; import io.trino.sql.ir.IrVisitor; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import java.util.List; import java.util.Optional; @@ -28,8 +28,8 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableSet.toImmutableSet; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static java.util.Collections.singletonList; import static java.util.Comparator.comparing; import static java.util.function.Function.identity; @@ -104,15 +104,15 @@ protected Optional visitExpression(Expression expression, } @Override - protected Optional visitComparisonExpression(ComparisonExpression comparison, Void context) + protected Optional visitComparison(Comparison comparison, Void context) { - return switch (comparison.getOperator()) { + return switch (comparison.operator()) { case GREATER_THAN, GREATER_THAN_OR_EQUAL, LESS_THAN, LESS_THAN_OR_EQUAL -> { - Optional sortChannel = asBuildSymbolReference(buildSymbols, comparison.getRight()); - boolean hasBuildReferencesOnOtherSide = hasBuildSymbolReference(buildSymbols, comparison.getLeft()); + Optional sortChannel = asBuildSymbolReference(buildSymbols, comparison.right()); + boolean hasBuildReferencesOnOtherSide = hasBuildSymbolReference(buildSymbols, comparison.left()); if (sortChannel.isEmpty()) { - sortChannel = asBuildSymbolReference(buildSymbols, comparison.getLeft()); - hasBuildReferencesOnOtherSide = hasBuildSymbolReference(buildSymbols, comparison.getRight()); + sortChannel = asBuildSymbolReference(buildSymbols, comparison.left()); + hasBuildReferencesOnOtherSide = hasBuildSymbolReference(buildSymbols, comparison.right()); } if (sortChannel.isPresent() && !hasBuildReferencesOnOtherSide) { yield sortChannel.map(symbolReference -> new SortExpressionContext(symbolReference, singletonList(comparison))); @@ -124,22 +124,22 @@ protected Optional visitComparisonExpression(ComparisonEx } @Override - protected Optional visitBetweenPredicate(BetweenPredicate node, Void context) + protected Optional visitBetween(Between node, Void context) { - Optional result = visitComparisonExpression(new ComparisonExpression(GREATER_THAN_OR_EQUAL, node.getValue(), node.getMin()), context); + Optional result = visitComparison(new Comparison(GREATER_THAN_OR_EQUAL, node.value(), node.min()), context); if (result.isPresent()) { return result; } - return visitComparisonExpression(new ComparisonExpression(LESS_THAN_OR_EQUAL, node.getValue(), node.getMax()), context); + return visitComparison(new Comparison(LESS_THAN_OR_EQUAL, node.value(), node.max()), context); } } - private static Optional asBuildSymbolReference(Set buildLayout, Expression expression) + private static Optional asBuildSymbolReference(Set buildLayout, Expression expression) { // Currently we only support symbol as sort expression on build side - if (expression instanceof SymbolReference symbolReference) { - if (buildLayout.contains(new Symbol(symbolReference.type(), symbolReference.getName()))) { - return Optional.of(symbolReference); + if (expression instanceof Reference reference) { + if (buildLayout.contains(new Symbol(reference.type(), reference.name()))) { + return Optional.of(reference); } } return Optional.empty(); @@ -165,7 +165,7 @@ public BuildSymbolReferenceFinder(Set buildSymbols) @Override protected Boolean visitExpression(Expression node, Void context) { - for (Expression child : node.getChildren()) { + for (Expression child : node.children()) { if (process(child, context)) { return true; } @@ -174,9 +174,9 @@ protected Boolean visitExpression(Expression node, Void context) } @Override - protected Boolean visitSymbolReference(SymbolReference symbolReference, Void context) + protected Boolean visitReference(Reference reference, Void context) { - return buildSymbols.contains(symbolReference.getName()); + return buildSymbols.contains(reference.name()); } } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/SubqueryPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/SubqueryPlanner.java index 5c9cf007f005..b0524a1bd340 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/SubqueryPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/SubqueryPlanner.java @@ -26,7 +26,7 @@ import io.trino.sql.analyzer.Scope; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.NotExpression; +import io.trino.sql.ir.Not; import io.trino.sql.ir.Row; import io.trino.sql.planner.QueryPlanner.PlanAndMappings; import io.trino.sql.planner.plan.ApplyNode; @@ -58,7 +58,7 @@ import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.collect.Streams.stream; import static io.trino.spi.type.BooleanType.BOOLEAN; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.PlanBuilder.newPlanBuilder; import static io.trino.sql.planner.ScopeAware.scopeAwareKey; import static java.lang.String.format; @@ -260,7 +260,7 @@ private PlanBuilder planScalarSubquery(PlanBuilder subPlan, Cluster> { @Override - protected Void visitSymbolReference(SymbolReference node, ImmutableList.Builder builder) + protected Void visitReference(Reference node, ImmutableList.Builder builder) { builder.add(Symbol.from(node)); return null; } @Override - protected Void visitLambdaExpression(LambdaExpression node, ImmutableList.Builder context) + protected Void visitLambda(Lambda node, ImmutableList.Builder context) { // Symbols in lambda expression are bound to lambda arguments, so no need to extract them return null; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java b/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java index 09dd6b0453b0..833588180168 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java @@ -40,9 +40,23 @@ import io.trino.sql.analyzer.ResolvedField; import io.trino.sql.analyzer.Scope; import io.trino.sql.analyzer.TypeSignatureTranslator; -import io.trino.sql.ir.ArithmeticNegation; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Between; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Case; +import io.trino.sql.ir.Coalesce; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.In; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Lambda; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Negation; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.NullIf; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; +import io.trino.sql.ir.Switch; import io.trino.sql.tree.ArithmeticBinaryExpression; import io.trino.sql.tree.ArithmeticUnaryExpression; import io.trino.sql.tree.Array; @@ -130,8 +144,8 @@ import static io.trino.spi.type.TinyintType.TINYINT; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.analyzer.ExpressionAnalyzer.JSON_NO_PARAMETERS_ROW_TYPE; -import static io.trino.sql.ir.BooleanLiteral.FALSE_LITERAL; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.FALSE; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.IrExpressions.ifExpression; import static io.trino.sql.planner.ScopeAware.scopeAwareKey; import static io.trino.sql.tree.JsonQuery.EmptyOrErrorBehavior.ERROR; @@ -286,7 +300,7 @@ private io.trino.sql.ir.Expression translateExpression(Expression expression) private io.trino.sql.ir.Expression translate(Expression expr, boolean isRoot) { - Optional mapped = tryGetMapping(expr); + Optional mapped = tryGetMapping(expr); io.trino.sql.ir.Expression result; if (mapped.isPresent()) { @@ -358,7 +372,7 @@ private io.trino.sql.ir.Expression translate(Expression expr, boolean isRoot) private io.trino.sql.ir.Expression translate(NullIfExpression expression) { - return new io.trino.sql.ir.NullIfExpression( + return new NullIf( translateExpression(expression.getFirst()), translateExpression(expression.getSecond())); } @@ -367,7 +381,7 @@ private io.trino.sql.ir.Expression translate(ArithmeticUnaryExpression expressio { return switch (expression.getSign()) { case PLUS -> translateExpression(expression.getValue()); - case MINUS -> new ArithmeticNegation(translateExpression(expression.getValue())); + case MINUS -> new Negation(translateExpression(expression.getValue())); }; } @@ -386,7 +400,7 @@ private io.trino.sql.ir.Expression translate(IntervalLiteral expression) private io.trino.sql.ir.Expression translate(SearchedCaseExpression expression) { - return new io.trino.sql.ir.SearchedCaseExpression( + return new Case( expression.getWhenClauses().stream() .map(clause -> new io.trino.sql.ir.WhenClause( translateExpression(clause.getOperand()), @@ -397,7 +411,7 @@ private io.trino.sql.ir.Expression translate(SearchedCaseExpression expression) private io.trino.sql.ir.Expression translate(SimpleCaseExpression expression) { - return new io.trino.sql.ir.SimpleCaseExpression( + return new Switch( translateExpression(expression.getOperand()), expression.getWhenClauses().stream() .map(clause -> new io.trino.sql.ir.WhenClause( @@ -409,7 +423,7 @@ private io.trino.sql.ir.Expression translate(SimpleCaseExpression expression) private io.trino.sql.ir.Expression translate(InPredicate expression) { - return new io.trino.sql.ir.InPredicate( + return new In( translateExpression(expression.getValue()), ((InListExpression) expression.getValueList()).getValues().stream() .map(this::translateExpression) @@ -437,7 +451,7 @@ private io.trino.sql.ir.Expression translate(BinaryLiteral expression) private io.trino.sql.ir.Expression translate(BetweenPredicate expression) { - return new io.trino.sql.ir.BetweenPredicate( + return new Between( translateExpression(expression.getValue()), translateExpression(expression.getMin()), translateExpression(expression.getMax())); @@ -445,19 +459,19 @@ private io.trino.sql.ir.Expression translate(BetweenPredicate expression) private io.trino.sql.ir.Expression translate(IsNullPredicate expression) { - return new io.trino.sql.ir.IsNullPredicate(translateExpression(expression.getValue())); + return new IsNull(translateExpression(expression.getValue())); } private io.trino.sql.ir.Expression translate(IsNotNullPredicate expression) { - return new io.trino.sql.ir.NotExpression( - new io.trino.sql.ir.IsNullPredicate( + return new Not( + new IsNull( translateExpression(expression.getValue()))); } private io.trino.sql.ir.Expression translate(CoalesceExpression expression) { - return new io.trino.sql.ir.CoalesceExpression(expression.getOperands().stream() + return new Coalesce(expression.getOperands().stream() .map(this::translateExpression) .collect(toImmutableList())); } @@ -491,10 +505,10 @@ private io.trino.sql.ir.Expression translate(DecimalLiteral expression) private io.trino.sql.ir.Expression translate(LogicalExpression expression) { - return new io.trino.sql.ir.LogicalExpression( + return new Logical( switch (expression.getOperator()) { - case AND -> io.trino.sql.ir.LogicalExpression.Operator.AND; - case OR -> io.trino.sql.ir.LogicalExpression.Operator.OR; + case AND -> Logical.Operator.AND; + case OR -> Logical.Operator.OR; }, expression.getTerms().stream() .map(this::translateExpression) @@ -504,11 +518,11 @@ private io.trino.sql.ir.Expression translate(LogicalExpression expression) private io.trino.sql.ir.Expression translate(BooleanLiteral expression) { if (expression.equals(BooleanLiteral.TRUE_LITERAL)) { - return TRUE_LITERAL; + return TRUE; } if (expression.equals(BooleanLiteral.FALSE_LITERAL)) { - return FALSE_LITERAL; + return FALSE; } throw new IllegalArgumentException("Unknown boolean literal: " + expression); @@ -516,7 +530,7 @@ private io.trino.sql.ir.Expression translate(BooleanLiteral expression) private io.trino.sql.ir.Expression translate(NotExpression expression) { - return new io.trino.sql.ir.NotExpression(translateExpression(expression.getValue())); + return new Not(translateExpression(expression.getValue())); } private io.trino.sql.ir.Expression translate(Row expression) @@ -528,15 +542,15 @@ private io.trino.sql.ir.Expression translate(Row expression) private io.trino.sql.ir.Expression translate(ComparisonExpression expression) { - return new io.trino.sql.ir.ComparisonExpression( + return new Comparison( switch (expression.getOperator()) { - case EQUAL -> io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; - case NOT_EQUAL -> io.trino.sql.ir.ComparisonExpression.Operator.NOT_EQUAL; - case LESS_THAN -> io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; - case LESS_THAN_OR_EQUAL -> io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; - case GREATER_THAN -> io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; - case GREATER_THAN_OR_EQUAL -> io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; - case IS_DISTINCT_FROM -> io.trino.sql.ir.ComparisonExpression.Operator.IS_DISTINCT_FROM; + case EQUAL -> Comparison.Operator.EQUAL; + case NOT_EQUAL -> Comparison.Operator.NOT_EQUAL; + case LESS_THAN -> Comparison.Operator.LESS_THAN; + case LESS_THAN_OR_EQUAL -> Comparison.Operator.LESS_THAN_OR_EQUAL; + case GREATER_THAN -> Comparison.Operator.GREATER_THAN; + case GREATER_THAN_OR_EQUAL -> Comparison.Operator.GREATER_THAN_OR_EQUAL; + case IS_DISTINCT_FROM -> Comparison.Operator.IS_DISTINCT_FROM; }, translateExpression(expression.getLeft()), translateExpression(expression.getRight())); @@ -565,14 +579,14 @@ private io.trino.sql.ir.Expression translate(ArithmeticBinaryExpression expressi case MODULUS -> OperatorType.MODULUS; }; - return new io.trino.sql.ir.ArithmeticBinaryExpression( + return new Arithmetic( plannerContext.getMetadata().resolveOperator(operatorType, ImmutableList.of(getCoercedType(expression.getLeft()), getCoercedType(expression.getRight()))), switch (expression.getOperator()) { - case ADD -> io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; - case SUBTRACT -> io.trino.sql.ir.ArithmeticBinaryExpression.Operator.SUBTRACT; - case MULTIPLY -> io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MULTIPLY; - case DIVIDE -> io.trino.sql.ir.ArithmeticBinaryExpression.Operator.DIVIDE; - case MODULUS -> io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MODULUS; + case ADD -> Arithmetic.Operator.ADD; + case SUBTRACT -> Arithmetic.Operator.SUBTRACT; + case MULTIPLY -> Arithmetic.Operator.MULTIPLY; + case DIVIDE -> Arithmetic.Operator.DIVIDE; + case MODULUS -> Arithmetic.Operator.MODULUS; }, translateExpression(expression.getLeft()), translateExpression(expression.getRight())); @@ -626,7 +640,7 @@ private io.trino.sql.ir.Expression translate(FunctionCall expression) ResolvedFunction resolvedFunction = analysis.getResolvedFunction(expression); checkArgument(resolvedFunction != null, "Function has not been analyzed: %s", expression); - return new io.trino.sql.ir.FunctionCall( + return new Call( resolvedFunction, expression.getArguments().stream() .map(this::translateExpression) @@ -656,7 +670,7 @@ private io.trino.sql.ir.Expression translate(DereferenceExpression expression) checkState(index >= 0, "could not find field name: %s", fieldName); - return new io.trino.sql.ir.SubscriptExpression( + return new Subscript( rowType.getFields().get(index).getType(), translateExpression(expression.getBase()), new Constant(INTEGER, (long) (index + 1))); @@ -682,7 +696,7 @@ private io.trino.sql.ir.Expression translate(Array expression) private io.trino.sql.ir.Expression translate(CurrentCatalog unused) { - return new io.trino.sql.ir.FunctionCall( + return new Call( plannerContext.getMetadata() .resolveBuiltinFunction("$current_catalog", ImmutableList.of()), ImmutableList.of()); @@ -690,7 +704,7 @@ private io.trino.sql.ir.Expression translate(CurrentCatalog unused) private io.trino.sql.ir.Expression translate(CurrentSchema unused) { - return new io.trino.sql.ir.FunctionCall( + return new Call( plannerContext.getMetadata() .resolveBuiltinFunction("$current_schema", ImmutableList.of()), ImmutableList.of()); @@ -698,7 +712,7 @@ private io.trino.sql.ir.Expression translate(CurrentSchema unused) private io.trino.sql.ir.Expression translate(CurrentPath unused) { - return new io.trino.sql.ir.FunctionCall( + return new Call( plannerContext.getMetadata() .resolveBuiltinFunction("$current_path", ImmutableList.of()), ImmutableList.of()); @@ -706,7 +720,7 @@ private io.trino.sql.ir.Expression translate(CurrentPath unused) private io.trino.sql.ir.Expression translate(CurrentUser unused) { - return new io.trino.sql.ir.FunctionCall( + return new Call( plannerContext.getMetadata() .resolveBuiltinFunction("$current_user", ImmutableList.of()), ImmutableList.of()); @@ -828,7 +842,7 @@ private io.trino.sql.ir.Expression translate(AtTimeZone node) Type timeZoneType = analysis.getType(node.getTimeZone()); io.trino.sql.ir.Expression timeZone = translateExpression(node.getTimeZone()); - io.trino.sql.ir.FunctionCall call; + Call call; if (valueType instanceof TimeType type) { call = BuiltinFunctionCallBuilder.resolve(plannerContext.getMetadata()) .setName("$at_timezone") @@ -873,7 +887,7 @@ private io.trino.sql.ir.Expression translate(Format node) .map(analysis::getType) .collect(toImmutableList()); - io.trino.sql.ir.FunctionCall call = BuiltinFunctionCallBuilder.resolve(plannerContext.getMetadata()) + Call call = BuiltinFunctionCallBuilder.resolve(plannerContext.getMetadata()) .setName(FormatFunction.NAME) .addArgument(VARCHAR, arguments.get(0)) .addArgument(RowType.anonymous(argumentTypes.subList(1, arguments.size())), new io.trino.sql.ir.Row(arguments.subList(1, arguments.size()))) @@ -889,7 +903,7 @@ private io.trino.sql.ir.Expression translate(TryExpression node) return BuiltinFunctionCallBuilder.resolve(plannerContext.getMetadata()) .setName(TryFunction.NAME) - .addArgument(new FunctionType(ImmutableList.of(), type), new io.trino.sql.ir.LambdaExpression(ImmutableList.of(), expression)) + .addArgument(new FunctionType(ImmutableList.of(), type), new Lambda(ImmutableList.of(), expression)) .build(); } @@ -899,7 +913,7 @@ private io.trino.sql.ir.Expression translate(LikePredicate node) io.trino.sql.ir.Expression pattern = translateExpression(node.getPattern()); Optional escape = node.getEscape().map(this::translateExpression); - io.trino.sql.ir.FunctionCall patternCall; + Call patternCall; if (escape.isPresent()) { patternCall = BuiltinFunctionCallBuilder.resolve(plannerContext.getMetadata()) .setName(LIKE_PATTERN_FUNCTION_NAME) @@ -914,7 +928,7 @@ private io.trino.sql.ir.Expression translate(LikePredicate node) .build(); } - io.trino.sql.ir.FunctionCall call = BuiltinFunctionCallBuilder.resolve(plannerContext.getMetadata()) + Call call = BuiltinFunctionCallBuilder.resolve(plannerContext.getMetadata()) .setName(LIKE_FUNCTION_NAME) .addArgument(analysis.getType(node.getValue()), value) .addArgument(LIKE_PATTERN, patternCall) @@ -934,7 +948,7 @@ private io.trino.sql.ir.Expression translate(Trim node) .map(this::translateExpression) .ifPresent(arguments::add); - return new io.trino.sql.ir.FunctionCall(resolvedFunction, arguments.build()); + return new Call(resolvedFunction, arguments.build()); } private io.trino.sql.ir.Expression translate(SubscriptExpression node) @@ -944,12 +958,12 @@ private io.trino.sql.ir.Expression translate(SubscriptExpression node) // Do not rewrite subscript index into symbol. Row subscript index is required to be a literal. io.trino.sql.ir.Expression rewrittenBase = translateExpression(node.getBase()); LongLiteral index = (LongLiteral) node.getIndex(); - return new io.trino.sql.ir.SubscriptExpression( + return new Subscript( analysis.getType(node), rewrittenBase, new Constant(INTEGER, index.getParsedValue())); } - return new io.trino.sql.ir.SubscriptExpression( + return new Subscript( analysis.getType(node), translateExpression(node.getBase()), translateExpression(node.getIndex())); @@ -964,7 +978,7 @@ private io.trino.sql.ir.Expression translate(LambdaExpression node) newArguments.add(lambdaArguments.get(NodeRef.of(argument))); } io.trino.sql.ir.Expression rewrittenBody = translateExpression(node.getBody()); - return new io.trino.sql.ir.LambdaExpression(newArguments.build(), rewrittenBody); + return new Lambda(newArguments.build(), rewrittenBody); } private io.trino.sql.ir.Expression translate(Parameter node) @@ -981,7 +995,7 @@ private io.trino.sql.ir.Expression translate(JsonExists node) // apply the input function to the input expression Constant failOnError = new Constant(BOOLEAN, node.getErrorBehavior() == JsonExists.ErrorBehavior.ERROR); ResolvedFunction inputToJson = analysis.getJsonInputFunction(node.getJsonPathInvocation().getInputExpression()); - io.trino.sql.ir.Expression input = new io.trino.sql.ir.FunctionCall(inputToJson, ImmutableList.of( + io.trino.sql.ir.Expression input = new Call(inputToJson, ImmutableList.of( translateExpression(node.getJsonPathInvocation().getInputExpression()), failOnError)); @@ -1004,7 +1018,7 @@ private io.trino.sql.ir.Expression translate(JsonExists node) .add(orderedParameters.getParametersRow()) .add(new Constant(TINYINT, (long) node.getErrorBehavior().ordinal())); - return new io.trino.sql.ir.FunctionCall(resolvedFunction, arguments.build()); + return new Call(resolvedFunction, arguments.build()); } private io.trino.sql.ir.Expression translate(JsonValue node) @@ -1015,7 +1029,7 @@ private io.trino.sql.ir.Expression translate(JsonValue node) // apply the input function to the input expression Constant failOnError = new Constant(BOOLEAN, node.getErrorBehavior() == JsonValue.EmptyOrErrorBehavior.ERROR); ResolvedFunction inputToJson = analysis.getJsonInputFunction(node.getJsonPathInvocation().getInputExpression()); - io.trino.sql.ir.Expression input = new io.trino.sql.ir.FunctionCall(inputToJson, ImmutableList.of( + io.trino.sql.ir.Expression input = new Call(inputToJson, ImmutableList.of( translateExpression(node.getJsonPathInvocation().getInputExpression()), failOnError)); @@ -1045,7 +1059,7 @@ private io.trino.sql.ir.Expression translate(JsonValue node) .map(this::translateExpression) .orElseGet(() -> new Constant(resolvedFunction.getSignature().getReturnType(), null))); - return new io.trino.sql.ir.FunctionCall(resolvedFunction, arguments.build()); + return new Call(resolvedFunction, arguments.build()); } private io.trino.sql.ir.Expression translate(JsonQuery node) @@ -1056,7 +1070,7 @@ private io.trino.sql.ir.Expression translate(JsonQuery node) // apply the input function to the input expression Constant failOnError = new Constant(BOOLEAN, node.getErrorBehavior() == ERROR); ResolvedFunction inputToJson = analysis.getJsonInputFunction(node.getJsonPathInvocation().getInputExpression()); - io.trino.sql.ir.Expression input = new io.trino.sql.ir.FunctionCall(inputToJson, ImmutableList.of( + io.trino.sql.ir.Expression input = new Call(inputToJson, ImmutableList.of( translateExpression(node.getJsonPathInvocation().getInputExpression()), failOnError)); @@ -1081,13 +1095,13 @@ private io.trino.sql.ir.Expression translate(JsonQuery node) .add(new Constant(TINYINT, (long) node.getEmptyBehavior().ordinal())) .add(new Constant(TINYINT, (long) node.getErrorBehavior().ordinal())); - io.trino.sql.ir.Expression function = new io.trino.sql.ir.FunctionCall(resolvedFunction, arguments.build()); + io.trino.sql.ir.Expression function = new Call(resolvedFunction, arguments.build()); // apply function to format output Constant errorBehavior = new Constant(TINYINT, (long) node.getErrorBehavior().ordinal()); Constant omitQuotes = new Constant(BOOLEAN, node.getQuotesBehavior().orElse(KEEP) == OMIT); ResolvedFunction outputFunction = analysis.getJsonOutputFunction(node); - io.trino.sql.ir.Expression result = new io.trino.sql.ir.FunctionCall(outputFunction, ImmutableList.of(function, errorBehavior, omitQuotes)); + io.trino.sql.ir.Expression result = new Call(outputFunction, ImmutableList.of(function, errorBehavior, omitQuotes)); // cast to requested returned type Type returnedType = node.getReturnedType() @@ -1130,7 +1144,7 @@ private io.trino.sql.ir.Expression translate(JsonObject node) io.trino.sql.ir.Expression rewrittenValue = translateExpression(value); ResolvedFunction valueToJson = analysis.getJsonInputFunction(value); if (valueToJson != null) { - values.add(new io.trino.sql.ir.FunctionCall(valueToJson, ImmutableList.of(rewrittenValue, TRUE_LITERAL))); + values.add(new Call(valueToJson, ImmutableList.of(rewrittenValue, TRUE))); } else { values.add(rewrittenValue); @@ -1143,18 +1157,18 @@ private io.trino.sql.ir.Expression translate(JsonObject node) List arguments = ImmutableList.builder() .add(keysRow) .add(valuesRow) - .add(node.isNullOnNull() ? TRUE_LITERAL : FALSE_LITERAL) - .add(node.isUniqueKeys() ? TRUE_LITERAL : FALSE_LITERAL) + .add(node.isNullOnNull() ? TRUE : FALSE) + .add(node.isUniqueKeys() ? TRUE : FALSE) .build(); - io.trino.sql.ir.Expression function = new io.trino.sql.ir.FunctionCall(resolvedFunction, arguments); + io.trino.sql.ir.Expression function = new Call(resolvedFunction, arguments); // apply function to format output ResolvedFunction outputFunction = analysis.getJsonOutputFunction(node); - io.trino.sql.ir.Expression result = new io.trino.sql.ir.FunctionCall(outputFunction, ImmutableList.of( + io.trino.sql.ir.Expression result = new Call(outputFunction, ImmutableList.of( function, new Constant(TINYINT, (long) ERROR.ordinal()), - FALSE_LITERAL)); + FALSE)); // cast to requested returned type Type returnedType = node.getReturnedType() @@ -1189,7 +1203,7 @@ private io.trino.sql.ir.Expression translate(JsonArray node) io.trino.sql.ir.Expression rewrittenElement = translateExpression(element); ResolvedFunction elementToJson = analysis.getJsonInputFunction(element); if (elementToJson != null) { - elements.add(new io.trino.sql.ir.FunctionCall(elementToJson, ImmutableList.of(rewrittenElement, TRUE_LITERAL))); + elements.add(new Call(elementToJson, ImmutableList.of(rewrittenElement, TRUE))); } else { elements.add(rewrittenElement); @@ -1200,17 +1214,17 @@ private io.trino.sql.ir.Expression translate(JsonArray node) List arguments = ImmutableList.builder() .add(elementsRow) - .add(node.isNullOnNull() ? TRUE_LITERAL : FALSE_LITERAL) + .add(node.isNullOnNull() ? TRUE : FALSE) .build(); - io.trino.sql.ir.Expression function = new io.trino.sql.ir.FunctionCall(resolvedFunction, arguments); + io.trino.sql.ir.Expression function = new Call(resolvedFunction, arguments); // apply function to format output ResolvedFunction outputFunction = analysis.getJsonOutputFunction(node); - io.trino.sql.ir.Expression result = new io.trino.sql.ir.FunctionCall(outputFunction, ImmutableList.of( + io.trino.sql.ir.Expression result = new Call(outputFunction, ImmutableList.of( function, new Constant(TINYINT, (long) ERROR.ordinal()), - FALSE_LITERAL)); + FALSE)); // cast to requested returned type Type returnedType = node.getReturnedType() @@ -1226,7 +1240,7 @@ private io.trino.sql.ir.Expression translate(JsonArray node) return result; } - private Optional tryGetMapping(Expression expression) + private Optional tryGetMapping(Expression expression) { Symbol symbol = substitutions.get(NodeRef.of(expression)); if (symbol == null) { @@ -1277,7 +1291,7 @@ public ParametersRow getParametersRow( ResolvedFunction parameterToJson = analysis.getJsonInputFunction(pathParameters.get(i).getParameter()); io.trino.sql.ir.Expression rewrittenParameter = rewrittenPathParameters.get(i); if (parameterToJson != null) { - parameters.add(new io.trino.sql.ir.FunctionCall(parameterToJson, ImmutableList.of(rewrittenParameter, failOnError))); + parameters.add(new Call(parameterToJson, ImmutableList.of(rewrittenParameter, failOnError))); } else { parameters.add(rewrittenParameter); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ArraySortAfterArrayDistinct.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ArraySortAfterArrayDistinct.java index e8fa678df477..22c77e15ba7c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ArraySortAfterArrayDistinct.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ArraySortAfterArrayDistinct.java @@ -22,10 +22,10 @@ import io.trino.spi.function.CatalogSchemaFunctionName; import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Expression; import io.trino.sql.ir.ExpressionTreeRewriter; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.BuiltinFunctionCallBuilder; import io.trino.sql.planner.iterative.Rule; @@ -59,7 +59,7 @@ public Set> rules() private static Expression rewrite(Expression expression, Metadata metadata) { - if (expression instanceof SymbolReference) { + if (expression instanceof Reference) { return expression; } return ExpressionTreeRewriter.rewriteWith(new Visitor(metadata), expression); @@ -76,19 +76,19 @@ public Visitor(Metadata metadata) } @Override - public Expression rewriteFunctionCall(FunctionCall node, Void context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteCall(Call node, Void context, ExpressionTreeRewriter treeRewriter) { - FunctionCall rewritten = treeRewriter.defaultRewrite(node, context); - if (node.getFunction().getName().equals(ARRAY_DISTINCT_NAME) && - getOnlyElement(rewritten.getArguments()) instanceof FunctionCall) { - Expression expression = getOnlyElement(rewritten.getArguments()); - FunctionCall functionCall = (FunctionCall) expression; - ResolvedFunction resolvedFunction = functionCall.getFunction(); + Call rewritten = treeRewriter.defaultRewrite(node, context); + if (node.function().getName().equals(ARRAY_DISTINCT_NAME) && + getOnlyElement(rewritten.arguments()) instanceof Call) { + Expression expression = getOnlyElement(rewritten.arguments()); + Call call = (Call) expression; + ResolvedFunction resolvedFunction = call.function(); if (resolvedFunction.getName().equals(ARRAY_SORT_NAME)) { - List arraySortArguments = functionCall.getArguments(); + List arraySortArguments = call.arguments(); List arraySortArgumentsTypes = resolvedFunction.getSignature().getArgumentTypes(); - FunctionCall arrayDistinctCall = BuiltinFunctionCallBuilder.resolve(metadata) + Call arrayDistinctCall = BuiltinFunctionCallBuilder.resolve(metadata) .setName(ArrayDistinctFunction.NAME) .setArguments( ImmutableList.of(arraySortArgumentsTypes.get(0)), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/CanonicalizeExpressionRewriter.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/CanonicalizeExpressionRewriter.java index 843fa0de3587..440d9cfb95d4 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/CanonicalizeExpressionRewriter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/CanonicalizeExpressionRewriter.java @@ -21,20 +21,20 @@ import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.ir.ExpressionRewriter; import io.trino.sql.ir.ExpressionTreeRewriter; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName; import static io.trino.spi.type.DateType.DATE; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MULTIPLY; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; public final class CanonicalizeExpressionRewriter { @@ -47,7 +47,7 @@ private CanonicalizeExpressionRewriter() {} public static Expression rewrite(Expression expression, PlannerContext plannerContext) { - if (expression instanceof SymbolReference) { + if (expression instanceof Reference) { return expression; } @@ -66,12 +66,12 @@ public Visitor(PlannerContext plannerContext) @SuppressWarnings("ArgumentSelectionDefectChecker") @Override - public Expression rewriteComparisonExpression(ComparisonExpression node, Void context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteComparison(Comparison node, Void context, ExpressionTreeRewriter treeRewriter) { // if we have a comparison of the form , normalize it to // - if (isConstant(node.getLeft()) && !isConstant(node.getRight())) { - node = new ComparisonExpression(node.getOperator().flip(), node.getRight(), node.getLeft()); + if (isConstant(node.left()) && !isConstant(node.right())) { + node = new Comparison(node.operator().flip(), node.right(), node.left()); } return treeRewriter.defaultRewrite(node, context); @@ -79,25 +79,25 @@ public Expression rewriteComparisonExpression(ComparisonExpression node, Void co @SuppressWarnings("ArgumentSelectionDefectChecker") @Override - public Expression rewriteArithmeticBinary(ArithmeticBinaryExpression node, Void context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteArithmetic(Arithmetic node, Void context, ExpressionTreeRewriter treeRewriter) { - if (node.getOperator() == MULTIPLY || node.getOperator() == ADD) { + if (node.operator() == MULTIPLY || node.operator() == ADD) { // if we have a operation of the form [+|*] , normalize it to // [+|*] - if (isConstant(node.getLeft()) && !isConstant(node.getRight())) { - node = new ArithmeticBinaryExpression( + if (isConstant(node.left()) && !isConstant(node.right())) { + node = new Arithmetic( plannerContext.getMetadata().resolveOperator( - switch (node.getOperator()) { + switch (node.operator()) { case ADD -> OperatorType.ADD; case MULTIPLY -> OperatorType.MULTIPLY; - default -> throw new IllegalStateException("Unexpected value: " + node.getOperator()); + default -> throw new IllegalStateException("Unexpected value: " + node.operator()); }, ImmutableList.of( - node.getFunction().getSignature().getArgumentType(1), - node.getFunction().getSignature().getArgumentType(0))), - node.getOperator(), - node.getRight(), - node.getLeft()); + node.function().getSignature().getArgumentType(1), + node.function().getSignature().getArgumentType(0))), + node.operator(), + node.right(), + node.left()); } } @@ -105,11 +105,11 @@ public Expression rewriteArithmeticBinary(ArithmeticBinaryExpression node, Void } @Override - public Expression rewriteFunctionCall(FunctionCall node, Void context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteCall(Call node, Void context, ExpressionTreeRewriter treeRewriter) { - CatalogSchemaFunctionName functionName = node.getFunction().getName(); - if (functionName.equals(builtinFunctionName("date")) && node.getArguments().size() == 1) { - Expression argument = node.getArguments().get(0); + CatalogSchemaFunctionName functionName = node.function().getName(); + if (functionName.equals(builtinFunctionName("date")) && node.arguments().size() == 1) { + Expression argument = node.arguments().get(0); Type argumentType = argument.type(); if (argumentType instanceof TimestampType || argumentType instanceof TimestampWithTimeZoneType diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateInnerUnnestWithGlobalAggregation.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateInnerUnnestWithGlobalAggregation.java index 75e6ea4b852c..40e69c2aaec0 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateInnerUnnestWithGlobalAggregation.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateInnerUnnestWithGlobalAggregation.java @@ -20,8 +20,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.NotExpression; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Not; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolAllocator; @@ -50,7 +50,7 @@ import static io.trino.matching.Pattern.nonEmpty; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.IrUtils.and; import static io.trino.sql.planner.iterative.rule.AggregationDecorrelation.rewriteWithMasks; import static io.trino.sql.planner.iterative.rule.Util.restrictOutputs; @@ -117,7 +117,7 @@ public class DecorrelateInnerUnnestWithGlobalAggregation { private static final Pattern PATTERN = correlatedJoin() .with(nonEmpty(correlation())) - .with(filter().equalTo(TRUE_LITERAL)) + .with(filter().equalTo(TRUE)) .matching(node -> node.getType() == JoinType.INNER || node.getType() == JoinType.LEFT); @Override @@ -193,7 +193,7 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co rewrittenUnnest, Assignments.builder() .putIdentities(rewrittenUnnest.getOutputSymbols()) - .put(mask, new NotExpression(new IsNullPredicate(ordinalitySymbol.toSymbolReference()))) + .put(mask, new Not(new IsNull(ordinalitySymbol.toSymbolReference()))) .build()); // restore all projections, grouped aggregations and global aggregations from the subquery diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateLeftUnnestWithGlobalAggregation.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateLeftUnnestWithGlobalAggregation.java index 389f8daecf1c..7c36378533ac 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateLeftUnnestWithGlobalAggregation.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateLeftUnnestWithGlobalAggregation.java @@ -39,7 +39,7 @@ import static com.google.common.collect.Iterables.getOnlyElement; import static io.trino.matching.Pattern.nonEmpty; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.iterative.rule.Util.restrictOutputs; import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.isScalar; import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE; @@ -101,7 +101,7 @@ public class DecorrelateLeftUnnestWithGlobalAggregation { private static final Pattern PATTERN = correlatedJoin() .with(nonEmpty(correlation())) - .with(filter().equalTo(TRUE_LITERAL)) + .with(filter().equalTo(TRUE)) .matching(node -> node.getType() == JoinType.INNER || node.getType() == JoinType.LEFT); @Override diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateUnnest.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateUnnest.java index 4942229c0525..a0e4fd9a4d73 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateUnnest.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateUnnest.java @@ -21,10 +21,10 @@ import io.trino.matching.Pattern; import io.trino.metadata.Metadata; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.IsNullPredicate; +import io.trino.sql.ir.IsNull; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolAllocator; @@ -56,9 +56,9 @@ import static io.trino.spi.StandardErrorCode.SUBQUERY_MULTIPLE_ROWS; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.ir.IrExpressions.ifExpression; import static io.trino.sql.planner.LogicalPlanner.failFunction; import static io.trino.sql.planner.iterative.rule.ImplementLimitWithTies.rewriteLimitWithTiesWithPartitioning; @@ -146,7 +146,7 @@ public class DecorrelateUnnest { private static final Pattern PATTERN = correlatedJoin() .with(nonEmpty(correlation())) - .with(filter().equalTo(TRUE_LITERAL)) + .with(filter().equalTo(TRUE)) .matching(node -> node.getType() == JoinType.INNER || node.getType() == JoinType.LEFT); private final Metadata metadata; @@ -260,7 +260,7 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co assignments.put( subquerySymbol, ifExpression( - new IsNullPredicate(ordinalitySymbol.toSymbolReference()), + new IsNull(ordinalitySymbol.toSymbolReference()), new Constant(subquerySymbol.getType(), null), subquerySymbol.toSymbolReference())); } @@ -401,14 +401,14 @@ public RewriteResult visitEnforceSingleRow(EnforceSingleRowNode node, Void conte Optional.empty()); } Expression predicate = ifExpression( - new ComparisonExpression( + new Comparison( GREATER_THAN, rowNumberSymbol.toSymbolReference(), new Constant(BIGINT, 1L)), new Cast( failFunction(metadata, SUBQUERY_MULTIPLE_ROWS, "Scalar sub-query has returned multiple rows"), BOOLEAN), - TRUE_LITERAL); + TRUE); return new RewriteResult(new FilterNode(idAllocator.getNextId(), sourceNode, predicate), Optional.of(rowNumberSymbol)); } @@ -447,7 +447,7 @@ public RewriteResult visitLimit(LimitNode node, Void context) new FilterNode( idAllocator.getNextId(), sourceNode, - new ComparisonExpression(LESS_THAN_OR_EQUAL, rowNumberSymbol.toSymbolReference(), new Constant(BIGINT, node.getCount()))), + new Comparison(LESS_THAN_OR_EQUAL, rowNumberSymbol.toSymbolReference(), new Constant(BIGINT, node.getCount()))), Optional.of(rowNumberSymbol)); } @@ -476,7 +476,7 @@ public RewriteResult visitTopN(TopNNode node, Void context) new FilterNode( idAllocator.getNextId(), windowNode, - new ComparisonExpression(LESS_THAN_OR_EQUAL, rowNumberSymbol.toSymbolReference(), new Constant(BIGINT, node.getCount()))), + new Comparison(LESS_THAN_OR_EQUAL, rowNumberSymbol.toSymbolReference(), new Constant(BIGINT, node.getCount()))), Optional.of(rowNumberSymbol)); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DereferencePushdown.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DereferencePushdown.java index 6af4c0e26cc9..9e2bea55db37 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DereferencePushdown.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DereferencePushdown.java @@ -17,9 +17,9 @@ import io.trino.spi.type.RowType; import io.trino.sql.ir.DefaultTraversalVisitor; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.LambdaExpression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Lambda; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.Symbol; import java.util.Collection; @@ -38,7 +38,7 @@ class DereferencePushdown { private DereferencePushdown() {} - public static Set extractRowSubscripts(Collection expressions, boolean allowOverlap) + public static Set extractRowSubscripts(Collection expressions, boolean allowOverlap) { Set symbolReferencesAndRowSubscripts = expressions.stream() .flatMap(expression -> getSymbolReferencesAndRowSubscripts(expression).stream()) @@ -54,28 +54,28 @@ public static Set extractRowSubscripts(Collection projections) { return projections.stream() - .allMatch(expression -> expression instanceof SymbolReference || - (expression instanceof SubscriptExpression && - isRowSubscriptChain((SubscriptExpression) expression) && + .allMatch(expression -> expression instanceof Reference || + (expression instanceof Subscript && + isRowSubscriptChain((Subscript) expression) && !prefixExists(expression, projections))); } - public static Symbol getBase(SubscriptExpression expression) + public static Symbol getBase(Subscript expression) { return getOnlyElement(extractAll(expression)); } /** - * Extract the sub-expressions of type {@link SubscriptExpression} or {@link SymbolReference} from the expression - * in a top-down manner. The expressions within the base of a valid {@link SubscriptExpression} sequence are not extracted. + * Extract the sub-expressions of type {@link Subscript} or {@link Reference} from the expression + * in a top-down manner. The expressions within the base of a valid {@link Subscript} sequence are not extracted. */ private static List getSymbolReferencesAndRowSubscripts(Expression expression) { @@ -84,7 +84,7 @@ private static List getSymbolReferencesAndRowSubscripts(Expression e new DefaultTraversalVisitor>() { @Override - protected Void visitSubscriptExpression(SubscriptExpression node, ImmutableList.Builder context) + protected Void visitSubscript(Subscript node, ImmutableList.Builder context) { if (isRowSubscriptChain(node)) { context.add(node); @@ -93,14 +93,14 @@ protected Void visitSubscriptExpression(SubscriptExpression node, ImmutableList. } @Override - protected Void visitSymbolReference(SymbolReference node, ImmutableList.Builder context) + protected Void visitReference(Reference node, ImmutableList.Builder context) { context.add(node); return null; } @Override - protected Void visitLambdaExpression(LambdaExpression node, ImmutableList.Builder context) + protected Void visitLambda(Lambda node, ImmutableList.Builder context) { return null; } @@ -109,27 +109,27 @@ protected Void visitLambdaExpression(LambdaExpression node, ImmutableList.Builde return builder.build(); } - private static boolean isRowSubscriptChain(SubscriptExpression expression) + private static boolean isRowSubscriptChain(Subscript expression) { - if (!(expression.getBase().type() instanceof RowType)) { + if (!(expression.base().type() instanceof RowType)) { return false; } - return (expression.getBase() instanceof SymbolReference) || - ((expression.getBase() instanceof SubscriptExpression) && isRowSubscriptChain((SubscriptExpression) expression.getBase())); + return (expression.base() instanceof Reference) || + ((expression.base() instanceof Subscript) && isRowSubscriptChain((Subscript) expression.base())); } private static boolean prefixExists(Expression expression, Set expressions) { Expression current = expression; - while (current instanceof SubscriptExpression) { - current = ((SubscriptExpression) current).getBase(); + while (current instanceof Subscript) { + current = ((Subscript) current).base(); if (expressions.contains(current)) { return true; } } - verify(current instanceof SymbolReference); + verify(current instanceof Reference); return false; } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java index e425524f806a..1bd409a8aadc 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java @@ -304,7 +304,7 @@ public Result apply(ValuesNode valuesNode, Captures captures, Context context) Expression rewritten; if (row instanceof Row) { // preserve the structure of row - rewritten = new Row(((Row) row).getItems().stream() + rewritten = new Row(((Row) row).items().stream() .map(item -> rewriter.rewrite(item, context)) .collect(toImmutableList())); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractCommonPredicatesExpressionRewriter.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractCommonPredicatesExpressionRewriter.java index c8d8fd6b8f17..714baa82c6d9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractCommonPredicatesExpressionRewriter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractCommonPredicatesExpressionRewriter.java @@ -19,7 +19,7 @@ import io.trino.sql.ir.Expression; import io.trino.sql.ir.ExpressionRewriter; import io.trino.sql.ir.ExpressionTreeRewriter; -import io.trino.sql.ir.LogicalExpression; +import io.trino.sql.ir.Logical; import java.util.Collection; import java.util.List; @@ -28,7 +28,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.sql.ir.IrUtils.combinePredicates; import static io.trino.sql.ir.IrUtils.extractPredicates; -import static io.trino.sql.ir.LogicalExpression.Operator.OR; +import static io.trino.sql.ir.Logical.Operator.OR; import static io.trino.sql.planner.DeterminismEvaluator.isDeterministic; import static java.util.Collections.emptySet; import static java.util.stream.Collectors.toList; @@ -57,29 +57,29 @@ protected Expression rewriteExpression(Expression node, NodeContext context, Exp } @Override - public Expression rewriteLogicalExpression(LogicalExpression node, NodeContext context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteLogical(Logical node, NodeContext context, ExpressionTreeRewriter treeRewriter) { Expression expression = combinePredicates( - node.getOperator(), - extractPredicates(node.getOperator(), node).stream() + node.operator(), + extractPredicates(node.operator(), node).stream() .map(subExpression -> treeRewriter.rewrite(subExpression, NodeContext.NOT_ROOT_NODE)) .collect(toImmutableList())); - if (!(expression instanceof LogicalExpression)) { + if (!(expression instanceof Logical)) { return expression; } - Expression simplified = extractCommonPredicates((LogicalExpression) expression); + Expression simplified = extractCommonPredicates((Logical) expression); // Prefer AND LogicalBinaryExpression at the root if possible - if (context.isRootNode() && simplified instanceof LogicalExpression && ((LogicalExpression) simplified).getOperator() == OR) { - return distributeIfPossible((LogicalExpression) simplified); + if (context.isRootNode() && simplified instanceof Logical && ((Logical) simplified).operator() == OR) { + return distributeIfPossible((Logical) simplified); } return simplified; } - private Expression extractCommonPredicates(LogicalExpression node) + private Expression extractCommonPredicates(Logical node) { List> subPredicates = getSubPredicates(node); @@ -92,12 +92,12 @@ private Expression extractCommonPredicates(LogicalExpression node) .map(predicateList -> removeAll(predicateList, commonPredicates)) .collect(toImmutableList()); - LogicalExpression.Operator flippedOperator = node.getOperator().flip(); + Logical.Operator flippedOperator = node.operator().flip(); List uncorrelatedPredicates = uncorrelatedSubPredicates.stream() .map(predicate -> combinePredicates(flippedOperator, predicate)) .collect(toImmutableList()); - Expression combinedUncorrelatedPredicates = combinePredicates(node.getOperator(), uncorrelatedPredicates); + Expression combinedUncorrelatedPredicates = combinePredicates(node.operator(), uncorrelatedPredicates); return combinePredicates(flippedOperator, ImmutableList.builder() .addAll(commonPredicates) @@ -105,11 +105,11 @@ private Expression extractCommonPredicates(LogicalExpression node) .build()); } - private static List> getSubPredicates(LogicalExpression expression) + private static List> getSubPredicates(Logical expression) { - return extractPredicates(expression.getOperator(), expression).stream() - .map(predicate -> predicate instanceof LogicalExpression ? - extractPredicates((LogicalExpression) predicate) : ImmutableList.of(predicate)) + return extractPredicates(expression.operator(), expression).stream() + .map(predicate -> predicate instanceof Logical ? + extractPredicates((Logical) predicate) : ImmutableList.of(predicate)) .collect(toImmutableList()); } @@ -122,7 +122,7 @@ private static List> getSubPredicates(LogicalExpression express * Returns the original expression if the expression is non-deterministic or if the distribution will * expand the expression by too much. */ - private Expression distributeIfPossible(LogicalExpression expression) + private Expression distributeIfPossible(Logical expression) { if (!isDeterministic(expression)) { // Do not distribute boolean expressions if there are any non-deterministic elements @@ -159,9 +159,9 @@ private Expression distributeIfPossible(LogicalExpression expression) Set> crossProduct = Sets.cartesianProduct(subPredicates); return combinePredicates( - expression.getOperator().flip(), + expression.operator().flip(), crossProduct.stream() - .map(expressions -> combinePredicates(expression.getOperator(), expressions)) + .map(expressions -> combinePredicates(expression.operator(), expressions)) .collect(toImmutableList())); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractDereferencesFromFilterAboveScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractDereferencesFromFilterAboveScan.java index 6ff8c9a3447d..aa1b098147ae 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractDereferencesFromFilterAboveScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractDereferencesFromFilterAboveScan.java @@ -19,8 +19,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.FilterNode; @@ -78,13 +78,13 @@ public Pattern getPattern() @Override public Result apply(FilterNode node, Captures captures, Context context) { - Set dereferences = extractRowSubscripts(ImmutableList.of(node.getPredicate()), true); + Set dereferences = extractRowSubscripts(ImmutableList.of(node.getPredicate()), true); if (dereferences.isEmpty()) { return Result.empty(); } Assignments assignments = Assignments.of(dereferences, context.getSymbolAllocator()); - Map mappings = HashBiMap.create(assignments.getMap()) + Map mappings = HashBiMap.create(assignments.getMap()) .inverse() .entrySet().stream() .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractSpatialJoins.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractSpatialJoins.java index 0e9fc9057cf5..a2d697478f26 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractSpatialJoins.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractSpatialJoins.java @@ -46,12 +46,12 @@ import io.trino.split.SplitSource; import io.trino.split.SplitSource.SplitBatch; import io.trino.sql.PlannerContext; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.BuiltinFunctionCallBuilder; import io.trino.sql.planner.ResolvedFunctionCallBuilder; import io.trino.sql.planner.Symbol; @@ -86,8 +86,8 @@ import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.planner.ExpressionNodeInliner.replaceExpression; import static io.trino.sql.planner.SymbolsExtractor.extractUnique; import static io.trino.sql.planner.plan.JoinType.INNER; @@ -208,16 +208,16 @@ public Result apply(FilterNode node, Captures captures, Context context) { JoinNode joinNode = captures.get(JOIN); Expression filter = node.getPredicate(); - List spatialFunctions = extractSupportedSpatialFunctions(filter); - for (FunctionCall spatialFunction : spatialFunctions) { + List spatialFunctions = extractSupportedSpatialFunctions(filter); + for (Call spatialFunction : spatialFunctions) { Result result = tryCreateSpatialJoin(context, joinNode, filter, node.getId(), node.getOutputSymbols(), spatialFunction, Optional.empty(), plannerContext, splitManager, pageSourceManager); if (!result.isEmpty()) { return result; } } - List spatialComparisons = extractSupportedSpatialComparisons(filter); - for (ComparisonExpression spatialComparison : spatialComparisons) { + List spatialComparisons = extractSupportedSpatialComparisons(filter); + for (Comparison spatialComparison : spatialComparisons) { Result result = tryCreateSpatialJoin(context, joinNode, filter, node.getId(), node.getOutputSymbols(), spatialComparison, plannerContext, splitManager, pageSourceManager); if (!result.isEmpty()) { return result; @@ -261,16 +261,16 @@ public Pattern getPattern() public Result apply(JoinNode joinNode, Captures captures, Context context) { Expression filter = joinNode.getFilter().get(); - List spatialFunctions = extractSupportedSpatialFunctions(filter); - for (FunctionCall spatialFunction : spatialFunctions) { + List spatialFunctions = extractSupportedSpatialFunctions(filter); + for (Call spatialFunction : spatialFunctions) { Result result = tryCreateSpatialJoin(context, joinNode, filter, joinNode.getId(), joinNode.getOutputSymbols(), spatialFunction, Optional.empty(), plannerContext, splitManager, pageSourceManager); if (!result.isEmpty()) { return result; } } - List spatialComparisons = extractSupportedSpatialComparisons(filter); - for (ComparisonExpression spatialComparison : spatialComparisons) { + List spatialComparisons = extractSupportedSpatialComparisons(filter); + for (Comparison spatialComparison : spatialComparisons) { Result result = tryCreateSpatialJoin(context, joinNode, filter, joinNode.getId(), joinNode.getOutputSymbols(), spatialComparison, plannerContext, splitManager, pageSourceManager); if (!result.isEmpty()) { return result; @@ -287,7 +287,7 @@ private static Result tryCreateSpatialJoin( Expression filter, PlanNodeId nodeId, List outputSymbols, - ComparisonExpression spatialComparison, + Comparison spatialComparison, PlannerContext plannerContext, SplitManager splitManager, PageSourceManager pageSourceManager) @@ -300,14 +300,14 @@ private static Result tryCreateSpatialJoin( Expression radius; Optional newRadiusSymbol; - ComparisonExpression newComparison; - if (spatialComparison.getOperator() == LESS_THAN || spatialComparison.getOperator() == LESS_THAN_OR_EQUAL) { + Comparison newComparison; + if (spatialComparison.operator() == LESS_THAN || spatialComparison.operator() == LESS_THAN_OR_EQUAL) { // ST_Distance(a, b) <= r - radius = spatialComparison.getRight(); + radius = spatialComparison.right(); Set radiusSymbols = extractUnique(radius); if (radiusSymbols.isEmpty() || (rightSymbols.containsAll(radiusSymbols) && containsNone(leftSymbols, radiusSymbols))) { newRadiusSymbol = newRadiusSymbol(context, radius); - newComparison = new ComparisonExpression(spatialComparison.getOperator(), spatialComparison.getLeft(), toExpression(newRadiusSymbol, radius)); + newComparison = new Comparison(spatialComparison.operator(), spatialComparison.left(), toExpression(newRadiusSymbol, radius)); } else { return Result.empty(); @@ -315,11 +315,11 @@ private static Result tryCreateSpatialJoin( } else { // r >= ST_Distance(a, b) - radius = spatialComparison.getLeft(); + radius = spatialComparison.left(); Set radiusSymbols = extractUnique(radius); if (radiusSymbols.isEmpty() || (rightSymbols.containsAll(radiusSymbols) && containsNone(leftSymbols, radiusSymbols))) { newRadiusSymbol = newRadiusSymbol(context, radius); - newComparison = new ComparisonExpression(spatialComparison.getOperator().flip(), spatialComparison.getRight(), toExpression(newRadiusSymbol, radius)); + newComparison = new Comparison(spatialComparison.operator().flip(), spatialComparison.right(), toExpression(newRadiusSymbol, radius)); } else { return Result.empty(); @@ -346,7 +346,7 @@ private static Result tryCreateSpatialJoin( joinNode.getDynamicFilters(), joinNode.getReorderJoinStatsAndCost()); - return tryCreateSpatialJoin(context, newJoinNode, newFilter, nodeId, outputSymbols, (FunctionCall) newComparison.getLeft(), Optional.of(newComparison.getRight()), plannerContext, splitManager, pageSourceManager); + return tryCreateSpatialJoin(context, newJoinNode, newFilter, nodeId, outputSymbols, (Call) newComparison.left(), Optional.of(newComparison.right()), plannerContext, splitManager, pageSourceManager); } private static Result tryCreateSpatialJoin( @@ -355,7 +355,7 @@ private static Result tryCreateSpatialJoin( Expression filter, PlanNodeId nodeId, List outputSymbols, - FunctionCall spatialFunction, + Call spatialFunction, Optional radius, PlannerContext plannerContext, SplitManager splitManager, @@ -365,7 +365,7 @@ private static Result tryCreateSpatialJoin( Optional spatialPartitioningTableName = joinNode.getType() == INNER ? getSpatialPartitioningTableName(context.getSession()) : Optional.empty(); Optional kdbTree = spatialPartitioningTableName.map(tableName -> loadKdbTree(tableName, context.getSession(), plannerContext.getMetadata(), splitManager, pageSourceManager)); - List arguments = spatialFunction.getArguments(); + List arguments = spatialFunction.arguments(); verify(arguments.size() == 2); Expression firstArgument = arguments.get(0); @@ -426,7 +426,7 @@ else if (alignment < 0) { } } - ResolvedFunction resolvedFunction = spatialFunction.getFunction(); + ResolvedFunction resolvedFunction = spatialFunction.function(); Expression newSpatialFunction = ResolvedFunctionCallBuilder.builder(resolvedFunction) .addArgument(newFirstArgument) .addArgument(newSecondArgument) @@ -553,7 +553,7 @@ private static Expression toExpression(Optional optionalSymbol, Expressi private static Optional newGeometrySymbol(Context context, Expression expression, TypeManager typeManager) { - if (expression instanceof SymbolReference) { + if (expression instanceof Reference) { return Optional.empty(); } @@ -562,7 +562,7 @@ private static Optional newGeometrySymbol(Context context, Expression ex private static Optional newRadiusSymbol(Context context, Expression expression) { - if (expression instanceof SymbolReference) { + if (expression instanceof Reference) { return Optional.empty(); } @@ -593,7 +593,7 @@ private static PlanNode addPartitioningNodes(PlannerContext plannerContext, Cont .addArgument(typeSignature, new Cast(new Constant(VARCHAR, Slices.utf8Slice(KdbTreeUtils.toJson(kdbTree))), plannerContext.getTypeManager().getType(typeSignature))) .addArgument(GEOMETRY_TYPE_SIGNATURE, geometry); radius.ifPresent(value -> spatialPartitionsCall.addArgument(DOUBLE, value)); - FunctionCall partitioningFunction = spatialPartitionsCall.build(); + Call partitioningFunction = spatialPartitionsCall.build(); Symbol partitionsSymbol = context.getSymbolAllocator().newSymbol(partitioningFunction, new ArrayType(INTEGER)); projections.put(partitionsSymbol, partitioningFunction); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementBernoulliSampleAsFilter.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementBernoulliSampleAsFilter.java index 05165810429d..02ce4f1f935b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementBernoulliSampleAsFilter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementBernoulliSampleAsFilter.java @@ -16,7 +16,7 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.metadata.Metadata; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.planner.BuiltinFunctionCallBuilder; import io.trino.sql.planner.iterative.Rule; @@ -65,8 +65,8 @@ public Result apply(SampleNode sample, Captures captures, Context context) return Result.ofPlanNode(new FilterNode( sample.getId(), sample.getSource(), - new ComparisonExpression( - ComparisonExpression.Operator.LESS_THAN, + new Comparison( + Comparison.Operator.LESS_THAN, BuiltinFunctionCallBuilder.resolve(metadata) .setName("rand") .build(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementExceptAll.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementExceptAll.java index 402064072f26..fc4888650c9e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementExceptAll.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementExceptAll.java @@ -19,11 +19,11 @@ import io.trino.metadata.Metadata; import io.trino.metadata.ResolvedFunction; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.ExceptNode; @@ -33,8 +33,8 @@ import static com.google.common.base.Preconditions.checkState; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.SUBTRACT; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.planner.plan.Patterns.Except.distinct; import static io.trino.sql.planner.plan.Patterns.except; import static java.util.Objects.requireNonNull; @@ -101,10 +101,10 @@ public Result apply(ExceptNode node, Captures captures, Context context) Expression count = result.getCountSymbols().get(0).toSymbolReference(); for (int i = 1; i < result.getCountSymbols().size(); i++) { - count = new FunctionCall( + count = new Call( greatest, ImmutableList.of( - new ArithmeticBinaryExpression( + new Arithmetic( metadata.resolveOperator(OperatorType.SUBTRACT, ImmutableList.of(BIGINT, BIGINT)), SUBTRACT, count, @@ -113,7 +113,7 @@ public Result apply(ExceptNode node, Captures captures, Context context) } // filter rows so that expected number of rows remains - Expression removeExtraRows = new ComparisonExpression(LESS_THAN_OR_EQUAL, result.getRowNumberSymbol().toSymbolReference(), count); + Expression removeExtraRows = new Comparison(LESS_THAN_OR_EQUAL, result.getRowNumberSymbol().toSymbolReference(), count); FilterNode filter = new FilterNode( context.getIdAllocator().getNextId(), result.getPlanNode(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementExceptDistinctAsUnion.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementExceptDistinctAsUnion.java index db960f109a9f..b7ff33d85936 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementExceptDistinctAsUnion.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementExceptDistinctAsUnion.java @@ -17,7 +17,7 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.metadata.Metadata; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.planner.iterative.Rule; @@ -27,8 +27,8 @@ import io.trino.sql.planner.plan.ProjectNode; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; import static io.trino.sql.ir.IrUtils.and; import static io.trino.sql.planner.plan.Patterns.Except.distinct; import static io.trino.sql.planner.plan.Patterns.except; @@ -90,9 +90,9 @@ public Result apply(ExceptNode node, Captures captures, Context context) // except predicate: the row must be present in the first source and absent in all the other sources ImmutableList.Builder predicatesBuilder = ImmutableList.builder(); - predicatesBuilder.add(new ComparisonExpression(GREATER_THAN_OR_EQUAL, result.getCountSymbols().get(0).toSymbolReference(), new Constant(BIGINT, 1L))); + predicatesBuilder.add(new Comparison(GREATER_THAN_OR_EQUAL, result.getCountSymbols().get(0).toSymbolReference(), new Constant(BIGINT, 1L))); for (int i = 1; i < node.getSources().size(); i++) { - predicatesBuilder.add(new ComparisonExpression(EQUAL, result.getCountSymbols().get(i).toSymbolReference(), new Constant(BIGINT, 0L))); + predicatesBuilder.add(new Comparison(EQUAL, result.getCountSymbols().get(i).toSymbolReference(), new Constant(BIGINT, 0L))); } return Result.ofPlanNode( new ProjectNode( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementFilteredAggregations.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementFilteredAggregations.java index bc11a1d0c92d..64632fe12ff8 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementFilteredAggregations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementFilteredAggregations.java @@ -30,7 +30,7 @@ import java.util.Optional; import static io.trino.spi.type.BooleanType.BOOLEAN; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.IrUtils.and; import static io.trino.sql.ir.IrUtils.combineDisjunctsWithDefault; import static io.trino.sql.planner.plan.Patterns.aggregation; @@ -123,9 +123,9 @@ else if (mask.isPresent()) { mask)); } - Expression predicate = TRUE_LITERAL; + Expression predicate = TRUE; if (!aggregationNode.hasNonEmptyGroupingSet() && !aggregateWithoutFilterOrMaskPresent) { - predicate = combineDisjunctsWithDefault(maskSymbols.build(), TRUE_LITERAL); + predicate = combineDisjunctsWithDefault(maskSymbols.build(), TRUE); } // identity projection for all existing inputs diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementIntersectAll.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementIntersectAll.java index eea2eb025b09..58325d1b28f9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementIntersectAll.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementIntersectAll.java @@ -18,9 +18,9 @@ import io.trino.matching.Pattern; import io.trino.metadata.Metadata; import io.trino.metadata.ResolvedFunction; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.FilterNode; @@ -30,7 +30,7 @@ import static com.google.common.base.Preconditions.checkState; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.planner.plan.Patterns.Intersect.distinct; import static io.trino.sql.planner.plan.Patterns.intersect; import static java.util.Objects.requireNonNull; @@ -97,11 +97,11 @@ public Result apply(IntersectNode node, Captures captures, Context context) Expression minCount = result.getCountSymbols().get(0).toSymbolReference(); for (int i = 1; i < result.getCountSymbols().size(); i++) { - minCount = new FunctionCall(least, ImmutableList.of(minCount, result.getCountSymbols().get(i).toSymbolReference())); + minCount = new Call(least, ImmutableList.of(minCount, result.getCountSymbols().get(i).toSymbolReference())); } // filter rows so that expected number of rows remains - Expression removeExtraRows = new ComparisonExpression(LESS_THAN_OR_EQUAL, result.getRowNumberSymbol().toSymbolReference(), minCount); + Expression removeExtraRows = new Comparison(LESS_THAN_OR_EQUAL, result.getRowNumberSymbol().toSymbolReference(), minCount); FilterNode filter = new FilterNode( context.getIdAllocator().getNextId(), result.getPlanNode(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementIntersectDistinctAsUnion.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementIntersectDistinctAsUnion.java index f24ad37f717c..7703e9345ccb 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementIntersectDistinctAsUnion.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementIntersectDistinctAsUnion.java @@ -16,7 +16,7 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.metadata.Metadata; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.planner.iterative.Rule; @@ -27,7 +27,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; import static io.trino.sql.ir.IrUtils.and; import static io.trino.sql.planner.plan.Patterns.Intersect.distinct; import static io.trino.sql.planner.plan.Patterns.intersect; @@ -89,7 +89,7 @@ public Result apply(IntersectNode node, Captures captures, Context context) // intersect predicate: the row must be present in every source Expression predicate = and(result.getCountSymbols().stream() - .map(symbol -> new ComparisonExpression(GREATER_THAN_OR_EQUAL, symbol.toSymbolReference(), new Constant(BIGINT, 1L))) + .map(symbol -> new Comparison(GREATER_THAN_OR_EQUAL, symbol.toSymbolReference(), new Constant(BIGINT, 1L))) .collect(toImmutableList())); return Result.ofPlanNode( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementLimitWithTies.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementLimitWithTies.java index 6fd18b71593b..ce8b8a246e03 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementLimitWithTies.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementLimitWithTies.java @@ -21,7 +21,7 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.metadata.Metadata; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; @@ -41,7 +41,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static io.trino.matching.Capture.newCapture; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.planner.plan.Patterns.Limit.requiresPreSortedInputs; import static io.trino.sql.planner.plan.Patterns.limit; import static io.trino.sql.planner.plan.Patterns.source; @@ -138,7 +138,7 @@ public static PlanNode rewriteLimitWithTiesWithPartitioning(LimitNode limitNode, return new FilterNode( idAllocator.getNextId(), windowNode, - new ComparisonExpression( + new Comparison( LESS_THAN_OR_EQUAL, rankSymbol.toSymbolReference(), new Constant(BIGINT, limitNode.getCount()))); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementOffset.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementOffset.java index dec889d86141..ef6b4ca9bf23 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementOffset.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementOffset.java @@ -16,7 +16,7 @@ import com.google.common.collect.ImmutableList; import io.trino.matching.Captures; import io.trino.matching.Pattern; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; @@ -29,7 +29,7 @@ import java.util.Optional; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.plan.Patterns.offset; /** @@ -77,7 +77,7 @@ public Result apply(OffsetNode parent, Captures captures, Context context) FilterNode filterNode = new FilterNode( context.getIdAllocator().getNextId(), rowNumberNode, - new ComparisonExpression( + new Comparison( GREATER_THAN, rowNumberSymbol.toSymbolReference(), new Constant(BIGINT, parent.getCount()))); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementTableFunctionSource.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementTableFunctionSource.java index 02944ae9fe3b..d708a96de7ee 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementTableFunctionSource.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementTableFunctionSource.java @@ -22,12 +22,12 @@ import io.trino.metadata.Metadata; import io.trino.metadata.ResolvedFunction; import io.trino.spi.type.Type; -import io.trino.sql.ir.CoalesceExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Coalesce; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.NotExpression; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Not; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; @@ -59,12 +59,12 @@ import static com.google.common.collect.Iterables.getOnlyElement; import static io.trino.spi.connector.SortOrder.ASC_NULLS_LAST; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.IS_DISTINCT_FROM; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.IS_DISTINCT_FROM; import static io.trino.sql.ir.IrExpressions.ifExpression; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; -import static io.trino.sql.ir.LogicalExpression.Operator.OR; +import static io.trino.sql.ir.Logical.Operator.AND; +import static io.trino.sql.ir.Logical.Operator.OR; import static io.trino.sql.planner.plan.FrameBoundType.UNBOUNDED_FOLLOWING; import static io.trino.sql.planner.plan.FrameBoundType.UNBOUNDED_PRECEDING; import static io.trino.sql.planner.plan.JoinType.FULL; @@ -376,7 +376,7 @@ private static JoinedNodes copartition(NodeWithSymbols left, NodeWithSymbols rig List copartitionConjuncts = Streams.zip( leftPartitionBy.stream(), rightPartitionBy.stream(), - (leftColumn, rightColumn) -> new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, leftColumn, rightColumn))) + (leftColumn, rightColumn) -> new Not(new Comparison(IS_DISTINCT_FROM, leftColumn, rightColumn))) .collect(toImmutableList()); // Align matching partitions (co-partitions) from left and right source, according to row number. @@ -392,18 +392,18 @@ private static JoinedNodes copartition(NodeWithSymbols left, NodeWithSymbols rig // (R1 > S2 AND R2 = 1) // OR // (R2 > S1 AND R1 = 1)) - Expression joinCondition = new LogicalExpression( + Expression joinCondition = new Logical( AND, ImmutableList.builder() .addAll(copartitionConjuncts) - .add(new LogicalExpression(OR, ImmutableList.of( - new ComparisonExpression(EQUAL, leftRowNumber, rightRowNumber), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, leftRowNumber, rightPartitionSize), - new ComparisonExpression(EQUAL, rightRowNumber, new Constant(BIGINT, 1L)))), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, rightRowNumber, leftPartitionSize), - new ComparisonExpression(EQUAL, leftRowNumber, new Constant(BIGINT, 1L))))))) + .add(new Logical(OR, ImmutableList.of( + new Comparison(EQUAL, leftRowNumber, rightRowNumber), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, leftRowNumber, rightPartitionSize), + new Comparison(EQUAL, rightRowNumber, new Constant(BIGINT, 1L)))), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, rightRowNumber, leftPartitionSize), + new Comparison(EQUAL, leftRowNumber, new Constant(BIGINT, 1L))))))) .build()); // The join type depends on the prune when empty property of the sources. @@ -498,20 +498,20 @@ private static NodeWithSymbols appendHelperSymbolsForCopartitionedNodes( // Derive row number for joined partitions: this is the bigger partition's row number. One of the combined values might be null as a result of outer join. Symbol joinedRowNumber = context.getSymbolAllocator().newSymbol("combined_row_number", BIGINT); Expression rowNumberExpression = ifExpression( - new ComparisonExpression( + new Comparison( GREATER_THAN, - new CoalesceExpression(leftRowNumber, new Constant(BIGINT, -1L)), - new CoalesceExpression(rightRowNumber, new Constant(BIGINT, -1L))), + new Coalesce(leftRowNumber, new Constant(BIGINT, -1L)), + new Coalesce(rightRowNumber, new Constant(BIGINT, -1L))), leftRowNumber, rightRowNumber); // Derive partition size for joined partitions: this is the bigger partition's size. One of the combined values might be null as a result of outer join. Symbol joinedPartitionSize = context.getSymbolAllocator().newSymbol("combined_partition_size", BIGINT); Expression partitionSizeExpression = ifExpression( - new ComparisonExpression( + new Comparison( GREATER_THAN, - new CoalesceExpression(leftPartitionSize, new Constant(BIGINT, -1L)), - new CoalesceExpression(rightPartitionSize, new Constant(BIGINT, -1L))), + new Coalesce(leftPartitionSize, new Constant(BIGINT, -1L)), + new Coalesce(rightPartitionSize, new Constant(BIGINT, -1L))), leftPartitionSize, rightPartitionSize); @@ -526,7 +526,7 @@ private static NodeWithSymbols appendHelperSymbolsForCopartitionedNodes( Type type = leftColumn.getType(); Symbol joinedColumn = context.getSymbolAllocator().newSymbol("combined_partition_column", type); - joinedPartitionByAssignments.put(joinedColumn, new CoalesceExpression(leftColumn.toSymbolReference(), rightColumn.toSymbolReference())); + joinedPartitionByAssignments.put(joinedColumn, new Coalesce(leftColumn.toSymbolReference(), rightColumn.toSymbolReference())); joinedPartitionBy.add(joinedColumn); } @@ -567,14 +567,14 @@ private static JoinedNodes join(NodeWithSymbols left, NodeWithSymbols right, Con // (R1 > S2 AND R2 = 1) // OR // (R2 > S1 AND R1 = 1) - Expression joinCondition = new LogicalExpression(OR, ImmutableList.of( - new ComparisonExpression(EQUAL, leftRowNumber, rightRowNumber), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, leftRowNumber, rightPartitionSize), - new ComparisonExpression(EQUAL, rightRowNumber, new Constant(BIGINT, 1L)))), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, rightRowNumber, leftPartitionSize), - new ComparisonExpression(EQUAL, leftRowNumber, new Constant(BIGINT, 1L)))))); + Expression joinCondition = new Logical(OR, ImmutableList.of( + new Comparison(EQUAL, leftRowNumber, rightRowNumber), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, leftRowNumber, rightPartitionSize), + new Comparison(EQUAL, rightRowNumber, new Constant(BIGINT, 1L)))), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, rightRowNumber, leftPartitionSize), + new Comparison(EQUAL, leftRowNumber, new Constant(BIGINT, 1L)))))); JoinType joinType; if (left.pruneWhenEmpty() && right.pruneWhenEmpty()) { @@ -629,20 +629,20 @@ private static NodeWithSymbols appendHelperSymbolsForJoinedNodes(JoinedNodes joi // Derive row number for joined partitions: this is the bigger partition's row number. One of the combined values might be null as a result of outer join. Symbol joinedRowNumber = context.getSymbolAllocator().newSymbol("combined_row_number", BIGINT); Expression rowNumberExpression = ifExpression( - new ComparisonExpression( + new Comparison( GREATER_THAN, - new CoalesceExpression(leftRowNumber, new Constant(BIGINT, -1L)), - new CoalesceExpression(rightRowNumber, new Constant(BIGINT, -1L))), + new Coalesce(leftRowNumber, new Constant(BIGINT, -1L)), + new Coalesce(rightRowNumber, new Constant(BIGINT, -1L))), leftRowNumber, rightRowNumber); // Derive partition size for joined partitions: this is the bigger partition's size. One of the combined values might be null as a result of outer join. Symbol joinedPartitionSize = context.getSymbolAllocator().newSymbol("combined_partition_size", BIGINT); Expression partitionSizeExpression = ifExpression( - new ComparisonExpression( + new Comparison( GREATER_THAN, - new CoalesceExpression(leftPartitionSize, new Constant(BIGINT, -1L)), - new CoalesceExpression(rightPartitionSize, new Constant(BIGINT, -1L))), + new Coalesce(leftPartitionSize, new Constant(BIGINT, -1L)), + new Coalesce(rightPartitionSize, new Constant(BIGINT, -1L))), leftPartitionSize, rightPartitionSize); @@ -682,7 +682,7 @@ private static NodeWithMarkers appendMarkerSymbols(PlanNode node, Set sy symbolsToMarkers.put(symbol, marker); Expression actual = symbol.toSymbolReference(); Expression reference = referenceSymbol.toSymbolReference(); - assignments.put(marker, ifExpression(new ComparisonExpression(EQUAL, actual, reference), actual, new Constant(BIGINT, null))); + assignments.put(marker, ifExpression(new Comparison(EQUAL, actual, reference), actual, new Constant(BIGINT, null))); } PlanNode project = new ProjectNode( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/InlineProjectIntoFilter.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/InlineProjectIntoFilter.java index 48f2562a0fb9..b475bda0710b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/InlineProjectIntoFilter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/InlineProjectIntoFilter.java @@ -19,7 +19,7 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolsExtractor; import io.trino.sql.planner.iterative.Rule; @@ -35,7 +35,7 @@ import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.matching.Capture.newCapture; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.IrUtils.combineConjuncts; import static io.trino.sql.ir.IrUtils.extractConjuncts; import static io.trino.sql.planner.plan.Patterns.filter; @@ -108,7 +108,7 @@ public Result apply(FilterNode node, Captures captures, Context context) List filterConjuncts = extractConjuncts(node.getPredicate()); Map> conjuncts = filterConjuncts.stream() - .collect(partitioningBy(SymbolReference.class::isInstance)); + .collect(partitioningBy(Reference.class::isInstance)); List simpleConjuncts = conjuncts.get(true); List complexConjuncts = conjuncts.get(false); @@ -138,7 +138,7 @@ public Result apply(FilterNode node, Captures captures, Context context) for (Expression conjunct : filterConjuncts) { if (simpleConjunctsToInline.contains(conjunct)) { Expression expression = projectNode.getAssignments().get(Symbol.from(conjunct)); - if (expression == null || expression instanceof SymbolReference) { + if (expression == null || expression instanceof Reference) { // expression == null -> The symbol is not produced by the underlying projection (i.e. it is a correlation symbol). // expression instanceof SymbolReference -> Do not inline trivial projections. newConjuncts.add(conjunct); @@ -146,7 +146,7 @@ public Result apply(FilterNode node, Captures captures, Context context) else { newConjuncts.add(expression); newAssignments.putIdentities(SymbolsExtractor.extractUnique(expression)); - postFilterAssignmentsBuilder.put(Symbol.from(conjunct), TRUE_LITERAL); + postFilterAssignmentsBuilder.put(Symbol.from(conjunct), TRUE); } } else { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/InlineProjections.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/InlineProjections.java index 19f6eed58ce4..6b4daa1e99d8 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/InlineProjections.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/InlineProjections.java @@ -22,8 +22,8 @@ import io.trino.spi.type.RowType; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolsExtractor; import io.trino.sql.planner.iterative.Rule; @@ -169,7 +169,7 @@ private static Set extractInliningTargets(ProjectNode parent, ProjectNod // find references to simple constants or symbol references Set basicReferences = dependencies.keySet().stream() - .filter(input -> child.getAssignments().get(input) instanceof Constant || child.getAssignments().get(input) instanceof SymbolReference) + .filter(input -> child.getAssignments().get(input) instanceof Constant || child.getAssignments().get(input) instanceof Reference) .filter(input -> !child.getAssignments().isIdentity(input)) // skip identities, otherwise, this rule will keep firing forever .collect(toSet()); @@ -180,8 +180,8 @@ private static Set extractInliningTargets(ProjectNode parent, ProjectNod // skip dereferences, otherwise, inlining can cause conflicts with PushdownDereferences Expression assignment = child.getAssignments().get(entry.getKey()); - if (assignment instanceof SubscriptExpression) { - if (((SubscriptExpression) assignment).getBase().type() instanceof RowType) { + if (assignment instanceof Subscript) { + if (((Subscript) assignment).base().type() instanceof RowType) { return false; } } @@ -196,6 +196,6 @@ private static Set extractInliningTargets(ProjectNode parent, ProjectNod private static boolean isSymbolReference(Symbol symbol, Expression expression) { - return expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(symbol.getName()); + return expression instanceof Reference && ((Reference) expression).name().equals(symbol.getName()); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/LambdaCaptureDesugaringRewriter.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/LambdaCaptureDesugaringRewriter.java index 38ef63c35ff5..5da6b33fd5fb 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/LambdaCaptureDesugaringRewriter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/LambdaCaptureDesugaringRewriter.java @@ -18,12 +18,12 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; -import io.trino.sql.ir.BindExpression; +import io.trino.sql.ir.Bind; import io.trino.sql.ir.Expression; import io.trino.sql.ir.ExpressionRewriter; import io.trino.sql.ir.ExpressionTreeRewriter; -import io.trino.sql.ir.LambdaExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Lambda; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolAllocator; @@ -56,13 +56,13 @@ public Visitor(SymbolAllocator symbolAllocator) } @Override - public Expression rewriteLambdaExpression(LambdaExpression node, Context context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteLambda(Lambda node, Context context, ExpressionTreeRewriter treeRewriter) { // Use linked hash set to guarantee deterministic iteration order LinkedHashSet referencedSymbols = new LinkedHashSet<>(); - Expression rewrittenBody = treeRewriter.rewrite(node.getBody(), context.withReferencedSymbols(referencedSymbols)); + Expression rewrittenBody = treeRewriter.rewrite(node.body(), context.withReferencedSymbols(referencedSymbols)); - List lambdaArguments = node.getArguments(); + List lambdaArguments = node.arguments(); Set captureSymbols = Sets.difference(referencedSymbols, ImmutableSet.copyOf(lambdaArguments)); @@ -76,18 +76,18 @@ public Expression rewriteLambdaExpression(LambdaExpression node, Context context captureSymbolToExtraSymbol.put(captureSymbol, extraSymbol); newLambdaArguments.add(extraSymbol); } - newLambdaArguments.addAll(node.getArguments()); + newLambdaArguments.addAll(node.arguments()); ImmutableMap symbolsMap = captureSymbolToExtraSymbol.buildOrThrow(); Function symbolMapping = symbol -> symbolsMap.getOrDefault(symbol, symbol).toSymbolReference(); - LambdaExpression lambdaExpression = new LambdaExpression(newLambdaArguments.build(), inlineSymbols(symbolMapping, rewrittenBody)); + Lambda lambda = new Lambda(newLambdaArguments.build(), inlineSymbols(symbolMapping, rewrittenBody)); - Expression rewrittenExpression = lambdaExpression; + Expression rewrittenExpression = lambda; if (captureSymbols.size() != 0) { List capturedValues = captureSymbols.stream() - .map(symbol -> new SymbolReference(symbol.getType(), symbol.getName())) + .map(symbol -> new Reference(symbol.getType(), symbol.getName())) .collect(toImmutableList()); - rewrittenExpression = new BindExpression(capturedValues, lambdaExpression); + rewrittenExpression = new Bind(capturedValues, lambda); } context.getReferencedSymbols().addAll(captureSymbols); @@ -95,9 +95,9 @@ public Expression rewriteLambdaExpression(LambdaExpression node, Context context } @Override - public Expression rewriteSymbolReference(SymbolReference node, Context context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteReference(Reference node, Context context, ExpressionTreeRewriter treeRewriter) { - context.getReferencedSymbols().add(new Symbol(node.type(), node.getName())); + context.getReferencedSymbols().add(new Symbol(node.type(), node.name())); return null; } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MergeProjectWithValues.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MergeProjectWithValues.java index 0a4536c98045..049840473cef 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MergeProjectWithValues.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MergeProjectWithValues.java @@ -21,8 +21,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.Reference; import io.trino.sql.ir.Row; -import io.trino.sql.ir.SymbolReference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolsExtractor; import io.trino.sql.planner.iterative.Rule; @@ -131,7 +131,7 @@ public Result apply(ProjectNode node, Captures captures, Context context) for (Expression rowExpression : valuesNode.getRows().get()) { Row row = (Row) rowExpression; for (int i = 0; i < valuesNode.getOutputSymbols().size(); i++) { - if (!isDeterministic(row.getItems().get(i))) { + if (!isDeterministic(row.items().get(i))) { nonDeterministicValuesOutputs.add(valuesNode.getOutputSymbols().get(i)); } } @@ -150,7 +150,7 @@ public Result apply(ProjectNode node, Captures captures, Context context) // inline values expressions into projection's assignments ImmutableList.Builder projectedRows = ImmutableList.builder(); for (Expression rowExpression : valuesNode.getRows().get()) { - Map mapping = buildMappings(valuesNode.getOutputSymbols(), (Row) rowExpression); + Map mapping = buildMappings(valuesNode.getOutputSymbols(), (Row) rowExpression); Row projectedRow = new Row(expressions.stream() .map(expression -> replaceExpression(expression, mapping)) .collect(toImmutableList())); @@ -164,11 +164,11 @@ private static boolean isSupportedValues(ValuesNode valuesNode) return valuesNode.getRows().isEmpty() || valuesNode.getRows().get().stream().allMatch(Row.class::isInstance); } - private Map buildMappings(List symbols, Row row) + private Map buildMappings(List symbols, Row row) { - ImmutableMap.Builder mappingBuilder = ImmutableMap.builder(); - for (int i = 0; i < row.getItems().size(); i++) { - mappingBuilder.put(symbols.get(i).toSymbolReference(), row.getItems().get(i)); + ImmutableMap.Builder mappingBuilder = ImmutableMap.builder(); + for (int i = 0; i < row.items().size(); i++) { + mappingBuilder.put(symbols.get(i).toSymbolReference(), row.items().get(i)); } return mappingBuilder.buildOrThrow(); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/NormalizeOrExpressionRewriter.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/NormalizeOrExpressionRewriter.java index cc2bf9cc99c7..e06b6389e9c1 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/NormalizeOrExpressionRewriter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/NormalizeOrExpressionRewriter.java @@ -16,12 +16,12 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.ImmutableSet; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Expression; import io.trino.sql.ir.ExpressionRewriter; import io.trino.sql.ir.ExpressionTreeRewriter; -import io.trino.sql.ir.InPredicate; -import io.trino.sql.ir.LogicalExpression; +import io.trino.sql.ir.In; +import io.trino.sql.ir.Logical; import java.util.Collection; import java.util.LinkedHashSet; @@ -30,10 +30,10 @@ import java.util.Set; import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.IrUtils.and; import static io.trino.sql.ir.IrUtils.or; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; +import static io.trino.sql.ir.Logical.Operator.AND; public final class NormalizeOrExpressionRewriter { @@ -48,35 +48,35 @@ private static class Visitor extends ExpressionRewriter { @Override - public Expression rewriteLogicalExpression(LogicalExpression node, Void context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteLogical(Logical node, Void context, ExpressionTreeRewriter treeRewriter) { - List terms = node.getTerms().stream() + List terms = node.terms().stream() .map(expression -> treeRewriter.rewrite(expression, context)) .collect(toImmutableList()); - if (node.getOperator() == AND) { + if (node.operator() == AND) { return and(terms); } - ImmutableList.Builder inPredicateBuilder = ImmutableList.builder(); + ImmutableList.Builder inPredicateBuilder = ImmutableList.builder(); ImmutableSet.Builder expressionToSkipBuilder = ImmutableSet.builder(); ImmutableList.Builder othersExpressionBuilder = ImmutableList.builder(); groupComparisonAndInPredicate(terms).forEach((expression, values) -> { if (values.size() > 1) { - inPredicateBuilder.add(new InPredicate(expression, mergeToInListExpression(values))); + inPredicateBuilder.add(new In(expression, mergeToInListExpression(values))); expressionToSkipBuilder.add(expression); } }); Set expressionToSkip = expressionToSkipBuilder.build(); for (Expression expression : terms) { - if (expression instanceof ComparisonExpression comparisonExpression && comparisonExpression.getOperator() == EQUAL) { - if (!expressionToSkip.contains(comparisonExpression.getLeft())) { + if (expression instanceof Comparison comparison && comparison.operator() == EQUAL) { + if (!expressionToSkip.contains(comparison.left())) { othersExpressionBuilder.add(expression); } } - else if (expression instanceof InPredicate inPredicate) { - if (!expressionToSkip.contains(inPredicate.getValue())) { + else if (expression instanceof In in) { + if (!expressionToSkip.contains(in.value())) { othersExpressionBuilder.add(expression); } } @@ -95,11 +95,11 @@ private List mergeToInListExpression(Collection expressi { LinkedHashSet expressionValues = new LinkedHashSet<>(); for (Expression expression : expressions) { - if (expression instanceof ComparisonExpression comparisonExpression && comparisonExpression.getOperator() == EQUAL) { - expressionValues.add(comparisonExpression.getRight()); + if (expression instanceof Comparison comparison && comparison.operator() == EQUAL) { + expressionValues.add(comparison.right()); } - else if (expression instanceof InPredicate inPredicate) { - expressionValues.addAll(inPredicate.getValueList()); + else if (expression instanceof In in) { + expressionValues.addAll(in.valueList()); } else { throw new IllegalStateException("Unexpected expression: " + expression); @@ -113,11 +113,11 @@ private Map> groupComparisonAndInPredicate(Li { ImmutableMultimap.Builder expressionBuilder = ImmutableMultimap.builder(); for (Expression expression : terms) { - if (expression instanceof ComparisonExpression comparisonExpression && comparisonExpression.getOperator() == EQUAL) { - expressionBuilder.put(comparisonExpression.getLeft(), comparisonExpression); + if (expression instanceof Comparison comparison && comparison.operator() == EQUAL) { + expressionBuilder.put(comparison.left(), comparison); } - else if (expression instanceof InPredicate inPredicate) { - expressionBuilder.put(inPredicate.getValue(), inPredicate); + else if (expression instanceof In in) { + expressionBuilder.put(in.value(), in); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/OptimizeDuplicateInsensitiveJoins.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/OptimizeDuplicateInsensitiveJoins.java index 96da45888ac1..d19d44ccb8f3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/OptimizeDuplicateInsensitiveJoins.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/OptimizeDuplicateInsensitiveJoins.java @@ -34,7 +34,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.SystemSessionProperties.isOptimizeDuplicateInsensitiveJoins; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.DeterminismEvaluator.isDeterministic; import static io.trino.sql.planner.plan.Patterns.aggregation; import static java.util.Objects.requireNonNull; @@ -134,7 +134,7 @@ public Optional visitJoin(JoinNode node, Void context) // LookupJoinOperator will evaluate non-deterministic condition on output rows until one of the // rows matches. Therefore it's safe to set maySkipOutputDuplicates for joins with non-deterministic // filters. - if (!isDeterministic(node.getFilter().orElse(TRUE_LITERAL))) { + if (!isDeterministic(node.getFilter().orElse(TRUE))) { if (node.isMaySkipOutputDuplicates()) { // join node is already set to skip duplicates, return empty to prevent rule from looping forever return Optional.empty(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PreAggregateCaseAggregations.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PreAggregateCaseAggregations.java index 6a358d9ad017..b17a5199adac 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PreAggregateCaseAggregations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PreAggregateCaseAggregations.java @@ -27,11 +27,11 @@ import io.trino.spi.type.Int128; import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; +import io.trino.sql.ir.Case; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SearchedCaseExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.ir.WhenClause; import io.trino.sql.planner.IrExpressionInterpreter; import io.trino.sql.planner.Symbol; @@ -222,7 +222,7 @@ private ProjectNode createNewProjection( assignments.putIdentities(aggregationNode.getGroupingKeys()); newProjectionSymbols.forEach((aggregation, symbol) -> assignments.put( symbol, - new SearchedCaseExpression(ImmutableList.of( + new Case(ImmutableList.of( new WhenClause( aggregation.getOperand(), preAggregations.get(new PreAggregationKey(aggregation)).getAggregationSymbol().toSymbolReference())), @@ -298,7 +298,7 @@ private Map getPreAggregations(List> extractCaseAggregations(AggregationNode private Optional extractCaseAggregation(Symbol aggregationSymbol, Aggregation aggregation, ProjectNode projectNode, Context context) { if (aggregation.getArguments().size() != 1 - || !(aggregation.getArguments().get(0) instanceof SymbolReference) + || !(aggregation.getArguments().get(0) instanceof Reference) || aggregation.isDistinct() || aggregation.getFilter().isPresent() || aggregation.getMask().isPresent() @@ -352,17 +352,17 @@ private Optional extractCaseAggregation(Symbol aggregationSymbo Expression unwrappedProjection; // unwrap top-level cast if (projection instanceof Cast) { - unwrappedProjection = ((Cast) projection).getExpression(); + unwrappedProjection = ((Cast) projection).expression(); } else { unwrappedProjection = projection; } - if (!(unwrappedProjection instanceof SearchedCaseExpression caseExpression)) { + if (!(unwrappedProjection instanceof Case caseExpression)) { return Optional.empty(); } - if (caseExpression.getWhenClauses().size() != 1) { + if (caseExpression.whenClauses().size() != 1) { return Optional.empty(); } @@ -382,9 +382,9 @@ private Optional extractCaseAggregation(Symbol aggregationSymbo } Optional cumulativeAggregationDefaultValue = Optional.empty(); - if (caseExpression.getDefaultValue().isPresent()) { - Type defaultType = getType(caseExpression.getDefaultValue().get()); - Object defaultValue = optimizeExpression(caseExpression.getDefaultValue().get(), context); + if (caseExpression.defaultValue().isPresent()) { + Type defaultType = getType(caseExpression.defaultValue().get()); + Object defaultValue = optimizeExpression(caseExpression.defaultValue().get(), context); if (defaultValue != null) { if (!name.equals(SUM)) { return Optional.empty(); @@ -408,7 +408,7 @@ private Optional extractCaseAggregation(Symbol aggregationSymbo } // cumulative aggregation default value need to be CAST to cumulative aggregation input type - cumulativeAggregationDefaultValue = Optional.of(new Cast(caseExpression.getDefaultValue().get(), aggregationType)); + cumulativeAggregationDefaultValue = Optional.of(new Cast(caseExpression.defaultValue().get(), aggregationType)); } return Optional.of(new CaseAggregation( @@ -416,8 +416,8 @@ private Optional extractCaseAggregation(Symbol aggregationSymbo resolvedFunction, cumulativeFunction, name, - caseExpression.getWhenClauses().get(0).getOperand(), - caseExpression.getWhenClauses().get(0).getResult(), + caseExpression.whenClauses().get(0).getOperand(), + caseExpression.whenClauses().get(0).getResult(), cumulativeAggregationDefaultValue)); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneCorrelatedJoinColumns.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneCorrelatedJoinColumns.java index 9e98fda63791..0d50a6929e5b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneCorrelatedJoinColumns.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneCorrelatedJoinColumns.java @@ -22,7 +22,7 @@ import java.util.Set; import static com.google.common.collect.Sets.intersection; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.SymbolsExtractor.extractUnique; import static io.trino.sql.planner.iterative.rule.Util.restrictOutputs; import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.isAtMostScalar; @@ -86,7 +86,7 @@ protected Optional pushDownProjectOff(Context context, CorrelatedJoinN // remove unused correlated join node, retain input if (intersection(ImmutableSet.copyOf(subquery.getOutputSymbols()), referencedOutputs).isEmpty()) { // remove unused subquery of inner join - if (correlatedJoinNode.getType() == INNER && isScalar(subquery, context.getLookup()) && correlatedJoinNode.getFilter().equals(TRUE_LITERAL)) { + if (correlatedJoinNode.getType() == INNER && isScalar(subquery, context.getLookup()) && correlatedJoinNode.getFilter().equals(TRUE)) { return Optional.of(input); } // remove unused subquery of left join @@ -103,7 +103,7 @@ protected Optional pushDownProjectOff(Context context, CorrelatedJoinN // remove unused input node, retain subquery if (intersection(ImmutableSet.copyOf(input.getOutputSymbols()), referencedAndCorrelationSymbols).isEmpty()) { // remove unused input of inner join - if (correlatedJoinNode.getType() == INNER && isScalar(input, context.getLookup()) && correlatedJoinNode.getFilter().equals(TRUE_LITERAL)) { + if (correlatedJoinNode.getType() == INNER && isScalar(input, context.getLookup()) && correlatedJoinNode.getFilter().equals(TRUE)) { return Optional.of(subquery); } // remove unused input of right join diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneValuesColumns.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneValuesColumns.java index a02107ab79a4..851cc12e4f56 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneValuesColumns.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneValuesColumns.java @@ -68,7 +68,7 @@ protected Optional pushDownProjectOff(Context context, ValuesNode valu ImmutableList.Builder rowsBuilder = ImmutableList.builder(); for (Expression row : valuesNode.getRows().get()) { rowsBuilder.add(new Row(Arrays.stream(mapping) - .mapToObj(i -> ((Row) row).getItems().get(i)) + .mapToObj(i -> ((Row) row).items().get(i)) .collect(Collectors.toList()))); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java index 13f31990eef7..55d2a129671f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java @@ -32,7 +32,7 @@ import io.trino.sql.PlannerContext; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.ConnectorExpressionTranslator; import io.trino.sql.planner.IrExpressionInterpreter; import io.trino.sql.planner.NoOpSymbolResolver; @@ -102,7 +102,7 @@ private static boolean allArgumentsAreSimpleReferences(AggregationNode node) return node.getAggregations() .values().stream() .flatMap(aggregation -> aggregation.getArguments().stream()) - .allMatch(SymbolReference.class::isInstance); + .allMatch(Reference.class::isInstance); } private static boolean hasNoMasks(AggregationNode node) @@ -238,8 +238,8 @@ private static AggregateFunction toAggregateFunction(AggregationNode.Aggregation ImmutableList.Builder arguments = ImmutableList.builder(); for (int i = 0; i < aggregation.getArguments().size(); i++) { - SymbolReference argument = (SymbolReference) aggregation.getArguments().get(i); - arguments.add(new Variable(argument.getName(), signature.getArgumentTypes().get(i))); + Reference argument = (Reference) aggregation.getArguments().get(i); + arguments.add(new Variable(argument.name(), signature.getArgumentTypes().get(i))); } Optional orderingScheme = aggregation.getOrderingScheme(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java index f314378214d3..40525ca3b2c0 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java @@ -21,7 +21,7 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.spi.type.Type; -import io.trino.sql.ir.CoalesceExpression; +import io.trino.sql.ir.Coalesce; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.ir.Row; @@ -260,7 +260,7 @@ private Optional coalesceWithNullAggregation(AggregationNode aggregati Assignments.Builder assignmentsBuilder = Assignments.builder(); for (Symbol symbol : outerJoin.getOutputSymbols()) { if (aggregationNode.getAggregations().containsKey(symbol)) { - assignmentsBuilder.put(symbol, new CoalesceExpression(symbol.toSymbolReference(), sourceAggregationToOverNullMapping.get(symbol).toSymbolReference())); + assignmentsBuilder.put(symbol, new Coalesce(symbol.toSymbolReference(), sourceAggregationToOverNullMapping.get(symbol).toSymbolReference())); } else { assignmentsBuilder.putIdentity(symbol); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushCastIntoRow.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushCastIntoRow.java index 64834bfbe6ef..3cd9aa2ce1c9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushCastIntoRow.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushCastIntoRow.java @@ -62,7 +62,7 @@ private static class Rewriter @Override public Expression rewriteCast(Cast node, Boolean inRowCast, ExpressionTreeRewriter treeRewriter) { - if (!(node.getType() instanceof RowType type)) { + if (!(node.type() instanceof RowType type)) { return treeRewriter.defaultRewrite(node, false); } @@ -70,15 +70,15 @@ public Expression rewriteCast(Cast node, Boolean inRowCast, ExpressionTreeRewrit // otherwise, apply recursively with inRowCast == true and don't push this one if (inRowCast || type.getFields().stream().allMatch(field -> field.getName().isEmpty())) { - Expression value = treeRewriter.rewrite(node.getExpression(), true); + Expression value = treeRewriter.rewrite(node.expression(), true); if (value instanceof Row row) { ImmutableList.Builder items = ImmutableList.builder(); - for (int i = 0; i < row.getItems().size(); i++) { - Expression item = row.getItems().get(i); + for (int i = 0; i < row.items().size(); i++) { + Expression item = row.items().get(i); Type itemType = type.getFields().get(i).getType(); if (!(itemType instanceof UnknownType)) { - item = new Cast(item, itemType, node.isSafe()); + item = new Cast(item, itemType, node.safe()); } items.add(item); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughFilter.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughFilter.java index b2f42520bd32..5c23f7d0bd88 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughFilter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughFilter.java @@ -19,8 +19,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.FilterNode; @@ -80,7 +80,7 @@ public Result apply(ProjectNode node, Captures captures, Rule.Context context) .build(); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(expressions, false); + Set dereferences = extractRowSubscripts(expressions, false); if (dereferences.isEmpty()) { return Result.empty(); @@ -90,7 +90,7 @@ public Result apply(ProjectNode node, Captures captures, Rule.Context context) Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSymbolAllocator()); // Rewrite project node assignments using new symbols for dereference expressions - Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) + Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) .inverse() .entrySet().stream() .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughJoin.java index a70fd2fb6583..63afe746a73a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughJoin.java @@ -20,8 +20,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; @@ -92,7 +92,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) ImmutableList.Builder expressionsBuilder = ImmutableList.builder(); expressionsBuilder.addAll(projectNode.getAssignments().getExpressions()); joinNode.getFilter().ifPresent(expressionsBuilder::add); - Set dereferences = extractRowSubscripts(expressionsBuilder.build(), false); + Set dereferences = extractRowSubscripts(expressionsBuilder.build(), false); // Exclude criteria symbols ImmutableSet.Builder criteriaSymbolsBuilder = ImmutableSet.builder(); @@ -114,7 +114,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSymbolAllocator()); // Rewrite project node assignments using new symbols for dereference expressions - Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) + Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) .inverse() .entrySet().stream() .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughProject.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughProject.java index 8b9a7abd5f94..d97bb3cf3bba 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughProject.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughProject.java @@ -18,8 +18,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.ProjectNode; @@ -66,7 +66,7 @@ public Result apply(ProjectNode node, Captures captures, Context context) ProjectNode child = captures.get(CHILD); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(node.getAssignments().getExpressions(), false); + Set dereferences = extractRowSubscripts(node.getAssignments().getExpressions(), false); // Exclude dereferences on symbols being synthesized within child dereferences = dereferences.stream() @@ -81,7 +81,7 @@ public Result apply(ProjectNode node, Captures captures, Context context) Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSymbolAllocator()); // Rewrite project node assignments using new symbols for dereference expressions - Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) + Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) .inverse() .entrySet().stream() .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughSemiJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughSemiJoin.java index 6f4a7be2342c..f212b553df96 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughSemiJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughSemiJoin.java @@ -19,8 +19,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.PlanNode; @@ -78,7 +78,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) SemiJoinNode semiJoinNode = captures.get(CHILD); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); + Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); // All dereferences can be assumed on the symbols coming from source, since filteringSource output is not propagated, // and semiJoinOutput is of type boolean. We exclude pushdown of dereferences on sourceJoinSymbol. @@ -94,7 +94,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSymbolAllocator()); // Rewrite project node assignments using new symbols for dereference expressions - Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) + Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) .inverse() .entrySet().stream() .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughUnnest.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughUnnest.java index ad015da48c9d..3f0f0bbe7641 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughUnnest.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughUnnest.java @@ -19,8 +19,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; @@ -79,7 +79,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) expressionsBuilder.addAll(projectNode.getAssignments().getExpressions()); // Extract dereferences for pushdown - Set dereferences = extractRowSubscripts(expressionsBuilder.build(), false); + Set dereferences = extractRowSubscripts(expressionsBuilder.build(), false); // Only retain dereferences on replicate symbols dereferences = dereferences.stream() @@ -94,7 +94,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSymbolAllocator()); // Rewrite project node assignments using new symbols for dereference expressions - Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) + Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) .inverse() .entrySet().stream() .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughAssignUniqueId.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughAssignUniqueId.java index 9538e09048b0..33df63435a95 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughAssignUniqueId.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughAssignUniqueId.java @@ -19,8 +19,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.AssignUniqueId; import io.trino.sql.planner.plan.Assignments; @@ -70,7 +70,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) AssignUniqueId assignUniqueId = captures.get(CHILD); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); + Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); // We do not need to filter dereferences on idColumn symbol since it is supposed to be of BIGINT type. @@ -82,7 +82,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSymbolAllocator()); // Rewrite project node assignments using new symbols for dereference expressions - Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) + Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) .inverse() .entrySet().stream() .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughLimit.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughLimit.java index 2c0cf12e017e..804864696895 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughLimit.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughLimit.java @@ -20,8 +20,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; @@ -75,7 +75,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) LimitNode limitNode = captures.get(CHILD); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); + Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); // Exclude dereferences on symbols being used in tiesResolvingScheme and requiresPreSortedInputs Set excludedSymbols = ImmutableSet.builder() @@ -96,7 +96,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSymbolAllocator()); // Rewrite project node assignments using new symbols for dereference expressions - Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) + Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) .inverse() .entrySet().stream() .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughMarkDistinct.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughMarkDistinct.java index 954af4add92b..6da535d17790 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughMarkDistinct.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughMarkDistinct.java @@ -19,8 +19,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.MarkDistinctNode; @@ -75,7 +75,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) MarkDistinctNode markDistinctNode = captures.get(CHILD); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); + Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); // Exclude dereferences on distinct symbols being used in markDistinctNode. We do not need to filter // dereferences on markerSymbol since it is supposed to be of boolean type. @@ -91,7 +91,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSymbolAllocator()); // Rewrite project node assignments using new symbols for dereference expressions - Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) + Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) .inverse() .entrySet().stream() .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughRowNumber.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughRowNumber.java index 22c07311fc87..6b9e01dd1220 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughRowNumber.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughRowNumber.java @@ -19,8 +19,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.ProjectNode; @@ -75,7 +75,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) RowNumberNode rowNumberNode = captures.get(CHILD); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); + Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); // Exclude dereferences on symbols being used in partitionBy dereferences = dereferences.stream() @@ -90,7 +90,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSymbolAllocator()); // Rewrite project node assignments using new symbols for dereference expressions - Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) + Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) .inverse() .entrySet().stream() .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughSort.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughSort.java index 3d1c41725545..26a90cd005b0 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughSort.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughSort.java @@ -19,8 +19,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.ProjectNode; @@ -75,7 +75,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) SortNode sortNode = captures.get(CHILD); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); + Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); // Exclude dereferences on symbols used in ordering scheme to avoid replication of data dereferences = dereferences.stream() @@ -90,7 +90,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSymbolAllocator()); // Rewrite project node assignments using new symbols for dereference expressions - Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) + Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) .inverse() .entrySet().stream() .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughTopN.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughTopN.java index f94ecbe5c7fd..b724bd7a7948 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughTopN.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughTopN.java @@ -19,8 +19,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.ProjectNode; @@ -75,7 +75,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) TopNNode topNNode = captures.get(CHILD); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); + Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); // Exclude dereferences on symbols being used in orderBy dereferences = dereferences.stream() @@ -90,7 +90,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSymbolAllocator()); // Rewrite project node assignments using new symbols for dereference expressions - Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) + Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) .inverse() .entrySet().stream() .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughTopNRanking.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughTopNRanking.java index 48bb0857f93a..eebc12f73756 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughTopNRanking.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughTopNRanking.java @@ -19,8 +19,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; @@ -78,7 +78,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) TopNRankingNode topNRankingNode = captures.get(CHILD); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); + Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); // Exclude dereferences on symbols being used in partitionBy and orderBy DataOrganizationSpecification specification = topNRankingNode.getSpecification(); @@ -98,7 +98,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSymbolAllocator()); // Rewrite project node assignments using new symbols for dereference expressions - Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) + Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) .inverse() .entrySet().stream() .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughWindow.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughWindow.java index 7d7f36d2044a..92d1861abe04 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughWindow.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughWindow.java @@ -19,8 +19,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; @@ -79,7 +79,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) WindowNode windowNode = captures.get(CHILD); // Extract dereferences for pushdown - Set dereferences = extractRowSubscripts( + Set dereferences = extractRowSubscripts( ImmutableList.builder() .addAll(projectNode.getAssignments().getExpressions()) // also include dereference projections used in window functions @@ -108,7 +108,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSymbolAllocator()); // Rewrite project node assignments using new symbols for dereference expressions - Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) + Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) .inverse() .entrySet().stream() .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownNegationsExpressionRewriter.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownNegationsExpressionRewriter.java index b7a92d146edb..609cfb7475f1 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownNegationsExpressionRewriter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownNegationsExpressionRewriter.java @@ -16,22 +16,22 @@ import io.trino.spi.type.DoubleType; import io.trino.spi.type.RealType; import io.trino.spi.type.Type; -import io.trino.sql.ir.ComparisonExpression; -import io.trino.sql.ir.ComparisonExpression.Operator; +import io.trino.sql.ir.Comparison; +import io.trino.sql.ir.Comparison.Operator; import io.trino.sql.ir.Expression; import io.trino.sql.ir.ExpressionRewriter; import io.trino.sql.ir.ExpressionTreeRewriter; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.NotExpression; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Not; import java.util.List; import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.IS_DISTINCT_FROM; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.IS_DISTINCT_FROM; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.ir.IrUtils.combinePredicates; import static io.trino.sql.ir.IrUtils.extractPredicates; @@ -48,17 +48,17 @@ private static class Visitor extends ExpressionRewriter { @Override - public Expression rewriteNotExpression(NotExpression node, Void context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteNot(Not node, Void context, ExpressionTreeRewriter treeRewriter) { - if (node.getValue() instanceof LogicalExpression child) { + if (node.value() instanceof Logical child) { List predicates = extractPredicates(child); - List negatedPredicates = predicates.stream().map(predicate -> treeRewriter.rewrite((Expression) new NotExpression(predicate), context)).collect(toImmutableList()); - return combinePredicates(child.getOperator().flip(), negatedPredicates); + List negatedPredicates = predicates.stream().map(predicate -> treeRewriter.rewrite((Expression) new Not(predicate), context)).collect(toImmutableList()); + return combinePredicates(child.operator().flip(), negatedPredicates); } - if (node.getValue() instanceof ComparisonExpression child && child.getOperator() != IS_DISTINCT_FROM) { - Operator operator = child.getOperator(); - Expression left = child.getLeft(); - Expression right = child.getRight(); + if (node.value() instanceof Comparison child && child.operator() != IS_DISTINCT_FROM) { + Operator operator = child.operator(); + Expression left = child.left(); + Expression right = child.right(); Type leftType = left.type(); Type rightType = right.type(); if ((typeHasNaN(leftType) || typeHasNaN(rightType)) && ( @@ -66,15 +66,15 @@ public Expression rewriteNotExpression(NotExpression node, Void context, Express operator == GREATER_THAN || operator == LESS_THAN_OR_EQUAL || operator == LESS_THAN)) { - return new NotExpression(new ComparisonExpression(operator, treeRewriter.rewrite(left, context), treeRewriter.rewrite(right, context))); + return new Not(new Comparison(operator, treeRewriter.rewrite(left, context), treeRewriter.rewrite(right, context))); } - return new ComparisonExpression(operator.negate(), treeRewriter.rewrite(left, context), treeRewriter.rewrite(right, context)); + return new Comparison(operator.negate(), treeRewriter.rewrite(left, context), treeRewriter.rewrite(right, context)); } - if (node.getValue() instanceof NotExpression child) { - return treeRewriter.rewrite(child.getValue(), context); + if (node.value() instanceof Not child) { + return treeRewriter.rewrite(child.value(), context); } - return new NotExpression(treeRewriter.rewrite(node.getValue(), context)); + return new Not(treeRewriter.rewrite(node.value(), context)); } private boolean typeHasNaN(Type type) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownProjectionsFromPatternRecognition.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownProjectionsFromPatternRecognition.java index 7c06638af9ea..a3c5a5a7cef3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownProjectionsFromPatternRecognition.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownProjectionsFromPatternRecognition.java @@ -19,7 +19,7 @@ import io.trino.matching.Pattern; import io.trino.spi.type.Type; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolsExtractor; import io.trino.sql.planner.iterative.Rule; @@ -147,7 +147,7 @@ private static ExpressionAndValuePointers rewrite(ExpressionAndValuePointers exp ImmutableList.Builder rewrittenArguments = ImmutableList.builder(); for (int i = 0; i < pointer.getArguments().size(); i++) { Expression argument = pointer.getArguments().get(i); - if (argument instanceof SymbolReference || SymbolsExtractor.extractUnique(argument).stream() + if (argument instanceof Reference || SymbolsExtractor.extractUnique(argument).stream() .anyMatch(runtimeEvaluatedSymbols::contains)) { rewrittenArguments.add(argument); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFilterThroughCountAggregation.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFilterThroughCountAggregation.java index 799161b2886d..595986f52e36 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFilterThroughCountAggregation.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFilterThroughCountAggregation.java @@ -46,7 +46,7 @@ import static com.google.common.collect.Iterables.getOnlyElement; import static io.trino.matching.Capture.newCapture; import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.IrUtils.combineConjuncts; import static io.trino.sql.planner.DomainTranslator.getExtractionResult; import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE; @@ -231,7 +231,7 @@ private static Result pushFilter(FilterNode filterNode, AggregationNode aggregat Expression newPredicate = combineConjuncts( new DomainTranslator().toPredicate(newTupleDomain), extractionResult.getRemainingExpression()); - if (newPredicate.equals(TRUE_LITERAL)) { + if (newPredicate.equals(TRUE)) { return Result.ofPlanNode(filterSource); } return Result.ofPlanNode(new FilterNode(filterNode.getId(), filterSource, newPredicate)); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushInequalityFilterExpressionBelowJoinRuleSet.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushInequalityFilterExpressionBelowJoinRuleSet.java index d95fb5614f33..d77de7b78bb2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushInequalityFilterExpressionBelowJoinRuleSet.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushInequalityFilterExpressionBelowJoinRuleSet.java @@ -19,9 +19,9 @@ import io.trino.matching.Capture; import io.trino.matching.Captures; import io.trino.matching.Pattern; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; @@ -37,11 +37,11 @@ import static com.google.common.base.Preconditions.checkArgument; import static io.trino.matching.Capture.newCapture; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.ir.IrUtils.combineConjuncts; import static io.trino.sql.ir.IrUtils.extractConjuncts; import static io.trino.sql.planner.DeterminismEvaluator.isDeterministic; @@ -78,7 +78,7 @@ */ public class PushInequalityFilterExpressionBelowJoinRuleSet { - private static final Set SUPPORTED_COMPARISONS = ImmutableSet.of( + private static final Set SUPPORTED_COMPARISONS = ImmutableSet.of( GREATER_THAN, GREATER_THAN_OR_EQUAL, LESS_THAN, @@ -109,7 +109,7 @@ private Result pushInequalityFilterExpressionBelowJoin(Context context, JoinNode { JoinNodeContext joinNodeContext = new JoinNodeContext(joinNode); - Expression parentFilterPredicate = filterNode.map(FilterNode::getPredicate).orElse(TRUE_LITERAL); + Expression parentFilterPredicate = filterNode.map(FilterNode::getPredicate).orElse(TRUE); Map> parentFilterCandidates; if (joinNode.getType() == INNER) { parentFilterCandidates = extractPushDownCandidates(joinNodeContext, parentFilterPredicate); @@ -123,7 +123,7 @@ private Result pushInequalityFilterExpressionBelowJoin(Context context, JoinNode false, extractConjuncts(parentFilterPredicate)); } - Map> joinFilterCandidates = extractPushDownCandidates(joinNodeContext, joinNode.getFilter().orElse(TRUE_LITERAL)); + Map> joinFilterCandidates = extractPushDownCandidates(joinNodeContext, joinNode.getFilter().orElse(TRUE)); if (parentFilterCandidates.get(true).isEmpty() && joinFilterCandidates.get(true).isEmpty()) { // no push-down candidates @@ -160,7 +160,7 @@ private Result pushInequalityFilterExpressionBelowJoin(Context context, JoinNode private Optional conjunctsToFilter(List conjuncts) { - return Optional.of(combineConjuncts(conjuncts)).filter(expression -> !TRUE_LITERAL.equals(expression)); + return Optional.of(combineConjuncts(conjuncts)).filter(expression -> !TRUE.equals(expression)); } Map> extractPushDownCandidates(JoinNodeContext joinNodeContext, Expression filter) @@ -171,14 +171,14 @@ Map> extractPushDownCandidates(JoinNodeContext joinNod private boolean isSupportedExpression(JoinNodeContext joinNodeContext, Expression expression) { - if (!(expression instanceof ComparisonExpression comparison && isDeterministic(expression))) { + if (!(expression instanceof Comparison comparison && isDeterministic(expression))) { return false; } - if (!SUPPORTED_COMPARISONS.contains(comparison.getOperator())) { + if (!SUPPORTED_COMPARISONS.contains(comparison.operator())) { return false; } - Set leftComparisonSymbols = extractUnique(comparison.getLeft()); - Set rightComparisonSymbols = extractUnique(comparison.getRight()); + Set leftComparisonSymbols = extractUnique(comparison.left()); + Set rightComparisonSymbols = extractUnique(comparison.right()); if (leftComparisonSymbols.isEmpty() || rightComparisonSymbols.isEmpty()) { return false; } @@ -190,10 +190,10 @@ private boolean isSupportedExpression(JoinNodeContext joinNodeContext, Expressio } boolean alignedComparison = joinNodeContext.isComparisonAligned(comparison); - Expression buildExpression = alignedComparison ? comparison.getRight() : comparison.getLeft(); + Expression buildExpression = alignedComparison ? comparison.right() : comparison.left(); // if buildExpression is a symbol, and it is available, we don't need to push down anything - return !(buildExpression instanceof SymbolReference); + return !(buildExpression instanceof Reference); } Map pushDownRightComplexExpressions( @@ -214,14 +214,14 @@ private void pushDownRightComplexExpression( ImmutableMap.Builder newProjections, Expression conjunct) { - checkArgument(conjunct instanceof ComparisonExpression, "conjunct '%s' is not a comparison", conjunct); - ComparisonExpression comparison = (ComparisonExpression) conjunct; + checkArgument(conjunct instanceof Comparison, "conjunct '%s' is not a comparison", conjunct); + Comparison comparison = (Comparison) conjunct; boolean alignedComparison = joinNodeContext.isComparisonAligned(comparison); - Expression rightExpression = alignedComparison ? comparison.getRight() : comparison.getLeft(); - Expression leftExpression = alignedComparison ? comparison.getLeft() : comparison.getRight(); + Expression rightExpression = alignedComparison ? comparison.right() : comparison.left(); + Expression leftExpression = alignedComparison ? comparison.left() : comparison.right(); Symbol rightSymbol = symbolForExpression(context, rightExpression); - newConjuncts.add(new ComparisonExpression( - comparison.getOperator(), + newConjuncts.add(new Comparison( + comparison.operator(), alignedComparison ? leftExpression : rightSymbol.toSymbolReference(), alignedComparison ? rightSymbol.toSymbolReference() : leftExpression)); newProjections.put(rightSymbol, rightExpression); @@ -281,7 +281,7 @@ private Assignments buildAssignments(PlanNode source, Map ne private Symbol symbolForExpression(Context context, Expression expression) { - checkArgument(!(expression instanceof SymbolReference), "expression '%s' is a SymbolReference", expression); + checkArgument(!(expression instanceof Reference), "expression '%s' is a SymbolReference", expression); return context.getSymbolAllocator().newSymbol(expression, expression.type()); } @@ -339,9 +339,9 @@ public Set getRightSymbols() return rightSymbols; } - public boolean isComparisonAligned(ComparisonExpression comparison) + public boolean isComparisonAligned(Comparison comparison) { - return leftSymbols.containsAll(extractUnique(comparison.getLeft())); + return leftSymbols.containsAll(extractUnique(comparison.left())); } } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.java index 964500db586a..23cba4fff6ef 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.java @@ -29,7 +29,7 @@ import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.BooleanLiteral; +import io.trino.sql.ir.Booleans; import io.trino.sql.ir.Expression; import io.trino.sql.planner.ConnectorExpressionTranslator; import io.trino.sql.planner.ConnectorExpressionTranslator.ConnectorExpressionTranslation; @@ -110,7 +110,7 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) context.getSession(), effectiveFilter); - if (!translation.remainingExpression().equals(BooleanLiteral.TRUE_LITERAL)) { + if (!translation.remainingExpression().equals(Booleans.TRUE)) { // TODO add extra filter node above join return Result.empty(); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushLimitThroughProject.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushLimitThroughProject.java index 4d2c34b56d2e..88433ff9ace6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushLimitThroughProject.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushLimitThroughProject.java @@ -19,7 +19,7 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; @@ -83,7 +83,7 @@ && exclusiveDereferences(projections)) { for (Symbol symbol : symbolsForRewrite) { Expression expression = projectNode.getAssignments().get(symbol); // if a symbol results from some computation, the translation fails - if (!(expression instanceof SymbolReference)) { + if (!(expression instanceof Reference)) { return Result.empty(); } symbolMapper.put(symbol, Symbol.from(expression)); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushMergeWriterUpdateIntoConnector.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushMergeWriterUpdateIntoConnector.java index acd5b88f73e2..14bb72439d90 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushMergeWriterUpdateIntoConnector.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushMergeWriterUpdateIntoConnector.java @@ -25,8 +25,8 @@ import io.trino.spi.type.Type; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.Reference; import io.trino.sql.ir.Row; -import io.trino.sql.ir.SymbolReference; import io.trino.sql.planner.ConnectorExpressionTranslator; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.MergeProcessorNode; @@ -117,11 +117,11 @@ private Map buildAssignments( { ImmutableMap.Builder assignments = ImmutableMap.builder(); if (mergeRow instanceof Row row) { - List fields = row.getChildren(); + List fields = row.children(); for (int i = 0; i < orderedColumnNames.size(); i++) { String columnName = orderedColumnNames.get(i); Expression field = fields.get(i); - if (field instanceof SymbolReference) { + if (field instanceof Reference) { // the column is not updated continue; } @@ -138,8 +138,8 @@ private Map buildAssignments( } } else if (mergeRow instanceof Constant row) { - RowType type = (RowType) row.getType(); - SqlRow rowValue = (SqlRow) row.getValue(); + RowType type = (RowType) row.type(); + SqlRow rowValue = (SqlRow) row.value(); for (int i = 0; i < orderedColumnNames.size(); i++) { Type fieldType = type.getFields().get(i).getType(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java index 5542f56d7ee5..fecf9c57508f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java @@ -23,7 +23,7 @@ import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.LambdaExpression; +import io.trino.sql.ir.Lambda; import io.trino.sql.planner.Partitioning; import io.trino.sql.planner.PartitioningScheme; import io.trino.sql.planner.Symbol; @@ -225,7 +225,7 @@ private PlanNode split(AggregationNode node, Context context) ImmutableList.builder() .add(intermediateSymbol.toSymbolReference()) .addAll(originalAggregation.getArguments().stream() - .filter(LambdaExpression.class::isInstance) + .filter(Lambda.class::isInstance) .collect(toImmutableList())) .build(), false, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java index 702e2745025b..201d7c04cec6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java @@ -32,6 +32,7 @@ import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; import io.trino.sql.PlannerContext; +import io.trino.sql.ir.Booleans; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.planner.ConnectorExpressionTranslator; @@ -60,9 +61,7 @@ import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.trino.SystemSessionProperties.isAllowPushdownIntoConnectors; import static io.trino.matching.Capture.newCapture; -import static io.trino.spi.expression.Constant.TRUE; import static io.trino.sql.DynamicFilters.isDynamicFilter; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; import static io.trino.sql.ir.IrUtils.combineConjuncts; import static io.trino.sql.ir.IrUtils.extractConjuncts; import static io.trino.sql.planner.DeterminismEvaluator.isDeterministic; @@ -180,7 +179,7 @@ public static Optional pushFilterIntoTableScan( Constraint constraint; // use evaluator only when there is some predicate which could not be translated into tuple domain - if (pruneWithPredicateExpression && !TRUE_LITERAL.equals(decomposedPredicate.getRemainingExpression())) { + if (pruneWithPredicateExpression && !Booleans.TRUE.equals(decomposedPredicate.getRemainingExpression())) { LayoutConstraintEvaluator evaluator = new LayoutConstraintEvaluator( plannerContext, session, @@ -201,17 +200,17 @@ public static Optional pushFilterIntoTableScan( // check if new domain is wider than domain already provided by table scan if (constraint.predicate().isEmpty() && // TODO do we need to track enforced ConnectorExpression in TableScanNode? - TRUE.equals(expressionTranslation.connectorExpression()) && + io.trino.spi.expression.Constant.TRUE.equals(expressionTranslation.connectorExpression()) && newDomain.contains(node.getEnforcedConstraint())) { Expression resultingPredicate = createResultingPredicate( plannerContext, session, splitExpression.getDynamicFilter(), - TRUE_LITERAL, + Booleans.TRUE, splitExpression.getNonDeterministicPredicate(), decomposedPredicate.getRemainingExpression()); - if (!TRUE_LITERAL.equals(resultingPredicate)) { + if (!Booleans.TRUE.equals(resultingPredicate)) { return Optional.of(new FilterNode(filterNode.getId(), node, resultingPredicate)); } @@ -284,7 +283,7 @@ public static Optional pushFilterIntoTableScan( splitExpression.getNonDeterministicPredicate(), remainingDecomposedPredicate); - if (!TRUE_LITERAL.equals(resultingPredicate)) { + if (!Booleans.TRUE.equals(resultingPredicate)) { return Optional.of(new FilterNode(filterNode.getId(), tableScan, resultingPredicate)); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateThroughProjectIntoRowNumber.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateThroughProjectIntoRowNumber.java index 86e858762851..e28834e5d9b2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateThroughProjectIntoRowNumber.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateThroughProjectIntoRowNumber.java @@ -37,7 +37,7 @@ import static io.trino.matching.Capture.newCapture; import static io.trino.spi.predicate.Range.range; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.IrUtils.combineConjuncts; import static io.trino.sql.planner.plan.Patterns.filter; import static io.trino.sql.planner.plan.Patterns.project; @@ -140,7 +140,7 @@ public Result apply(FilterNode filter, Captures captures, Context context) Expression newPredicate = combineConjuncts( extractionResult.getRemainingExpression(), new DomainTranslator().toPredicate(newTupleDomain)); - if (newPredicate.equals(TRUE_LITERAL)) { + if (newPredicate.equals(TRUE)) { return Result.ofPlanNode(project); } return Result.ofPlanNode(new FilterNode(filter.getId(), project, newPredicate)); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateThroughProjectIntoWindow.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateThroughProjectIntoWindow.java index 7b8d4732205c..d6a8aa1cd93f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateThroughProjectIntoWindow.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateThroughProjectIntoWindow.java @@ -41,7 +41,7 @@ import static io.trino.SystemSessionProperties.isOptimizeTopNRanking; import static io.trino.matching.Capture.newCapture; import static io.trino.spi.predicate.Range.range; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.IrUtils.combineConjuncts; import static io.trino.sql.planner.iterative.rule.Util.toTopNRankingType; import static io.trino.sql.planner.plan.Patterns.filter; @@ -146,7 +146,7 @@ public Result apply(FilterNode filter, Captures captures, Context context) Expression newPredicate = combineConjuncts( extractionResult.getRemainingExpression(), new DomainTranslator().toPredicate(newTupleDomain)); - if (newPredicate.equals(TRUE_LITERAL)) { + if (newPredicate.equals(TRUE)) { return Result.ofPlanNode(project); } return Result.ofPlanNode(new FilterNode(filter.getId(), project, newPredicate)); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughExchange.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughExchange.java index 34965c399c1d..4bbfaf5f481e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughExchange.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughExchange.java @@ -21,7 +21,7 @@ import io.trino.matching.Pattern; import io.trino.spi.type.Type; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.PartitioningScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; @@ -190,7 +190,7 @@ public Result apply(ProjectNode project, Captures captures, Context context) private static boolean isSymbolToSymbolProjection(ProjectNode project) { - return project.getAssignments().getExpressions().stream().allMatch(SymbolReference.class::isInstance); + return project.getAssignments().getExpressions().stream().allMatch(Reference.class::isInstance); } private static Map mapExchangeOutputToInput(ExchangeNode exchange, int sourceIndex) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughUnion.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughUnion.java index cc215ba5bd29..c89bbfc8216b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughUnion.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughUnion.java @@ -20,7 +20,7 @@ import io.trino.matching.Pattern; import io.trino.spi.type.Type; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; @@ -68,7 +68,7 @@ public Result apply(ProjectNode parent, Captures captures, Context context) ImmutableList.Builder outputSources = ImmutableList.builder(); for (int i = 0; i < source.getSources().size(); i++) { - Map outputToInput = source.sourceSymbolMap(i); // Map: output of union -> input of this source to the union + Map outputToInput = source.sourceSymbolMap(i); // Map: output of union -> input of this source to the union Assignments.Builder assignments = Assignments.builder(); // assignments for the new ProjectNode // mapping from current ProjectNode to new ProjectNode, used to identify the output layout @@ -93,6 +93,6 @@ private static boolean nonTrivialProjection(ProjectNode project) { return !project.getAssignments() .getExpressions().stream() - .allMatch(SymbolReference.class::isInstance); + .allMatch(Reference.class::isInstance); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushTopNThroughProject.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushTopNThroughProject.java index 99bf57402c93..9fbb5361bdb5 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushTopNThroughProject.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushTopNThroughProject.java @@ -19,7 +19,7 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.optimizations.SymbolMapper; @@ -113,7 +113,7 @@ private Optional symbolMapper(List symbols, Assignments as SymbolMapper.Builder mapper = SymbolMapper.builder(); for (Symbol symbol : symbols) { Expression expression = assignments.get(symbol); - if (!(expression instanceof SymbolReference)) { + if (!(expression instanceof Reference)) { return Optional.empty(); } mapper.put(symbol, Symbol.from(expression)); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushdownFilterIntoRowNumber.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushdownFilterIntoRowNumber.java index 7265c20d0cb8..3bc53c8b1cdc 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushdownFilterIntoRowNumber.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushdownFilterIntoRowNumber.java @@ -23,7 +23,7 @@ import io.trino.spi.predicate.TupleDomain; import io.trino.spi.predicate.ValueSet; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.BooleanLiteral; +import io.trino.sql.ir.Booleans; import io.trino.sql.ir.Expression; import io.trino.sql.planner.DomainTranslator; import io.trino.sql.planner.Symbol; @@ -103,7 +103,7 @@ public Result apply(FilterNode node, Captures captures, Context context) extractionResult.getRemainingExpression(), new DomainTranslator().toPredicate(newTupleDomain)); - if (newPredicate.equals(BooleanLiteral.TRUE_LITERAL)) { + if (newPredicate.equals(Booleans.TRUE)) { return Result.ofPlanNode(source); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushdownFilterIntoWindow.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushdownFilterIntoWindow.java index 092ccfb50e3b..9c29bfaca5e4 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushdownFilterIntoWindow.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushdownFilterIntoWindow.java @@ -23,7 +23,7 @@ import io.trino.spi.predicate.TupleDomain; import io.trino.spi.predicate.ValueSet; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.BooleanLiteral; +import io.trino.sql.ir.Booleans; import io.trino.sql.ir.Expression; import io.trino.sql.planner.DomainTranslator; import io.trino.sql.planner.Symbol; @@ -123,7 +123,7 @@ public Result apply(FilterNode node, Captures captures, Context context) extractionResult.getRemainingExpression(), new DomainTranslator().toPredicate(newTupleDomain)); - if (newPredicate.equals(BooleanLiteral.TRUE_LITERAL)) { + if (newPredicate.equals(Booleans.TRUE)) { return Result.ofPlanNode(newSource); } return Result.ofPlanNode(new FilterNode(node.getId(), newSource, newPredicate)); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveDuplicateConditions.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveDuplicateConditions.java index d1aa811a1726..1b88f52fe0a9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveDuplicateConditions.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveDuplicateConditions.java @@ -15,7 +15,7 @@ import io.trino.sql.ir.Expression; import io.trino.sql.ir.ExpressionTreeRewriter; -import io.trino.sql.ir.LogicalExpression; +import io.trino.sql.ir.Logical; import static io.trino.sql.ir.IrUtils.combinePredicates; import static io.trino.sql.ir.IrUtils.extractPredicates; @@ -46,9 +46,9 @@ private static class Visitor extends io.trino.sql.ir.ExpressionRewriter { @Override - public Expression rewriteLogicalExpression(LogicalExpression node, Void context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteLogical(Logical node, Void context, ExpressionTreeRewriter treeRewriter) { - return combinePredicates(node.getOperator(), extractPredicates(node)); + return combinePredicates(node.operator(), extractPredicates(node)); } } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantDateTrunc.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantDateTrunc.java index 904d08ba790f..c0914662c180 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantDateTrunc.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantDateTrunc.java @@ -18,11 +18,11 @@ import io.trino.spi.function.CatalogSchemaFunctionName; import io.trino.spi.type.VarcharType; import io.trino.sql.PlannerContext; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.ir.ExpressionTreeRewriter; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.IrExpressionInterpreter; import io.trino.sql.planner.NoOpSymbolResolver; @@ -44,7 +44,7 @@ private static Expression rewrite(Expression expression, Session session, Planne { requireNonNull(plannerContext, "plannerContext is null"); - if (expression instanceof SymbolReference) { + if (expression instanceof Reference) { return expression; } return ExpressionTreeRewriter.rewriteWith(new Visitor(session, plannerContext), expression); @@ -63,12 +63,12 @@ public Visitor(Session session, PlannerContext plannerContext) } @Override - public Expression rewriteFunctionCall(FunctionCall node, Void context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteCall(Call node, Void context, ExpressionTreeRewriter treeRewriter) { - CatalogSchemaFunctionName functionName = node.getFunction().getName(); - if (functionName.equals(builtinFunctionName("date_trunc")) && node.getArguments().size() == 2) { - Expression unitExpression = node.getArguments().get(0); - Expression argument = node.getArguments().get(1); + CatalogSchemaFunctionName functionName = node.function().getName(); + if (functionName.equals(builtinFunctionName("date_trunc")) && node.arguments().size() == 2) { + Expression unitExpression = node.arguments().get(0); + Expression argument = node.arguments().get(1); if (argument.type() == DATE && unitExpression.type() instanceof VarcharType && unitExpression instanceof Constant) { Slice unitValue = (Slice) new IrExpressionInterpreter(unitExpression, plannerContext, session) .optimize(NoOpSymbolResolver.INSTANCE); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantExists.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantExists.java index 6d8f1f0f50d7..5b018719f9c0 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantExists.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantExists.java @@ -23,8 +23,8 @@ import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.ProjectNode; -import static io.trino.sql.ir.BooleanLiteral.FALSE_LITERAL; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.FALSE; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.extractCardinality; import static io.trino.sql.planner.plan.Patterns.applyNode; @@ -80,10 +80,10 @@ public Result apply(ApplyNode node, Captures captures, Context context) Cardinality subqueryCardinality = extractCardinality(node.getSubquery(), context.getLookup()); Expression result; if (subqueryCardinality.isEmpty()) { - result = FALSE_LITERAL; + result = FALSE; } else if (subqueryCardinality.isAtLeastScalar()) { - result = TRUE_LITERAL; + result = TRUE; } else { return Result.empty(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantPredicateAboveTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantPredicateAboveTableScan.java index 4f7d104cdb00..ab4875a1e177 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantPredicateAboveTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantPredicateAboveTableScan.java @@ -24,6 +24,7 @@ import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; +import io.trino.sql.ir.Booleans; import io.trino.sql.ir.Expression; import io.trino.sql.planner.DomainTranslator; import io.trino.sql.planner.DomainTranslator.ExtractionResult; @@ -40,7 +41,6 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.matching.Capture.newCapture; import static io.trino.spi.predicate.TupleDomain.intersect; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; import static io.trino.sql.ir.IrUtils.combineConjuncts; import static io.trino.sql.ir.IrUtils.extractConjuncts; import static io.trino.sql.ir.IrUtils.filterDeterministicConjuncts; @@ -50,7 +50,6 @@ import static io.trino.sql.planner.plan.Patterns.source; import static io.trino.sql.planner.plan.Patterns.tableScan; import static java.lang.Boolean.FALSE; -import static java.lang.Boolean.TRUE; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.groupingBy; import static java.util.stream.Collectors.toList; @@ -134,12 +133,12 @@ public Result apply(FilterNode filterNode, Captures captures, Context context) Expression resultingPredicate = createResultingPredicate( plannerContext, session, - TRUE_LITERAL, // Dynamic filters are included in decomposedPredicate.getRemainingExpression() + Booleans.TRUE, // Dynamic filters are included in decomposedPredicate.getRemainingExpression() new DomainTranslator().toPredicate(unenforcedDomain.transformKeys(assignments::get)), nonDeterministicPredicate, decomposedPredicate.getRemainingExpression()); - if (!TRUE_LITERAL.equals(resultingPredicate)) { + if (!Booleans.TRUE.equals(resultingPredicate)) { return Result.ofPlanNode(new FilterNode(context.getIdAllocator().getNextId(), node, resultingPredicate)); } @@ -150,9 +149,9 @@ private ExtractionResult getFullyExtractedPredicates(Session session, Expression { Map> extractedPredicates = extractConjuncts(predicate).stream() .map(conjunct -> DomainTranslator.getExtractionResult(plannerContext, session, conjunct)) - .collect(groupingBy(result -> result.getRemainingExpression().equals(TRUE_LITERAL), toList())); + .collect(groupingBy(result -> result.getRemainingExpression().equals(Booleans.TRUE), toList())); return new ExtractionResult( - intersect(extractedPredicates.getOrDefault(TRUE, ImmutableList.of()).stream() + intersect(extractedPredicates.getOrDefault(Boolean.TRUE, ImmutableList.of()).stream() .map(ExtractionResult::getTupleDomain) .collect(toImmutableList())), combineConjuncts( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveTrivialFilters.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveTrivialFilters.java index 3e62b331af9d..6805f4c0ca79 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveTrivialFilters.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveTrivialFilters.java @@ -21,8 +21,8 @@ import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.ValuesNode; -import static io.trino.sql.ir.BooleanLiteral.FALSE_LITERAL; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.FALSE; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.plan.Patterns.filter; import static java.util.Collections.emptyList; @@ -42,12 +42,12 @@ public Result apply(FilterNode filterNode, Captures captures, Context context) { Expression predicate = filterNode.getPredicate(); - if (predicate.equals(TRUE_LITERAL)) { + if (predicate.equals(TRUE)) { return Result.ofPlanNode(filterNode.getSource()); } - if (predicate.equals(FALSE_LITERAL) || - predicate instanceof Constant literal && literal.getValue() == null) { + if (predicate.equals(FALSE) || + predicate instanceof Constant literal && literal.value() == null) { return Result.ofPlanNode(new ValuesNode(context.getIdAllocator().getNextId(), filterNode.getOutputSymbols(), emptyList())); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveUnreferencedScalarSubqueries.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveUnreferencedScalarSubqueries.java index e48fe7887e5f..cf20dd229ea4 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveUnreferencedScalarSubqueries.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveUnreferencedScalarSubqueries.java @@ -20,7 +20,7 @@ import io.trino.sql.planner.plan.CorrelatedJoinNode; import io.trino.sql.planner.plan.PlanNode; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.isAtLeastScalar; import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.isScalar; import static io.trino.sql.planner.plan.JoinType.INNER; @@ -32,7 +32,7 @@ public class RemoveUnreferencedScalarSubqueries implements Rule { private static final Pattern PATTERN = correlatedJoin() - .with(filter().equalTo(TRUE_LITERAL)); + .with(filter().equalTo(TRUE)); @Override public Pattern getPattern() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java index 1a7d46ab2cdc..69aedd61c23a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java @@ -24,8 +24,8 @@ import io.trino.sql.ir.Expression; import io.trino.sql.ir.ExpressionRewriter; import io.trino.sql.ir.ExpressionTreeRewriter; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.optimizations.PlanOptimizer; import io.trino.sql.planner.plan.DynamicFilterId; @@ -53,7 +53,7 @@ import static io.trino.sql.DynamicFilters.extractDynamicFilters; import static io.trino.sql.DynamicFilters.getDescriptor; import static io.trino.sql.DynamicFilters.isDynamicFilter; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.IrUtils.combineConjuncts; import static io.trino.sql.ir.IrUtils.combinePredicates; import static io.trino.sql.ir.IrUtils.extractConjuncts; @@ -138,7 +138,7 @@ public PlanWithConsumedDynamicFilters visitJoin(JoinNode node, Set filter = node .getFilter().map(this::removeAllDynamicFilters) // no DF support at Join operators. - .filter(expression -> !expression.equals(TRUE_LITERAL)); + .filter(expression -> !expression.equals(TRUE)); PlanNode left = leftResult.getNode(); PlanNode right = rightResult.getNode(); @@ -270,7 +270,7 @@ public PlanWithConsumedDynamicFilters visitFilter(FilterNode node, Set() { @Override - public Expression rewriteLogicalExpression(LogicalExpression node, Void context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteLogical(Logical node, Void context, ExpressionTreeRewriter treeRewriter) { - LogicalExpression rewrittenNode = treeRewriter.defaultRewrite(node, context); + Logical rewrittenNode = treeRewriter.defaultRewrite(node, context); boolean modified = (node != rewrittenNode); ImmutableList.Builder expressionBuilder = ImmutableList.builder(); - for (Expression term : rewrittenNode.getTerms()) { + for (Expression term : rewrittenNode.terms()) { if (isDynamicFilter(term)) { - expressionBuilder.add(TRUE_LITERAL); + expressionBuilder.add(TRUE); modified = true; } else { @@ -367,7 +367,7 @@ public Expression rewriteLogicalExpression(LogicalExpression node, Void context, if (!modified) { return node; } - return combinePredicates(node.getOperator(), expressionBuilder.build()); + return combinePredicates(node.operator(), expressionBuilder.build()); } }, expression); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReorderJoins.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReorderJoins.java index a78a7459505f..31787b6cab8f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReorderJoins.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReorderJoins.java @@ -31,9 +31,9 @@ import io.trino.cost.StatsProvider; import io.trino.matching.Captures; import io.trino.matching.Pattern; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.EqualityInference; import io.trino.sql.planner.OptimizerConfig.JoinDistributionType; import io.trino.sql.planner.PlanNodeIdAllocator; @@ -69,8 +69,8 @@ import static io.trino.SystemSessionProperties.getJoinDistributionType; import static io.trino.SystemSessionProperties.getJoinReorderingStrategy; import static io.trino.SystemSessionProperties.getMaxReorderedJoins; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.IrUtils.and; import static io.trino.sql.ir.IrUtils.combineConjuncts; import static io.trino.sql.ir.IrUtils.extractConjuncts; @@ -107,7 +107,7 @@ public ReorderJoins(CostComparator costComparator) this.pattern = join().matching( joinNode -> joinNode.getDistributionType().isEmpty() && joinNode.getType() == INNER - && isDeterministic(joinNode.getFilter().orElse(TRUE_LITERAL))); + && isDeterministic(joinNode.getFilter().orElse(TRUE))); } @Override @@ -268,7 +268,7 @@ private JoinEnumerationResult createJoin(LinkedHashSet leftSources, Li List joinPredicates = getJoinPredicates(leftSymbols, rightSymbols); List joinConditions = joinPredicates.stream() .filter(JoinEnumerator::isJoinEqualityCondition) - .map(predicate -> toEquiJoinClause((ComparisonExpression) predicate, leftSymbols)) + .map(predicate -> toEquiJoinClause((Comparison) predicate, leftSymbols)) .collect(toImmutableList()); if (joinConditions.isEmpty()) { return INFINITE_COST_RESULT; @@ -371,7 +371,7 @@ private JoinEnumerationResult getJoinSource(LinkedHashSet nodes, List< .filter(Objects::nonNull) .forEach(predicates::add); Expression filter = combineConjuncts(predicates.build()); - if (!TRUE_LITERAL.equals(filter)) { + if (!TRUE.equals(filter)) { planNode = new FilterNode(idAllocator.getNextId(), planNode, filter); } return createJoinEnumerationResult(planNode); @@ -381,16 +381,16 @@ private JoinEnumerationResult getJoinSource(LinkedHashSet nodes, List< private static boolean isJoinEqualityCondition(Expression expression) { - return expression instanceof ComparisonExpression - && ((ComparisonExpression) expression).getOperator() == EQUAL - && ((ComparisonExpression) expression).getLeft() instanceof SymbolReference - && ((ComparisonExpression) expression).getRight() instanceof SymbolReference; + return expression instanceof Comparison + && ((Comparison) expression).operator() == EQUAL + && ((Comparison) expression).left() instanceof Reference + && ((Comparison) expression).right() instanceof Reference; } - private static EquiJoinClause toEquiJoinClause(ComparisonExpression equality, Set leftSymbols) + private static EquiJoinClause toEquiJoinClause(Comparison equality, Set leftSymbols) { - Symbol leftSymbol = Symbol.from(equality.getLeft()); - Symbol rightSymbol = Symbol.from(equality.getRight()); + Symbol leftSymbol = Symbol.from(equality.left()); + Symbol rightSymbol = Symbol.from(equality.right()); EquiJoinClause equiJoinClause = new EquiJoinClause(leftSymbol, rightSymbol); return leftSymbols.contains(leftSymbol) ? equiJoinClause : equiJoinClause.flip(); } @@ -616,7 +616,7 @@ private void flattenNode(PlanNode node, int limit) return; } - if (joinNode.getType() != INNER || !isDeterministic(joinNode.getFilter().orElse(TRUE_LITERAL)) || joinNode.getDistributionType().isPresent()) { + if (joinNode.getType() != INNER || !isDeterministic(joinNode.getFilter().orElse(TRUE)) || joinNode.getDistributionType().isPresent()) { sources.add(node); return; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceJoinOverConstantWithProject.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceJoinOverConstantWithProject.java index 82f3985ec3f4..35fb64c4d980 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceJoinOverConstantWithProject.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceJoinOverConstantWithProject.java @@ -32,7 +32,7 @@ import java.util.Map; import static com.google.common.collect.Iterables.getOnlyElement; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.DeterminismEvaluator.isDeterministic; import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.extractCardinality; import static io.trino.sql.planner.plan.Patterns.join; @@ -147,7 +147,7 @@ public Result apply(JoinNode node, Captures captures, Context context) private static boolean isUnconditional(JoinNode joinNode) { return joinNode.getCriteria().isEmpty() && - (joinNode.getFilter().isEmpty() || joinNode.getFilter().get().equals(TRUE_LITERAL)); + (joinNode.getFilter().isEmpty() || joinNode.getFilter().get().equals(TRUE)); } private boolean canInlineJoinSource(PlanNode source) @@ -186,7 +186,7 @@ private ProjectNode appendProjection(PlanNode source, List sourceOutputs Map mapping = new HashMap<>(); for (int i = 0; i < values.getOutputSymbols().size(); i++) { - mapping.put(values.getOutputSymbols().get(i), row.getItems().get(i)); + mapping.put(values.getOutputSymbols().get(i), row.items().get(i)); } Assignments.Builder assignments = Assignments.builder() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java index 1862069c808e..bcf0293f3406 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java @@ -22,9 +22,9 @@ import io.trino.spi.function.CatalogSchemaFunctionName; import io.trino.spi.type.TypeSignature; import io.trino.sql.PlannerContext; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; import io.trino.sql.planner.BuiltinFunctionCallBuilder; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; @@ -149,10 +149,10 @@ public Result apply(AggregationNode node, Captures captures, Context context) private boolean isStEnvelopeFunctionCall(Expression expression, ResolvedFunction stEnvelopeFunction) { - if (!(expression instanceof FunctionCall functionCall)) { + if (!(expression instanceof Call call)) { return false; } - return functionCall.getFunction().getFunctionId().equals(stEnvelopeFunction.getFunctionId()); + return call.function().getFunctionId().equals(stEnvelopeFunction.getFunctionId()); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SetOperationNodeTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SetOperationNodeTranslator.java index 20242a9e2043..985be903c6c4 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SetOperationNodeTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SetOperationNodeTranslator.java @@ -23,7 +23,7 @@ import io.trino.spi.type.Type; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolAllocator; @@ -46,7 +46,7 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.FrameBoundType.UNBOUNDED_FOLLOWING; @@ -133,18 +133,18 @@ private List appendMarkers(List markers, List nodes, return result.build(); } - private static PlanNode appendMarkers(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, PlanNode source, int markerIndex, List markers, Map projections) + private static PlanNode appendMarkers(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, PlanNode source, int markerIndex, List markers, Map projections) { Assignments.Builder assignments = Assignments.builder(); // add existing intersect symbols to projection - for (Map.Entry entry : projections.entrySet()) { + for (Map.Entry entry : projections.entrySet()) { Symbol symbol = symbolAllocator.newSymbol(entry.getKey().getName(), entry.getKey().getType()); assignments.put(symbol, entry.getValue()); } // add extra marker fields to the projection for (int i = 0; i < markers.size(); ++i) { - Expression expression = (i == markerIndex) ? TRUE_LITERAL : new Constant(BOOLEAN, null); + Expression expression = (i == markerIndex) ? TRUE : new Constant(BOOLEAN, null); assignments.put(symbolAllocator.newSymbol(markers.get(i).getName(), BOOLEAN), expression); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyCountOverConstant.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyCountOverConstant.java index bca4408c2fe8..a543e073086d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyCountOverConstant.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyCountOverConstant.java @@ -23,7 +23,7 @@ import io.trino.sql.PlannerContext; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.AggregationNode; @@ -109,10 +109,10 @@ private boolean isCountOverConstant(AggregationNode.Aggregation aggregation, Ass } Expression argument = aggregation.getArguments().get(0); - if (argument instanceof SymbolReference) { + if (argument instanceof Reference) { argument = inputs.get(Symbol.from(argument)); } - return argument instanceof Constant constant && constant.getValue() != null; + return argument instanceof Constant constant && constant.value() != null; } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyExpressions.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyExpressions.java index 8dd0e7bc5d5f..cb490bbd7fb5 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyExpressions.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyExpressions.java @@ -18,7 +18,7 @@ import io.trino.sql.PlannerContext; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.IrExpressionInterpreter; import io.trino.sql.planner.NoOpSymbolResolver; import io.trino.sql.planner.iterative.Rule; @@ -36,7 +36,7 @@ public class SimplifyExpressions public static Expression rewrite(Expression expression, Session session, PlannerContext plannerContext) { requireNonNull(plannerContext, "plannerContext is null"); - if (expression instanceof SymbolReference) { + if (expression instanceof Reference) { return expression; } expression = pushDownNegations(expression); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyFilterPredicate.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyFilterPredicate.java index 1c73688217f9..ba3860dc3257 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyFilterPredicate.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyFilterPredicate.java @@ -16,14 +16,14 @@ import com.google.common.collect.ImmutableList; import io.trino.matching.Captures; import io.trino.matching.Pattern; +import io.trino.sql.ir.Case; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.NullIfExpression; -import io.trino.sql.ir.SearchedCaseExpression; -import io.trino.sql.ir.SimpleCaseExpression; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.NullIf; +import io.trino.sql.ir.Switch; import io.trino.sql.ir.WhenClause; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.FilterNode; @@ -33,8 +33,8 @@ import java.util.Optional; import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.trino.sql.ir.BooleanLiteral.FALSE_LITERAL; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.FALSE; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.IrUtils.combineConjuncts; import static io.trino.sql.ir.IrUtils.extractConjuncts; import static io.trino.sql.planner.DeterminismEvaluator.isDeterministic; @@ -68,9 +68,9 @@ public Result apply(FilterNode node, Captures captures, Context context) boolean simplified = false; for (Expression conjunct : conjuncts) { Optional simplifiedConjunct = switch (conjunct) { - case NullIfExpression expression -> Optional.of(LogicalExpression.and(expression.getFirst(), isFalseOrNullPredicate(expression.getSecond()))); - case SearchedCaseExpression expression -> simplify(expression); - case SimpleCaseExpression expression -> simplify(expression); + case NullIf expression -> Optional.of(Logical.and(expression.first(), isFalseOrNullPredicate(expression.second()))); + case Case expression -> simplify(expression); + case Switch expression -> simplify(expression); case null, default -> Optional.empty(); }; @@ -94,64 +94,64 @@ public Result apply(FilterNode node, Captures captures, Context context) private static Optional simplify(Expression condition, Expression trueValue, Optional falseValue) { - if (trueValue.equals(TRUE_LITERAL) && (falseValue.isEmpty() || isNotTrue(falseValue.get()))) { + if (trueValue.equals(TRUE) && (falseValue.isEmpty() || isNotTrue(falseValue.get()))) { return Optional.of(condition); } - if (isNotTrue(trueValue) && falseValue.isPresent() && falseValue.get().equals(TRUE_LITERAL)) { + if (isNotTrue(trueValue) && falseValue.isPresent() && falseValue.get().equals(TRUE)) { return Optional.of(isFalseOrNullPredicate(condition)); } if (falseValue.isPresent() && falseValue.get().equals(trueValue) && isDeterministic(trueValue)) { return Optional.of(trueValue); } if (isNotTrue(trueValue) && (falseValue.isEmpty() || isNotTrue(falseValue.get()))) { - return Optional.of(FALSE_LITERAL); + return Optional.of(FALSE); } - if (condition.equals(TRUE_LITERAL)) { + if (condition.equals(TRUE)) { return Optional.of(trueValue); } if (isNotTrue(condition)) { - return Optional.of(falseValue.orElse(FALSE_LITERAL)); + return Optional.of(falseValue.orElse(FALSE)); } return Optional.empty(); } - private static Optional simplify(SearchedCaseExpression caseExpression) + private static Optional simplify(Case caseExpression) { - Optional defaultValue = caseExpression.getDefaultValue(); + Optional defaultValue = caseExpression.defaultValue(); - if (caseExpression.getWhenClauses().size() == 1) { + if (caseExpression.whenClauses().size() == 1) { // if-like expression return simplify( - caseExpression.getWhenClauses().getFirst().getOperand(), - caseExpression.getWhenClauses().getFirst().getResult(), + caseExpression.whenClauses().getFirst().getOperand(), + caseExpression.whenClauses().getFirst().getResult(), defaultValue); } - List operands = caseExpression.getWhenClauses().stream() + List operands = caseExpression.whenClauses().stream() .map(WhenClause::getOperand) .collect(toImmutableList()); - List results = caseExpression.getWhenClauses().stream() + List results = caseExpression.whenClauses().stream() .map(WhenClause::getResult) .collect(toImmutableList()); long trueResultsCount = results.stream() - .filter(result -> result.equals(TRUE_LITERAL)) + .filter(result -> result.equals(TRUE)) .count(); long notTrueResultsCount = results.stream() .filter(SimplifyFilterPredicate::isNotTrue) .count(); // all results true - if (trueResultsCount == results.size() && defaultValue.isPresent() && defaultValue.get().equals(TRUE_LITERAL)) { - return Optional.of(TRUE_LITERAL); + if (trueResultsCount == results.size() && defaultValue.isPresent() && defaultValue.get().equals(TRUE)) { + return Optional.of(TRUE); } // all results not true if (notTrueResultsCount == results.size() && (defaultValue.isEmpty() || isNotTrue(defaultValue.get()))) { - return Optional.of(FALSE_LITERAL); + return Optional.of(FALSE); } // one result true, and remaining results not true if (trueResultsCount == 1 && notTrueResultsCount == results.size() - 1 && (defaultValue.isEmpty() || isNotTrue(defaultValue.get()))) { ImmutableList.Builder builder = ImmutableList.builder(); - for (WhenClause whenClause : caseExpression.getWhenClauses()) { + for (WhenClause whenClause : caseExpression.whenClauses()) { Expression operand = whenClause.getOperand(); Expression result = whenClause.getResult(); if (isNotTrue(result)) { @@ -164,62 +164,62 @@ private static Optional simplify(SearchedCaseExpression caseExpressi } } // all results not true, and default true - if (notTrueResultsCount == results.size() && defaultValue.isPresent() && defaultValue.get().equals(TRUE_LITERAL)) { + if (notTrueResultsCount == results.size() && defaultValue.isPresent() && defaultValue.get().equals(TRUE)) { ImmutableList.Builder builder = ImmutableList.builder(); operands.forEach(operand -> builder.add(isFalseOrNullPredicate(operand))); return Optional.of(combineConjuncts(builder.build())); } // skip clauses with not true conditions List whenClauses = new ArrayList<>(); - for (WhenClause whenClause : caseExpression.getWhenClauses()) { + for (WhenClause whenClause : caseExpression.whenClauses()) { Expression operand = whenClause.getOperand(); - if (operand.equals(TRUE_LITERAL)) { + if (operand.equals(TRUE)) { if (whenClauses.isEmpty()) { return Optional.of(whenClause.getResult()); } - return Optional.of(new SearchedCaseExpression(whenClauses, Optional.of(whenClause.getResult()))); + return Optional.of(new Case(whenClauses, Optional.of(whenClause.getResult()))); } if (!isNotTrue(operand)) { whenClauses.add(whenClause); } } if (whenClauses.isEmpty()) { - return Optional.of(defaultValue.orElse(FALSE_LITERAL)); + return Optional.of(defaultValue.orElse(FALSE)); } - if (whenClauses.size() < caseExpression.getWhenClauses().size()) { - return Optional.of(new SearchedCaseExpression(whenClauses, defaultValue)); + if (whenClauses.size() < caseExpression.whenClauses().size()) { + return Optional.of(new Case(whenClauses, defaultValue)); } return Optional.empty(); } - private static Optional simplify(SimpleCaseExpression caseExpression) + private static Optional simplify(Switch caseExpression) { - Optional defaultValue = caseExpression.getDefaultValue(); + Optional defaultValue = caseExpression.defaultValue(); - if (caseExpression.getOperand() instanceof Constant literal && literal.getValue() == null) { - return Optional.of(defaultValue.orElse(FALSE_LITERAL)); + if (caseExpression.operand() instanceof Constant literal && literal.value() == null) { + return Optional.of(defaultValue.orElse(FALSE)); } - List results = caseExpression.getWhenClauses().stream() + List results = caseExpression.whenClauses().stream() .map(WhenClause::getResult) .collect(toImmutableList()); - if (results.stream().allMatch(result -> result.equals(TRUE_LITERAL)) && defaultValue.isPresent() && defaultValue.get().equals(TRUE_LITERAL)) { - return Optional.of(TRUE_LITERAL); + if (results.stream().allMatch(result -> result.equals(TRUE)) && defaultValue.isPresent() && defaultValue.get().equals(TRUE)) { + return Optional.of(TRUE); } if (results.stream().allMatch(SimplifyFilterPredicate::isNotTrue) && (defaultValue.isEmpty() || isNotTrue(defaultValue.get()))) { - return Optional.of(FALSE_LITERAL); + return Optional.of(FALSE); } return Optional.empty(); } private static boolean isNotTrue(Expression expression) { - return expression.equals(FALSE_LITERAL) || - expression instanceof Constant literal && literal.getValue() == null; + return expression.equals(FALSE) || + expression instanceof Constant literal && literal.value() == null; } private static Expression isFalseOrNullPredicate(Expression expression) { - return LogicalExpression.or(new IsNullPredicate(expression), new NotExpression(expression)); + return Logical.or(new IsNull(expression), new Not(expression)); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedDistinctAggregationWithProjection.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedDistinctAggregationWithProjection.java index 474e9645d536..37d1a05e3cd2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedDistinctAggregationWithProjection.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedDistinctAggregationWithProjection.java @@ -41,7 +41,7 @@ import static io.trino.matching.Capture.newCapture; import static io.trino.matching.Pattern.nonEmpty; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.JoinType.LEFT; import static io.trino.sql.planner.plan.Patterns.CorrelatedJoin.filter; @@ -83,7 +83,7 @@ public class TransformCorrelatedDistinctAggregationWithProjection private static final Pattern PATTERN = correlatedJoin() .with(type().equalTo(LEFT)) .with(nonEmpty(Patterns.CorrelatedJoin.correlation())) - .with(filter().equalTo(TRUE_LITERAL)) + .with(filter().equalTo(TRUE)) .with(subquery().matching(project() .capturedAs(PROJECTION) .with(source().matching(aggregation() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedDistinctAggregationWithoutProjection.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedDistinctAggregationWithoutProjection.java index f17f35235c79..5d709da22c34 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedDistinctAggregationWithoutProjection.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedDistinctAggregationWithoutProjection.java @@ -36,7 +36,7 @@ import static io.trino.matching.Capture.newCapture; import static io.trino.matching.Pattern.nonEmpty; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.iterative.rule.Util.restrictOutputs; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.JoinType.LEFT; @@ -76,7 +76,7 @@ public class TransformCorrelatedDistinctAggregationWithoutProjection private static final Pattern PATTERN = correlatedJoin() .with(type().equalTo(LEFT)) .with(nonEmpty(Patterns.CorrelatedJoin.correlation())) - .with(filter().equalTo(TRUE_LITERAL)) + .with(filter().equalTo(TRUE)) .with(subquery().matching(aggregation() .matching(AggregationDecorrelation::isDistinctOperator) .capturedAs(AGGREGATION))); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithProjection.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithProjection.java index 38e33d6d19e1..5e7f8d21669e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithProjection.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithProjection.java @@ -47,7 +47,7 @@ import static io.trino.matching.Pattern.nonEmpty; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.IrUtils.and; import static io.trino.sql.planner.iterative.rule.AggregationDecorrelation.isDistinctOperator; import static io.trino.sql.planner.iterative.rule.AggregationDecorrelation.restoreDistinctAggregation; @@ -127,7 +127,7 @@ public class TransformCorrelatedGlobalAggregationWithProjection private static final Pattern PATTERN = correlatedJoin() .with(nonEmpty(Patterns.CorrelatedJoin.correlation())) - .with(filter().equalTo(TRUE_LITERAL)) + .with(filter().equalTo(TRUE)) .with(subquery().matching(project() .capturedAs(PROJECTION) .with(source().matching(aggregation() @@ -183,7 +183,7 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co source, Assignments.builder() .putIdentities(source.getOutputSymbols()) - .put(nonNull, TRUE_LITERAL) + .put(nonNull, TRUE) .build()); // assign unique id on correlated join's input. It will be used to distinguish between original input rows after join diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithoutProjection.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithoutProjection.java index 0dfd73444ada..678563bf92e2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithoutProjection.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithoutProjection.java @@ -44,7 +44,7 @@ import static io.trino.matching.Pattern.nonEmpty; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.IrUtils.and; import static io.trino.sql.planner.iterative.rule.AggregationDecorrelation.isDistinctOperator; import static io.trino.sql.planner.iterative.rule.AggregationDecorrelation.restoreDistinctAggregation; @@ -122,7 +122,7 @@ public class TransformCorrelatedGlobalAggregationWithoutProjection private static final Pattern PATTERN = correlatedJoin() .with(nonEmpty(Patterns.CorrelatedJoin.correlation())) - .with(filter().equalTo(TRUE_LITERAL)) // todo non-trivial join filter: adding filter/project on top of aggregation + .with(filter().equalTo(TRUE)) // todo non-trivial join filter: adding filter/project on top of aggregation .with(subquery().matching(aggregation() .with(empty(groupingColumns())) .with(source().capturedAs(SOURCE)) @@ -175,7 +175,7 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co source, Assignments.builder() .putIdentities(source.getOutputSymbols()) - .put(nonNull, TRUE_LITERAL) + .put(nonNull, TRUE) .build()); // assign unique id on correlated join's input. It will be used to distinguish between original input rows after join diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithProjection.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithProjection.java index b337b19905c9..3695a1105f8d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithProjection.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithProjection.java @@ -41,7 +41,7 @@ import static io.trino.matching.Capture.newCapture; import static io.trino.matching.Pattern.nonEmpty; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.iterative.rule.AggregationDecorrelation.isDistinctOperator; import static io.trino.sql.planner.iterative.rule.AggregationDecorrelation.restoreDistinctAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; @@ -118,7 +118,7 @@ public class TransformCorrelatedGroupedAggregationWithProjection private static final Pattern PATTERN = correlatedJoin() .with(type().equalTo(INNER)) .with(nonEmpty(Patterns.CorrelatedJoin.correlation())) - .with(filter().equalTo(TRUE_LITERAL)) + .with(filter().equalTo(TRUE)) .with(subquery().matching(project() .capturedAs(PROJECTION) .with(source().matching(aggregation() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithoutProjection.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithoutProjection.java index f70dff7159fd..f7a85802a918 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithoutProjection.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithoutProjection.java @@ -36,7 +36,7 @@ import static io.trino.matching.Capture.newCapture; import static io.trino.matching.Pattern.nonEmpty; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.iterative.rule.AggregationDecorrelation.isDistinctOperator; import static io.trino.sql.planner.iterative.rule.AggregationDecorrelation.restoreDistinctAggregation; import static io.trino.sql.planner.iterative.rule.Util.restrictOutputs; @@ -111,7 +111,7 @@ public class TransformCorrelatedGroupedAggregationWithoutProjection private static final Pattern PATTERN = correlatedJoin() .with(type().equalTo(INNER)) .with(nonEmpty(Patterns.CorrelatedJoin.correlation())) - .with(filter().equalTo(TRUE_LITERAL)) + .with(filter().equalTo(TRUE)) .with(subquery().matching(aggregation() .with(nonEmpty(groupingColumns())) .matching(aggregation -> aggregation.getGroupingSetCount() == 1) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java index 1fa8d578726a..9a61811f8c94 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java @@ -19,14 +19,14 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.metadata.Metadata; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Case; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.ir.IrUtils; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.SearchedCaseExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.ir.WhenClause; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; @@ -177,9 +177,9 @@ private PlanNode buildInPredicateEquivalent( Expression joinExpression = and( or( - new IsNullPredicate(probeSideSymbol.toSymbolReference()), - new ComparisonExpression(ComparisonExpression.Operator.EQUAL, probeSideSymbol.toSymbolReference(), buildSideSymbol.toSymbolReference()), - new IsNullPredicate(buildSideSymbol.toSymbolReference())), + new IsNull(probeSideSymbol.toSymbolReference()), + new Comparison(Comparison.Operator.EQUAL, probeSideSymbol.toSymbolReference(), buildSideSymbol.toSymbolReference()), + new IsNull(buildSideSymbol.toSymbolReference())), correlationCondition); JoinNode leftOuterJoin = leftOuterJoin(idAllocator, probeSide, buildSide, joinExpression); @@ -216,7 +216,7 @@ private PlanNode buildInPredicateEquivalent( singleGroupingSet(probeSide.getOutputSymbols())); // TODO since we care only about "some count > 0", we could have specialized node instead of leftOuterJoin that does the job without materializing join results - SearchedCaseExpression inPredicateEquivalent = new SearchedCaseExpression( + Case inPredicateEquivalent = new Case( ImmutableList.of( new WhenClause(isGreaterThan(countMatchesSymbol, 0), booleanConstant(true)), new WhenClause(isGreaterThan(countNullMatchesSymbol, 0), booleanConstant(null))), @@ -263,20 +263,20 @@ private AggregationNode.Aggregation countWithFilter(Symbol filter) private static Expression isGreaterThan(Symbol symbol, long value) { - return new ComparisonExpression( - ComparisonExpression.Operator.GREATER_THAN, + return new Comparison( + Comparison.Operator.GREATER_THAN, symbol.toSymbolReference(), bigint(value)); } private static Expression not(Expression booleanExpression) { - return new NotExpression(booleanExpression); + return new Not(booleanExpression); } private static Expression isNotNull(Symbol symbol) { - return new NotExpression(new IsNullPredicate(symbol.toSymbolReference())); + return new Not(new IsNull(symbol.toSymbolReference())); } private static Expression bigint(long value) @@ -325,8 +325,8 @@ public Optional visitProject(ProjectNode node, PlanNode reference) // Pull up all symbols used by a filter (except correlation) decorrelated.getCorrelatedPredicates().stream() .flatMap(IrUtils::preOrder) - .filter(SymbolReference.class::isInstance) - .map(SymbolReference.class::cast) + .filter(Reference.class::isInstance) + .map(Reference.class::cast) .filter(symbolReference -> !correlation.contains(Symbol.from(symbolReference))) .forEach(symbolReference -> assignments.putIdentity(Symbol.from(symbolReference))); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedJoinToJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedJoinToJoin.java index 2c91df0c53fc..9da0239a5233 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedJoinToJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedJoinToJoin.java @@ -30,7 +30,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static io.trino.matching.Pattern.nonEmpty; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.IrUtils.combineConjuncts; import static io.trino.sql.planner.plan.JoinType.INNER; import static io.trino.sql.planner.plan.JoinType.LEFT; @@ -75,7 +75,7 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co DecorrelatedNode decorrelatedSubquery = decorrelatedNodeOptional.get(); Expression filter = combineConjuncts( - decorrelatedSubquery.getCorrelatedPredicates().orElse(TRUE_LITERAL), + decorrelatedSubquery.getCorrelatedPredicates().orElse(TRUE), correlatedJoinNode.getFilter()); return Result.ofPlanNode(new JoinNode( @@ -87,7 +87,7 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co correlatedJoinNode.getInput().getOutputSymbols(), correlatedJoinNode.getSubquery().getOutputSymbols(), false, - filter.equals(TRUE_LITERAL) ? Optional.empty() : Optional.of(filter), + filter.equals(TRUE) ? Optional.empty() : Optional.of(filter), Optional.empty(), Optional.empty(), Optional.empty(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedScalarSubquery.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedScalarSubquery.java index dcc1412990cc..170eadb83e20 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedScalarSubquery.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedScalarSubquery.java @@ -19,7 +19,7 @@ import io.trino.metadata.Metadata; import io.trino.spi.type.BigintType; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.SimpleCaseExpression; +import io.trino.sql.ir.Switch; import io.trino.sql.ir.WhenClause; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; @@ -39,7 +39,7 @@ import static io.trino.matching.Pattern.nonEmpty; import static io.trino.spi.StandardErrorCode.SUBQUERY_MULTIPLE_ROWS; import static io.trino.spi.type.BooleanType.BOOLEAN; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.LogicalPlanner.failFunction; import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.extractCardinality; @@ -83,7 +83,7 @@ public class TransformCorrelatedScalarSubquery { private static final Pattern PATTERN = correlatedJoin() .with(nonEmpty(correlation())) - .with(filter().equalTo(TRUE_LITERAL)); + .with(filter().equalTo(TRUE)); private final Metadata metadata; @@ -159,10 +159,10 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co FilterNode filterNode = new FilterNode( context.getIdAllocator().getNextId(), markDistinctNode, - new SimpleCaseExpression( + new Switch( isDistinct.toSymbolReference(), ImmutableList.of( - new WhenClause(TRUE_LITERAL, TRUE_LITERAL)), + new WhenClause(TRUE, TRUE)), Optional.of(new Cast( failFunction(metadata, SUBQUERY_MULTIPLE_ROWS, "Scalar sub-query has returned multiple rows"), BOOLEAN)))); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedSingleRowSubqueryToProject.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedSingleRowSubqueryToProject.java index 1525325da548..a1144fa87f2a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedSingleRowSubqueryToProject.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedSingleRowSubqueryToProject.java @@ -24,7 +24,7 @@ import io.trino.sql.planner.plan.ValuesNode; import static com.google.common.collect.Streams.forEachPair; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.plan.Patterns.CorrelatedJoin.filter; import static io.trino.sql.planner.plan.Patterns.correlatedJoin; @@ -61,7 +61,7 @@ public class TransformCorrelatedSingleRowSubqueryToProject implements Rule { private static final Pattern PATTERN = correlatedJoin() - .with(filter().equalTo(TRUE_LITERAL)); + .with(filter().equalTo(TRUE)); @Override public Pattern getPattern() @@ -93,7 +93,7 @@ public Result apply(CorrelatedJoinNode parent, Captures captures, Context contex .putIdentities(parent.getInput().getOutputSymbols()); forEachPair( values.getOutputSymbols().stream(), - row.getItems().stream(), + row.items().stream(), assignments::put); return Result.ofPlanNode(projectNode(parent.getInput(), assignments.build(), context)); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformExistsApplyToCorrelatedJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformExistsApplyToCorrelatedJoin.java index b2797fcc64ac..5b96246c251f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformExistsApplyToCorrelatedJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformExistsApplyToCorrelatedJoin.java @@ -19,9 +19,9 @@ import io.trino.matching.Pattern; import io.trino.metadata.ResolvedFunction; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.BooleanLiteral; -import io.trino.sql.ir.CoalesceExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Booleans; +import io.trino.sql.ir.Coalesce; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; @@ -40,8 +40,8 @@ import static com.google.common.collect.Iterables.getOnlyElement; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.plan.AggregationNode.globalAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.JoinType.INNER; @@ -133,7 +133,7 @@ private Optional rewriteToNonDefaultAggregation(ApplyNode applyNode, C applyNode.getSubquery(), 1L, false), - Assignments.of(subqueryTrue, TRUE_LITERAL)); + Assignments.of(subqueryTrue, TRUE)); PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup()); if (decorrelator.decorrelateFilters(subquery, applyNode.getCorrelation()).isEmpty()) { @@ -143,7 +143,7 @@ private Optional rewriteToNonDefaultAggregation(ApplyNode applyNode, C Symbol exists = getOnlyElement(applyNode.getSubqueryAssignments().keySet()); Assignments.Builder assignments = Assignments.builder() .putIdentities(applyNode.getInput().getOutputSymbols()) - .put(exists, new CoalesceExpression(ImmutableList.of(subqueryTrue.toSymbolReference(), BooleanLiteral.FALSE_LITERAL))); + .put(exists, new Coalesce(ImmutableList.of(subqueryTrue.toSymbolReference(), Booleans.FALSE))); return Optional.of(new ProjectNode(context.getIdAllocator().getNextId(), new CorrelatedJoinNode( @@ -152,7 +152,7 @@ private Optional rewriteToNonDefaultAggregation(ApplyNode applyNode, C subquery, applyNode.getCorrelation(), LEFT, - TRUE_LITERAL, + TRUE, applyNode.getOriginSubquery()), assignments.build())); } @@ -179,10 +179,10 @@ private PlanNode rewriteToDefaultAggregation(ApplyNode applyNode, Context contex Optional.empty(), Optional.empty())), globalAggregation()), - Assignments.of(exists, new ComparisonExpression(GREATER_THAN, count.toSymbolReference(), new Constant(BIGINT, 0L)))), + Assignments.of(exists, new Comparison(GREATER_THAN, count.toSymbolReference(), new Constant(BIGINT, 0L)))), applyNode.getCorrelation(), INNER, - TRUE_LITERAL, + TRUE, applyNode.getOriginSubquery()); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformFilteringSemiJoinToInnerJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformFilteringSemiJoinToInnerJoin.java index 5d736e8d66c4..4eda7d738873 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformFilteringSemiJoinToInnerJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformFilteringSemiJoinToInnerJoin.java @@ -39,7 +39,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.SystemSessionProperties.isRewriteFilteringSemiJoinToInnerJoin; import static io.trino.matching.Capture.newCapture; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.IrUtils.and; import static io.trino.sql.ir.IrUtils.extractConjuncts; import static io.trino.sql.planner.ExpressionSymbolInliner.inlineSymbols; @@ -116,12 +116,12 @@ public Result apply(FilterNode filterNode, Captures captures, Context context) Expression simplifiedPredicate = inlineSymbols(symbol -> { if (symbol.equals(semiJoinSymbol)) { - return TRUE_LITERAL; + return TRUE; } return symbol.toSymbolReference(); }, filteredPredicate); - Optional joinFilter = simplifiedPredicate.equals(TRUE_LITERAL) ? Optional.empty() : Optional.of(simplifiedPredicate); + Optional joinFilter = simplifiedPredicate.equals(TRUE) ? Optional.empty() : Optional.of(simplifiedPredicate); PlanNode filteringSourceDistinct = singleAggregation( context.getIdAllocator().getNextId(), @@ -153,7 +153,7 @@ public Result apply(FilterNode filterNode, Captures captures, Context context) innerJoin, Assignments.builder() .putIdentities(innerJoin.getOutputSymbols()) - .put(semiJoinSymbol, TRUE_LITERAL) + .put(semiJoinSymbol, TRUE) .build()); return Result.ofPlanNode(project); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformUncorrelatedSubqueryToJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformUncorrelatedSubqueryToJoin.java index 539e57df02bc..fe34f2945a8d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformUncorrelatedSubqueryToJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformUncorrelatedSubqueryToJoin.java @@ -35,7 +35,7 @@ import static com.google.common.base.Preconditions.checkState; import static io.trino.matching.Pattern.empty; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.IrExpressions.ifExpression; import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.extractCardinality; import static io.trino.sql.planner.plan.JoinType.FULL; @@ -81,9 +81,9 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co else { type = JoinType.LEFT; } - JoinNode joinNode = rewriteToJoin(correlatedJoinNode, type, TRUE_LITERAL, context.getLookup()); + JoinNode joinNode = rewriteToJoin(correlatedJoinNode, type, TRUE, context.getLookup()); - if (correlatedJoinNode.getFilter().equals(TRUE_LITERAL)) { + if (correlatedJoinNode.getFilter().equals(TRUE)) { return Result.ofPlanNode(joinNode); } @@ -113,7 +113,7 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co private JoinNode rewriteToJoin(CorrelatedJoinNode parent, JoinType type, Expression filter, Lookup lookup) { - if (type == JoinType.LEFT && extractCardinality(parent.getSubquery(), lookup).isAtLeastScalar() && filter.equals(TRUE_LITERAL)) { + if (type == JoinType.LEFT && extractCardinality(parent.getSubquery(), lookup).isAtLeastScalar() && filter.equals(TRUE)) { // input rows will always be matched against subquery rows type = JoinType.INNER; } @@ -126,7 +126,7 @@ private JoinNode rewriteToJoin(CorrelatedJoinNode parent, JoinType type, Express parent.getInput().getOutputSymbols(), parent.getSubquery().getOutputSymbols(), false, - filter.equals(TRUE_LITERAL) ? Optional.empty() : Optional.of(filter), + filter.equals(TRUE) ? Optional.empty() : Optional.of(filter), Optional.empty(), Optional.empty(), Optional.empty(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapCastInComparison.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapCastInComparison.java index 0cc39837c1da..01b7aa3bd1d9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapCastInComparison.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapCastInComparison.java @@ -36,12 +36,12 @@ import io.trino.sql.InterpretedFunctionInvoker; import io.trino.sql.PlannerContext; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.ir.ExpressionTreeRewriter; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.NotExpression; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Not; import io.trino.sql.planner.IrExpressionInterpreter; import io.trino.sql.planner.NoOpSymbolResolver; import io.trino.type.TypeCoercion; @@ -69,13 +69,13 @@ import static io.trino.spi.type.RealType.REAL; import static io.trino.spi.type.Timestamps.PICOSECONDS_PER_NANOSECOND; import static io.trino.spi.type.TypeUtils.isFloatingPointNaN; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.NOT_EQUAL; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.NOT_EQUAL; import static io.trino.sql.ir.IrUtils.and; import static io.trino.sql.ir.IrUtils.or; import static java.lang.Float.intBitsToFloat; @@ -153,28 +153,28 @@ public Visitor(PlannerContext plannerContext, Session session) } @Override - public Expression rewriteComparisonExpression(ComparisonExpression node, Void context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteComparison(Comparison node, Void context, ExpressionTreeRewriter treeRewriter) { - ComparisonExpression expression = treeRewriter.defaultRewrite(node, null); + Comparison expression = treeRewriter.defaultRewrite(node, null); return unwrapCast(expression); } - private Expression unwrapCast(ComparisonExpression expression) + private Expression unwrapCast(Comparison expression) { // Canonicalization is handled by CanonicalizeExpressionRewriter - if (!(expression.getLeft() instanceof Cast cast)) { + if (!(expression.left() instanceof Cast cast)) { return expression; } - Object right = new IrExpressionInterpreter(expression.getRight(), plannerContext, session) + Object right = new IrExpressionInterpreter(expression.right(), plannerContext, session) .optimize(NoOpSymbolResolver.INSTANCE); - ComparisonExpression.Operator operator = expression.getOperator(); + Comparison.Operator operator = expression.operator(); if (right == null) { return switch (operator) { case EQUAL, NOT_EQUAL, LESS_THAN, LESS_THAN_OR_EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL -> new Constant(BOOLEAN, null); - case IS_DISTINCT_FROM -> new NotExpression(new IsNullPredicate(cast)); + case IS_DISTINCT_FROM -> new Not(new IsNull(cast)); }; } @@ -186,7 +186,7 @@ private Expression unwrapCast(ComparisonExpression expression) Type targetType = expression.right().type(); if (sourceType instanceof TimestampType && targetType == DATE) { - return unwrapTimestampToDateCast((TimestampType) sourceType, operator, cast.getExpression(), (long) right).orElse(expression); + return unwrapTimestampToDateCast((TimestampType) sourceType, operator, cast.expression(), (long) right).orElse(expression); } if (targetType instanceof TimestampWithTimeZoneType) { @@ -207,12 +207,12 @@ private Expression unwrapCast(ComparisonExpression expression) case GREATER_THAN_OR_EQUAL: case LESS_THAN: case LESS_THAN_OR_EQUAL: - return falseIfNotNull(cast.getExpression()); + return falseIfNotNull(cast.expression()); case NOT_EQUAL: - return trueIfNotNull(cast.getExpression()); + return trueIfNotNull(cast.expression()); case IS_DISTINCT_FROM: if (!typeHasNaN(sourceType)) { - return TRUE_LITERAL; + return TRUE; } // NaN on the right of comparison will be cast to source type later break; @@ -240,21 +240,21 @@ private Expression unwrapCast(ComparisonExpression expression) if (upperBoundComparison > 0) { // larger than maximum representable value return switch (operator) { - case EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL -> falseIfNotNull(cast.getExpression()); - case NOT_EQUAL, LESS_THAN, LESS_THAN_OR_EQUAL -> trueIfNotNull(cast.getExpression()); - case IS_DISTINCT_FROM -> TRUE_LITERAL; + case EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL -> falseIfNotNull(cast.expression()); + case NOT_EQUAL, LESS_THAN, LESS_THAN_OR_EQUAL -> trueIfNotNull(cast.expression()); + case IS_DISTINCT_FROM -> TRUE; }; } if (upperBoundComparison == 0) { // equal to max representable value return switch (operator) { - case GREATER_THAN -> falseIfNotNull(cast.getExpression()); - case GREATER_THAN_OR_EQUAL -> new ComparisonExpression(EQUAL, cast.getExpression(), new Constant(sourceType, max)); - case LESS_THAN_OR_EQUAL -> trueIfNotNull(cast.getExpression()); - case LESS_THAN -> new ComparisonExpression(NOT_EQUAL, cast.getExpression(), new Constant(sourceType, max)); + case GREATER_THAN -> falseIfNotNull(cast.expression()); + case GREATER_THAN_OR_EQUAL -> new Comparison(EQUAL, cast.expression(), new Constant(sourceType, max)); + case LESS_THAN_OR_EQUAL -> trueIfNotNull(cast.expression()); + case LESS_THAN -> new Comparison(NOT_EQUAL, cast.expression(), new Constant(sourceType, max)); case EQUAL, NOT_EQUAL, IS_DISTINCT_FROM -> - new ComparisonExpression(operator, cast.getExpression(), new Constant(sourceType, max)); + new Comparison(operator, cast.expression(), new Constant(sourceType, max)); }; } @@ -265,21 +265,21 @@ private Expression unwrapCast(ComparisonExpression expression) if (lowerBoundComparison < 0) { // smaller than minimum representable value return switch (operator) { - case NOT_EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL -> trueIfNotNull(cast.getExpression()); - case EQUAL, LESS_THAN, LESS_THAN_OR_EQUAL -> falseIfNotNull(cast.getExpression()); - case IS_DISTINCT_FROM -> TRUE_LITERAL; + case NOT_EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL -> trueIfNotNull(cast.expression()); + case EQUAL, LESS_THAN, LESS_THAN_OR_EQUAL -> falseIfNotNull(cast.expression()); + case IS_DISTINCT_FROM -> TRUE; }; } if (lowerBoundComparison == 0) { // equal to min representable value return switch (operator) { - case LESS_THAN -> falseIfNotNull(cast.getExpression()); - case LESS_THAN_OR_EQUAL -> new ComparisonExpression(EQUAL, cast.getExpression(), new Constant(sourceType, min)); - case GREATER_THAN_OR_EQUAL -> trueIfNotNull(cast.getExpression()); - case GREATER_THAN -> new ComparisonExpression(NOT_EQUAL, cast.getExpression(), new Constant(sourceType, min)); + case LESS_THAN -> falseIfNotNull(cast.expression()); + case LESS_THAN_OR_EQUAL -> new Comparison(EQUAL, cast.expression(), new Constant(sourceType, min)); + case GREATER_THAN_OR_EQUAL -> trueIfNotNull(cast.expression()); + case GREATER_THAN -> new Comparison(NOT_EQUAL, cast.expression(), new Constant(sourceType, min)); case EQUAL, NOT_EQUAL, IS_DISTINCT_FROM -> - new ComparisonExpression(operator, cast.getExpression(), new Constant(sourceType, min)); + new Comparison(operator, cast.expression(), new Constant(sourceType, min)); }; } } @@ -316,43 +316,43 @@ private Expression unwrapCast(ComparisonExpression expression) if (literalVsRoundtripped > 0) { // cast rounded down return switch (operator) { - case EQUAL -> falseIfNotNull(cast.getExpression()); - case NOT_EQUAL -> trueIfNotNull(cast.getExpression()); - case IS_DISTINCT_FROM -> TRUE_LITERAL; + case EQUAL -> falseIfNotNull(cast.expression()); + case NOT_EQUAL -> trueIfNotNull(cast.expression()); + case IS_DISTINCT_FROM -> TRUE; case LESS_THAN, LESS_THAN_OR_EQUAL -> { if (sourceRange.isPresent() && compare(sourceType, sourceRange.get().getMin(), literalInSourceType) == 0) { - yield new ComparisonExpression(EQUAL, cast.getExpression(), new Constant(sourceType, literalInSourceType)); + yield new Comparison(EQUAL, cast.expression(), new Constant(sourceType, literalInSourceType)); } - yield new ComparisonExpression(LESS_THAN_OR_EQUAL, cast.getExpression(), new Constant(sourceType, literalInSourceType)); + yield new Comparison(LESS_THAN_OR_EQUAL, cast.expression(), new Constant(sourceType, literalInSourceType)); } case GREATER_THAN, GREATER_THAN_OR_EQUAL -> // We expect implicit coercions to be order-preserving, so the result of converting back from target -> source cannot produce a value // larger than the next value in the source type - new ComparisonExpression(GREATER_THAN, cast.getExpression(), new Constant(sourceType, literalInSourceType)); + new Comparison(GREATER_THAN, cast.expression(), new Constant(sourceType, literalInSourceType)); }; } if (literalVsRoundtripped < 0) { // cast rounded up return switch (operator) { - case EQUAL -> falseIfNotNull(cast.getExpression()); - case NOT_EQUAL -> trueIfNotNull(cast.getExpression()); - case IS_DISTINCT_FROM -> TRUE_LITERAL; + case EQUAL -> falseIfNotNull(cast.expression()); + case NOT_EQUAL -> trueIfNotNull(cast.expression()); + case IS_DISTINCT_FROM -> TRUE; case LESS_THAN, LESS_THAN_OR_EQUAL -> // We expect implicit coercions to be order-preserving, so the result of converting back from target -> source cannot produce a value // smaller than the next value in the source type - new ComparisonExpression(LESS_THAN, cast.getExpression(), new Constant(sourceType, literalInSourceType)); + new Comparison(LESS_THAN, cast.expression(), new Constant(sourceType, literalInSourceType)); case GREATER_THAN, GREATER_THAN_OR_EQUAL -> sourceRange.isPresent() && compare(sourceType, sourceRange.get().getMax(), literalInSourceType) == 0 ? - new ComparisonExpression(EQUAL, cast.getExpression(), new Constant(sourceType, literalInSourceType)) : - new ComparisonExpression(GREATER_THAN_OR_EQUAL, cast.getExpression(), new Constant(sourceType, literalInSourceType)); + new Comparison(EQUAL, cast.expression(), new Constant(sourceType, literalInSourceType)) : + new Comparison(GREATER_THAN_OR_EQUAL, cast.expression(), new Constant(sourceType, literalInSourceType)); }; } } - return new ComparisonExpression(operator, cast.getExpression(), new Constant(sourceType, literalInSourceType)); + return new Comparison(operator, cast.expression(), new Constant(sourceType, literalInSourceType)); } - private Optional unwrapTimestampToDateCast(TimestampType sourceType, ComparisonExpression.Operator operator, Expression timestampExpression, long date) + private Optional unwrapTimestampToDateCast(TimestampType sourceType, Comparison.Operator operator, Expression timestampExpression, long date) { ResolvedFunction targetToSource; try { @@ -368,21 +368,21 @@ private Optional unwrapTimestampToDateCast(TimestampType sourceType, return switch (operator) { case EQUAL -> Optional.of( and( - new ComparisonExpression(GREATER_THAN_OR_EQUAL, timestampExpression, dateTimestamp), - new ComparisonExpression(LESS_THAN, timestampExpression, nextDateTimestamp))); + new Comparison(GREATER_THAN_OR_EQUAL, timestampExpression, dateTimestamp), + new Comparison(LESS_THAN, timestampExpression, nextDateTimestamp))); case NOT_EQUAL -> Optional.of( or( - new ComparisonExpression(LESS_THAN, timestampExpression, dateTimestamp), - new ComparisonExpression(GREATER_THAN_OR_EQUAL, timestampExpression, nextDateTimestamp))); - case LESS_THAN -> Optional.of(new ComparisonExpression(LESS_THAN, timestampExpression, dateTimestamp)); - case LESS_THAN_OR_EQUAL -> Optional.of(new ComparisonExpression(LESS_THAN, timestampExpression, nextDateTimestamp)); - case GREATER_THAN -> Optional.of(new ComparisonExpression(GREATER_THAN_OR_EQUAL, timestampExpression, nextDateTimestamp)); - case GREATER_THAN_OR_EQUAL -> Optional.of(new ComparisonExpression(GREATER_THAN_OR_EQUAL, timestampExpression, dateTimestamp)); + new Comparison(LESS_THAN, timestampExpression, dateTimestamp), + new Comparison(GREATER_THAN_OR_EQUAL, timestampExpression, nextDateTimestamp))); + case LESS_THAN -> Optional.of(new Comparison(LESS_THAN, timestampExpression, dateTimestamp)); + case LESS_THAN_OR_EQUAL -> Optional.of(new Comparison(LESS_THAN, timestampExpression, nextDateTimestamp)); + case GREATER_THAN -> Optional.of(new Comparison(GREATER_THAN_OR_EQUAL, timestampExpression, nextDateTimestamp)); + case GREATER_THAN_OR_EQUAL -> Optional.of(new Comparison(GREATER_THAN_OR_EQUAL, timestampExpression, dateTimestamp)); case IS_DISTINCT_FROM -> Optional.of( or( - new IsNullPredicate(timestampExpression), - new ComparisonExpression(LESS_THAN, timestampExpression, dateTimestamp), - new ComparisonExpression(GREATER_THAN_OR_EQUAL, timestampExpression, nextDateTimestamp))); + new IsNull(timestampExpression), + new Comparison(LESS_THAN, timestampExpression, dateTimestamp), + new Comparison(GREATER_THAN_OR_EQUAL, timestampExpression, nextDateTimestamp))); }; } @@ -553,11 +553,11 @@ private static Instant getInstantWithTruncation(TimestampWithTimeZoneType type, public static Expression falseIfNotNull(Expression argument) { - return and(new IsNullPredicate(argument), new Constant(BOOLEAN, null)); + return and(new IsNull(argument), new Constant(BOOLEAN, null)); } public static Expression trueIfNotNull(Expression argument) { - return or(new NotExpression(new IsNullPredicate(argument)), new Constant(BOOLEAN, null)); + return or(new Not(new IsNull(argument)), new Constant(BOOLEAN, null)); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapDateTruncInComparison.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapDateTruncInComparison.java index 72b8ba0f0de8..36339ad65a07 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapDateTruncInComparison.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapDateTruncInComparison.java @@ -28,14 +28,14 @@ import io.trino.spi.type.VarcharType; import io.trino.sql.InterpretedFunctionInvoker; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.BetweenPredicate; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Between; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.ir.ExpressionTreeRewriter; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.NotExpression; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Not; import io.trino.sql.planner.IrExpressionInterpreter; import io.trino.sql.planner.NoOpSymbolResolver; @@ -55,11 +55,11 @@ import static io.trino.spi.type.DateType.DATE; import static io.trino.spi.type.TimestampType.createTimestampType; import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_SECOND; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.ir.IrUtils.or; import static io.trino.sql.planner.iterative.rule.UnwrapCastInComparison.falseIfNotNull; import static io.trino.sql.planner.iterative.rule.UnwrapCastInComparison.trueIfNotNull; @@ -125,25 +125,25 @@ public Visitor(PlannerContext plannerContext, Session session) } @Override - public Expression rewriteComparisonExpression(ComparisonExpression node, Void context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteComparison(Comparison node, Void context, ExpressionTreeRewriter treeRewriter) { - ComparisonExpression expression = treeRewriter.defaultRewrite(node, null); + Comparison expression = treeRewriter.defaultRewrite(node, null); return unwrapDateTrunc(expression); } // Simplify `date_trunc(unit, d) ? value` - private Expression unwrapDateTrunc(ComparisonExpression expression) + private Expression unwrapDateTrunc(Comparison expression) { // Expect date_trunc on the left side and value on the right side of the comparison. // This is provided by CanonicalizeExpressionRewriter. - if (!(expression.getLeft() instanceof FunctionCall call) || - !call.getFunction().getName().equals(builtinFunctionName("date_trunc")) || - call.getArguments().size() != 2) { + if (!(expression.left() instanceof Call call) || + !call.function().getName().equals(builtinFunctionName("date_trunc")) || + call.arguments().size() != 2) { return expression; } - Expression unitExpression = call.getArguments().get(0); + Expression unitExpression = call.arguments().get(0); if (!(unitExpression.type() instanceof VarcharType) || !(unitExpression instanceof Constant)) { return expression; } @@ -153,19 +153,19 @@ private Expression unwrapDateTrunc(ComparisonExpression expression) return expression; } - Expression argument = call.getArguments().get(1); + Expression argument = call.arguments().get(1); Type argumentType = argument.type(); Type rightType = expression.right().type(); verify(argumentType.equals(rightType), "Mismatched types: %s and %s", argumentType, rightType); - Object right = new IrExpressionInterpreter(expression.getRight(), plannerContext, session) + Object right = new IrExpressionInterpreter(expression.right(), plannerContext, session) .optimize(NoOpSymbolResolver.INSTANCE); if (right == null) { - return switch (expression.getOperator()) { + return switch (expression.operator()) { case EQUAL, NOT_EQUAL, LESS_THAN, LESS_THAN_OR_EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL -> new Constant(BOOLEAN, null); - case IS_DISTINCT_FROM -> new NotExpression(new IsNullPredicate(argument)); + case IS_DISTINCT_FROM -> new Not(new IsNull(argument)); }; } @@ -178,7 +178,7 @@ private Expression unwrapDateTrunc(ComparisonExpression expression) return expression; } - ResolvedFunction resolvedFunction = call.getFunction(); + ResolvedFunction resolvedFunction = call.function(); Optional unitIfSupported = Enums.getIfPresent(SupportedUnit.class, unitName.toStringUtf8().toUpperCase(Locale.ENGLISH)).toJavaUtil(); if (unitIfSupported.isEmpty()) { @@ -195,7 +195,7 @@ private Expression unwrapDateTrunc(ComparisonExpression expression) verify(compare <= 0, "Truncation of %s value %s resulted in a bigger value %s", rightType, right, rangeLow); boolean rightValueAtRangeLow = compare == 0; - return switch (expression.getOperator()) { + return switch (expression.operator()) { case EQUAL -> { if (!rightValueAtRangeLow) { yield falseIfNotNull(argument); @@ -206,33 +206,33 @@ private Expression unwrapDateTrunc(ComparisonExpression expression) if (!rightValueAtRangeLow) { yield trueIfNotNull(argument); } - yield new NotExpression(between(argument, rightType, rangeLow, calculateRangeEndInclusive(rangeLow, rightType, unit))); + yield new Not(between(argument, rightType, rangeLow, calculateRangeEndInclusive(rangeLow, rightType, unit))); } case IS_DISTINCT_FROM -> { if (!rightValueAtRangeLow) { - yield TRUE_LITERAL; + yield TRUE; } yield or( - new IsNullPredicate(argument), - new NotExpression(between(argument, rightType, rangeLow, calculateRangeEndInclusive(rangeLow, rightType, unit)))); + new IsNull(argument), + new Not(between(argument, rightType, rangeLow, calculateRangeEndInclusive(rangeLow, rightType, unit)))); } case LESS_THAN -> { if (rightValueAtRangeLow) { - yield new ComparisonExpression(LESS_THAN, argument, new Constant(rightType, rangeLow)); + yield new Comparison(LESS_THAN, argument, new Constant(rightType, rangeLow)); } - yield new ComparisonExpression(LESS_THAN_OR_EQUAL, argument, new Constant(rightType, calculateRangeEndInclusive(rangeLow, rightType, unit))); + yield new Comparison(LESS_THAN_OR_EQUAL, argument, new Constant(rightType, calculateRangeEndInclusive(rangeLow, rightType, unit))); } case LESS_THAN_OR_EQUAL -> { - yield new ComparisonExpression(LESS_THAN_OR_EQUAL, argument, new Constant(rightType, calculateRangeEndInclusive(rangeLow, rightType, unit))); + yield new Comparison(LESS_THAN_OR_EQUAL, argument, new Constant(rightType, calculateRangeEndInclusive(rangeLow, rightType, unit))); } case GREATER_THAN -> { - yield new ComparisonExpression(GREATER_THAN, argument, new Constant(rightType, calculateRangeEndInclusive(rangeLow, rightType, unit))); + yield new Comparison(GREATER_THAN, argument, new Constant(rightType, calculateRangeEndInclusive(rangeLow, rightType, unit))); } case GREATER_THAN_OR_EQUAL -> { if (rightValueAtRangeLow) { - yield new ComparisonExpression(GREATER_THAN_OR_EQUAL, argument, new Constant(rightType, rangeLow)); + yield new Comparison(GREATER_THAN_OR_EQUAL, argument, new Constant(rightType, rangeLow)); } - yield new ComparisonExpression(GREATER_THAN, argument, new Constant(rightType, calculateRangeEndInclusive(rangeLow, rightType, unit))); + yield new Comparison(GREATER_THAN, argument, new Constant(rightType, calculateRangeEndInclusive(rangeLow, rightType, unit))); } }; } @@ -273,9 +273,9 @@ private Object calculateRangeEndInclusive(Object rangeStart, Type type, Supporte throw new UnsupportedOperationException("Unsupported type: " + type); } - private BetweenPredicate between(Expression argument, Type type, Object minInclusive, Object maxInclusive) + private Between between(Expression argument, Type type, Object minInclusive, Object maxInclusive) { - return new BetweenPredicate( + return new Between( argument, new Constant(type, minInclusive), new Constant(type, maxInclusive)); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapRowSubscript.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapRowSubscript.java index b397e07dc245..9648d9b33e5c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapRowSubscript.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapRowSubscript.java @@ -20,7 +20,7 @@ import io.trino.sql.ir.Expression; import io.trino.sql.ir.ExpressionTreeRewriter; import io.trino.sql.ir.Row; -import io.trino.sql.ir.SubscriptExpression; +import io.trino.sql.ir.Subscript; import io.trino.type.UnknownType; import java.util.ArrayDeque; @@ -48,28 +48,28 @@ private static class Rewriter extends io.trino.sql.ir.ExpressionRewriter { @Override - public Expression rewriteSubscriptExpression(SubscriptExpression node, Void context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteSubscript(Subscript node, Void context, ExpressionTreeRewriter treeRewriter) { - Expression base = treeRewriter.rewrite(node.getBase(), context); + Expression base = treeRewriter.rewrite(node.base(), context); Deque coercions = new ArrayDeque<>(); while (base instanceof Cast cast) { - if (!(cast.getType() instanceof RowType rowType)) { + if (!(cast.type() instanceof RowType rowType)) { break; } - int index = (int) (long) ((Constant) node.getIndex()).getValue(); + int index = (int) (long) ((Constant) node.index()).value(); Type type = rowType.getFields().get(index - 1).getType(); if (!(type instanceof UnknownType)) { - coercions.push(new Coercion(type, cast.isSafe())); + coercions.push(new Coercion(type, cast.safe())); } - base = cast.getExpression(); + base = cast.expression(); } if (base instanceof Row row) { - int index = (int) (long) ((Constant) node.getIndex()).getValue(); - Expression result = row.getItems().get(index - 1); + int index = (int) (long) ((Constant) node.index()).value(); + Expression result = row.items().get(index - 1); while (!coercions.isEmpty()) { Coercion coercion = coercions.pop(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapSingleColumnRowInApply.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapSingleColumnRowInApply.java index ec0f1185ce81..7e2677a5ab8a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapSingleColumnRowInApply.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapSingleColumnRowInApply.java @@ -20,7 +20,7 @@ import io.trino.spi.type.Type; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SubscriptExpression; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.ApplyNode; @@ -148,8 +148,8 @@ private Optional unwrapSingleColumnRow(Context context, Expression v Symbol valueSymbol = context.getSymbolAllocator().newSymbol("input", elementType); Symbol listSymbol = context.getSymbolAllocator().newSymbol("subquery", elementType); - Assignment inputAssignment = new Assignment(valueSymbol, new SubscriptExpression(elementType, value, new Constant(INTEGER, 1L))); - Assignment nestedPlanAssignment = new Assignment(listSymbol, new SubscriptExpression(elementType, list, new Constant(INTEGER, 1L))); + Assignment inputAssignment = new Assignment(valueSymbol, new Subscript(elementType, value, new Constant(INTEGER, 1L))); + Assignment nestedPlanAssignment = new Assignment(listSymbol, new Subscript(elementType, list, new Constant(INTEGER, 1L))); ApplyNode.SetExpression comparison = function.apply(valueSymbol, listSymbol); return Optional.of(new Unwrapping(comparison, inputAssignment, nestedPlanAssignment)); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapYearInComparison.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapYearInComparison.java index d10890100571..e33044b04e2b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapYearInComparison.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapYearInComparison.java @@ -21,15 +21,15 @@ import io.trino.spi.type.TimestampWithTimeZoneType; import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.BetweenPredicate; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Between; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.ir.ExpressionTreeRewriter; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.InPredicate; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.NotExpression; +import io.trino.sql.ir.In; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Not; import io.trino.sql.planner.IrExpressionInterpreter; import io.trino.sql.planner.NoOpSymbolResolver; @@ -41,11 +41,11 @@ import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.DateType.DATE; import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_SECOND; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.ir.IrUtils.or; import static io.trino.type.DateTimes.PICOSECONDS_PER_MICROSECOND; import static io.trino.type.DateTimes.scaleFactor; @@ -104,32 +104,32 @@ public Visitor(PlannerContext plannerContext, Session session) } @Override - public Expression rewriteComparisonExpression(ComparisonExpression node, Void context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteComparison(Comparison node, Void context, ExpressionTreeRewriter treeRewriter) { - ComparisonExpression expression = treeRewriter.defaultRewrite(node, null); + Comparison expression = treeRewriter.defaultRewrite(node, null); return unwrapYear(expression); } @Override - public Expression rewriteInPredicate(InPredicate node, Void context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteIn(In node, Void context, ExpressionTreeRewriter treeRewriter) { - InPredicate inPredicate = treeRewriter.defaultRewrite(node, null); - Expression value = inPredicate.getValue(); + In in = treeRewriter.defaultRewrite(node, null); + Expression value = in.value(); - if (!(value instanceof FunctionCall call) || - !call.getFunction().getName().equals(builtinFunctionName("year")) || - call.getArguments().size() != 1) { - return inPredicate; + if (!(value instanceof Call call) || + !call.function().getName().equals(builtinFunctionName("year")) || + call.arguments().size() != 1) { + return in; } // Convert each value to a comparison expression and try to unwrap it. // unwrap the InPredicate only in case we manage to unwrap the entire value list - ImmutableList.Builder comparisonExpressions = ImmutableList.builderWithExpectedSize(node.getValueList().size()); - for (Expression rightExpression : node.getValueList()) { - ComparisonExpression comparisonExpression = new ComparisonExpression(EQUAL, value, rightExpression); - Expression unwrappedExpression = unwrapYear(comparisonExpression); - if (unwrappedExpression == comparisonExpression) { - return inPredicate; + ImmutableList.Builder comparisonExpressions = ImmutableList.builderWithExpectedSize(node.valueList().size()); + for (Expression rightExpression : node.valueList()) { + Comparison comparison = new Comparison(EQUAL, value, rightExpression); + Expression unwrappedExpression = unwrapYear(comparison); + if (unwrappedExpression == comparison) { + return in; } comparisonExpressions.add(unwrappedExpression); } @@ -138,26 +138,26 @@ public Expression rewriteInPredicate(InPredicate node, Void context, ExpressionT } // Simplify `year(d) ? value` - private Expression unwrapYear(ComparisonExpression expression) + private Expression unwrapYear(Comparison expression) { // Expect year on the left side and value on the right side of the comparison. // This is provided by CanonicalizeExpressionRewriter. - if (!(expression.getLeft() instanceof FunctionCall call) || - !call.getFunction().getName().equals(builtinFunctionName("year")) || - call.getArguments().size() != 1) { + if (!(expression.left() instanceof Call call) || + !call.function().getName().equals(builtinFunctionName("year")) || + call.arguments().size() != 1) { return expression; } - Expression argument = getOnlyElement(call.getArguments()); + Expression argument = getOnlyElement(call.arguments()); Type argumentType = argument.type(); - Object right = new IrExpressionInterpreter(expression.getRight(), plannerContext, session) + Object right = new IrExpressionInterpreter(expression.right(), plannerContext, session) .optimize(NoOpSymbolResolver.INSTANCE); if (right == null) { - return switch (expression.getOperator()) { + return switch (expression.operator()) { case EQUAL, NOT_EQUAL, LESS_THAN, LESS_THAN_OR_EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL -> new Constant(BOOLEAN, null); - case IS_DISTINCT_FROM -> new NotExpression(new IsNullPredicate(argument)); + case IS_DISTINCT_FROM -> new Not(new IsNull(argument)); }; } @@ -175,34 +175,34 @@ private Expression unwrapYear(ComparisonExpression expression) } int year = toIntExact((Long) right); - return switch (expression.getOperator()) { + return switch (expression.operator()) { case EQUAL -> between(argument, argumentType, calculateRangeStartInclusive(year, argumentType), calculateRangeEndInclusive(year, argumentType)); - case NOT_EQUAL -> new NotExpression(between(argument, argumentType, calculateRangeStartInclusive(year, argumentType), calculateRangeEndInclusive(year, argumentType))); + case NOT_EQUAL -> new Not(between(argument, argumentType, calculateRangeStartInclusive(year, argumentType), calculateRangeEndInclusive(year, argumentType))); case IS_DISTINCT_FROM -> or( - new IsNullPredicate(argument), - new NotExpression(between(argument, argumentType, calculateRangeStartInclusive(year, argumentType), calculateRangeEndInclusive(year, argumentType)))); + new IsNull(argument), + new Not(between(argument, argumentType, calculateRangeStartInclusive(year, argumentType), calculateRangeEndInclusive(year, argumentType)))); case LESS_THAN -> { Object value = calculateRangeStartInclusive(year, argumentType); - yield new ComparisonExpression(LESS_THAN, argument, new Constant(argumentType, value)); + yield new Comparison(LESS_THAN, argument, new Constant(argumentType, value)); } case LESS_THAN_OR_EQUAL -> { Object value = calculateRangeEndInclusive(year, argumentType); - yield new ComparisonExpression(LESS_THAN_OR_EQUAL, argument, new Constant(argumentType, value)); + yield new Comparison(LESS_THAN_OR_EQUAL, argument, new Constant(argumentType, value)); } case GREATER_THAN -> { Object value = calculateRangeEndInclusive(year, argumentType); - yield new ComparisonExpression(GREATER_THAN, argument, new Constant(argumentType, value)); + yield new Comparison(GREATER_THAN, argument, new Constant(argumentType, value)); } case GREATER_THAN_OR_EQUAL -> { Object value = calculateRangeStartInclusive(year, argumentType); - yield new ComparisonExpression(GREATER_THAN_OR_EQUAL, argument, new Constant(argumentType, value)); + yield new Comparison(GREATER_THAN_OR_EQUAL, argument, new Constant(argumentType, value)); } }; } - private BetweenPredicate between(Expression argument, Type type, Object minInclusive, Object maxInclusive) + private Between between(Expression argument, Type type, Object minInclusive, Object maxInclusive) { - return new BetweenPredicate( + return new Between( argument, new Constant(type, minInclusive), new Constant(type, maxInclusive)); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java index 6c188d581b01..3407b44eb82b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java @@ -34,7 +34,7 @@ import io.trino.spi.connector.WriterScalingOptions; import io.trino.sql.PlannerContext; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.DomainTranslator; import io.trino.sql.planner.Partitioning; import io.trino.sql.planner.PartitioningHandle; @@ -1476,7 +1476,7 @@ private static Map computeIdentityTranslations(Assignments assig { Map outputToInput = new HashMap<>(); for (Map.Entry assignment : assignments.getMap().entrySet()) { - if (assignment.getValue() instanceof SymbolReference) { + if (assignment.getValue() instanceof Reference) { outputToInput.put(assignment.getKey(), Symbol.from(assignment.getValue())); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java index a19eedf85249..e2cf0b1af2af 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java @@ -23,7 +23,7 @@ import io.trino.spi.connector.WriterScalingOptions; import io.trino.sql.PlannerContext; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Partitioning; import io.trino.sql.planner.PartitioningHandle; import io.trino.sql.planner.PartitioningScheme; @@ -183,7 +183,7 @@ public PlanWithProperties visitProject(ProjectNode node, StreamPreferredProperti { // Special handling for trivial projections. Applies to identity and renaming projections, and constants // It might be extended to handle other low-cost projections. - if (node.getAssignments().getExpressions().stream().allMatch(expression -> expression instanceof SymbolReference || expression instanceof Constant constant && constant.getValue() != null)) { + if (node.getAssignments().getExpressions().stream().allMatch(expression -> expression instanceof Reference || expression instanceof Constant constant && constant.value() != null)) { if (parentPreferences.isSingleStreamPreferred()) { // Do not enforce gathering exchange below project: // - if project's source is single stream, no exchanges will be added around project, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java index 2124dc549729..63d57da73eeb 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java @@ -26,11 +26,11 @@ import io.trino.SystemSessionProperties; import io.trino.metadata.Metadata; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.CoalesceExpression; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Coalesce; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.BuiltinFunctionCallBuilder; import io.trino.sql.planner.Partitioning.ArgumentBinding; import io.trino.sql.planner.PartitioningHandle; @@ -862,10 +862,10 @@ public static Optional getHashExpression(Metadata metadata, SymbolAl for (Symbol symbol : symbols) { Expression hashField = BuiltinFunctionCallBuilder.resolve(metadata) .setName(HASH_CODE) - .addArgument(symbol.getType(), new SymbolReference(BIGINT, symbol.getName())) + .addArgument(symbol.getType(), new Reference(BIGINT, symbol.getName())) .build(); - hashField = new CoalesceExpression(hashField, new Constant(BIGINT, (long) NULL_HASH_CODE)); + hashField = new Coalesce(hashField, new Constant(BIGINT, (long) NULL_HASH_CODE)); result = BuiltinFunctionCallBuilder.resolve(metadata) .setName("combine_hash") @@ -921,7 +921,7 @@ private Expression getHashExpression(Metadata metadata) private static Expression getHashFunctionCall(Expression previousHashValue, Symbol symbol, Metadata metadata) { - FunctionCall functionCall = BuiltinFunctionCallBuilder.resolve(metadata) + Call call = BuiltinFunctionCallBuilder.resolve(metadata) .setName(HASH_CODE) .addArgument(symbol.getType(), symbol.toSymbolReference()) .build(); @@ -929,13 +929,13 @@ private static Expression getHashFunctionCall(Expression previousHashValue, Symb return BuiltinFunctionCallBuilder.resolve(metadata) .setName("combine_hash") .addArgument(BIGINT, previousHashValue) - .addArgument(BIGINT, orNullHashCode(functionCall)) + .addArgument(BIGINT, orNullHashCode(call)) .build(); } private static Expression orNullHashCode(Expression expression) { - return new CoalesceExpression(expression, new Constant(BIGINT, (long) NULL_HASH_CODE)); + return new Coalesce(expression, new Constant(BIGINT, (long) NULL_HASH_CODE)); } @Override @@ -999,7 +999,7 @@ private static Map computeIdentityTranslations(Map outputToInput = new HashMap<>(); for (Map.Entry assignment : assignments.entrySet()) { - if (assignment.getValue() instanceof SymbolReference) { + if (assignment.getValue() instanceof Reference) { outputToInput.put(assignment.getKey(), Symbol.from(assignment.getValue())); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/IndexJoinOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/IndexJoinOptimizer.java index a2c9ebef3879..ea8de6514964 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/IndexJoinOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/IndexJoinOptimizer.java @@ -26,9 +26,9 @@ import io.trino.spi.connector.ColumnHandle; import io.trino.spi.predicate.TupleDomain; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.BooleanLiteral; +import io.trino.sql.ir.Booleans; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.DomainTranslator; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; @@ -59,7 +59,7 @@ import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.spi.function.FunctionKind.AGGREGATE; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.IrUtils.combineConjuncts; import static io.trino.sql.planner.plan.WindowFrameType.RANGE; import static java.util.Objects.requireNonNull; @@ -274,7 +274,7 @@ public PlanNode visitPlan(PlanNode node, RewriteContext context) @Override public PlanNode visitTableScan(TableScanNode node, RewriteContext context) { - return planTableScan(node, BooleanLiteral.TRUE_LITERAL, context.get()); + return planTableScan(node, Booleans.TRUE, context.get()); } private PlanNode planTableScan(TableScanNode node, Expression predicate, Context context) @@ -317,7 +317,7 @@ private PlanNode planTableScan(TableScanNode node, Expression predicate, Context domainTranslator.toPredicate(resolvedIndex.getUnresolvedTupleDomain().transformKeys(inverseAssignments::get)), decomposedPredicate.getRemainingExpression()); - if (!resultingPredicate.equals(TRUE_LITERAL)) { + if (!resultingPredicate.equals(TRUE)) { // todo it is likely we end up with redundant filters here because the predicate push down has already been run... the fix is to run predicate push down again source = new FilterNode(idAllocator.getNextId(), source, resultingPredicate); } @@ -331,7 +331,7 @@ public PlanNode visitProject(ProjectNode node, RewriteContext context) // Rewrite the lookup symbols in terms of only the pre-projected symbols that have direct translations Set newLookupSymbols = context.get().getLookupSymbols().stream() .map(node.getAssignments()::get) - .filter(SymbolReference.class::isInstance) + .filter(Reference.class::isInstance) .map(Symbol::from) .collect(toImmutableSet()); @@ -480,7 +480,7 @@ protected Map visitPlan(PlanNode node, Set lookupSymbols public Map visitProject(ProjectNode node, Set lookupSymbols) { // Map from output Symbols to source Symbols - Map directSymbolTranslationOutputMap = Maps.transformValues(Maps.filterValues(node.getAssignments().getMap(), SymbolReference.class::isInstance), Symbol::from); + Map directSymbolTranslationOutputMap = Maps.transformValues(Maps.filterValues(node.getAssignments().getMap(), Reference.class::isInstance), Symbol::from); Map outputToSourceMap = lookupSymbols.stream() .filter(directSymbolTranslationOutputMap.keySet()::contains) .collect(toImmutableMap(identity(), directSymbolTranslationOutputMap::get)); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java index db981bbe9038..555417ec3dc7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java @@ -20,8 +20,8 @@ import io.trino.metadata.Metadata; import io.trino.spi.function.CatalogSchemaFunctionName; import io.trino.spi.type.Type; -import io.trino.sql.ir.CoalesceExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Coalesce; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.planner.PlanNodeIdAllocator; @@ -206,7 +206,7 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext extractCorrelatedSymbolsMapping(List mapping = ImmutableMultimap.builder(); for (Expression conjunct : correlatedConjuncts) { - if (!(conjunct instanceof ComparisonExpression comparison)) { + if (!(conjunct instanceof Comparison comparison)) { continue; } - if (!(comparison.getLeft() instanceof SymbolReference - && comparison.getRight() instanceof SymbolReference - && comparison.getOperator() == EQUAL)) { + if (!(comparison.left() instanceof Reference + && comparison.right() instanceof Reference + && comparison.operator() == EQUAL)) { continue; } - Symbol left = Symbol.from(comparison.getLeft()); - Symbol right = Symbol.from(comparison.getRight()); + Symbol left = Symbol.from(comparison.left()); + Symbol right = Symbol.from(comparison.right()); if (correlation.contains(left) && !correlation.contains(right)) { mapping.put(left, right); @@ -504,18 +504,18 @@ private Set extractConstantSymbols(List correlatedConjuncts) ImmutableSet.Builder constants = ImmutableSet.builder(); correlatedConjuncts.stream() - .filter(ComparisonExpression.class::isInstance) - .map(ComparisonExpression.class::cast) - .filter(comparison -> comparison.getOperator() == EQUAL) + .filter(Comparison.class::isInstance) + .map(Comparison.class::cast) + .filter(comparison -> comparison.operator() == EQUAL) .forEach(comparison -> { - Expression left = comparison.getLeft(); - Expression right = comparison.getRight(); + Expression left = comparison.left(); + Expression right = comparison.right(); - if (!isCorrelated(left) && (left instanceof SymbolReference || isSimpleInjectiveCast(left)) && isConstant(right)) { + if (!isCorrelated(left) && (left instanceof Reference || isSimpleInjectiveCast(left)) && isConstant(right)) { constants.add(getSymbol(left)); } - if (!isCorrelated(right) && (right instanceof SymbolReference || isSimpleInjectiveCast(right)) && isConstant(left)) { + if (!isCorrelated(right) && (right instanceof Reference || isSimpleInjectiveCast(right)) && isConstant(left)) { constants.add(getSymbol(right)); } }); @@ -536,23 +536,23 @@ private boolean isSimpleInjectiveCast(Expression expression) if (!(expression instanceof Cast cast)) { return false; } - if (!(cast.getExpression() instanceof SymbolReference)) { + if (!(cast.expression() instanceof Reference)) { return false; } - Symbol sourceSymbol = Symbol.from(cast.getExpression()); + Symbol sourceSymbol = Symbol.from(cast.expression()); Type sourceType = sourceSymbol.getType(); - Type targetType = ((Cast) expression).getType(); + Type targetType = ((Cast) expression).type(); return typeCoercion.isInjectiveCoercion(sourceType, targetType); } private Symbol getSymbol(Expression expression) { - if (expression instanceof SymbolReference) { + if (expression instanceof Reference) { return Symbol.from(expression); } - return Symbol.from(((Cast) expression).getExpression()); + return Symbol.from(((Cast) expression).expression()); } private boolean isCorrelated(Expression expression) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java index 2638916318b1..674a18cf9c39 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java @@ -23,13 +23,13 @@ import io.trino.metadata.Metadata; import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.BetweenPredicate; -import io.trino.sql.ir.BooleanLiteral; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Between; +import io.trino.sql.ir.Booleans; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.DomainTranslator; import io.trino.sql.planner.EffectivePredicateExtractor; import io.trino.sql.planner.EqualityInference; @@ -86,13 +86,13 @@ import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.RealType.REAL; import static io.trino.sql.DynamicFilters.createDynamicFilterExpression; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.IS_DISTINCT_FROM; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.IS_DISTINCT_FROM; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.ir.IrUtils.combineConjuncts; import static io.trino.sql.ir.IrUtils.extractConjuncts; import static io.trino.sql.ir.IrUtils.filterDeterministicConjuncts; @@ -110,7 +110,7 @@ public class PredicatePushDown implements PlanOptimizer { - private static final Set DYNAMIC_FILTERING_SUPPORTED_COMPARISONS = ImmutableSet.of( + private static final Set DYNAMIC_FILTERING_SUPPORTED_COMPARISONS = ImmutableSet.of( EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL, @@ -139,7 +139,7 @@ public PlanNode optimize(PlanNode plan, Context context) return SimplePlanRewriter.rewriteWith( new Rewriter(context.symbolAllocator(), context.idAllocator(), plannerContext, context.session(), useTableProperties, dynamicFiltering), plan, - TRUE_LITERAL); + TRUE); } private static class Rewriter @@ -179,8 +179,8 @@ private Rewriter( @Override public PlanNode visitPlan(PlanNode node, RewriteContext context) { - PlanNode rewrittenNode = context.defaultRewrite(node, TRUE_LITERAL); - if (!context.get().equals(TRUE_LITERAL)) { + PlanNode rewrittenNode = context.defaultRewrite(node, TRUE); + if (!context.get().equals(TRUE)) { // Drop in a FilterNode b/c we cannot push our predicate down any further rewrittenNode = new FilterNode(idAllocator.getNextId(), rewrittenNode, context.get()); } @@ -193,7 +193,7 @@ public PlanNode visitExchange(ExchangeNode node, RewriteContext cont boolean modified = false; ImmutableList.Builder builder = ImmutableList.builder(); for (int i = 0; i < node.getSources().size(); i++) { - Map outputsToInputs = new HashMap<>(); + Map outputsToInputs = new HashMap<>(); for (int index = 0; index < node.getInputs().get(i).size(); index++) { outputsToInputs.put( node.getOutputSymbols().get(index), @@ -303,13 +303,13 @@ private boolean isInliningCandidate(Expression expression, ProjectNode node) return dependencies.entrySet().stream() .allMatch(entry -> entry.getValue() == 1 || node.getAssignments().get(entry.getKey()) instanceof Constant - || node.getAssignments().get(entry.getKey()) instanceof SymbolReference); + || node.getAssignments().get(entry.getKey()) instanceof Reference); } @Override public PlanNode visitGroupId(GroupIdNode node, RewriteContext context) { - Map commonGroupingSymbolMapping = node.getGroupingColumns().entrySet().stream() + Map commonGroupingSymbolMapping = node.getGroupingColumns().entrySet().stream() .filter(entry -> node.getCommonGroupingColumns().contains(entry.getKey())) .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); @@ -446,8 +446,8 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) newJoinPredicate = rightOuterJoinPushDownResult.getJoinPredicate(); } case FULL -> { - leftPredicate = TRUE_LITERAL; - rightPredicate = TRUE_LITERAL; + leftPredicate = TRUE; + rightPredicate = TRUE; postJoinPredicate = inheritedPredicate; newJoinPredicate = joinPredicate; } @@ -472,11 +472,11 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) ImmutableList.Builder joinFilterBuilder = ImmutableList.builder(); for (Expression conjunct : extractConjuncts(newJoinPredicate)) { if (joinEqualityExpression(conjunct, node.getLeft().getOutputSymbols(), node.getRight().getOutputSymbols())) { - ComparisonExpression equality = (ComparisonExpression) conjunct; + Comparison equality = (Comparison) conjunct; - boolean alignedComparison = node.getLeft().getOutputSymbols().containsAll(extractUnique(equality.getLeft())); - Expression leftExpression = alignedComparison ? equality.getLeft() : equality.getRight(); - Expression rightExpression = alignedComparison ? equality.getRight() : equality.getLeft(); + boolean alignedComparison = node.getLeft().getOutputSymbols().containsAll(extractUnique(equality.left())); + Expression leftExpression = alignedComparison ? equality.left() : equality.right(); + Expression rightExpression = alignedComparison ? equality.right() : equality.left(); Symbol leftSymbol = symbolForExpression(leftExpression); if (!node.getLeft().getOutputSymbols().contains(leftSymbol)) { @@ -513,7 +513,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) } Optional newJoinFilter = Optional.of(combineConjuncts(joinFilter)); - if (newJoinFilter.get().equals(TRUE_LITERAL)) { + if (newJoinFilter.get().equals(TRUE)) { newJoinFilter = Optional.empty(); } @@ -557,7 +557,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) node.getReorderJoinStatsAndCost()); } - if (!postJoinPredicate.equals(TRUE_LITERAL)) { + if (!postJoinPredicate.equals(TRUE)) { output = new FilterNode(idAllocator.getNextId(), output, postJoinPredicate); } @@ -585,25 +585,25 @@ private DynamicFiltersResult createDynamicFilters( equiJoinClauses .stream() .map(clause -> new DynamicFilterExpression( - new ComparisonExpression(EQUAL, clause.getLeft().toSymbolReference(), clause.getRight().toSymbolReference()))), + new Comparison(EQUAL, clause.getLeft().toSymbolReference(), clause.getRight().toSymbolReference()))), joinFilterClauses.stream() .flatMap(Rewriter::tryConvertBetweenIntoComparisons) .filter(clause -> joinDynamicFilteringExpression(clause, node.getLeft().getOutputSymbols(), node.getRight().getOutputSymbols())) .map(expression -> { - if (expression instanceof NotExpression notExpression) { - ComparisonExpression comparison = (ComparisonExpression) notExpression.getValue(); - return new DynamicFilterExpression(new ComparisonExpression(EQUAL, comparison.getLeft(), comparison.getRight()), true); + if (expression instanceof Not notExpression) { + Comparison comparison = (Comparison) notExpression.value(); + return new DynamicFilterExpression(new Comparison(EQUAL, comparison.left(), comparison.right()), true); } - return new DynamicFilterExpression((ComparisonExpression) expression); + return new DynamicFilterExpression((Comparison) expression); }) .map(expression -> { - ComparisonExpression comparison = expression.getComparison(); - Expression leftExpression = comparison.getLeft(); - Expression rightExpression = comparison.getRight(); + Comparison comparison = expression.getComparison(); + Expression leftExpression = comparison.left(); + Expression rightExpression = comparison.right(); boolean alignedComparison = node.getLeft().getOutputSymbols().containsAll(extractUnique(leftExpression)); return new DynamicFilterExpression( - new ComparisonExpression( - alignedComparison ? comparison.getOperator() : comparison.getOperator().flip(), + new Comparison( + alignedComparison ? comparison.operator() : comparison.operator().flip(), alignedComparison ? leftExpression : rightExpression, alignedComparison ? rightExpression : leftExpression), expression.isNullAllowed()); @@ -618,7 +618,7 @@ private DynamicFiltersResult createDynamicFilters( // Collect build symbols: Set buildSymbols = clauses.stream() .map(DynamicFilterExpression::getComparison) - .map(ComparisonExpression::getRight) + .map(Comparison::right) .map(Symbol::from) .collect(toImmutableSet()); @@ -634,13 +634,13 @@ private DynamicFiltersResult createDynamicFilters( List predicates = clauses .stream() .map(clause -> { - ComparisonExpression comparison = clause.getComparison(); - Expression probeExpression = comparison.getLeft(); - Symbol buildSymbol = Symbol.from(comparison.getRight()); + Comparison comparison = clause.getComparison(); + Expression probeExpression = comparison.left(); + Symbol buildSymbol = Symbol.from(comparison.right()); // we can take type of buildSymbol instead probeExpression as comparison expression must have the same type on both sides Type type = buildSymbol.getType(); DynamicFilterId id = requireNonNull(buildSymbolToDynamicFilter.get(buildSymbol), () -> "missing dynamic filter for symbol " + buildSymbol); - return createDynamicFilterExpression(metadata, id, type, probeExpression, comparison.getOperator(), clause.isNullAllowed()); + return createDynamicFilterExpression(metadata, id, type, probeExpression, comparison.operator(), clause.isNullAllowed()); }) .collect(toImmutableList()); // Return a mapping from build symbols to corresponding dynamic filter IDs: @@ -649,31 +649,31 @@ private DynamicFiltersResult createDynamicFilters( private static Stream tryConvertBetweenIntoComparisons(Expression clause) { - if (clause instanceof BetweenPredicate between) { + if (clause instanceof Between between) { return Stream.of( - new ComparisonExpression(GREATER_THAN_OR_EQUAL, between.getValue(), between.getMin()), - new ComparisonExpression(LESS_THAN_OR_EQUAL, between.getValue(), between.getMax())); + new Comparison(GREATER_THAN_OR_EQUAL, between.value(), between.min()), + new Comparison(LESS_THAN_OR_EQUAL, between.value(), between.max())); } return Stream.of(clause); } private static class DynamicFilterExpression { - private final ComparisonExpression comparison; + private final Comparison comparison; private final boolean nullAllowed; - private DynamicFilterExpression(ComparisonExpression comparison) + private DynamicFilterExpression(Comparison comparison) { this(comparison, false); } - private DynamicFilterExpression(ComparisonExpression comparison, boolean nullAllowed) + private DynamicFilterExpression(Comparison comparison, boolean nullAllowed) { this.comparison = requireNonNull(comparison, "comparison is null"); this.nullAllowed = nullAllowed; } - public ComparisonExpression getComparison() + public Comparison getComparison() { return comparison; } @@ -756,7 +756,7 @@ public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext private boolean joinDynamicFilteringExpression(Expression expression, Collection leftSymbols, Collection rightSymbols) { - ComparisonExpression comparison; - if (expression instanceof NotExpression notExpression) { - boolean isDistinctFrom = joinComparisonExpression(notExpression.getValue(), leftSymbols, rightSymbols, ImmutableSet.of(IS_DISTINCT_FROM)); + Comparison comparison; + if (expression instanceof Not not) { + boolean isDistinctFrom = joinComparisonExpression(not.value(), leftSymbols, rightSymbols, ImmutableSet.of(IS_DISTINCT_FROM)); if (!isDistinctFrom) { return false; } - comparison = (ComparisonExpression) notExpression.getValue(); + comparison = (Comparison) not.value(); Set expressionTypes = ImmutableSet.of( comparison.left().type(), comparison.right().type()); @@ -1223,21 +1223,21 @@ private boolean joinDynamicFilteringExpression(Expression expression, Collection if (!joinComparisonExpression(expression, leftSymbols, rightSymbols, DYNAMIC_FILTERING_SUPPORTED_COMPARISONS)) { return false; } - comparison = (ComparisonExpression) expression; + comparison = (Comparison) expression; } // Build side expression must be a symbol reference, since DynamicFilterSourceOperator can only collect column values (not expressions) - return (comparison.getRight() instanceof SymbolReference && rightSymbols.contains(Symbol.from(comparison.getRight()))) - || (comparison.getLeft() instanceof SymbolReference && rightSymbols.contains(Symbol.from(comparison.getLeft()))); + return (comparison.right() instanceof Reference && rightSymbols.contains(Symbol.from(comparison.right()))) + || (comparison.left() instanceof Reference && rightSymbols.contains(Symbol.from(comparison.left()))); } - private boolean joinComparisonExpression(Expression expression, Collection leftSymbols, Collection rightSymbols, Set operators) + private boolean joinComparisonExpression(Expression expression, Collection leftSymbols, Collection rightSymbols, Set operators) { // At this point in time, our join predicates need to be deterministic - if (expression instanceof ComparisonExpression comparison && isDeterministic(expression)) { - if (operators.contains(comparison.getOperator())) { - Set symbols1 = extractUnique(comparison.getLeft()); - Set symbols2 = extractUnique(comparison.getRight()); + if (expression instanceof Comparison comparison && isDeterministic(expression)) { + if (operators.contains(comparison.operator())) { + Set symbols1 = extractUnique(comparison.left()); + Set symbols2 = extractUnique(comparison.right()); if (symbols1.isEmpty() || symbols2.isEmpty()) { return false; } @@ -1266,7 +1266,7 @@ private PlanNode visitNonFilteringSemiJoin(SemiJoinNode node, RewriteContext sourceScope = ImmutableSet.copyOf(node.getSource().getOutputSymbols()); @@ -1316,7 +1316,7 @@ private PlanNode visitFilteringSemiJoin(SemiJoinNode node, RewriteContext co { Expression predicate = simplifyExpression(context.get()); - if (!TRUE_LITERAL.equals(predicate)) { + if (!TRUE.equals(predicate)) { return new FilterNode(idAllocator.getNextId(), node, predicate); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java index e146b3c51e67..afb19c0bdb69 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java @@ -30,9 +30,9 @@ import io.trino.spi.predicate.NullableValue; import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.CoalesceExpression; +import io.trino.sql.ir.Coalesce; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.DomainTranslator; import io.trino.sql.planner.IrExpressionInterpreter; import io.trino.sql.planner.NoOpSymbolResolver; @@ -776,8 +776,8 @@ public ActualProperties visitProject(ProjectNode node, List in // ("ROW comparison not supported for fields with null elements", etc) Object value = optimizer.optimize(NoOpSymbolResolver.INSTANCE); - if (value instanceof SymbolReference) { - Symbol symbol = Symbol.from((SymbolReference) value); + if (value instanceof Reference) { + Symbol symbol = Symbol.from((Reference) value); NullableValue existingConstantValue = constants.get(symbol); if (existingConstantValue != null) { constants.put(assignment.getKey(), new NullableValue(type, value)); @@ -910,7 +910,7 @@ private static Map computeIdentityTranslations(Map inputToOutput = new HashMap<>(); for (Map.Entry assignment : assignments.entrySet()) { - if (assignment.getValue() instanceof SymbolReference) { + if (assignment.getValue() instanceof Reference) { inputToOutput.put(Symbol.from(assignment.getValue()), assignment.getKey()); } } @@ -967,12 +967,12 @@ public static Optional filterOrRewrite(Collection columns, Colle private static Optional rewriteExpression(Map assignments, Expression expression) { // Only simple coalesce expressions supported currently - if (!(expression instanceof CoalesceExpression)) { + if (!(expression instanceof Coalesce)) { return Optional.empty(); } - Set arguments = ImmutableSet.copyOf(((CoalesceExpression) expression).getOperands()); - if (!arguments.stream().allMatch(SymbolReference.class::isInstance)) { + Set arguments = ImmutableSet.copyOf(((Coalesce) expression).operands()); + if (!arguments.stream().allMatch(Reference.class::isInstance)) { return Optional.empty(); } @@ -980,9 +980,9 @@ private static Optional rewriteExpression(Map assign // of the arguments. Thus we extract and compare the symbols of the CoalesceExpression as a set rather than compare the // CoalesceExpression directly. for (Map.Entry entry : assignments.entrySet()) { - if (entry.getValue() instanceof CoalesceExpression) { - Set candidateArguments = ImmutableSet.copyOf(((CoalesceExpression) entry.getValue()).getOperands()); - if (!candidateArguments.stream().allMatch(SymbolReference.class::isInstance)) { + if (entry.getValue() instanceof Coalesce) { + Set candidateArguments = ImmutableSet.copyOf(((Coalesce) entry.getValue()).operands()); + if (!candidateArguments.stream().allMatch(Reference.class::isInstance)) { return Optional.empty(); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPropertyDerivations.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPropertyDerivations.java index 6dc5bb4f2900..ab181c9e58f2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPropertyDerivations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPropertyDerivations.java @@ -27,7 +27,7 @@ import io.trino.spi.connector.LocalProperty; import io.trino.sql.PlannerContext; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Partitioning.ArgumentBinding; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.plan.AggregationNode; @@ -400,7 +400,7 @@ private static Map computeIdentityTranslations(Map inputToOutput = new HashMap<>(); for (Map.Entry assignment : assignments.entrySet()) { - if (assignment.getValue() instanceof SymbolReference) { + if (assignment.getValue() instanceof Reference) { inputToOutput.put(Symbol.from(assignment.getValue()), assignment.getKey()); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java index 97857d163be7..d1b9cd0052d8 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java @@ -19,8 +19,8 @@ import io.trino.sql.ir.Expression; import io.trino.sql.ir.ExpressionRewriter; import io.trino.sql.ir.ExpressionTreeRewriter; -import io.trino.sql.ir.LambdaExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Lambda; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.PartitioningScheme; import io.trino.sql.planner.Symbol; @@ -152,22 +152,22 @@ public Expression map(Expression expression) return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<>() { @Override - public Expression rewriteSymbolReference(SymbolReference node, Void context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteReference(Reference node, Void context, ExpressionTreeRewriter treeRewriter) { Symbol canonical = map(Symbol.from(node)); return canonical.toSymbolReference(); } @Override - public Expression rewriteLambdaExpression(LambdaExpression node, Void context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteLambda(Lambda node, Void context, ExpressionTreeRewriter treeRewriter) { - List arguments = node.getArguments().stream() + List arguments = node.arguments().stream() .map(symbol -> map(new Symbol(symbol.getType(), symbol.getName()))) .collect(toImmutableList()); - Expression body = treeRewriter.rewrite(node.getBody(), context); - if (body != node.getBody()) { - return new LambdaExpression(arguments, body); + Expression body = treeRewriter.rewrite(node.body(), context); + if (body != node.body()) { + return new Lambda(arguments, body); } return node; @@ -356,7 +356,7 @@ private ExpressionAndValuePointers map(ExpressionAndValuePointers expressionAndV .map(expression -> ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter() { @Override - public Expression rewriteSymbolReference(SymbolReference node, Void context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteReference(Reference node, Void context, ExpressionTreeRewriter treeRewriter) { if (pointer.getClassifierSymbol().isPresent() && Symbol.from(node).equals(pointer.getClassifierSymbol().get()) || pointer.getMatchNumberSymbol().isPresent() && Symbol.from(node).equals(pointer.getMatchNumberSymbol().get())) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java index 9d09630a8486..951ba36480c0 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java @@ -18,11 +18,11 @@ import io.trino.metadata.Metadata; import io.trino.spi.type.BigintType; import io.trino.spi.type.Type; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Case; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SearchedCaseExpression; -import io.trino.sql.ir.SimpleCaseExpression; +import io.trino.sql.ir.Switch; import io.trino.sql.ir.WhenClause; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; @@ -46,14 +46,14 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.BooleanLiteral.FALSE_LITERAL; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.NOT_EQUAL; +import static io.trino.sql.ir.Booleans.FALSE; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.NOT_EQUAL; import static io.trino.sql.ir.IrUtils.combineConjuncts; import static io.trino.sql.ir.IrUtils.combineDisjuncts; import static io.trino.sql.planner.plan.AggregationNode.globalAggregation; @@ -163,7 +163,7 @@ countNonNullValue, new Aggregation( subqueryPlan, node.getCorrelation(), JoinType.INNER, - TRUE_LITERAL, + TRUE, node.getOriginSubquery()); Expression valueComparedToSubquery = rewriteUsingBounds(quantifiedComparison, minValue, maxValue, countAllValue, countNonNullValue); @@ -178,26 +178,26 @@ public Expression rewriteUsingBounds(ApplyNode.QuantifiedComparison quantifiedCo Constant emptySetResult; Function, Expression> quantifier; if (quantifiedComparison.quantifier() == ALL) { - emptySetResult = TRUE_LITERAL; + emptySetResult = TRUE; quantifier = expressions -> combineConjuncts(expressions); } else { - emptySetResult = FALSE_LITERAL; + emptySetResult = FALSE; quantifier = expressions -> combineDisjuncts(expressions); } Expression comparisonWithExtremeValue = getBoundComparisons(quantifiedComparison, minValue, maxValue); - return new SimpleCaseExpression( + return new Switch( countAllValue.toSymbolReference(), ImmutableList.of(new WhenClause( new Constant(BIGINT, 0L), emptySetResult)), Optional.of(quantifier.apply(ImmutableList.of( comparisonWithExtremeValue, - new SearchedCaseExpression( + new Case( ImmutableList.of( new WhenClause( - new ComparisonExpression(NOT_EQUAL, countAllValue.toSymbolReference(), countNonNullValue.toSymbolReference()), + new Comparison(NOT_EQUAL, countAllValue.toSymbolReference(), countNonNullValue.toSymbolReference()), new Constant(BOOLEAN, null))), Optional.of(emptySetResult)))))); } @@ -207,8 +207,8 @@ private Expression getBoundComparisons(ApplyNode.QuantifiedComparison quantified if (mapOperator(quantifiedComparison) == EQUAL && quantifiedComparison.quantifier() == ALL) { // A = ALL B <=> min B = max B && A = min B return combineConjuncts( - new ComparisonExpression(EQUAL, minValue.toSymbolReference(), maxValue.toSymbolReference()), - new ComparisonExpression(EQUAL, quantifiedComparison.value().toSymbolReference(), maxValue.toSymbolReference())); + new Comparison(EQUAL, minValue.toSymbolReference(), maxValue.toSymbolReference()), + new Comparison(EQUAL, quantifiedComparison.value().toSymbolReference(), maxValue.toSymbolReference())); } if (EnumSet.of(LESS_THAN, LESS_THAN_OR_EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL).contains(mapOperator(quantifiedComparison))) { @@ -217,12 +217,12 @@ private Expression getBoundComparisons(ApplyNode.QuantifiedComparison quantified // A < ANY B <=> A < max B // A > ANY B <=> A > min B Symbol boundValue = shouldCompareValueWithLowerBound(quantifiedComparison) ? minValue : maxValue; - return new ComparisonExpression(mapOperator(quantifiedComparison), quantifiedComparison.value().toSymbolReference(), boundValue.toSymbolReference()); + return new Comparison(mapOperator(quantifiedComparison), quantifiedComparison.value().toSymbolReference(), boundValue.toSymbolReference()); } throw new IllegalArgumentException("Unsupported quantified comparison: " + quantifiedComparison); } - private static ComparisonExpression.Operator mapOperator(ApplyNode.QuantifiedComparison quantifiedComparison) + private static Comparison.Operator mapOperator(ApplyNode.QuantifiedComparison quantifiedComparison) { return switch (quantifiedComparison.operator()) { case EQUAL -> EQUAL; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java index 4ee1735745a5..268aebeaa80c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -22,10 +22,10 @@ import io.trino.cost.PlanNodeStatsEstimate; import io.trino.spi.connector.ColumnHandle; import io.trino.sql.DynamicFilters; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; +import io.trino.sql.ir.Reference; import io.trino.sql.ir.Row; -import io.trino.sql.ir.SymbolReference; import io.trino.sql.planner.DeterminismEvaluator; import io.trino.sql.planner.NodeAndMappings; import io.trino.sql.planner.OrderingScheme; @@ -611,7 +611,7 @@ public PlanAndMappings visitValues(ValuesNode node, UnaliasContext context) for (int i = 0; i < node.getOutputSymbols().size(); i++) { ImmutableList.Builder expressionsBuilder = ImmutableList.builder(); for (Expression row : node.getRows().get()) { - expressionsBuilder.add(mapper.map(((Row) row).getItems().get(i))); + expressionsBuilder.add(mapper.map(((Row) row).items().get(i))); } rewrittenAssignmentsBuilder.add(new SimpleEntry<>(mapper.map(node.getOutputSymbols().get(i)), expressionsBuilder.build())); } @@ -908,7 +908,7 @@ private Map mappingFromAssignments(Map assig // If the assignment potentially introduces a reused (ambiguous) symbol, do not map output to input // to avoid mixing semantics. Input symbols represent semantics as in the source plan, // while output symbols represent newly established semantics. - if (expression instanceof SymbolReference && !ambiguousSymbolsPresent) { + if (expression instanceof Reference && !ambiguousSymbolsPresent) { Symbol value = Symbol.from(expression); if (!assignment.getKey().equals(value)) { newMapping.put(assignment.getKey(), value); @@ -1445,7 +1445,7 @@ private Expression updateDynamicFilterIds(Map Expression newConjunct = conjunct; if (mappedId != null) { // DF was remapped - newConjunct = replaceDynamicFilterId((FunctionCall) conjunct, mappedId); + newConjunct = replaceDynamicFilterId((Call) conjunct, mappedId); updated = true; } newConjuncts.add(newConjunct); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/WindowFilterPushDown.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/WindowFilterPushDown.java index 00ff730d4f64..9a31ff50c44e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/WindowFilterPushDown.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/WindowFilterPushDown.java @@ -21,7 +21,7 @@ import io.trino.spi.predicate.TupleDomain; import io.trino.spi.predicate.ValueSet; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.BooleanLiteral; +import io.trino.sql.ir.Booleans; import io.trino.sql.ir.Expression; import io.trino.sql.planner.DomainTranslator; import io.trino.sql.planner.PlanNodeIdAllocator; @@ -200,7 +200,7 @@ private PlanNode rewriteFilterSource(FilterNode filterNode, PlanNode source, Sym extractionResult.getRemainingExpression(), domainTranslator.toPredicate(newTupleDomain)); - if (newPredicate.equals(BooleanLiteral.TRUE_LITERAL)) { + if (newPredicate.equals(Booleans.TRUE)) { return source; } return new FilterNode(filterNode.getId(), source, newPredicate); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/AggregationNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/AggregationNode.java index 067422ddc931..a08356043a9d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/AggregationNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/AggregationNode.java @@ -25,8 +25,8 @@ import io.trino.metadata.ResolvedFunction; import io.trino.spi.function.AggregationFunctionMetadata; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.LambdaExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Lambda; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.type.FunctionType; @@ -407,7 +407,7 @@ public Aggregation( this.resolvedFunction = requireNonNull(resolvedFunction, "resolvedFunction is null"); this.arguments = ImmutableList.copyOf(requireNonNull(arguments, "arguments is null")); for (Expression argument : arguments) { - checkArgument(argument instanceof SymbolReference || argument instanceof LambdaExpression, + checkArgument(argument instanceof Reference || argument instanceof Lambda, "argument must be symbol or lambda expression: %s", argument.getClass().getSimpleName()); } this.distinct = distinct; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/Assignments.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/Assignments.java index bb5bd286769e..5d9b7d2ce447 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/Assignments.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/Assignments.java @@ -20,7 +20,7 @@ import com.google.common.collect.Maps; import io.trino.spi.type.Type; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolAllocator; @@ -132,7 +132,7 @@ public boolean isIdentity(Symbol output) { Expression expression = assignments.get(output); - return expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(output.getName()); + return expression instanceof Reference && ((Reference) expression).name().equals(output.getName()); } public boolean isIdentity() @@ -140,7 +140,7 @@ public boolean isIdentity() for (Map.Entry entry : assignments.entrySet()) { Expression expression = entry.getValue(); Symbol symbol = entry.getKey(); - if (!(expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(symbol.getName()))) { + if (!(expression instanceof Reference && ((Reference) expression).name().equals(symbol.getName()))) { return false; } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/JoinNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/JoinNode.java index 12ab2ef03465..b6096557df89 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/JoinNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/JoinNode.java @@ -20,7 +20,7 @@ import com.google.common.collect.ImmutableSet; import com.google.errorprone.annotations.Immutable; import io.trino.cost.PlanNodeStatsAndCostSummary; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Expression; import io.trino.sql.planner.Symbol; @@ -345,9 +345,9 @@ public Symbol getRight() return right; } - public ComparisonExpression toExpression() + public Comparison toExpression() { - return new ComparisonExpression(ComparisonExpression.Operator.EQUAL, left.toSymbolReference(), right.toSymbolReference()); + return new Comparison(Comparison.Operator.EQUAL, left.toSymbolReference(), right.toSymbolReference()); } public EquiJoinClause flip() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/SetOperationNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/SetOperationNode.java index 5b4e03b7fdcf..3b746ab4dbd6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/SetOperationNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/SetOperationNode.java @@ -25,7 +25,7 @@ import com.google.common.collect.Multimap; import com.google.common.collect.Multimaps; import com.google.errorprone.annotations.Immutable; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import java.util.Collection; @@ -105,9 +105,9 @@ public List sourceOutputLayout(int sourceIndex) /** * Returns the output to input symbol mapping for the given source channel */ - public Map sourceSymbolMap(int sourceIndex) + public Map sourceSymbolMap(int sourceIndex) { - ImmutableMap.Builder builder = ImmutableMap.builder(); + ImmutableMap.Builder builder = ImmutableMap.builder(); for (Map.Entry> entry : outputToInputs.asMap().entrySet()) { builder.put(entry.getKey(), Iterables.get(entry.getValue(), sourceIndex).toSymbolReference()); } @@ -119,7 +119,7 @@ public Map sourceSymbolMap(int sourceIndex) * Returns the input to output symbol mapping for the given source channel. * A single input symbol can map to multiple output symbols, thus requiring a Multimap. */ - public Multimap outputSymbolMap(int sourceIndex) + public Multimap outputSymbolMap(int sourceIndex) { return Multimaps.transformValues(FluentIterable.from(getOutputSymbols()) .toMap(outputToSourceSymbolFunction(sourceIndex)) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/ValuesNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/ValuesNode.java index 578dc565927d..27330c6c2faa 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/ValuesNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/ValuesNode.java @@ -77,7 +77,7 @@ public ValuesNode( List rowSizes = rows.get().stream() .map(row -> requireNonNull(row, "row is null")) .filter(expression -> expression instanceof Row) - .map(expression -> ((Row) expression).getItems().size()) + .map(expression -> ((Row) expression).items().size()) .distinct() .collect(toImmutableList()); checkState(rowSizes.size() <= 1, "mismatched rows. All rows must be the same size"); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/CounterBasedAnonymizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/CounterBasedAnonymizer.java index 101204f1d4fe..13be7eccebdc 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/CounterBasedAnonymizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/CounterBasedAnonymizer.java @@ -25,7 +25,7 @@ import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.ir.ExpressionFormatter; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.PartitioningHandle; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SystemPartitioningHandle; @@ -99,20 +99,20 @@ public String anonymize(Expression expression) return anonymizeExpressionFormatter.process(expression); } - private String anonymizeSymbolReference(SymbolReference node) + private String anonymizeSymbolReference(Reference node) { return '"' + anonymize(Symbol.from(node)) + '"'; } private String anonymizeLiteral(Constant literal) { - if (literal.getValue() == null) { + if (literal.value() == null) { return "null"; } - if (literal.getType().equals(BOOLEAN)) { - return literal.getValue().toString(); + if (literal.type().equals(BOOLEAN)) { + return literal.value().toString(); } - return anonymizeLiteral(literal.getType().getDisplayName(), literal.getValue()); + return anonymizeLiteral(literal.type().getDisplayName(), literal.value()); } private String anonymizeLiteral(String type, T value) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/GraphvizPrinter.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/GraphvizPrinter.java index eadbe209fd51..89e892c2ecb9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/GraphvizPrinter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/GraphvizPrinter.java @@ -16,9 +16,9 @@ import com.google.common.base.Joiner; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Partitioning.ArgumentBinding; import io.trino.sql.planner.PlanFragment; import io.trino.sql.planner.SubPlan; @@ -72,7 +72,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Maps.immutableEnumMap; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.plan.ExchangeNode.Type.REPARTITION; import static io.trino.sql.planner.planprinter.PlanPrinter.formatAggregation; import static java.lang.String.format; @@ -390,8 +390,8 @@ public Void visitProject(ProjectNode node, Void context) { StringBuilder builder = new StringBuilder(); for (Map.Entry entry : node.getAssignments().entrySet()) { - if ((entry.getValue() instanceof SymbolReference) && - ((SymbolReference) entry.getValue()).getName().equals(entry.getKey().getName())) { + if ((entry.getValue() instanceof Reference) && + ((Reference) entry.getValue()).name().equals(entry.getKey().getName())) { // skip identity assignments continue; } @@ -544,7 +544,7 @@ public Void visitCorrelatedJoin(CorrelatedJoinNode node, Void context) { String correlationSymbols = Joiner.on(",").join(node.getCorrelation()); String filterExpression = ""; - if (!node.getFilter().equals(TRUE_LITERAL)) { + if (!node.getFilter().equals(TRUE)) { filterExpression = " " + node.getFilter().toString(); } @@ -568,7 +568,7 @@ public Void visitIndexJoin(IndexJoinNode node, Void context) { List joinExpressions = new ArrayList<>(); for (IndexJoinNode.EquiJoinClause clause : node.getCriteria()) { - joinExpressions.add(new ComparisonExpression(ComparisonExpression.Operator.EQUAL, + joinExpressions.add(new Comparison(Comparison.Operator.EQUAL, clause.getProbe().toSymbolReference(), clause.getIndex().toSymbolReference())); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/JsonRenderer.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/JsonRenderer.java index ab5ce776d4e0..bcb9a1465c73 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/JsonRenderer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/JsonRenderer.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableList; import io.airlift.json.JsonCodec; import io.trino.cost.PlanNodeStatsAndCostSummary; +import io.trino.sql.planner.Symbol; import io.trino.sql.planner.plan.PlanNodeId; import java.util.List; @@ -24,7 +25,6 @@ import java.util.Optional; import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.trino.sql.planner.planprinter.NodeRepresentation.TypedSymbol; import static java.util.Objects.requireNonNull; public class JsonRenderer @@ -69,7 +69,7 @@ public static class JsonRenderedNode private final String id; private final String name; private final Map descriptor; - private final List outputs; + private final List outputs; private final List details; private final List estimates; private final List children; @@ -78,7 +78,7 @@ public JsonRenderedNode( String id, String name, Map descriptor, - List outputs, + List outputs, List details, List estimates, List children) @@ -111,7 +111,7 @@ public Map getDescriptor() } @JsonProperty - public List getOutputs() + public List getOutputs() { return outputs; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/NodeRepresentation.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/NodeRepresentation.java index 1faa518bbc0b..64aadcd90ee3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/NodeRepresentation.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/NodeRepresentation.java @@ -13,25 +13,20 @@ */ package io.trino.sql.planner.planprinter; -import com.fasterxml.jackson.annotation.JsonIgnore; -import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.FormatMethod; import io.trino.cost.LocalCostEstimate; import io.trino.cost.PlanCostEstimate; import io.trino.cost.PlanNodeStatsAndCostSummary; import io.trino.cost.PlanNodeStatsEstimate; -import io.trino.spi.type.Type; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.plan.PlanNodeId; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -41,7 +36,7 @@ public class NodeRepresentation private final String name; private final String type; private final Map descriptor; - private final List outputs; + private final List outputs; private final List children; private final List initialChildren; private final Optional stats; @@ -56,7 +51,7 @@ public NodeRepresentation( String name, String type, Map descriptor, - List outputs, + List outputs, Optional stats, List estimatedStats, List estimatedCost, @@ -111,7 +106,7 @@ public Map getDescriptor() return descriptor; } - public List getOutputs() + public List getOutputs() { return outputs; } @@ -163,13 +158,9 @@ public List getEstimates() PlanNodeStatsEstimate stats = getEstimatedStats().get(i); LocalCostEstimate cost = getEstimatedCost().get(i).getRootNodeLocalCostEstimate(); - List outputSymbols = getOutputs().stream() - .map(NodeRepresentation.TypedSymbol::getSymbol) - .collect(toImmutableList()); - estimates.add(new PlanNodeStatsAndCostSummary( stats.getOutputRowCount(), - stats.getOutputSizeInBytes(outputSymbols), + stats.getOutputSizeInBytes(getOutputs()), cost.getCpuCost(), cost.getMaxMemory(), cost.getNetworkCost())); @@ -177,60 +168,4 @@ public List getEstimates() return estimates.build(); } - - @Deprecated // TODO: replace with Symbol now that it carries a type - public static class TypedSymbol - { - private final Symbol symbol; - private final Type trinoType; - - public TypedSymbol(Symbol symbol, Type trinoType) - { - this.symbol = symbol; - this.trinoType = trinoType; - } - - @JsonProperty - public Symbol getSymbol() - { - return symbol; - } - - @JsonProperty - public String getType() - { - return trinoType.getDisplayName(); - } - - @JsonIgnore - public Type getTrinoType() - { - return trinoType; - } - - public static TypedSymbol typedSymbol(String symbol, Type type) - { - return new TypedSymbol(new Symbol(type, symbol), type); - } - - @Override - public boolean equals(Object o) - { - if (this == o) { - return true; - } - if (!(o instanceof TypedSymbol)) { - return false; - } - TypedSymbol that = (TypedSymbol) o; - return symbol.equals(that.symbol) - && trinoType.equals(that.trinoType); - } - - @Override - public int hashCode() - { - return Objects.hash(symbol, trinoType); - } - } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java index 741782291948..b482d9c2d49b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java @@ -53,10 +53,10 @@ import io.trino.spi.statistics.TableStatisticType; import io.trino.spi.type.Type; import io.trino.sql.DynamicFilters; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.Reference; import io.trino.sql.ir.Row; -import io.trino.sql.ir.SymbolReference; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Partitioning; import io.trino.sql.planner.PartitioningScheme; @@ -126,7 +126,6 @@ import io.trino.sql.planner.plan.UnnestNode; import io.trino.sql.planner.plan.ValuesNode; import io.trino.sql.planner.plan.WindowNode; -import io.trino.sql.planner.planprinter.NodeRepresentation.TypedSymbol; import io.trino.sql.planner.rowpattern.AggregationValuePointer; import io.trino.sql.planner.rowpattern.ClassifierValuePointer; import io.trino.sql.planner.rowpattern.ExpressionAndValuePointers; @@ -162,7 +161,7 @@ import static io.trino.server.DynamicFilterService.DynamicFilterDomainStats; import static io.trino.spi.function.table.DescriptorArgument.NULL_DESCRIPTOR; import static io.trino.sql.DynamicFilters.extractDynamicFilters; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.IrUtils.combineConjunctsWithDuplicates; import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static io.trino.sql.planner.plan.JoinType.INNER; @@ -741,7 +740,7 @@ public Void visitIndexJoin(IndexJoinNode node, Context context) { List joinExpressions = new ArrayList<>(); for (IndexJoinNode.EquiJoinClause clause : node.getCriteria()) { - joinExpressions.add(new ComparisonExpression(ComparisonExpression.Operator.EQUAL, + joinExpressions.add(new Comparison(Comparison.Operator.EQUAL, clause.getProbe().toSymbolReference(), clause.getIndex().toSymbolReference())); } @@ -1148,7 +1147,7 @@ public Void visitValues(ValuesNode node, Context context) List rows = node.getRows().get().stream() .map(row -> { if (row instanceof Row) { - return ((Row) row).getItems().stream() + return ((Row) row).items().stream() .map(anonymizer::anonymize) .collect(joining(", ", "(", ")")); } @@ -1921,7 +1920,7 @@ private Void processChildren(PlanNode node, Context context) private void printAssignments(NodeRepresentation nodeOutput, Assignments assignments) { for (Entry entry : assignments.getMap().entrySet()) { - if (entry.getValue() instanceof SymbolReference && ((SymbolReference) entry.getValue()).getName().equals(entry.getKey().getName())) { + if (entry.getValue() instanceof Reference && ((Reference) entry.getValue()).name().equals(entry.getKey().getName())) { // skip identity assignments continue; } @@ -2019,7 +2018,7 @@ private String formatDomain(Domain domain) private String formatFilter(Expression filter) { - return filter.equals(TRUE_LITERAL) ? "" : anonymizer.anonymize(filter); + return filter.equals(TRUE) ? "" : anonymizer.anonymize(filter); } private String formatBoolean(boolean value) @@ -2129,7 +2128,7 @@ public NodeRepresentation addNode( rootNode.getClass().getSimpleName(), descriptor, rootNode.getOutputSymbols().stream() - .map(s -> new TypedSymbol(new Symbol(s.getType(), anonymizer.anonymize(s)), s.getType())) + .map(s -> new Symbol(s.getType(), anonymizer.anonymize(s))) .collect(toImmutableList()), stats.map(s -> s.get(rootNode.getId())), estimatedStats, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/TextRenderer.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/TextRenderer.java index 24d3dce81f5f..781da29a7eeb 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/TextRenderer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/TextRenderer.java @@ -75,7 +75,7 @@ private String writeTextOutput(StringBuilder output, PlanRepresentation plan, In .append("\n"); String columns = node.getOutputs().stream() - .map(s -> s.getSymbol().getName() + ":" + s.getType()) + .map(s -> s.getName() + ":" + s.getType()) .collect(joining(", ")); output.append(indentMultilineString("Layout: [" + columns + "]\n", indent.detailIndent())); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/rowpattern/ExpressionAndValuePointers.java b/core/trino-main/src/main/java/io/trino/sql/planner/rowpattern/ExpressionAndValuePointers.java index 9513d58d8f70..abe1f24e86fd 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/rowpattern/ExpressionAndValuePointers.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/rowpattern/ExpressionAndValuePointers.java @@ -16,6 +16,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; +import io.trino.sql.ir.Booleans; import io.trino.sql.ir.Expression; import io.trino.sql.planner.Symbol; @@ -25,12 +26,11 @@ import java.util.stream.Collectors; import static com.google.common.collect.ImmutableSet.toImmutableSet; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; import static java.util.Objects.requireNonNull; public class ExpressionAndValuePointers { - public static final ExpressionAndValuePointers TRUE = new ExpressionAndValuePointers(TRUE_LITERAL, ImmutableList.of()); + public static final ExpressionAndValuePointers TRUE = new ExpressionAndValuePointers(Booleans.TRUE, ImmutableList.of()); private final Expression expression; private final List assignments; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/DynamicFiltersChecker.java b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/DynamicFiltersChecker.java index bc24289f741f..b9618168c7c1 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/DynamicFiltersChecker.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/DynamicFiltersChecker.java @@ -23,7 +23,7 @@ import io.trino.sql.ir.Cast; import io.trino.sql.ir.Expression; import io.trino.sql.ir.IrUtils; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.plan.DynamicFilterId; import io.trino.sql.planner.plan.DynamicFilterSourceNode; import io.trino.sql.planner.plan.FilterNode; @@ -170,14 +170,14 @@ public Set visitDynamicFilterSource(DynamicFilterSourceNode nod private static void validateDynamicFilterExpression(Expression expression) { - if (expression instanceof SymbolReference) { + if (expression instanceof Reference) { return; } verify(expression instanceof Cast, "Dynamic filter expression %s must be a SymbolReference or a CAST of SymbolReference.", expression); Cast castExpression = (Cast) expression; - verify(castExpression.getExpression() instanceof SymbolReference, - "The expression %s within in a CAST in dynamic filter must be a SymbolReference.", formatExpression(castExpression.getExpression())); + verify(castExpression.expression() instanceof Reference, + "The expression %s within in a CAST in dynamic filter must be a SymbolReference.", formatExpression(castExpression.expression())); } private static List extractDynamicPredicates(Expression expression) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/TypeValidator.java b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/TypeValidator.java index 56d5571b647c..ba8fac7956aa 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/TypeValidator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/TypeValidator.java @@ -21,7 +21,7 @@ import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.SimplePlanVisitor; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.plan.AggregationNode; @@ -98,8 +98,8 @@ public Void visitProject(ProjectNode node, Void context) for (Map.Entry entry : node.getAssignments().entrySet()) { Type expectedType = entry.getKey().getType(); - if (entry.getValue() instanceof SymbolReference symbolReference) { - Symbol symbol = Symbol.from(symbolReference); + if (entry.getValue() instanceof Reference reference) { + Symbol symbol = Symbol.from(reference); verifyTypeSignature(entry.getKey(), expectedType, symbol.getType()); continue; } diff --git a/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java b/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java index 7af086fa7bcf..d82ddf5e70c9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java @@ -21,29 +21,29 @@ import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeManager; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ArithmeticNegation; -import io.trino.sql.ir.BetweenPredicate; -import io.trino.sql.ir.BindExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Between; +import io.trino.sql.ir.Bind; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Case; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.CoalesceExpression; -import io.trino.sql.ir.ComparisonExpression; -import io.trino.sql.ir.ComparisonExpression.Operator; +import io.trino.sql.ir.Coalesce; +import io.trino.sql.ir.Comparison; +import io.trino.sql.ir.Comparison.Operator; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.InPredicate; +import io.trino.sql.ir.In; import io.trino.sql.ir.IrVisitor; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.LambdaExpression; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.NullIfExpression; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Lambda; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Negation; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.NullIf; +import io.trino.sql.ir.Reference; import io.trino.sql.ir.Row; -import io.trino.sql.ir.SearchedCaseExpression; -import io.trino.sql.ir.SimpleCaseExpression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Subscript; +import io.trino.sql.ir.Switch; import io.trino.sql.ir.WhenClause; import io.trino.sql.planner.Symbol; import io.trino.sql.relational.SpecialForm.Form; @@ -137,17 +137,17 @@ protected RowExpression visitExpression(Expression node, Void context) @Override protected RowExpression visitConstant(Constant node, Void context) { - return constant(node.getValue(), node.getType()); + return constant(node.value(), node.type()); } @Override - protected RowExpression visitComparisonExpression(ComparisonExpression node, Void context) + protected RowExpression visitComparison(Comparison node, Void context) { - RowExpression left = process(node.getLeft(), context); - RowExpression right = process(node.getRight(), context); - Operator operator = node.getOperator(); + RowExpression left = process(node.left(), context); + RowExpression right = process(node.right(), context); + Operator operator = node.operator(); - switch (node.getOperator()) { + switch (node.operator()) { case NOT_EQUAL: return new CallExpression( metadata.resolveBuiltinFunction("not", fromTypes(BOOLEAN)), @@ -170,76 +170,76 @@ private RowExpression visitComparisonExpression(Operator operator, RowExpression } @Override - protected RowExpression visitFunctionCall(FunctionCall node, Void context) + protected RowExpression visitCall(Call node, Void context) { - List arguments = node.getArguments().stream() + List arguments = node.arguments().stream() .map(value -> process(value, context)) .collect(toImmutableList()); - return new CallExpression(node.getFunction(), arguments); + return new CallExpression(node.function(), arguments); } @Override - protected RowExpression visitSymbolReference(SymbolReference node, Void context) + protected RowExpression visitReference(Reference node, Void context) { Integer field = layout.get(Symbol.from(node)); if (field != null) { return field(field, ((Expression) node).type()); } - return new VariableReferenceExpression(node.getName(), ((Expression) node).type()); + return new VariableReferenceExpression(node.name(), ((Expression) node).type()); } @Override - protected RowExpression visitLambdaExpression(LambdaExpression node, Void context) + protected RowExpression visitLambda(Lambda node, Void context) { return new LambdaDefinitionExpression( node.arguments(), - process(node.getBody(), context)); + process(node.body(), context)); } @Override - protected RowExpression visitBindExpression(BindExpression node, Void context) + protected RowExpression visitBind(Bind node, Void context) { ImmutableList.Builder valueTypesBuilder = ImmutableList.builder(); ImmutableList.Builder argumentsBuilder = ImmutableList.builder(); - for (Expression value : node.getValues()) { + for (Expression value : node.values()) { RowExpression valueRowExpression = process(value, context); valueTypesBuilder.add(valueRowExpression.getType()); argumentsBuilder.add(valueRowExpression); } - RowExpression function = process(node.getFunction(), context); + RowExpression function = process(node.function(), context); argumentsBuilder.add(function); return new SpecialForm(BIND, ((Expression) node).type(), argumentsBuilder.build()); } @Override - protected RowExpression visitArithmeticBinary(ArithmeticBinaryExpression node, Void context) + protected RowExpression visitArithmetic(Arithmetic node, Void context) { - RowExpression left = process(node.getLeft(), context); - RowExpression right = process(node.getRight(), context); + RowExpression left = process(node.left(), context); + RowExpression right = process(node.right(), context); return call( - standardFunctionResolution.arithmeticFunction(node.getOperator(), left.getType(), right.getType()), + standardFunctionResolution.arithmeticFunction(node.operator(), left.getType(), right.getType()), left, right); } @Override - protected RowExpression visitArithmeticNegation(ArithmeticNegation node, Void context) + protected RowExpression visitNegation(Negation node, Void context) { - RowExpression expression = process(node.getValue(), context); + RowExpression expression = process(node.value(), context); return call( metadata.resolveOperator(NEGATION, ImmutableList.of(expression.getType())), expression); } @Override - protected RowExpression visitLogicalExpression(LogicalExpression node, Void context) + protected RowExpression visitLogical(Logical node, Void context) { Form form; - switch (node.getOperator()) { + switch (node.operator()) { case AND: form = AND; break; @@ -247,12 +247,12 @@ protected RowExpression visitLogicalExpression(LogicalExpression node, Void cont form = OR; break; default: - throw new IllegalStateException("Unknown logical operator: " + node.getOperator()); + throw new IllegalStateException("Unknown logical operator: " + node.operator()); } return new SpecialForm( form, BOOLEAN, - node.getTerms().stream() + node.terms().stream() .map(term -> process(term, context)) .collect(toImmutableList())); } @@ -260,14 +260,14 @@ protected RowExpression visitLogicalExpression(LogicalExpression node, Void cont @Override protected RowExpression visitCast(Cast node, Void context) { - RowExpression value = process(node.getExpression(), context); + RowExpression value = process(node.expression(), context); Type returnType = ((Expression) node).type(); if (typeCoercion.isTypeOnlyCoercion(value.getType(), returnType)) { return changeType(value, returnType); } - if (node.isSafe()) { + if (node.safe()) { return call( metadata.getCoercion(builtinFunctionName("TRY_CAST"), value.getType(), returnType), value); @@ -332,9 +332,9 @@ public RowExpression visitVariableReference(VariableReferenceExpression referenc } @Override - protected RowExpression visitCoalesceExpression(CoalesceExpression node, Void context) + protected RowExpression visitCoalesce(Coalesce node, Void context) { - List arguments = node.getOperands().stream() + List arguments = node.operands().stream() .map(value -> process(value, context)) .collect(toImmutableList()); @@ -342,15 +342,15 @@ protected RowExpression visitCoalesceExpression(CoalesceExpression node, Void co } @Override - protected RowExpression visitSimpleCaseExpression(SimpleCaseExpression node, Void context) + protected RowExpression visitSwitch(Switch node, Void context) { ImmutableList.Builder arguments = ImmutableList.builder(); - RowExpression value = process(node.getOperand(), context); + RowExpression value = process(node.operand(), context); arguments.add(value); ImmutableList.Builder functionDependencies = ImmutableList.builder(); - for (WhenClause clause : node.getWhenClauses()) { + for (WhenClause clause : node.whenClauses()) { RowExpression operand = process(clause.getOperand(), context); RowExpression result = process(clause.getResult(), context); @@ -365,7 +365,7 @@ protected RowExpression visitSimpleCaseExpression(SimpleCaseExpression node, Voi Type returnType = ((Expression) node).type(); - arguments.add(node.getDefaultValue() + arguments.add(node.defaultValue() .map(defaultValue -> process(defaultValue, context)) .orElse(constantNull(returnType))); @@ -373,7 +373,7 @@ protected RowExpression visitSimpleCaseExpression(SimpleCaseExpression node, Voi } @Override - protected RowExpression visitSearchedCaseExpression(SearchedCaseExpression node, Void context) + protected RowExpression visitCase(Case node, Void context) { /* Translates an expression like: @@ -395,11 +395,11 @@ protected RowExpression visitSearchedCaseExpression(SearchedCaseExpression node, value4))) */ - RowExpression expression = node.getDefaultValue() + RowExpression expression = node.defaultValue() .map(value -> process(value, context)) .orElse(constantNull(((Expression) node).type())); - for (WhenClause clause : node.getWhenClauses().reversed()) { + for (WhenClause clause : node.whenClauses().reversed()) { expression = new SpecialForm( IF, ((Expression) node).type(), @@ -412,12 +412,12 @@ protected RowExpression visitSearchedCaseExpression(SearchedCaseExpression node, } @Override - protected RowExpression visitInPredicate(InPredicate node, Void context) + protected RowExpression visitIn(In node, Void context) { ImmutableList.Builder arguments = ImmutableList.builder(); - RowExpression value = process(node.getValue(), context); + RowExpression value = process(node.value(), context); arguments.add(value); - for (Expression testValue : node.getValueList()) { + for (Expression testValue : node.valueList()) { arguments.add(process(testValue, context)); } @@ -431,17 +431,17 @@ protected RowExpression visitInPredicate(InPredicate node, Void context) } @Override - protected RowExpression visitIsNullPredicate(IsNullPredicate node, Void context) + protected RowExpression visitIsNull(IsNull node, Void context) { - RowExpression expression = process(node.getValue(), context); + RowExpression expression = process(node.value(), context); return new SpecialForm(IS_NULL, BOOLEAN, expression); } @Override - protected RowExpression visitNotExpression(NotExpression node, Void context) + protected RowExpression visitNot(Not node, Void context) { - return notExpression(process(node.getValue(), context)); + return notExpression(process(node.value(), context)); } private RowExpression notExpression(RowExpression value) @@ -452,10 +452,10 @@ private RowExpression notExpression(RowExpression value) } @Override - protected RowExpression visitNullIfExpression(NullIfExpression node, Void context) + protected RowExpression visitNullIf(NullIf node, Void context) { - RowExpression first = process(node.getFirst(), context); - RowExpression second = process(node.getSecond(), context); + RowExpression first = process(node.first(), context); + RowExpression second = process(node.second(), context); ResolvedFunction resolvedFunction = metadata.resolveOperator(EQUAL, ImmutableList.of(first.getType(), second.getType())); List functionDependencies = ImmutableList.builder() @@ -472,11 +472,11 @@ protected RowExpression visitNullIfExpression(NullIfExpression node, Void contex } @Override - protected RowExpression visitBetweenPredicate(BetweenPredicate node, Void context) + protected RowExpression visitBetween(Between node, Void context) { - RowExpression value = process(node.getValue(), context); - RowExpression min = process(node.getMin(), context); - RowExpression max = process(node.getMax(), context); + RowExpression value = process(node.value(), context); + RowExpression min = process(node.min(), context); + RowExpression max = process(node.max(), context); List functionDependencies = ImmutableList.of( metadata.resolveOperator(LESS_THAN_OR_EQUAL, ImmutableList.of(value.getType(), max.getType()))); @@ -489,12 +489,12 @@ protected RowExpression visitBetweenPredicate(BetweenPredicate node, Void contex } @Override - protected RowExpression visitSubscriptExpression(SubscriptExpression node, Void context) + protected RowExpression visitSubscript(Subscript node, Void context) { - RowExpression base = process(node.getBase(), context); - RowExpression index = process(node.getIndex(), context); + RowExpression base = process(node.base(), context); + RowExpression index = process(node.index(), context); - if (node.getBase().type() instanceof RowType) { + if (node.base().type() instanceof RowType) { long value = (Long) ((ConstantExpression) index).getValue(); return new SpecialForm(DEREFERENCE, ((Expression) node).type(), base, constant(value - 1, INTEGER)); } @@ -508,7 +508,7 @@ protected RowExpression visitSubscriptExpression(SubscriptExpression node, Void @Override protected RowExpression visitRow(Row node, Void context) { - List arguments = node.getItems().stream() + List arguments = node.items().stream() .map(value -> process(value, context)) .collect(toImmutableList()); Type returnType = ((Expression) node).type(); diff --git a/core/trino-main/src/main/java/io/trino/sql/relational/StandardFunctionResolution.java b/core/trino-main/src/main/java/io/trino/sql/relational/StandardFunctionResolution.java index 89c03dcc7897..18fa35ff36ef 100644 --- a/core/trino-main/src/main/java/io/trino/sql/relational/StandardFunctionResolution.java +++ b/core/trino-main/src/main/java/io/trino/sql/relational/StandardFunctionResolution.java @@ -18,8 +18,8 @@ import io.trino.metadata.ResolvedFunction; import io.trino.spi.function.OperatorType; import io.trino.spi.type.Type; -import io.trino.sql.ir.ArithmeticBinaryExpression.Operator; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Arithmetic.Operator; +import io.trino.sql.ir.Comparison; import static io.trino.spi.function.OperatorType.ADD; import static io.trino.spi.function.OperatorType.DIVIDE; @@ -66,7 +66,7 @@ public ResolvedFunction arithmeticFunction(Operator operator, Type leftType, Typ return metadata.resolveOperator(operatorType, ImmutableList.of(leftType, rightType)); } - public ResolvedFunction comparisonFunction(ComparisonExpression.Operator operator, Type leftType, Type rightType) + public ResolvedFunction comparisonFunction(Comparison.Operator operator, Type leftType, Type rightType) { OperatorType operatorType; switch (operator) { diff --git a/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutinePlanner.java b/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutinePlanner.java index 738839bbe2dc..d1d0a254f614 100644 --- a/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutinePlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutinePlanner.java @@ -27,7 +27,7 @@ import io.trino.sql.analyzer.RelationType; import io.trino.sql.analyzer.Scope; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.IrExpressionInterpreter; import io.trino.sql.planner.NoOpSymbolResolver; import io.trino.sql.planner.Symbol; @@ -82,7 +82,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.planner.LogicalPlanner.buildLambdaDeclarationToSymbolMap; import static io.trino.sql.relational.Expressions.call; import static io.trino.sql.relational.Expressions.constantNull; @@ -404,13 +404,13 @@ public TranslationVisitor( } @Override - protected RowExpression visitSymbolReference(SymbolReference node, Void context) + protected RowExpression visitReference(Reference node, Void context) { - IrVariable variable = variables.get(node.getName()); + IrVariable variable = variables.get(node.name()); if (variable != null) { return field(variable.field(), variable.type()); } - return super.visitSymbolReference(node, context); + return super.visitReference(node, context); } } } diff --git a/core/trino-main/src/main/java/io/trino/util/SpatialJoinUtils.java b/core/trino-main/src/main/java/io/trino/util/SpatialJoinUtils.java index 606fc0f77c03..07d5863f19f2 100644 --- a/core/trino-main/src/main/java/io/trino/util/SpatialJoinUtils.java +++ b/core/trino-main/src/main/java/io/trino/util/SpatialJoinUtils.java @@ -14,9 +14,9 @@ package io.trino.util; import io.trino.spi.function.CatalogSchemaFunctionName; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; import java.util.List; @@ -41,18 +41,18 @@ private SpatialJoinUtils() {} *

* Doesn't check or guarantee anything about function arguments. */ - public static List extractSupportedSpatialFunctions(Expression filterExpression) + public static List extractSupportedSpatialFunctions(Expression filterExpression) { return extractConjuncts(filterExpression).stream() - .filter(FunctionCall.class::isInstance) - .map(FunctionCall.class::cast) + .filter(Call.class::isInstance) + .map(Call.class::cast) .filter(SpatialJoinUtils::isSupportedSpatialFunction) .collect(toImmutableList()); } - private static boolean isSupportedSpatialFunction(FunctionCall functionCall) + private static boolean isSupportedSpatialFunction(Call call) { - CatalogSchemaFunctionName functionName = functionCall.getFunction().getName(); + CatalogSchemaFunctionName functionName = call.function().getName(); return functionName.equals(builtinFunctionName(ST_CONTAINS)) || functionName.equals(builtinFunctionName(ST_WITHIN)) || functionName.equals(builtinFunctionName(ST_INTERSECTS)); @@ -68,24 +68,24 @@ private static boolean isSupportedSpatialFunction(FunctionCall functionCall) * Doesn't check or guarantee anything about ST_Distance functions arguments * or the other side of the comparison. */ - public static List extractSupportedSpatialComparisons(Expression filterExpression) + public static List extractSupportedSpatialComparisons(Expression filterExpression) { return extractConjuncts(filterExpression).stream() - .filter(ComparisonExpression.class::isInstance) - .map(ComparisonExpression.class::cast) + .filter(Comparison.class::isInstance) + .map(Comparison.class::cast) .filter(SpatialJoinUtils::isSupportedSpatialComparison) .collect(toImmutableList()); } - private static boolean isSupportedSpatialComparison(ComparisonExpression expression) + private static boolean isSupportedSpatialComparison(Comparison expression) { - switch (expression.getOperator()) { + switch (expression.operator()) { case LESS_THAN: case LESS_THAN_OR_EQUAL: - return isSTDistance(expression.getLeft()); + return isSTDistance(expression.left()); case GREATER_THAN: case GREATER_THAN_OR_EQUAL: - return isSTDistance(expression.getRight()); + return isSTDistance(expression.right()); default: return false; } @@ -93,8 +93,8 @@ private static boolean isSupportedSpatialComparison(ComparisonExpression express private static boolean isSTDistance(Expression expression) { - if (expression instanceof FunctionCall call) { - return call.getFunction().getName().equals(builtinFunctionName(ST_DISTANCE)); + if (expression instanceof Call call) { + return call.function().getName().equals(builtinFunctionName(ST_DISTANCE)); } return false; diff --git a/core/trino-main/src/test/java/io/trino/cost/TestAggregationStatsRule.java b/core/trino-main/src/test/java/io/trino/cost/TestAggregationStatsRule.java index 8c43fa5cfa77..1a1a02b3b799 100644 --- a/core/trino-main/src/test/java/io/trino/cost/TestAggregationStatsRule.java +++ b/core/trino-main/src/test/java/io/trino/cost/TestAggregationStatsRule.java @@ -15,7 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.plan.AggregationNode; import org.junit.jupiter.api.Test; @@ -95,9 +95,9 @@ private StatsCalculatorAssertion testAggregation(SymbolStatsEstimate zStats) { return tester().assertStatsFor(pb -> pb .aggregation(ab -> ab - .addAggregation(pb.symbol("sum", BIGINT), aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "x"))), ImmutableList.of(BIGINT)) + .addAggregation(pb.symbol("sum", BIGINT), aggregation("sum", ImmutableList.of(new Reference(BIGINT, "x"))), ImmutableList.of(BIGINT)) .addAggregation(pb.symbol("count", BIGINT), aggregation("count", ImmutableList.of()), ImmutableList.of()) - .addAggregation(pb.symbol("count_on_x", BIGINT), aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "x"))), ImmutableList.of(BIGINT)) + .addAggregation(pb.symbol("count_on_x", BIGINT), aggregation("count", ImmutableList.of(new Reference(BIGINT, "x"))), ImmutableList.of(BIGINT)) .singleGroupingSet(pb.symbol("y", BIGINT), pb.symbol("z", BIGINT)) .source(pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT), pb.symbol("z", BIGINT))))) .withSourceStats(PlanNodeStatsEstimate.builder() @@ -144,7 +144,7 @@ public void testAggregationStatsCappedToInputRows() { tester().assertStatsFor(pb -> pb .aggregation(ab -> ab - .addAggregation(pb.symbol("count_on_x", BIGINT), aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "x"))), ImmutableList.of(BIGINT)) + .addAggregation(pb.symbol("count_on_x", BIGINT), aggregation("count", ImmutableList.of(new Reference(BIGINT, "x"))), ImmutableList.of(BIGINT)) .singleGroupingSet(pb.symbol("y", BIGINT), pb.symbol("z", BIGINT)) .source(pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT), pb.symbol("z", BIGINT))))) .withSourceStats(PlanNodeStatsEstimate.builder() @@ -160,8 +160,8 @@ public void testAggregationWithGlobalGrouping() { tester().assertStatsFor(pb -> pb .aggregation(ab -> ab - .addAggregation(pb.symbol("count_on_x", BIGINT), aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "x"))), ImmutableList.of(BIGINT)) - .addAggregation(pb.symbol("sum", BIGINT), aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "x"))), ImmutableList.of(BIGINT)) + .addAggregation(pb.symbol("count_on_x", BIGINT), aggregation("count", ImmutableList.of(new Reference(BIGINT, "x"))), ImmutableList.of(BIGINT)) + .addAggregation(pb.symbol("sum", BIGINT), aggregation("sum", ImmutableList.of(new Reference(BIGINT, "x"))), ImmutableList.of(BIGINT)) .globalGrouping() .source(pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT), pb.symbol("z", BIGINT))))) .withSourceStats(PlanNodeStatsEstimate.unknown()) @@ -173,8 +173,8 @@ public void testAggregationWithMoreGroupingSets() { tester().assertStatsFor(pb -> pb .aggregation(ab -> ab - .addAggregation(pb.symbol("count_on_x", BIGINT), aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "x"))), ImmutableList.of(BIGINT)) - .addAggregation(pb.symbol("sum", BIGINT), aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "x"))), ImmutableList.of(BIGINT)) + .addAggregation(pb.symbol("count_on_x", BIGINT), aggregation("count", ImmutableList.of(new Reference(BIGINT, "x"))), ImmutableList.of(BIGINT)) + .addAggregation(pb.symbol("sum", BIGINT), aggregation("sum", ImmutableList.of(new Reference(BIGINT, "x"))), ImmutableList.of(BIGINT)) .groupingSets(new AggregationNode.GroupingSetDescriptor(ImmutableList.of(pb.symbol("y"), pb.symbol("z")), 3, ImmutableSet.of(0))) .source(pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT), pb.symbol("z", BIGINT))))) .withSourceStats(PlanNodeStatsEstimate.builder() diff --git a/core/trino-main/src/test/java/io/trino/cost/TestComparisonStatsCalculator.java b/core/trino-main/src/test/java/io/trino/cost/TestComparisonStatsCalculator.java index 906f54f7e854..113f0e0f6a1b 100644 --- a/core/trino-main/src/test/java/io/trino/cost/TestComparisonStatsCalculator.java +++ b/core/trino-main/src/test/java/io/trino/cost/TestComparisonStatsCalculator.java @@ -15,10 +15,10 @@ import io.airlift.slice.Slices; import io.trino.Session; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import org.junit.jupiter.api.Test; @@ -33,12 +33,12 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.NOT_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.NOT_EQUAL; import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT; import static io.trino.testing.TestingSession.testSessionBuilder; import static java.lang.Double.NEGATIVE_INFINITY; @@ -202,7 +202,7 @@ public void verifyTestInputConsistent() public void symbolToLiteralEqualStats() { // Simple case - assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "y"), new Constant(DOUBLE, 2.5))) + assertCalculate(new Comparison(EQUAL, new Reference(DOUBLE, "y"), new Constant(DOUBLE, 2.5))) .outputRowsCount(25.0) // all rows minus nulls divided by distinct values count .symbolStats("y", symbolAssert -> { symbolAssert.averageRowSize(4.0) @@ -213,7 +213,7 @@ public void symbolToLiteralEqualStats() }); // Literal on the edge of symbol range - assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, 10.0))) + assertCalculate(new Comparison(EQUAL, new Reference(DOUBLE, "x"), new Constant(DOUBLE, 10.0))) .outputRowsCount(18.75) // all rows minus nulls divided by distinct values count .symbolStats("x", symbolAssert -> { symbolAssert.averageRowSize(4.0) @@ -224,7 +224,7 @@ public void symbolToLiteralEqualStats() }); // Literal out of symbol range - assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "y"), new Constant(DOUBLE, 10.0))) + assertCalculate(new Comparison(EQUAL, new Reference(DOUBLE, "y"), new Constant(DOUBLE, 10.0))) .outputRowsCount(0.0) // all rows minus nulls divided by distinct values count .symbolStats("y", symbolAssert -> { symbolAssert.averageRowSize(0.0) @@ -234,7 +234,7 @@ public void symbolToLiteralEqualStats() }); // Literal in left open range - assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "leftOpen"), new Constant(DOUBLE, 2.5))) + assertCalculate(new Comparison(EQUAL, new Reference(DOUBLE, "leftOpen"), new Constant(DOUBLE, 2.5))) .outputRowsCount(18.0) // all rows minus nulls divided by distinct values count .symbolStats("leftOpen", symbolAssert -> { symbolAssert.averageRowSize(4.0) @@ -245,7 +245,7 @@ public void symbolToLiteralEqualStats() }); // Literal in right open range - assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "rightOpen"), new Constant(DOUBLE, -2.5))) + assertCalculate(new Comparison(EQUAL, new Reference(DOUBLE, "rightOpen"), new Constant(DOUBLE, -2.5))) .outputRowsCount(18.0) // all rows minus nulls divided by distinct values count .symbolStats("rightOpen", symbolAssert -> { symbolAssert.averageRowSize(4.0) @@ -256,7 +256,7 @@ public void symbolToLiteralEqualStats() }); // Literal in unknown range - assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "unknownRange"), new Constant(DOUBLE, 0.0))) + assertCalculate(new Comparison(EQUAL, new Reference(DOUBLE, "unknownRange"), new Constant(DOUBLE, 0.0))) .outputRowsCount(18.0) // all rows minus nulls divided by distinct values count .symbolStats("unknownRange", symbolAssert -> { symbolAssert.averageRowSize(4.0) @@ -267,12 +267,12 @@ public void symbolToLiteralEqualStats() }); // Literal in empty range - assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "emptyRange"), new Constant(DOUBLE, 0.0))) + assertCalculate(new Comparison(EQUAL, new Reference(DOUBLE, "emptyRange"), new Constant(DOUBLE, 0.0))) .outputRowsCount(0.0) .symbolStats("emptyRange", equalTo(emptyRangeStats)); // Column with values not representable as double (unknown range) - assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference(VARCHAR, "varchar"), new Constant(VARCHAR, Slices.utf8Slice("blah")))) + assertCalculate(new Comparison(EQUAL, new Reference(VARCHAR, "varchar"), new Constant(VARCHAR, Slices.utf8Slice("blah")))) .outputRowsCount(18.0) // all rows minus nulls divided by distinct values count .symbolStats("varchar", symbolAssert -> { symbolAssert.averageRowSize(4.0) @@ -287,7 +287,7 @@ public void symbolToLiteralEqualStats() public void symbolToLiteralNotEqualStats() { // Simple case - assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference(DOUBLE, "y"), new Constant(DOUBLE, 2.5))) + assertCalculate(new Comparison(NOT_EQUAL, new Reference(DOUBLE, "y"), new Constant(DOUBLE, 2.5))) .outputRowsCount(475.0) // all rows minus nulls multiplied by ((distinct values - 1) / distinct values) .symbolStats("y", symbolAssert -> { symbolAssert.averageRowSize(4.0) @@ -298,7 +298,7 @@ public void symbolToLiteralNotEqualStats() }); // Literal on the edge of symbol range - assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, 10.0))) + assertCalculate(new Comparison(NOT_EQUAL, new Reference(DOUBLE, "x"), new Constant(DOUBLE, 10.0))) .outputRowsCount(731.25) // all rows minus nulls multiplied by ((distinct values - 1) / distinct values) .symbolStats("x", symbolAssert -> { symbolAssert.averageRowSize(4.0) @@ -309,7 +309,7 @@ public void symbolToLiteralNotEqualStats() }); // Literal out of symbol range - assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference(DOUBLE, "y"), new Constant(DOUBLE, 10.0))) + assertCalculate(new Comparison(NOT_EQUAL, new Reference(DOUBLE, "y"), new Constant(DOUBLE, 10.0))) .outputRowsCount(500.0) // all rows minus nulls .symbolStats("y", symbolAssert -> { symbolAssert.averageRowSize(4.0) @@ -320,7 +320,7 @@ public void symbolToLiteralNotEqualStats() }); // Literal in left open range - assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference(DOUBLE, "leftOpen"), new Constant(DOUBLE, 2.5))) + assertCalculate(new Comparison(NOT_EQUAL, new Reference(DOUBLE, "leftOpen"), new Constant(DOUBLE, 2.5))) .outputRowsCount(882.0) // all rows minus nulls multiplied by ((distinct values - 1) / distinct values) .symbolStats("leftOpen", symbolAssert -> { symbolAssert.averageRowSize(4.0) @@ -331,7 +331,7 @@ public void symbolToLiteralNotEqualStats() }); // Literal in right open range - assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference(DOUBLE, "rightOpen"), new Constant(DOUBLE, -2.5))) + assertCalculate(new Comparison(NOT_EQUAL, new Reference(DOUBLE, "rightOpen"), new Constant(DOUBLE, -2.5))) .outputRowsCount(882.0) // all rows minus nulls divided by distinct values count .symbolStats("rightOpen", symbolAssert -> { symbolAssert.averageRowSize(4.0) @@ -342,7 +342,7 @@ public void symbolToLiteralNotEqualStats() }); // Literal in unknown range - assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference(DOUBLE, "unknownRange"), new Constant(DOUBLE, 0.0))) + assertCalculate(new Comparison(NOT_EQUAL, new Reference(DOUBLE, "unknownRange"), new Constant(DOUBLE, 0.0))) .outputRowsCount(882.0) // all rows minus nulls divided by distinct values count .symbolStats("unknownRange", symbolAssert -> { symbolAssert.averageRowSize(4.0) @@ -353,12 +353,12 @@ public void symbolToLiteralNotEqualStats() }); // Literal in empty range - assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference(DOUBLE, "emptyRange"), new Constant(DOUBLE, 0.0))) + assertCalculate(new Comparison(NOT_EQUAL, new Reference(DOUBLE, "emptyRange"), new Constant(DOUBLE, 0.0))) .outputRowsCount(0.0) .symbolStats("emptyRange", equalTo(emptyRangeStats)); // Column with values not representable as double (unknown range) - assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference(VARCHAR, "varchar"), new Constant(VARCHAR, Slices.utf8Slice("blah")))) + assertCalculate(new Comparison(NOT_EQUAL, new Reference(VARCHAR, "varchar"), new Constant(VARCHAR, Slices.utf8Slice("blah")))) .outputRowsCount(882.0) // all rows minus nulls divided by distinct values count .symbolStats("varchar", symbolAssert -> { symbolAssert.averageRowSize(4.0) @@ -373,7 +373,7 @@ public void symbolToLiteralNotEqualStats() public void symbolToLiteralLessThanStats() { // Simple case - assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "y"), new Constant(DOUBLE, 2.5))) + assertCalculate(new Comparison(LESS_THAN, new Reference(DOUBLE, "y"), new Constant(DOUBLE, 2.5))) .outputRowsCount(250.0) // all rows minus nulls times range coverage (50%) .symbolStats("y", symbolAssert -> { symbolAssert.averageRowSize(4.0) @@ -384,7 +384,7 @@ public void symbolToLiteralLessThanStats() }); // Literal on the edge of symbol range (whole range included) - assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, 10.0))) + assertCalculate(new Comparison(LESS_THAN, new Reference(DOUBLE, "x"), new Constant(DOUBLE, 10.0))) .outputRowsCount(750.0) // all rows minus nulls times range coverage (100%) .symbolStats("x", symbolAssert -> { symbolAssert.averageRowSize(4.0) @@ -395,7 +395,7 @@ public void symbolToLiteralLessThanStats() }); // Literal on the edge of symbol range (whole range excluded) - assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, -10.0))) + assertCalculate(new Comparison(LESS_THAN, new Reference(DOUBLE, "x"), new Constant(DOUBLE, -10.0))) .outputRowsCount(18.75) // all rows minus nulls divided by NDV (one value from edge is included as approximation) .symbolStats("x", symbolAssert -> { symbolAssert.averageRowSize(4.0) @@ -406,7 +406,7 @@ public void symbolToLiteralLessThanStats() }); // Literal range out of symbol range - assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "y"), new Constant(DOUBLE, -10.0))) + assertCalculate(new Comparison(LESS_THAN, new Reference(DOUBLE, "y"), new Constant(DOUBLE, -10.0))) .outputRowsCount(0.0) // all rows minus nulls times range coverage (0%) .symbolStats("y", symbolAssert -> { symbolAssert.averageRowSize(0.0) @@ -416,7 +416,7 @@ public void symbolToLiteralLessThanStats() }); // Literal in left open range - assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "leftOpen"), new Constant(DOUBLE, 0.0))) + assertCalculate(new Comparison(LESS_THAN, new Reference(DOUBLE, "leftOpen"), new Constant(DOUBLE, 0.0))) .outputRowsCount(450.0) // all rows minus nulls times range coverage (50% - heuristic) .symbolStats("leftOpen", symbolAssert -> { symbolAssert.averageRowSize(4.0) @@ -427,7 +427,7 @@ public void symbolToLiteralLessThanStats() }); // Literal in right open range - assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "rightOpen"), new Constant(DOUBLE, 0.0))) + assertCalculate(new Comparison(LESS_THAN, new Reference(DOUBLE, "rightOpen"), new Constant(DOUBLE, 0.0))) .outputRowsCount(225.0) // all rows minus nulls times range coverage (25% - heuristic) .symbolStats("rightOpen", symbolAssert -> { symbolAssert.averageRowSize(4.0) @@ -438,7 +438,7 @@ public void symbolToLiteralLessThanStats() }); // Literal in unknown range - assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "unknownRange"), new Constant(DOUBLE, 0.0))) + assertCalculate(new Comparison(LESS_THAN, new Reference(DOUBLE, "unknownRange"), new Constant(DOUBLE, 0.0))) .outputRowsCount(450.0) // all rows minus nulls times range coverage (50% - heuristic) .symbolStats("unknownRange", symbolAssert -> { symbolAssert.averageRowSize(4.0) @@ -449,7 +449,7 @@ public void symbolToLiteralLessThanStats() }); // Literal in empty range - assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "emptyRange"), new Constant(DOUBLE, 0.0))) + assertCalculate(new Comparison(LESS_THAN, new Reference(DOUBLE, "emptyRange"), new Constant(DOUBLE, 0.0))) .outputRowsCount(0.0) .symbolStats("emptyRange", equalTo(emptyRangeStats)); } @@ -458,7 +458,7 @@ public void symbolToLiteralLessThanStats() public void symbolToLiteralGreaterThanStats() { // Simple case - assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "y"), new Constant(DOUBLE, 2.5))) + assertCalculate(new Comparison(GREATER_THAN, new Reference(DOUBLE, "y"), new Constant(DOUBLE, 2.5))) .outputRowsCount(250.0) // all rows minus nulls times range coverage (50%) .symbolStats("y", symbolAssert -> { symbolAssert.averageRowSize(4.0) @@ -469,7 +469,7 @@ public void symbolToLiteralGreaterThanStats() }); // Literal on the edge of symbol range (whole range included) - assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, -10.0))) + assertCalculate(new Comparison(GREATER_THAN, new Reference(DOUBLE, "x"), new Constant(DOUBLE, -10.0))) .outputRowsCount(750.0) // all rows minus nulls times range coverage (100%) .symbolStats("x", symbolAssert -> { symbolAssert.averageRowSize(4.0) @@ -480,7 +480,7 @@ public void symbolToLiteralGreaterThanStats() }); // Literal on the edge of symbol range (whole range excluded) - assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, 10.0))) + assertCalculate(new Comparison(GREATER_THAN, new Reference(DOUBLE, "x"), new Constant(DOUBLE, 10.0))) .outputRowsCount(18.75) // all rows minus nulls divided by NDV (one value from edge is included as approximation) .symbolStats("x", symbolAssert -> { symbolAssert.averageRowSize(4.0) @@ -491,7 +491,7 @@ public void symbolToLiteralGreaterThanStats() }); // Literal range out of symbol range - assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "y"), new Constant(DOUBLE, 10.0))) + assertCalculate(new Comparison(GREATER_THAN, new Reference(DOUBLE, "y"), new Constant(DOUBLE, 10.0))) .outputRowsCount(0.0) // all rows minus nulls times range coverage (0%) .symbolStats("y", symbolAssert -> { symbolAssert.averageRowSize(0.0) @@ -501,7 +501,7 @@ public void symbolToLiteralGreaterThanStats() }); // Literal in left open range - assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "leftOpen"), new Constant(DOUBLE, 0.0))) + assertCalculate(new Comparison(GREATER_THAN, new Reference(DOUBLE, "leftOpen"), new Constant(DOUBLE, 0.0))) .outputRowsCount(225.0) // all rows minus nulls times range coverage (25% - heuristic) .symbolStats("leftOpen", symbolAssert -> { symbolAssert.averageRowSize(4.0) @@ -512,7 +512,7 @@ public void symbolToLiteralGreaterThanStats() }); // Literal in right open range - assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "rightOpen"), new Constant(DOUBLE, 0.0))) + assertCalculate(new Comparison(GREATER_THAN, new Reference(DOUBLE, "rightOpen"), new Constant(DOUBLE, 0.0))) .outputRowsCount(450.0) // all rows minus nulls times range coverage (50% - heuristic) .symbolStats("rightOpen", symbolAssert -> { symbolAssert.averageRowSize(4.0) @@ -523,7 +523,7 @@ public void symbolToLiteralGreaterThanStats() }); // Literal in unknown range - assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "unknownRange"), new Constant(DOUBLE, 0.0))) + assertCalculate(new Comparison(GREATER_THAN, new Reference(DOUBLE, "unknownRange"), new Constant(DOUBLE, 0.0))) .outputRowsCount(450.0) // all rows minus nulls times range coverage (50% - heuristic) .symbolStats("unknownRange", symbolAssert -> { symbolAssert.averageRowSize(4.0) @@ -534,7 +534,7 @@ public void symbolToLiteralGreaterThanStats() }); // Literal in empty range - assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "emptyRange"), new Constant(DOUBLE, 0.0))) + assertCalculate(new Comparison(GREATER_THAN, new Reference(DOUBLE, "emptyRange"), new Constant(DOUBLE, 0.0))) .outputRowsCount(0.0) .symbolStats("emptyRange", equalTo(emptyRangeStats)); } @@ -545,7 +545,7 @@ public void symbolToSymbolEqualStats() // z's stats should be unchanged when not involved, except NDV capping to row count // Equal ranges double rowCount = 2.7; - assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "u"), new SymbolReference(INTEGER, "w"))) + assertCalculate(new Comparison(EQUAL, new Reference(INTEGER, "u"), new Reference(INTEGER, "w"))) .outputRowsCount(rowCount) .symbolStats("u", equalTo(capNDV(zeroNullsFraction(uStats), rowCount))) .symbolStats("w", equalTo(capNDV(zeroNullsFraction(wStats), rowCount))) @@ -553,7 +553,7 @@ public void symbolToSymbolEqualStats() // One symbol's range is within the other's rowCount = 9.375; - assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "x"), new SymbolReference(DOUBLE, "y"))) + assertCalculate(new Comparison(EQUAL, new Reference(DOUBLE, "x"), new Reference(DOUBLE, "y"))) .outputRowsCount(rowCount) .symbolStats("x", symbolAssert -> { symbolAssert.averageRowSize(4) @@ -573,7 +573,7 @@ public void symbolToSymbolEqualStats() // Partially overlapping ranges rowCount = 16.875; - assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "x"), new SymbolReference(DOUBLE, "w"))) + assertCalculate(new Comparison(EQUAL, new Reference(DOUBLE, "x"), new Reference(DOUBLE, "w"))) .outputRowsCount(rowCount) .symbolStats("x", symbolAssert -> { symbolAssert.averageRowSize(6) @@ -593,7 +593,7 @@ public void symbolToSymbolEqualStats() // None of the ranges is included in the other, and one symbol has much higher cardinality, so that it has bigger NDV in intersect than the other in total rowCount = 2.25; - assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "x"), new SymbolReference(DOUBLE, "u"))) + assertCalculate(new Comparison(EQUAL, new Reference(DOUBLE, "x"), new Reference(DOUBLE, "u"))) .outputRowsCount(rowCount) .symbolStats("x", symbolAssert -> { symbolAssert.averageRowSize(6) @@ -617,7 +617,7 @@ public void symbolToSymbolNotEqual() { // Equal ranges double rowCount = 807.3; - assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference(DOUBLE, "u"), new SymbolReference(DOUBLE, "w"))) + assertCalculate(new Comparison(NOT_EQUAL, new Reference(DOUBLE, "u"), new Reference(DOUBLE, "w"))) .outputRowsCount(rowCount) .symbolStats("u", equalTo(capNDV(zeroNullsFraction(uStats), rowCount))) .symbolStats("w", equalTo(capNDV(zeroNullsFraction(wStats), rowCount))) @@ -625,7 +625,7 @@ public void symbolToSymbolNotEqual() // One symbol's range is within the other's rowCount = 365.625; - assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference(DOUBLE, "x"), new SymbolReference(DOUBLE, "y"))) + assertCalculate(new Comparison(NOT_EQUAL, new Reference(DOUBLE, "x"), new Reference(DOUBLE, "y"))) .outputRowsCount(rowCount) .symbolStats("x", equalTo(capNDV(zeroNullsFraction(xStats), rowCount))) .symbolStats("y", equalTo(capNDV(zeroNullsFraction(yStats), rowCount))) @@ -633,7 +633,7 @@ public void symbolToSymbolNotEqual() // Partially overlapping ranges rowCount = 658.125; - assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference(DOUBLE, "x"), new SymbolReference(DOUBLE, "w"))) + assertCalculate(new Comparison(NOT_EQUAL, new Reference(DOUBLE, "x"), new Reference(DOUBLE, "w"))) .outputRowsCount(rowCount) .symbolStats("x", equalTo(capNDV(zeroNullsFraction(xStats), rowCount))) .symbolStats("w", equalTo(capNDV(zeroNullsFraction(wStats), rowCount))) @@ -641,7 +641,7 @@ public void symbolToSymbolNotEqual() // None of the ranges is included in the other, and one symbol has much higher cardinality, so that it has bigger NDV in intersect than the other in total rowCount = 672.75; - assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference(DOUBLE, "x"), new SymbolReference(DOUBLE, "u"))) + assertCalculate(new Comparison(NOT_EQUAL, new Reference(DOUBLE, "x"), new Reference(DOUBLE, "u"))) .outputRowsCount(rowCount) .symbolStats("x", equalTo(capNDV(zeroNullsFraction(xStats), rowCount))) .symbolStats("u", equalTo(capNDV(zeroNullsFraction(uStats), rowCount))) @@ -652,7 +652,7 @@ public void symbolToSymbolNotEqual() public void symbolToCastExpressionNotEqual() { double rowCount = 897.0; - assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference(DOUBLE, "u"), new Constant(DOUBLE, 10.0))) + assertCalculate(new Comparison(NOT_EQUAL, new Reference(DOUBLE, "u"), new Constant(DOUBLE, 10.0))) .outputRowsCount(rowCount) .symbolStats("u", equalTo(capNDV(updateNDV(zeroNullsFraction(uStats), -1), rowCount))) .symbolStats("z", equalTo(capNDV(zStats, rowCount))); @@ -667,7 +667,7 @@ public void symbolToSymbolInequalityStats() double nullsFractionX = 0.25; double rowCount = inputRowCount * (1 - nullsFractionX); // Same symbol on both sides of inequality, gets simplified to x IS NOT NULL - assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "x"), new SymbolReference(DOUBLE, "x"))) + assertCalculate(new Comparison(LESS_THAN, new Reference(DOUBLE, "x"), new Reference(DOUBLE, "x"))) .outputRowsCount(rowCount) .symbolStats("x", equalTo(capNDV(zeroNullsFraction(xStats), rowCount))); @@ -675,7 +675,7 @@ public void symbolToSymbolInequalityStats() double nonNullRowCount = inputRowCount * (1 - nullsFractionU); rowCount = nonNullRowCount * OVERLAPPING_RANGE_INEQUALITY_FILTER_COEFFICIENT; // Equal ranges - assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "u"), new SymbolReference(DOUBLE, "w"))) + assertCalculate(new Comparison(LESS_THAN, new Reference(DOUBLE, "u"), new Reference(DOUBLE, "w"))) .outputRowsCount(rowCount) .symbolStats("u", equalTo(capNDV(zeroNullsFraction(uStats), rowCount))) .symbolStats("w", equalTo(capNDV(zeroNullsFraction(wStats), rowCount))) @@ -687,7 +687,7 @@ public void symbolToSymbolInequalityStats() nonNullRowCount = inputRowCount * (1 - nullsFractionY); rowCount = nonNullRowCount * (alwaysLesserFractionX + (overlappingFractionX * OVERLAPPING_RANGE_INEQUALITY_FILTER_COEFFICIENT)); // One symbol's range is within the other's - assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "x"), new SymbolReference(DOUBLE, "y"))) + assertCalculate(new Comparison(LESS_THAN, new Reference(DOUBLE, "x"), new Reference(DOUBLE, "y"))) .outputRowsCount(rowCount) .symbolStats("x", symbolAssert -> symbolAssert.averageRowSize(4) .lowValue(-10) @@ -700,7 +700,7 @@ public void symbolToSymbolInequalityStats() .distinctValuesCount(20) .nullsFraction(0)) .symbolStats("z", equalTo(capNDV(zStats, rowCount))); - assertCalculate(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(DOUBLE, "x"), new SymbolReference(DOUBLE, "y"))) + assertCalculate(new Comparison(LESS_THAN_OR_EQUAL, new Reference(DOUBLE, "x"), new Reference(DOUBLE, "y"))) .outputRowsCount(rowCount) .symbolStats("x", symbolAssert -> symbolAssert.averageRowSize(4) .lowValue(-10) @@ -714,7 +714,7 @@ public void symbolToSymbolInequalityStats() .nullsFraction(0)) .symbolStats("z", equalTo(capNDV(zStats, rowCount))); // Flip symbols to be on opposite sides - assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "y"), new SymbolReference(DOUBLE, "x"))) + assertCalculate(new Comparison(GREATER_THAN, new Reference(DOUBLE, "y"), new Reference(DOUBLE, "x"))) .outputRowsCount(rowCount) .symbolStats("x", symbolAssert -> symbolAssert.averageRowSize(4) .lowValue(-10) @@ -730,7 +730,7 @@ public void symbolToSymbolInequalityStats() double alwaysGreaterFractionX = 0.25; rowCount = nonNullRowCount * (alwaysGreaterFractionX + overlappingFractionX * OVERLAPPING_RANGE_INEQUALITY_FILTER_COEFFICIENT); - assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "x"), new SymbolReference(DOUBLE, "y"))) + assertCalculate(new Comparison(GREATER_THAN, new Reference(DOUBLE, "x"), new Reference(DOUBLE, "y"))) .outputRowsCount(rowCount) .symbolStats("x", symbolAssert -> symbolAssert.averageRowSize(4) .lowValue(0) @@ -743,7 +743,7 @@ public void symbolToSymbolInequalityStats() .distinctValuesCount(20) .nullsFraction(0)) .symbolStats("z", equalTo(capNDV(zStats, rowCount))); - assertCalculate(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(DOUBLE, "x"), new SymbolReference(DOUBLE, "y"))) + assertCalculate(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(DOUBLE, "x"), new Reference(DOUBLE, "y"))) .outputRowsCount(rowCount) .symbolStats("x", symbolAssert -> symbolAssert.averageRowSize(4) .lowValue(0) @@ -757,7 +757,7 @@ public void symbolToSymbolInequalityStats() .nullsFraction(0)) .symbolStats("z", equalTo(capNDV(zStats, rowCount))); // Flip symbols to be on opposite sides - assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "y"), new SymbolReference(DOUBLE, "x"))) + assertCalculate(new Comparison(LESS_THAN, new Reference(DOUBLE, "y"), new Reference(DOUBLE, "x"))) .outputRowsCount(rowCount) .symbolStats("x", symbolAssert -> symbolAssert.averageRowSize(4) .lowValue(0) @@ -778,7 +778,7 @@ public void symbolToSymbolInequalityStats() double alwaysGreaterFractionW = 0.5; rowCount = nonNullRowCount * (alwaysLesserFractionX + overlappingFractionX * (overlappingFractionW * OVERLAPPING_RANGE_INEQUALITY_FILTER_COEFFICIENT + alwaysGreaterFractionW)); - assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "x"), new SymbolReference(DOUBLE, "w"))) + assertCalculate(new Comparison(LESS_THAN, new Reference(DOUBLE, "x"), new Reference(DOUBLE, "w"))) .outputRowsCount(rowCount) .symbolStats("x", symbolAssert -> symbolAssert.averageRowSize(4) .lowValue(-10) @@ -792,7 +792,7 @@ public void symbolToSymbolInequalityStats() .nullsFraction(0)) .symbolStats("z", equalTo(capNDV(zStats, rowCount))); // Flip symbols to be on opposite sides - assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "w"), new SymbolReference(DOUBLE, "x"))) + assertCalculate(new Comparison(GREATER_THAN, new Reference(DOUBLE, "w"), new Reference(DOUBLE, "x"))) .outputRowsCount(rowCount) .symbolStats("x", symbolAssert -> symbolAssert.averageRowSize(4) .lowValue(-10) @@ -807,7 +807,7 @@ public void symbolToSymbolInequalityStats() .symbolStats("z", equalTo(capNDV(zStats, rowCount))); rowCount = nonNullRowCount * (overlappingFractionX * overlappingFractionW * OVERLAPPING_RANGE_INEQUALITY_FILTER_COEFFICIENT); - assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "x"), new SymbolReference(DOUBLE, "w"))) + assertCalculate(new Comparison(GREATER_THAN, new Reference(DOUBLE, "x"), new Reference(DOUBLE, "w"))) .outputRowsCount(rowCount) .symbolStats("x", symbolAssert -> symbolAssert.averageRowSize(4) .lowValue(0) @@ -821,7 +821,7 @@ public void symbolToSymbolInequalityStats() .nullsFraction(0)) .symbolStats("z", equalTo(capNDV(zStats, rowCount))); // Flip symbols to be on opposite sides - assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "w"), new SymbolReference(DOUBLE, "x"))) + assertCalculate(new Comparison(LESS_THAN, new Reference(DOUBLE, "w"), new Reference(DOUBLE, "x"))) .outputRowsCount(rowCount) .symbolStats("x", symbolAssert -> symbolAssert.averageRowSize(4) .lowValue(0) @@ -844,7 +844,7 @@ public void symbolToSymbolInequalityStats() double alwaysGreaterFractionRight = 0.5; rowCount = nonNullRowCount * (alwaysLesserFractionLeft + overlappingFractionLeft * (overlappingFractionRight * OVERLAPPING_RANGE_INEQUALITY_FILTER_COEFFICIENT + alwaysGreaterFractionRight)); - assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "leftOpen"), new SymbolReference(DOUBLE, "rightOpen"))) + assertCalculate(new Comparison(LESS_THAN, new Reference(DOUBLE, "leftOpen"), new Reference(DOUBLE, "rightOpen"))) .outputRowsCount(rowCount) .symbolStats("leftOpen", symbolAssert -> symbolAssert.averageRowSize(4) .lowValue(NEGATIVE_INFINITY) @@ -859,7 +859,7 @@ public void symbolToSymbolInequalityStats() .symbolStats("z", equalTo(capNDV(zStats, rowCount))); rowCount = nonNullRowCount * (alwaysLesserFractionLeft + overlappingFractionLeft * OVERLAPPING_RANGE_INEQUALITY_FILTER_COEFFICIENT); - assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "leftOpen"), new SymbolReference(DOUBLE, "unknownNdvRange"))) + assertCalculate(new Comparison(LESS_THAN, new Reference(DOUBLE, "leftOpen"), new Reference(DOUBLE, "unknownNdvRange"))) .outputRowsCount(rowCount) .symbolStats("leftOpen", symbolAssert -> symbolAssert.averageRowSize(4) .lowValue(NEGATIVE_INFINITY) @@ -874,7 +874,7 @@ public void symbolToSymbolInequalityStats() .symbolStats("z", equalTo(capNDV(zStats, rowCount))); rowCount = nonNullRowCount * OVERLAPPING_RANGE_INEQUALITY_FILTER_COEFFICIENT; - assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "leftOpen"), new SymbolReference(DOUBLE, "unknownRange"))) + assertCalculate(new Comparison(LESS_THAN, new Reference(DOUBLE, "leftOpen"), new Reference(DOUBLE, "unknownRange"))) .outputRowsCount(rowCount) .symbolStats("leftOpen", symbolAssert -> symbolAssert.averageRowSize(4) .lowValue(NEGATIVE_INFINITY) diff --git a/core/trino-main/src/test/java/io/trino/cost/TestCostCalculator.java b/core/trino-main/src/test/java/io/trino/cost/TestCostCalculator.java index f668ca127aa6..237e26c16ae7 100644 --- a/core/trino-main/src/test/java/io/trino/cost/TestCostCalculator.java +++ b/core/trino-main/src/test/java/io/trino/cost/TestCostCalculator.java @@ -30,8 +30,8 @@ import io.trino.sql.PlannerContext; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Plan; import io.trino.sql.planner.PlanFragmenter; import io.trino.sql.planner.SubPlan; @@ -156,7 +156,7 @@ public void testTableScan() public void testProject() { TableScanNode tableScan = tableScan("ts", new Symbol(BIGINT, "orderkey")); - PlanNode project = project("project", tableScan, new Symbol(VARCHAR, "string"), new Cast(new SymbolReference(BIGINT, "orderkey"), VARCHAR)); + PlanNode project = project("project", tableScan, new Symbol(VARCHAR, "string"), new Cast(new Reference(BIGINT, "orderkey"), VARCHAR)); Map costs = ImmutableMap.of("ts", cpuCost(1000)); Map stats = ImmutableMap.of( "project", statsEstimate(project, 4000), @@ -184,7 +184,7 @@ public void testProject() public void testFilter() { TableScanNode tableScan = tableScan("ts", new Symbol(VARCHAR, "string")); - IsNullPredicate expression = new IsNullPredicate(new SymbolReference(VARCHAR, "string")); + IsNull expression = new IsNull(new Reference(VARCHAR, "string")); FilterNode filter = new FilterNode(new PlanNodeId("filter"), tableScan, expression); Map costs = ImmutableMap.of("ts", cpuCost(1000)); Map stats = ImmutableMap.of( diff --git a/core/trino-main/src/test/java/io/trino/cost/TestFilterProjectAggregationStatsRule.java b/core/trino-main/src/test/java/io/trino/cost/TestFilterProjectAggregationStatsRule.java index 1a29b2d91ec3..2e12fac74e16 100644 --- a/core/trino-main/src/test/java/io/trino/cost/TestFilterProjectAggregationStatsRule.java +++ b/core/trino-main/src/test/java/io/trino/cost/TestFilterProjectAggregationStatsRule.java @@ -18,10 +18,10 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; import io.trino.sql.planner.plan.Assignments; @@ -35,9 +35,9 @@ import static io.trino.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.iterative.rule.test.PlanBuilder.aggregation; import static io.trino.testing.TestingSession.testSessionBuilder; import static io.trino.type.UnknownType.UNKNOWN; @@ -74,9 +74,9 @@ public class TestFilterProjectAggregationStatsRule public void testFilterOverAggregationStats() { Function planProvider = pb -> pb.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "count_on_x"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "count_on_x"), new Constant(INTEGER, 0L)), pb.aggregation(ab -> ab - .addAggregation(pb.symbol("count_on_x", BIGINT), aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "x"))), ImmutableList.of(BIGINT)) + .addAggregation(pb.symbol("count_on_x", BIGINT), aggregation("count", ImmutableList.of(new Reference(BIGINT, "x"))), ImmutableList.of(BIGINT)) .singleGroupingSet(pb.symbol("y", BIGINT)) .source(pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT))))); @@ -103,9 +103,9 @@ public void testFilterOverAggregationStats() // If filter estimate is known, approximation should not be applied tester().assertStatsFor(APPROXIMATION_ENABLED, pb -> pb.filter( - new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "y"), new Constant(INTEGER, 1L)), + new Comparison(EQUAL, new Reference(INTEGER, "y"), new Constant(INTEGER, 1L)), pb.aggregation(ab -> ab - .addAggregation(pb.symbol("count_on_x", BIGINT), aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "x"))), ImmutableList.of(BIGINT)) + .addAggregation(pb.symbol("count_on_x", BIGINT), aggregation("count", ImmutableList.of(new Reference(BIGINT, "x"))), ImmutableList.of(BIGINT)) .singleGroupingSet(pb.symbol("y", BIGINT)) .source(pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT)))))) .withSourceStats(sourceStats) @@ -126,11 +126,11 @@ public void testFilterAndProjectOverAggregationStats() pb -> { Symbol aggregatedOutput = pb.symbol("count_on_x", BIGINT); return pb.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "count_on_x"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "count_on_x"), new Constant(INTEGER, 0L)), // Narrowing identity projection pb.project(Assignments.identity(aggregatedOutput), pb.aggregation(ab -> ab - .addAggregation(aggregatedOutput, aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "x"))), ImmutableList.of(BIGINT)) + .addAggregation(aggregatedOutput, aggregation("count", ImmutableList.of(new Reference(BIGINT, "x"))), ImmutableList.of(BIGINT)) .singleGroupingSet(pb.symbol("y", BIGINT)) .source(pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT))) .nodeId(aggregationId)))); @@ -144,11 +144,11 @@ public void testFilterAndProjectOverAggregationStats() pb -> { Symbol aggregatedOutput = pb.symbol("count_on_x", BIGINT); return pb.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "count_on_x"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "count_on_x"), new Constant(INTEGER, 0L)), // Non-narrowing projection - pb.project(Assignments.of(pb.symbol("x_1"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "x"), new Constant(INTEGER, 1L)), aggregatedOutput, aggregatedOutput.toSymbolReference()), + pb.project(Assignments.of(pb.symbol("x_1"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "x"), new Constant(INTEGER, 1L)), aggregatedOutput, aggregatedOutput.toSymbolReference()), pb.aggregation(ab -> ab - .addAggregation(aggregatedOutput, aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "x"))), ImmutableList.of(BIGINT)) + .addAggregation(aggregatedOutput, aggregation("count", ImmutableList.of(new Reference(BIGINT, "x"))), ImmutableList.of(BIGINT)) .singleGroupingSet(pb.symbol("y", BIGINT)) .source(pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT))) .nodeId(aggregationId)))); diff --git a/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsCalculator.java b/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsCalculator.java index ac1443493d68..ca09adcb058a 100644 --- a/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsCalculator.java +++ b/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsCalculator.java @@ -26,20 +26,20 @@ import io.trino.spi.type.Decimals; import io.trino.spi.type.VarcharType; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ArithmeticNegation; -import io.trino.sql.ir.BetweenPredicate; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Between; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.CoalesceExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Coalesce; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.InPredicate; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.In; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Negation; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.transaction.TestingTransactionManager; import io.trino.transaction.TransactionManager; @@ -55,18 +55,18 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MULTIPLY; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.SUBTRACT; -import static io.trino.sql.ir.BooleanLiteral.FALSE_LITERAL; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; -import static io.trino.sql.ir.LogicalExpression.Operator.OR; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; +import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; +import static io.trino.sql.ir.Booleans.FALSE; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Logical.Operator.AND; +import static io.trino.sql.ir.Logical.Operator.OR; import static io.trino.sql.planner.TestingPlannerContext.plannerContextBuilder; import static io.trino.testing.TestingSession.testSessionBuilder; import static io.trino.testing.TransactionBuilder.transaction; @@ -177,8 +177,8 @@ public class TestFilterStatsCalculator @Test public void testBooleanLiteralStats() { - assertExpression(TRUE_LITERAL).equalTo(standardInputStatistics); - assertExpression(FALSE_LITERAL).equalTo(zeroStatistics); + assertExpression(TRUE).equalTo(standardInputStatistics); + assertExpression(FALSE).equalTo(zeroStatistics); assertExpression(new Constant(BOOLEAN, null)).equalTo(zeroStatistics); } @@ -186,7 +186,7 @@ public void testBooleanLiteralStats() public void testComparison() { double lessThan3Rows = 487.5; - assertExpression(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, 3.0))) + assertExpression(new Comparison(LESS_THAN, new Reference(DOUBLE, "x"), new Constant(DOUBLE, 3.0))) .outputRowsCount(lessThan3Rows) .symbolStats(new Symbol(UNKNOWN, "x"), symbolAssert -> symbolAssert.averageRowSize(4.0) @@ -195,14 +195,14 @@ public void testComparison() .distinctValuesCount(26) .nullsFraction(0.0)); - assertExpression(new ComparisonExpression(GREATER_THAN, new ArithmeticNegation(new SymbolReference(DOUBLE, "x")), new Constant(DOUBLE, -3.0))) + assertExpression(new Comparison(GREATER_THAN, new Negation(new Reference(DOUBLE, "x")), new Constant(DOUBLE, -3.0))) .outputRowsCount(lessThan3Rows); for (Expression minusThree : ImmutableList.of( new Constant(createDecimalType(3), Decimals.valueOfShort(new BigDecimal("-3"))), new Constant(DOUBLE, -3.0), - new ArithmeticBinaryExpression(SUBTRACT_DOUBLE, SUBTRACT, new Constant(DOUBLE, 4.0), new Constant(DOUBLE, 7.0)), new Cast(new Constant(INTEGER, -3L), createDecimalType(7, 3)))) { - assertExpression(new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "x"), new Cast(minusThree, DOUBLE))) + new Arithmetic(SUBTRACT_DOUBLE, SUBTRACT, new Constant(DOUBLE, 4.0), new Constant(DOUBLE, 7.0)), new Cast(new Constant(INTEGER, -3L), createDecimalType(7, 3)))) { + assertExpression(new Comparison(EQUAL, new Reference(DOUBLE, "x"), new Cast(minusThree, DOUBLE))) .outputRowsCount(18.75) .symbolStats(new Symbol(UNKNOWN, "x"), symbolAssert -> symbolAssert.averageRowSize(4.0) @@ -211,7 +211,7 @@ public void testComparison() .distinctValuesCount(1) .nullsFraction(0.0)); - assertExpression(new ComparisonExpression(EQUAL, new Cast(minusThree, DOUBLE), new SymbolReference(DOUBLE, "x"))) + assertExpression(new Comparison(EQUAL, new Cast(minusThree, DOUBLE), new Reference(DOUBLE, "x"))) .outputRowsCount(18.75) .symbolStats(new Symbol(UNKNOWN, "x"), symbolAssert -> symbolAssert.averageRowSize(4.0) @@ -220,11 +220,11 @@ public void testComparison() .distinctValuesCount(1) .nullsFraction(0.0)); - assertExpression(new ComparisonExpression( + assertExpression(new Comparison( EQUAL, - new CoalesceExpression( - new ArithmeticBinaryExpression(MULTIPLY_DOUBLE, MULTIPLY, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, null)), - new SymbolReference(DOUBLE, "x")), + new Coalesce( + new Arithmetic(MULTIPLY_DOUBLE, MULTIPLY, new Reference(DOUBLE, "x"), new Constant(DOUBLE, null)), + new Reference(DOUBLE, "x")), new Cast(minusThree, DOUBLE))) .outputRowsCount(18.75) .symbolStats(new Symbol(UNKNOWN, "x"), symbolAssert -> @@ -234,7 +234,7 @@ public void testComparison() .distinctValuesCount(1) .nullsFraction(0.0)); - assertExpression(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "x"), new Cast(minusThree, DOUBLE))) + assertExpression(new Comparison(LESS_THAN, new Reference(DOUBLE, "x"), new Cast(minusThree, DOUBLE))) .outputRowsCount(262.5) .symbolStats(new Symbol(UNKNOWN, "x"), symbolAssert -> symbolAssert.averageRowSize(4.0) @@ -243,7 +243,7 @@ public void testComparison() .distinctValuesCount(14) .nullsFraction(0.0)); - assertExpression(new ComparisonExpression(GREATER_THAN, new Cast(minusThree, DOUBLE), new SymbolReference(DOUBLE, "x"))) + assertExpression(new Comparison(GREATER_THAN, new Cast(minusThree, DOUBLE), new Reference(DOUBLE, "x"))) .outputRowsCount(262.5) .symbolStats(new Symbol(UNKNOWN, "x"), symbolAssert -> symbolAssert.averageRowSize(4.0) @@ -257,32 +257,32 @@ public void testComparison() @Test public void testInequalityComparisonApproximation() { - assertExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "x"), new SymbolReference(DOUBLE, "emptyRange"))) + assertExpression(new Comparison(GREATER_THAN, new Reference(DOUBLE, "x"), new Reference(DOUBLE, "emptyRange"))) .outputRowsCount(0); - assertExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "x"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "y"), new Constant(INTEGER, 20L)))) + assertExpression(new Comparison(GREATER_THAN, new Reference(DOUBLE, "x"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "y"), new Constant(INTEGER, 20L)))) .outputRowsCount(0); - assertExpression(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(DOUBLE, "x"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "y"), new Constant(INTEGER, 20L)))) + assertExpression(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(DOUBLE, "x"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "y"), new Constant(INTEGER, 20L)))) .outputRowsCount(0); - assertExpression(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "x"), new ArithmeticBinaryExpression(SUBTRACT_INTEGER, SUBTRACT, new SymbolReference(INTEGER, "y"), new Constant(INTEGER, 25L)))) + assertExpression(new Comparison(LESS_THAN, new Reference(DOUBLE, "x"), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "y"), new Constant(INTEGER, 25L)))) .outputRowsCount(0); - assertExpression(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(DOUBLE, "x"), new ArithmeticBinaryExpression(SUBTRACT_INTEGER, SUBTRACT, new SymbolReference(INTEGER, "y"), new Constant(INTEGER, 25L)))) + assertExpression(new Comparison(LESS_THAN_OR_EQUAL, new Reference(DOUBLE, "x"), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "y"), new Constant(INTEGER, 25L)))) .outputRowsCount(0); double nullsFractionY = 0.5; double inputRowCount = standardInputStatistics.getOutputRowCount(); double nonNullRowCount = inputRowCount * (1 - nullsFractionY); SymbolStatsEstimate nonNullStatsX = xStats.mapNullsFraction(nullsFraction -> 0.0); - assertExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "x"), new ArithmeticBinaryExpression(SUBTRACT_INTEGER, SUBTRACT, new SymbolReference(INTEGER, "y"), new Constant(INTEGER, 25L)))) + assertExpression(new Comparison(GREATER_THAN, new Reference(DOUBLE, "x"), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "y"), new Constant(INTEGER, 25L)))) .outputRowsCount(nonNullRowCount) .symbolStats("x", symbolAssert -> symbolAssert.isEqualTo(nonNullStatsX)); - assertExpression(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(DOUBLE, "x"), new ArithmeticBinaryExpression(SUBTRACT_INTEGER, SUBTRACT, new SymbolReference(INTEGER, "y"), new Constant(INTEGER, 25L)))) + assertExpression(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(DOUBLE, "x"), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "y"), new Constant(INTEGER, 25L)))) .outputRowsCount(nonNullRowCount) .symbolStats("x", symbolAssert -> symbolAssert.isEqualTo(nonNullStatsX)); - assertExpression(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "x"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "y"), new Constant(INTEGER, 20L)))) + assertExpression(new Comparison(LESS_THAN, new Reference(DOUBLE, "x"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "y"), new Constant(INTEGER, 20L)))) .outputRowsCount(nonNullRowCount) .symbolStats("x", symbolAssert -> symbolAssert.isEqualTo(nonNullStatsX)); - assertExpression(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(DOUBLE, "x"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "y"), new Constant(INTEGER, 20L)))) + assertExpression(new Comparison(LESS_THAN_OR_EQUAL, new Reference(DOUBLE, "x"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "y"), new Constant(INTEGER, 20L)))) .outputRowsCount(nonNullRowCount) .symbolStats("x", symbolAssert -> symbolAssert.isEqualTo(nonNullStatsX)); } @@ -290,7 +290,7 @@ public void testInequalityComparisonApproximation() @Test public void testOrStats() { - assertExpression(new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, 0.0)), new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, -7.5))))) + assertExpression(new Logical(OR, ImmutableList.of(new Comparison(LESS_THAN, new Reference(DOUBLE, "x"), new Constant(DOUBLE, 0.0)), new Comparison(LESS_THAN, new Reference(DOUBLE, "x"), new Constant(DOUBLE, -7.5))))) .outputRowsCount(375) .symbolStats(new Symbol(UNKNOWN, "x"), symbolAssert -> symbolAssert.averageRowSize(4.0) @@ -299,7 +299,7 @@ public void testOrStats() .distinctValuesCount(20.0) .nullsFraction(0.0)); - assertExpression(new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, 0.0)), new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, -7.5))))) + assertExpression(new Logical(OR, ImmutableList.of(new Comparison(EQUAL, new Reference(DOUBLE, "x"), new Constant(DOUBLE, 0.0)), new Comparison(EQUAL, new Reference(DOUBLE, "x"), new Constant(DOUBLE, -7.5))))) .outputRowsCount(37.5) .symbolStats(new Symbol(UNKNOWN, "x"), symbolAssert -> symbolAssert.averageRowSize(4.0) @@ -308,7 +308,7 @@ public void testOrStats() .distinctValuesCount(2.0) .nullsFraction(0.0)); - assertExpression(new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, 1.0)), new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, 3.0))))) + assertExpression(new Logical(OR, ImmutableList.of(new Comparison(EQUAL, new Reference(DOUBLE, "x"), new Constant(DOUBLE, 1.0)), new Comparison(EQUAL, new Reference(DOUBLE, "x"), new Constant(DOUBLE, 3.0))))) .outputRowsCount(37.5) .symbolStats(new Symbol(UNKNOWN, "x"), symbolAssert -> symbolAssert.averageRowSize(4.0) @@ -317,7 +317,7 @@ public void testOrStats() .distinctValuesCount(2) .nullsFraction(0)); - assertExpression(new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, 1.0)), new ComparisonExpression(EQUAL, new Constant(VarcharType.VARCHAR, Slices.utf8Slice("a")), new Constant(VarcharType.VARCHAR, Slices.utf8Slice("b"))), new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, 3.0))))) + assertExpression(new Logical(OR, ImmutableList.of(new Comparison(EQUAL, new Reference(DOUBLE, "x"), new Constant(DOUBLE, 1.0)), new Comparison(EQUAL, new Constant(VarcharType.VARCHAR, Slices.utf8Slice("a")), new Constant(VarcharType.VARCHAR, Slices.utf8Slice("b"))), new Comparison(EQUAL, new Reference(DOUBLE, "x"), new Constant(DOUBLE, 3.0))))) .outputRowsCount(37.5) .symbolStats(new Symbol(UNKNOWN, "x"), symbolAssert -> symbolAssert.averageRowSize(4.0) @@ -326,16 +326,16 @@ public void testOrStats() .distinctValuesCount(2) .nullsFraction(0)); - assertExpression(new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, 1.0)), new InPredicate(new Cast(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("b")), createVarcharType(3)), ImmutableList.of(new Cast(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("a")), createVarcharType(3)), new Cast(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("b")), createVarcharType(3)))), new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, 3.0))))) + assertExpression(new Logical(OR, ImmutableList.of(new Comparison(EQUAL, new Reference(DOUBLE, "x"), new Constant(DOUBLE, 1.0)), new In(new Cast(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("b")), createVarcharType(3)), ImmutableList.of(new Cast(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("a")), createVarcharType(3)), new Cast(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("b")), createVarcharType(3)))), new Comparison(EQUAL, new Reference(DOUBLE, "x"), new Constant(DOUBLE, 3.0))))) .equalTo(standardInputStatistics); } @Test public void testUnsupportedExpression() { - assertExpression(new FunctionCall(SIN, ImmutableList.of(new SymbolReference(DOUBLE, "x")))) + assertExpression(new Call(SIN, ImmutableList.of(new Reference(DOUBLE, "x")))) .outputRowsCountUnknown(); - assertExpression(new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "x"), new FunctionCall(SIN, ImmutableList.of(new SymbolReference(DOUBLE, "x"))))) + assertExpression(new Comparison(EQUAL, new Reference(DOUBLE, "x"), new Call(SIN, ImmutableList.of(new Reference(DOUBLE, "x"))))) .outputRowsCountUnknown(); } @@ -343,15 +343,15 @@ public void testUnsupportedExpression() public void testAndStats() { // unknown input - assertExpression(new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, 0.0)), new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, 1.0)))), PlanNodeStatsEstimate.unknown()).outputRowsCountUnknown(); - assertExpression(new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, 0.0)), new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "y"), new Constant(DOUBLE, 1.0)))), PlanNodeStatsEstimate.unknown()).outputRowsCountUnknown(); + assertExpression(new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN, new Reference(DOUBLE, "x"), new Constant(DOUBLE, 0.0)), new Comparison(LESS_THAN, new Reference(DOUBLE, "x"), new Constant(DOUBLE, 1.0)))), PlanNodeStatsEstimate.unknown()).outputRowsCountUnknown(); + assertExpression(new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN, new Reference(DOUBLE, "x"), new Constant(DOUBLE, 0.0)), new Comparison(LESS_THAN, new Reference(DOUBLE, "y"), new Constant(DOUBLE, 1.0)))), PlanNodeStatsEstimate.unknown()).outputRowsCountUnknown(); // zeroStatistics input - assertExpression(new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, 0.0)), new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, 1.0)))), zeroStatistics).equalTo(zeroStatistics); - assertExpression(new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, 0.0)), new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "y"), new Constant(DOUBLE, 1.0)))), zeroStatistics).equalTo(zeroStatistics); + assertExpression(new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN, new Reference(DOUBLE, "x"), new Constant(DOUBLE, 0.0)), new Comparison(LESS_THAN, new Reference(DOUBLE, "x"), new Constant(DOUBLE, 1.0)))), zeroStatistics).equalTo(zeroStatistics); + assertExpression(new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN, new Reference(DOUBLE, "x"), new Constant(DOUBLE, 0.0)), new Comparison(LESS_THAN, new Reference(DOUBLE, "y"), new Constant(DOUBLE, 1.0)))), zeroStatistics).equalTo(zeroStatistics); - assertExpression(new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, 0.0)), new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, 1.0))))).equalTo(zeroStatistics); + assertExpression(new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN, new Reference(DOUBLE, "x"), new Constant(DOUBLE, 0.0)), new Comparison(GREATER_THAN, new Reference(DOUBLE, "x"), new Constant(DOUBLE, 1.0))))).equalTo(zeroStatistics); - assertExpression(new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, 0.0)), new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, -7.5))))) + assertExpression(new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN, new Reference(DOUBLE, "x"), new Constant(DOUBLE, 0.0)), new Comparison(GREATER_THAN, new Reference(DOUBLE, "x"), new Constant(DOUBLE, -7.5))))) .outputRowsCount(281.25) .symbolStats(new Symbol(UNKNOWN, "x"), symbolAssert -> symbolAssert.averageRowSize(4.0) @@ -361,13 +361,13 @@ public void testAndStats() .nullsFraction(0.0)); // Impossible, with symbol-to-expression comparisons - assertExpression(new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "x"), new ArithmeticBinaryExpression(ADD_DOUBLE, ADD, new Constant(DOUBLE, 0.0), new Constant(DOUBLE, 1.0))), new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "x"), new ArithmeticBinaryExpression(ADD_DOUBLE, ADD, new Constant(DOUBLE, 0.0), new Constant(DOUBLE, 3.0)))))) + assertExpression(new Logical(AND, ImmutableList.of(new Comparison(EQUAL, new Reference(DOUBLE, "x"), new Arithmetic(ADD_DOUBLE, ADD, new Constant(DOUBLE, 0.0), new Constant(DOUBLE, 1.0))), new Comparison(EQUAL, new Reference(DOUBLE, "x"), new Arithmetic(ADD_DOUBLE, ADD, new Constant(DOUBLE, 0.0), new Constant(DOUBLE, 3.0)))))) .outputRowsCount(0) .symbolStats(new Symbol(UNKNOWN, "x"), SymbolStatsAssertion::emptyRange) .symbolStats(new Symbol(UNKNOWN, "y"), SymbolStatsAssertion::emptyRange); // first argument unknown - assertExpression(new LogicalExpression(AND, ImmutableList.of(new FunctionCall(JSON_ARRAY_CONTAINS, ImmutableList.of(new Constant(JSON, JsonTypeUtil.jsonParse(Slices.utf8Slice("[]"))), new SymbolReference(DOUBLE, "x"))), new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, 0.0))))) + assertExpression(new Logical(AND, ImmutableList.of(new Call(JSON_ARRAY_CONTAINS, ImmutableList.of(new Constant(JSON, JsonTypeUtil.jsonParse(Slices.utf8Slice("[]"))), new Reference(DOUBLE, "x"))), new Comparison(LESS_THAN, new Reference(DOUBLE, "x"), new Constant(DOUBLE, 0.0))))) .outputRowsCount(337.5) .symbolStats(new Symbol(UNKNOWN, "x"), symbolAssert -> symbolAssert.lowValue(-10) @@ -376,7 +376,7 @@ public void testAndStats() .nullsFraction(0)); // second argument unknown - assertExpression(new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, 0.0)), new FunctionCall(JSON_ARRAY_CONTAINS, ImmutableList.of(new Constant(JSON, JsonTypeUtil.jsonParse(Slices.utf8Slice("[]"))), new SymbolReference(DOUBLE, "x")))))) + assertExpression(new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN, new Reference(DOUBLE, "x"), new Constant(DOUBLE, 0.0)), new Call(JSON_ARRAY_CONTAINS, ImmutableList.of(new Constant(JSON, JsonTypeUtil.jsonParse(Slices.utf8Slice("[]"))), new Reference(DOUBLE, "x")))))) .outputRowsCount(337.5) .symbolStats(new Symbol(UNKNOWN, "x"), symbolAssert -> symbolAssert.lowValue(-10) @@ -385,16 +385,16 @@ public void testAndStats() .nullsFraction(0)); // both arguments unknown - assertExpression(new LogicalExpression(AND, ImmutableList.of( - new FunctionCall(JSON_ARRAY_CONTAINS, ImmutableList.of(new Constant(JSON, JsonTypeUtil.jsonParse(Slices.utf8Slice("[11]"))), new SymbolReference(DOUBLE, "x"))), - new FunctionCall(JSON_ARRAY_CONTAINS, ImmutableList.of(new Constant(JSON, JsonTypeUtil.jsonParse(Slices.utf8Slice("[13]"))), new SymbolReference(DOUBLE, "x")))))) + assertExpression(new Logical(AND, ImmutableList.of( + new Call(JSON_ARRAY_CONTAINS, ImmutableList.of(new Constant(JSON, JsonTypeUtil.jsonParse(Slices.utf8Slice("[11]"))), new Reference(DOUBLE, "x"))), + new Call(JSON_ARRAY_CONTAINS, ImmutableList.of(new Constant(JSON, JsonTypeUtil.jsonParse(Slices.utf8Slice("[13]"))), new Reference(DOUBLE, "x")))))) .outputRowsCountUnknown(); - assertExpression(new LogicalExpression(AND, ImmutableList.of(new InPredicate(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("a")), ImmutableList.of(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("b")), new Constant(VarcharType.VARCHAR, Slices.utf8Slice("c")))), new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "unknownRange"), new Constant(DOUBLE, 3.0))))) + assertExpression(new Logical(AND, ImmutableList.of(new In(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("a")), ImmutableList.of(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("b")), new Constant(VarcharType.VARCHAR, Slices.utf8Slice("c")))), new Comparison(EQUAL, new Reference(DOUBLE, "unknownRange"), new Constant(DOUBLE, 3.0))))) .outputRowsCount(0); - assertExpression(new LogicalExpression(AND, ImmutableList.of(new Constant(BOOLEAN, null), new Constant(BOOLEAN, null)))).equalTo(zeroStatistics); - assertExpression(new LogicalExpression(AND, ImmutableList.of(new Constant(BOOLEAN, null), new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, 0.0)), new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, 1.0))))))).equalTo(zeroStatistics); + assertExpression(new Logical(AND, ImmutableList.of(new Constant(BOOLEAN, null), new Constant(BOOLEAN, null)))).equalTo(zeroStatistics); + assertExpression(new Logical(AND, ImmutableList.of(new Constant(BOOLEAN, null), new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN, new Reference(DOUBLE, "x"), new Constant(DOUBLE, 0.0)), new Comparison(GREATER_THAN, new Reference(DOUBLE, "x"), new Constant(DOUBLE, 1.0))))))).equalTo(zeroStatistics); Consumer symbolAssertX = symbolAssert -> symbolAssert.averageRowSize(4.0) .lowValue(-5.0) @@ -411,27 +411,27 @@ public void testAndStats() double filterSelectivityX = 0.375; double inequalityFilterSelectivityY = 0.4; assertExpression( - new LogicalExpression(AND, ImmutableList.of( - new BetweenPredicate(new SymbolReference(DOUBLE, "x"), new Cast(new Constant(INTEGER, -5L), DOUBLE), new Cast(new Constant(INTEGER, 5L), DOUBLE)), - new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "y"), new Cast(new Constant(INTEGER, 1L), DOUBLE)))), + new Logical(AND, ImmutableList.of( + new Between(new Reference(DOUBLE, "x"), new Cast(new Constant(INTEGER, -5L), DOUBLE), new Cast(new Constant(INTEGER, 5L), DOUBLE)), + new Comparison(GREATER_THAN, new Reference(DOUBLE, "y"), new Cast(new Constant(INTEGER, 1L), DOUBLE)))), Session.builder(session).setSystemProperty(FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, "0").build()) .outputRowsCount(filterSelectivityX * inputRowCount) .symbolStats("x", symbolAssertX) .symbolStats("y", symbolAssertY); assertExpression( - new LogicalExpression(AND, ImmutableList.of( - new BetweenPredicate(new SymbolReference(DOUBLE, "x"), new Cast(new Constant(INTEGER, -5L), DOUBLE), new Cast(new Constant(INTEGER, 5L), DOUBLE)), - new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "y"), new Cast(new Constant(INTEGER, 1L), DOUBLE)))), + new Logical(AND, ImmutableList.of( + new Between(new Reference(DOUBLE, "x"), new Cast(new Constant(INTEGER, -5L), DOUBLE), new Cast(new Constant(INTEGER, 5L), DOUBLE)), + new Comparison(GREATER_THAN, new Reference(DOUBLE, "y"), new Cast(new Constant(INTEGER, 1L), DOUBLE)))), Session.builder(session).setSystemProperty(FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, "1").build()) .outputRowsCount(filterSelectivityX * inequalityFilterSelectivityY * inputRowCount) .symbolStats("x", symbolAssertX) .symbolStats("y", symbolAssertY); assertExpression( - new LogicalExpression(AND, ImmutableList.of( - new BetweenPredicate(new SymbolReference(DOUBLE, "x"), new Cast(new Constant(INTEGER, -5L), DOUBLE), new Cast(new Constant(INTEGER, 5L), DOUBLE)), - new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "y"), new Cast(new Constant(INTEGER, 1L), DOUBLE)))), + new Logical(AND, ImmutableList.of( + new Between(new Reference(DOUBLE, "x"), new Cast(new Constant(INTEGER, -5L), DOUBLE), new Cast(new Constant(INTEGER, 5L), DOUBLE)), + new Comparison(GREATER_THAN, new Reference(DOUBLE, "y"), new Cast(new Constant(INTEGER, 1L), DOUBLE)))), Session.builder(session).setSystemProperty(FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, "0.5").build()) .outputRowsCount(filterSelectivityX * Math.pow(inequalityFilterSelectivityY, 0.5) * inputRowCount) .symbolStats("x", symbolAssertX) @@ -439,36 +439,36 @@ public void testAndStats() double nullFilterSelectivityY = 0.5; assertExpression( - new LogicalExpression(AND, ImmutableList.of( - new BetweenPredicate(new SymbolReference(DOUBLE, "x"), new Cast(new Constant(INTEGER, -5L), DOUBLE), new Cast(new Constant(INTEGER, 5L), DOUBLE)), - new IsNullPredicate(new SymbolReference(DOUBLE, "y")))), + new Logical(AND, ImmutableList.of( + new Between(new Reference(DOUBLE, "x"), new Cast(new Constant(INTEGER, -5L), DOUBLE), new Cast(new Constant(INTEGER, 5L), DOUBLE)), + new IsNull(new Reference(DOUBLE, "y")))), Session.builder(session).setSystemProperty(FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, "1").build()) .outputRowsCount(filterSelectivityX * nullFilterSelectivityY * inputRowCount) .symbolStats("x", symbolAssertX) .symbolStats("y", symbolAssert -> symbolAssert.isEqualTo(SymbolStatsEstimate.zero())); assertExpression( - new LogicalExpression(AND, ImmutableList.of( - new BetweenPredicate(new SymbolReference(DOUBLE, "x"), new Cast(new Constant(INTEGER, -5L), DOUBLE), new Cast(new Constant(INTEGER, 5L), DOUBLE)), - new IsNullPredicate(new SymbolReference(DOUBLE, "y")))), + new Logical(AND, ImmutableList.of( + new Between(new Reference(DOUBLE, "x"), new Cast(new Constant(INTEGER, -5L), DOUBLE), new Cast(new Constant(INTEGER, 5L), DOUBLE)), + new IsNull(new Reference(DOUBLE, "y")))), Session.builder(session).setSystemProperty(FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, "0.5").build()) .outputRowsCount(filterSelectivityX * Math.pow(nullFilterSelectivityY, 0.5) * inputRowCount) .symbolStats("x", symbolAssertX) .symbolStats("y", symbolAssert -> symbolAssert.isEqualTo(SymbolStatsEstimate.zero())); assertExpression( - new LogicalExpression(AND, ImmutableList.of( - new BetweenPredicate(new SymbolReference(DOUBLE, "x"), new Cast(new Constant(INTEGER, -5L), DOUBLE), new Cast(new Constant(INTEGER, 5L), DOUBLE)), - new IsNullPredicate(new SymbolReference(DOUBLE, "y")))), + new Logical(AND, ImmutableList.of( + new Between(new Reference(DOUBLE, "x"), new Cast(new Constant(INTEGER, -5L), DOUBLE), new Cast(new Constant(INTEGER, 5L), DOUBLE)), + new IsNull(new Reference(DOUBLE, "y")))), Session.builder(session).setSystemProperty(FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, "0").build()) .outputRowsCount(filterSelectivityX * inputRowCount) .symbolStats("x", symbolAssertX) .symbolStats("y", symbolAssert -> symbolAssert.isEqualTo(SymbolStatsEstimate.zero())); assertExpression( - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "y"), new Cast(new Constant(INTEGER, 1L), DOUBLE)), - new ComparisonExpression(LESS_THAN, new Cast(new Constant(INTEGER, 0L), DOUBLE), new SymbolReference(DOUBLE, "y")))), + new Logical(AND, ImmutableList.of( + new Comparison(LESS_THAN, new Reference(DOUBLE, "y"), new Cast(new Constant(INTEGER, 1L), DOUBLE)), + new Comparison(LESS_THAN, new Cast(new Constant(INTEGER, 0L), DOUBLE), new Reference(DOUBLE, "y")))), Session.builder(session).setSystemProperty(FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, "0.5").build()) .outputRowsCount(100) .symbolStats("y", symbolAssert -> symbolAssert.averageRowSize(4.0) @@ -478,11 +478,11 @@ public void testAndStats() .nullsFraction(0.0)); assertExpression( - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "x"), new Cast(new Constant(INTEGER, 0L), DOUBLE)), - new LogicalExpression(OR, ImmutableList.of( - new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "y"), new Cast(new Constant(INTEGER, 1L), DOUBLE)), - new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "y"), new Cast(new Constant(INTEGER, 2L), DOUBLE)))))), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(DOUBLE, "x"), new Cast(new Constant(INTEGER, 0L), DOUBLE)), + new Logical(OR, ImmutableList.of( + new Comparison(LESS_THAN, new Reference(DOUBLE, "y"), new Cast(new Constant(INTEGER, 1L), DOUBLE)), + new Comparison(GREATER_THAN, new Reference(DOUBLE, "y"), new Cast(new Constant(INTEGER, 2L), DOUBLE)))))), Session.builder(session).setSystemProperty(FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, "0.5").build()) .outputRowsCount(filterSelectivityX * Math.pow(inequalityFilterSelectivityY, 0.5) * inputRowCount) .symbolStats("x", symbolAssert -> symbolAssert.averageRowSize(4.0) @@ -497,11 +497,11 @@ public void testAndStats() .nullsFraction(0.0)); assertExpression( - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "x"), new Cast(new Constant(INTEGER, 0L), DOUBLE)), - new LogicalExpression(OR, ImmutableList.of( - new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "x"), new Cast(new Constant(INTEGER, 1L), DOUBLE)), - new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "y"), new Cast(new Constant(INTEGER, 1L), DOUBLE)))))), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(DOUBLE, "x"), new Cast(new Constant(INTEGER, 0L), DOUBLE)), + new Logical(OR, ImmutableList.of( + new Comparison(LESS_THAN, new Reference(DOUBLE, "x"), new Cast(new Constant(INTEGER, 1L), DOUBLE)), + new Comparison(GREATER_THAN, new Reference(DOUBLE, "y"), new Cast(new Constant(INTEGER, 1L), DOUBLE)))))), Session.builder(session).setSystemProperty(FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, "0.5").build()) .outputRowsCount(172.0) .symbolStats("x", symbolAssert -> symbolAssert.averageRowSize(4.0) @@ -516,19 +516,19 @@ public void testAndStats() .nullsFraction(0.1053779069)); assertExpression( - new LogicalExpression(AND, ImmutableList.of( - new InPredicate(new SymbolReference(DOUBLE, "x"), ImmutableList.of( + new Logical(AND, ImmutableList.of( + new In(new Reference(DOUBLE, "x"), ImmutableList.of( new Cast(new Constant(INTEGER, 0L), DOUBLE), new Cast(new Constant(INTEGER, 1L), DOUBLE), new Cast(new Constant(INTEGER, 2L), DOUBLE))), - new LogicalExpression(OR, ImmutableList.of( - new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "x"), new Cast(new Constant(INTEGER, 0L), DOUBLE)), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "x"), new Cast(new Constant(INTEGER, 1L), DOUBLE)), - new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "y"), new Cast(new Constant(INTEGER, 1L), DOUBLE)))), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "x"), new Cast(new Constant(INTEGER, 2L), DOUBLE)), - new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "y"), new Cast(new Constant(INTEGER, 1L), DOUBLE)))))))), + new Logical(OR, ImmutableList.of( + new Comparison(EQUAL, new Reference(DOUBLE, "x"), new Cast(new Constant(INTEGER, 0L), DOUBLE)), + new Logical(AND, ImmutableList.of( + new Comparison(EQUAL, new Reference(DOUBLE, "x"), new Cast(new Constant(INTEGER, 1L), DOUBLE)), + new Comparison(EQUAL, new Reference(DOUBLE, "y"), new Cast(new Constant(INTEGER, 1L), DOUBLE)))), + new Logical(AND, ImmutableList.of( + new Comparison(EQUAL, new Reference(DOUBLE, "x"), new Cast(new Constant(INTEGER, 2L), DOUBLE)), + new Comparison(EQUAL, new Reference(DOUBLE, "y"), new Cast(new Constant(INTEGER, 1L), DOUBLE)))))))), Session.builder(session).setSystemProperty(FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, "0.5").build()) .outputRowsCount(20.373798) .symbolStats("x", symbolAssert -> symbolAssert.averageRowSize(4.0) @@ -543,8 +543,8 @@ public void testAndStats() .nullsFraction(0.2300749269)); assertExpression( - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "x"), new Cast(new Constant(INTEGER, 0L), DOUBLE)), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(DOUBLE, "x"), new Cast(new Constant(INTEGER, 0L), DOUBLE)), new Constant(BOOLEAN, null))), Session.builder(session).setSystemProperty(FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, "0.5").build()) .outputRowsCount(filterSelectivityX * inputRowCount * 0.9) @@ -558,7 +558,7 @@ public void testAndStats() @Test public void testNotStats() { - assertExpression(new NotExpression(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, 0.0)))) + assertExpression(new Not(new Comparison(LESS_THAN, new Reference(DOUBLE, "x"), new Constant(DOUBLE, 0.0)))) .outputRowsCount(625) // FIXME - nulls shouldn't be restored .symbolStats(new Symbol(UNKNOWN, "x"), symbolAssert -> symbolAssert.averageRowSize(4.0) @@ -568,7 +568,7 @@ public void testNotStats() .nullsFraction(0.4)) // FIXME - nulls shouldn't be restored .symbolStats(new Symbol(UNKNOWN, "y"), symbolAssert -> symbolAssert.isEqualTo(yStats)); - assertExpression(new NotExpression(new IsNullPredicate(new SymbolReference(DOUBLE, "x")))) + assertExpression(new Not(new IsNull(new Reference(DOUBLE, "x")))) .outputRowsCount(750) .symbolStats(new Symbol(UNKNOWN, "x"), symbolAssert -> symbolAssert.averageRowSize(4.0) @@ -578,21 +578,21 @@ public void testNotStats() .nullsFraction(0)) .symbolStats(new Symbol(UNKNOWN, "y"), symbolAssert -> symbolAssert.isEqualTo(yStats)); - assertExpression(new NotExpression(new FunctionCall(JSON_ARRAY_CONTAINS, ImmutableList.of(new Constant(JSON, JsonTypeUtil.jsonParse(Slices.utf8Slice("[]"))), new SymbolReference(DOUBLE, "x"))))) + assertExpression(new Not(new Call(JSON_ARRAY_CONTAINS, ImmutableList.of(new Constant(JSON, JsonTypeUtil.jsonParse(Slices.utf8Slice("[]"))), new Reference(DOUBLE, "x"))))) .outputRowsCountUnknown(); } @Test public void testIsNullFilter() { - assertExpression(new IsNullPredicate(new SymbolReference(DOUBLE, "x"))) + assertExpression(new IsNull(new Reference(DOUBLE, "x"))) .outputRowsCount(250.0) .symbolStats(new Symbol(UNKNOWN, "x"), symbolStats -> symbolStats.distinctValuesCount(0) .emptyRange() .nullsFraction(1.0)); - assertExpression(new IsNullPredicate(new SymbolReference(DOUBLE, "emptyRange"))) + assertExpression(new IsNull(new Reference(DOUBLE, "emptyRange"))) .outputRowsCount(1000.0) .symbolStats(new Symbol(UNKNOWN, "emptyRange"), SymbolStatsAssertion::empty); } @@ -600,7 +600,7 @@ public void testIsNullFilter() @Test public void testIsNotNullFilter() { - assertExpression(new NotExpression(new IsNullPredicate(new SymbolReference(DOUBLE, "x")))) + assertExpression(new Not(new IsNull(new Reference(DOUBLE, "x")))) .outputRowsCount(750.0) .symbolStats("x", symbolStats -> symbolStats.distinctValuesCount(40.0) @@ -608,7 +608,7 @@ public void testIsNotNullFilter() .highValue(10.0) .nullsFraction(0.0)); - assertExpression(new NotExpression(new IsNullPredicate(new SymbolReference(DOUBLE, "emptyRange")))) + assertExpression(new Not(new IsNull(new Reference(DOUBLE, "emptyRange")))) .outputRowsCount(0.0) .symbolStats("emptyRange", SymbolStatsAssertion::empty); } @@ -617,7 +617,7 @@ public void testIsNotNullFilter() public void testBetweenOperatorFilter() { // Only right side cut - assertExpression(new BetweenPredicate(new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, 7.5), new Constant(DOUBLE, 12.0))) + assertExpression(new Between(new Reference(DOUBLE, "x"), new Constant(DOUBLE, 7.5), new Constant(DOUBLE, 12.0))) .outputRowsCount(93.75) .symbolStats("x", symbolStats -> symbolStats.distinctValuesCount(5.0) @@ -626,14 +626,14 @@ public void testBetweenOperatorFilter() .nullsFraction(0.0)); // Only left side cut - assertExpression(new BetweenPredicate(new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, -12.0), new Constant(DOUBLE, -7.5))) + assertExpression(new Between(new Reference(DOUBLE, "x"), new Constant(DOUBLE, -12.0), new Constant(DOUBLE, -7.5))) .outputRowsCount(93.75) .symbolStats("x", symbolStats -> symbolStats.distinctValuesCount(5.0) .lowValue(-10) .highValue(-7.5) .nullsFraction(0.0)); - assertExpression(new BetweenPredicate(new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, -12.0), new Constant(DOUBLE, -7.5))) + assertExpression(new Between(new Reference(DOUBLE, "x"), new Constant(DOUBLE, -12.0), new Constant(DOUBLE, -7.5))) .outputRowsCount(93.75) .symbolStats("x", symbolStats -> symbolStats.distinctValuesCount(5.0) @@ -642,7 +642,7 @@ public void testBetweenOperatorFilter() .nullsFraction(0.0)); // Both sides cut - assertExpression(new BetweenPredicate(new SymbolReference(DOUBLE, "x"), new Constant(DOUBLE, -2.5), new Constant(DOUBLE, 2.5))) + assertExpression(new Between(new Reference(DOUBLE, "x"), new Constant(DOUBLE, -2.5), new Constant(DOUBLE, 2.5))) .outputRowsCount(187.5) .symbolStats("x", symbolStats -> symbolStats.distinctValuesCount(10.0) @@ -651,7 +651,7 @@ public void testBetweenOperatorFilter() .nullsFraction(0.0)); // Both sides cut unknownRange - assertExpression(new BetweenPredicate(new SymbolReference(DOUBLE, "unknownRange"), new Constant(DOUBLE, 2.72), new Constant(DOUBLE, 3.14))) + assertExpression(new Between(new Reference(DOUBLE, "unknownRange"), new Constant(DOUBLE, 2.72), new Constant(DOUBLE, 3.14))) .outputRowsCount(112.5) .symbolStats("unknownRange", symbolStats -> symbolStats.distinctValuesCount(6.25) @@ -660,7 +660,7 @@ public void testBetweenOperatorFilter() .nullsFraction(0.0)); // Left side open, cut on open side - assertExpression(new BetweenPredicate(new SymbolReference(DOUBLE, "leftOpen"), new Constant(DOUBLE, -10.0), new Constant(DOUBLE, 10.0))) + assertExpression(new Between(new Reference(DOUBLE, "leftOpen"), new Constant(DOUBLE, -10.0), new Constant(DOUBLE, 10.0))) .outputRowsCount(180.0) .symbolStats("leftOpen", symbolStats -> symbolStats.distinctValuesCount(10.0) @@ -669,7 +669,7 @@ public void testBetweenOperatorFilter() .nullsFraction(0.0)); // Right side open, cut on open side - assertExpression(new BetweenPredicate(new SymbolReference(DOUBLE, "rightOpen"), new Constant(DOUBLE, -10.0), new Constant(DOUBLE, 10.0))) + assertExpression(new Between(new Reference(DOUBLE, "rightOpen"), new Constant(DOUBLE, -10.0), new Constant(DOUBLE, 10.0))) .outputRowsCount(180.0) .symbolStats("rightOpen", symbolStats -> symbolStats.distinctValuesCount(10.0) @@ -678,12 +678,12 @@ public void testBetweenOperatorFilter() .nullsFraction(0.0)); // Filter all - assertExpression(new BetweenPredicate(new SymbolReference(DOUBLE, "y"), new Constant(DOUBLE, 27.5), new Constant(DOUBLE, 107.0))) + assertExpression(new Between(new Reference(DOUBLE, "y"), new Constant(DOUBLE, 27.5), new Constant(DOUBLE, 107.0))) .outputRowsCount(0.0) .symbolStats("y", SymbolStatsAssertion::empty); // Filter nothing - assertExpression(new BetweenPredicate(new SymbolReference(DOUBLE, "y"), new Constant(DOUBLE, -100.0), new Constant(DOUBLE, 100.0))) + assertExpression(new Between(new Reference(DOUBLE, "y"), new Constant(DOUBLE, -100.0), new Constant(DOUBLE, 100.0))) .outputRowsCount(500.0) .symbolStats("y", symbolStats -> symbolStats.distinctValuesCount(20.0) @@ -692,7 +692,7 @@ public void testBetweenOperatorFilter() .nullsFraction(0.0)); // Filter non exact match - assertExpression(new BetweenPredicate(new SymbolReference(DOUBLE, "z"), new Constant(DOUBLE, -100.0), new Constant(DOUBLE, 100.0))) + assertExpression(new Between(new Reference(DOUBLE, "z"), new Constant(DOUBLE, -100.0), new Constant(DOUBLE, 100.0))) .outputRowsCount(900.0) .symbolStats("z", symbolStats -> symbolStats.distinctValuesCount(5.0) @@ -702,7 +702,7 @@ public void testBetweenOperatorFilter() // Expression as value. CAST from DOUBLE to DECIMAL(7,2) // Produces row count estimate without updating symbol stats - assertExpression(new BetweenPredicate(new Cast(new SymbolReference(DOUBLE, "x"), createDecimalType(7, 2)), new Constant(createDecimalType(7, 2), Decimals.valueOfShort(new BigDecimal("-2.50"))), new Constant(createDecimalType(7, 2), Decimals.valueOfShort(new BigDecimal("2.50"))))) + assertExpression(new Between(new Cast(new Reference(DOUBLE, "x"), createDecimalType(7, 2)), new Constant(createDecimalType(7, 2), Decimals.valueOfShort(new BigDecimal("-2.50"))), new Constant(createDecimalType(7, 2), Decimals.valueOfShort(new BigDecimal("2.50"))))) .outputRowsCount(219.726563) .symbolStats("x", symbolStats -> symbolStats.distinctValuesCount(xStats.getDistinctValuesCount()) @@ -710,18 +710,18 @@ public void testBetweenOperatorFilter() .highValue(xStats.getHighValue()) .nullsFraction(xStats.getNullsFraction())); - assertExpression(new InPredicate(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("a")), ImmutableList.of(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("a")), new Constant(VarcharType.VARCHAR, Slices.utf8Slice("b"))))).equalTo(standardInputStatistics); - assertExpression(new InPredicate(new Constant(createVarcharType(1), Slices.utf8Slice("a")), ImmutableList.of(new Constant(createVarcharType(1), Slices.utf8Slice("a")), new Constant(createVarcharType(1), Slices.utf8Slice("b")), new Constant(createVarcharType(1), null)))).equalTo(standardInputStatistics); - assertExpression(new InPredicate(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("a")), ImmutableList.of(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("b")), new Constant(VarcharType.VARCHAR, Slices.utf8Slice("c"))))).outputRowsCount(0); - assertExpression(new InPredicate(new Constant(createVarcharType(1), Slices.utf8Slice("a")), ImmutableList.of(new Constant(createVarcharType(1), Slices.utf8Slice("b")), new Constant(createVarcharType(1), Slices.utf8Slice("c")), new Constant(createVarcharType(1), null)))).outputRowsCount(0); - assertExpression(new InPredicate(new Cast(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("b")), createVarcharType(3)), ImmutableList.of(new Cast(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("a")), createVarcharType(3)), new Cast(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("b")), createVarcharType(3))))).equalTo(standardInputStatistics); - assertExpression(new InPredicate(new Cast(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("c")), createVarcharType(3)), ImmutableList.of(new Cast(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("a")), createVarcharType(3)), new Cast(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("b")), createVarcharType(3))))).outputRowsCount(0); + assertExpression(new In(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("a")), ImmutableList.of(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("a")), new Constant(VarcharType.VARCHAR, Slices.utf8Slice("b"))))).equalTo(standardInputStatistics); + assertExpression(new In(new Constant(createVarcharType(1), Slices.utf8Slice("a")), ImmutableList.of(new Constant(createVarcharType(1), Slices.utf8Slice("a")), new Constant(createVarcharType(1), Slices.utf8Slice("b")), new Constant(createVarcharType(1), null)))).equalTo(standardInputStatistics); + assertExpression(new In(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("a")), ImmutableList.of(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("b")), new Constant(VarcharType.VARCHAR, Slices.utf8Slice("c"))))).outputRowsCount(0); + assertExpression(new In(new Constant(createVarcharType(1), Slices.utf8Slice("a")), ImmutableList.of(new Constant(createVarcharType(1), Slices.utf8Slice("b")), new Constant(createVarcharType(1), Slices.utf8Slice("c")), new Constant(createVarcharType(1), null)))).outputRowsCount(0); + assertExpression(new In(new Cast(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("b")), createVarcharType(3)), ImmutableList.of(new Cast(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("a")), createVarcharType(3)), new Cast(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("b")), createVarcharType(3))))).equalTo(standardInputStatistics); + assertExpression(new In(new Cast(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("c")), createVarcharType(3)), ImmutableList.of(new Cast(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("a")), createVarcharType(3)), new Cast(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("b")), createVarcharType(3))))).outputRowsCount(0); } @Test public void testSymbolEqualsSameSymbolFilter() { - assertExpression(new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "x"), new SymbolReference(DOUBLE, "x"))) + assertExpression(new Comparison(EQUAL, new Reference(DOUBLE, "x"), new Reference(DOUBLE, "x"))) .outputRowsCount(750) .symbolStats("x", symbolStats -> SymbolStatsEstimate.builder() @@ -736,28 +736,28 @@ public void testSymbolEqualsSameSymbolFilter() public void testInPredicateFilter() { // One value in range - assertExpression(new InPredicate(new SymbolReference(DOUBLE, "x"), ImmutableList.of(new Constant(DOUBLE, 7.5)))) + assertExpression(new In(new Reference(DOUBLE, "x"), ImmutableList.of(new Constant(DOUBLE, 7.5)))) .outputRowsCount(18.75) .symbolStats("x", symbolStats -> symbolStats.distinctValuesCount(1.0) .lowValue(7.5) .highValue(7.5) .nullsFraction(0.0)); - assertExpression(new InPredicate(new SymbolReference(DOUBLE, "x"), ImmutableList.of(new Constant(DOUBLE, -7.5)))) + assertExpression(new In(new Reference(DOUBLE, "x"), ImmutableList.of(new Constant(DOUBLE, -7.5)))) .outputRowsCount(18.75) .symbolStats("x", symbolStats -> symbolStats.distinctValuesCount(1.0) .lowValue(-7.5) .highValue(-7.5) .nullsFraction(0.0)); - assertExpression(new InPredicate(new SymbolReference(DOUBLE, "x"), ImmutableList.of(new ArithmeticBinaryExpression(ADD_DOUBLE, ADD, new Constant(DOUBLE, 2.0), new Constant(DOUBLE, 5.5))))) + assertExpression(new In(new Reference(DOUBLE, "x"), ImmutableList.of(new Arithmetic(ADD_DOUBLE, ADD, new Constant(DOUBLE, 2.0), new Constant(DOUBLE, 5.5))))) .outputRowsCount(18.75) .symbolStats("x", symbolStats -> symbolStats.distinctValuesCount(1.0) .lowValue(7.5) .highValue(7.5) .nullsFraction(0.0)); - assertExpression(new InPredicate(new SymbolReference(DOUBLE, "x"), ImmutableList.of(new Constant(DOUBLE, -7.5)))) + assertExpression(new In(new Reference(DOUBLE, "x"), ImmutableList.of(new Constant(DOUBLE, -7.5)))) .outputRowsCount(18.75) .symbolStats("x", symbolStats -> symbolStats.distinctValuesCount(1.0) @@ -766,7 +766,7 @@ public void testInPredicateFilter() .nullsFraction(0.0)); // Multiple values in range - assertExpression(new InPredicate(new SymbolReference(DOUBLE, "x"), ImmutableList.of(new Constant(DOUBLE, 1.5), new Constant(DOUBLE, 2.5), new Constant(DOUBLE, 7.5)))) + assertExpression(new In(new Reference(DOUBLE, "x"), ImmutableList.of(new Constant(DOUBLE, 1.5), new Constant(DOUBLE, 2.5), new Constant(DOUBLE, 7.5)))) .outputRowsCount(56.25) .symbolStats("x", symbolStats -> symbolStats.distinctValuesCount(3.0) @@ -781,7 +781,7 @@ public void testInPredicateFilter() .nullsFraction(0.5)); // Multiple values some in some out of range - assertExpression(new InPredicate(new SymbolReference(DOUBLE, "x"), ImmutableList.of(new Constant(DOUBLE, -42.0), new Constant(DOUBLE, 1.5), new Constant(DOUBLE, 2.5), new Constant(DOUBLE, 7.5), new Constant(DOUBLE, 314.0)))) + assertExpression(new In(new Reference(DOUBLE, "x"), ImmutableList.of(new Constant(DOUBLE, -42.0), new Constant(DOUBLE, 1.5), new Constant(DOUBLE, 2.5), new Constant(DOUBLE, 7.5), new Constant(DOUBLE, 314.0)))) .outputRowsCount(56.25) .symbolStats("x", symbolStats -> symbolStats.distinctValuesCount(3.0) @@ -790,7 +790,7 @@ public void testInPredicateFilter() .nullsFraction(0.0)); // Multiple values some including NULL - assertExpression(new InPredicate(new SymbolReference(DOUBLE, "x"), ImmutableList.of(new Constant(DOUBLE, -42.0), new Constant(DOUBLE, 1.5), new Constant(DOUBLE, 2.5), new Constant(DOUBLE, 7.5), new Constant(DOUBLE, 314.0), new Constant(DOUBLE, null)))) + assertExpression(new In(new Reference(DOUBLE, "x"), ImmutableList.of(new Constant(DOUBLE, -42.0), new Constant(DOUBLE, 1.5), new Constant(DOUBLE, 2.5), new Constant(DOUBLE, 7.5), new Constant(DOUBLE, 314.0), new Constant(DOUBLE, null)))) .outputRowsCount(56.25) .symbolStats("x", symbolStats -> symbolStats.distinctValuesCount(3.0) @@ -799,7 +799,7 @@ public void testInPredicateFilter() .nullsFraction(0.0)); // Multiple values in unknown range - assertExpression(new InPredicate(new SymbolReference(DOUBLE, "unknownRange"), ImmutableList.of(new Constant(DOUBLE, -42.0), new Constant(DOUBLE, 1.5), new Constant(DOUBLE, 2.5), new Constant(DOUBLE, 7.5), new Constant(DOUBLE, 314.0)))) + assertExpression(new In(new Reference(DOUBLE, "unknownRange"), ImmutableList.of(new Constant(DOUBLE, -42.0), new Constant(DOUBLE, 1.5), new Constant(DOUBLE, 2.5), new Constant(DOUBLE, 7.5), new Constant(DOUBLE, 314.0)))) .outputRowsCount(90.0) .symbolStats("unknownRange", symbolStats -> symbolStats.distinctValuesCount(5.0) @@ -808,25 +808,25 @@ public void testInPredicateFilter() .nullsFraction(0.0)); // Casted literals as value - assertExpression(new InPredicate(new SymbolReference(MEDIUM_VARCHAR_TYPE, "mediumVarchar"), ImmutableList.of(new Cast(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("abc")), MEDIUM_VARCHAR_TYPE)))) + assertExpression(new In(new Reference(MEDIUM_VARCHAR_TYPE, "mediumVarchar"), ImmutableList.of(new Cast(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("abc")), MEDIUM_VARCHAR_TYPE)))) .outputRowsCount(4) .symbolStats("mediumVarchar", symbolStats -> symbolStats.distinctValuesCount(1) .nullsFraction(0.0)); - assertExpression(new InPredicate(new SymbolReference(MEDIUM_VARCHAR_TYPE, "mediumVarchar"), ImmutableList.of(new Cast(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("abc")), createVarcharType(100)), new Cast(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("def")), createVarcharType(100))))) + assertExpression(new In(new Reference(MEDIUM_VARCHAR_TYPE, "mediumVarchar"), ImmutableList.of(new Cast(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("abc")), createVarcharType(100)), new Cast(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("def")), createVarcharType(100))))) .outputRowsCount(8) .symbolStats("mediumVarchar", symbolStats -> symbolStats.distinctValuesCount(2) .nullsFraction(0.0)); // No value in range - assertExpression(new InPredicate(new SymbolReference(DOUBLE, "y"), ImmutableList.of(new Constant(DOUBLE, -42.0), new Constant(DOUBLE, 6.0), new Constant(DOUBLE, 31.1341), new Constant(DOUBLE, -0.000000002), new Constant(DOUBLE, 314.0)))) + assertExpression(new In(new Reference(DOUBLE, "y"), ImmutableList.of(new Constant(DOUBLE, -42.0), new Constant(DOUBLE, 6.0), new Constant(DOUBLE, 31.1341), new Constant(DOUBLE, -0.000000002), new Constant(DOUBLE, 314.0)))) .outputRowsCount(0.0) .symbolStats("y", SymbolStatsAssertion::empty); // More values in range than distinct values - assertExpression(new InPredicate(new SymbolReference(DOUBLE, "z"), ImmutableList.of(new Constant(DOUBLE, -1.0), new Constant(DOUBLE, 3.14), new Constant(DOUBLE, 0.0), new Constant(DOUBLE, 1.0), new Constant(DOUBLE, 2.0), new Constant(DOUBLE, 3.0), new Constant(DOUBLE, 4.0), new Constant(DOUBLE, 5.0), new Constant(DOUBLE, 6.0), new Constant(DOUBLE, 7.0), new Constant(DOUBLE, 8.0), new Constant(DOUBLE, -2.0)))) + assertExpression(new In(new Reference(DOUBLE, "z"), ImmutableList.of(new Constant(DOUBLE, -1.0), new Constant(DOUBLE, 3.14), new Constant(DOUBLE, 0.0), new Constant(DOUBLE, 1.0), new Constant(DOUBLE, 2.0), new Constant(DOUBLE, 3.0), new Constant(DOUBLE, 4.0), new Constant(DOUBLE, 5.0), new Constant(DOUBLE, 6.0), new Constant(DOUBLE, 7.0), new Constant(DOUBLE, 8.0), new Constant(DOUBLE, -2.0)))) .outputRowsCount(900.0) .symbolStats("z", symbolStats -> symbolStats.distinctValuesCount(5.0) @@ -835,7 +835,7 @@ public void testInPredicateFilter() .nullsFraction(0.0)); // Values in weird order - assertExpression(new InPredicate(new SymbolReference(DOUBLE, "z"), ImmutableList.of(new Constant(DOUBLE, -1.0), new Constant(DOUBLE, 1.0), new Constant(DOUBLE, 0.0)))) + assertExpression(new In(new Reference(DOUBLE, "z"), ImmutableList.of(new Constant(DOUBLE, -1.0), new Constant(DOUBLE, 1.0), new Constant(DOUBLE, 0.0)))) .outputRowsCount(540.0) .symbolStats("z", symbolStats -> symbolStats.distinctValuesCount(3.0) diff --git a/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsRule.java b/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsRule.java index 0d12b40b7af6..9392f7194cbd 100644 --- a/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsRule.java +++ b/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsRule.java @@ -15,9 +15,9 @@ package io.trino.cost; import io.trino.metadata.TestingFunctionResolution; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; @@ -26,7 +26,7 @@ import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.testing.TestingSession.testSessionBuilder; import static io.trino.type.UnknownType.UNKNOWN; import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; @@ -57,7 +57,7 @@ public void tearDownClass() public void testEstimatableFilter() { tester().assertStatsFor(pb -> pb - .filter(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "i1"), new Constant(INTEGER, 5L)), + .filter(new Comparison(EQUAL, new Reference(INTEGER, "i1"), new Constant(INTEGER, 5L)), pb.values(pb.symbol("i1"), pb.symbol("i2"), pb.symbol("i3")))) .withSourceStats(0, PlanNodeStatsEstimate.builder() .setOutputRowCount(10) @@ -102,7 +102,7 @@ public void testEstimatableFilter() .nullsFraction(0.05))); defaultFilterTester.assertStatsFor(pb -> pb - .filter(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "i1"), new Constant(INTEGER, 5L)), + .filter(new Comparison(EQUAL, new Reference(INTEGER, "i1"), new Constant(INTEGER, 5L)), pb.values(pb.symbol("i1"), pb.symbol("i2"), pb.symbol("i3")))) .withSourceStats(0, PlanNodeStatsEstimate.builder() .setOutputRowCount(10) @@ -151,11 +151,11 @@ public void testEstimatableFilter() public void testUnestimatableFunction() { // can't estimate function and default filter factor is turned off - ComparisonExpression unestimatableExpression = new ComparisonExpression( + Comparison unestimatableExpression = new Comparison( EQUAL, new TestingFunctionResolution() .functionCallBuilder("sin") - .addArgument(DOUBLE, new SymbolReference(DOUBLE, "i1")) + .addArgument(DOUBLE, new Reference(DOUBLE, "i1")) .build(), new Constant(DOUBLE, 1.0)); diff --git a/core/trino-main/src/test/java/io/trino/cost/TestJoinStatsRule.java b/core/trino-main/src/test/java/io/trino/cost/TestJoinStatsRule.java index 9624dd418923..c61fecd3b3b8 100644 --- a/core/trino-main/src/test/java/io/trino/cost/TestJoinStatsRule.java +++ b/core/trino-main/src/test/java/io/trino/cost/TestJoinStatsRule.java @@ -14,7 +14,7 @@ package io.trino.cost; import com.google.common.collect.ImmutableList; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; @@ -199,7 +199,7 @@ public void testStatsForInnerJoinWithTwoEquiClausesAndNonEqualityFunction() Symbol rightJoinColumnSymbol = pb.symbol(RIGHT_JOIN_COLUMN, DOUBLE); Symbol leftJoinColumnSymbol2 = pb.symbol(LEFT_JOIN_COLUMN_2, BIGINT); Symbol rightJoinColumnSymbol2 = pb.symbol(RIGHT_JOIN_COLUMN_2, DOUBLE); - ComparisonExpression leftJoinColumnLessThanTen = new ComparisonExpression(ComparisonExpression.Operator.LESS_THAN, leftJoinColumnSymbol.toSymbolReference(), new Constant(INTEGER, 10L)); + Comparison leftJoinColumnLessThanTen = new Comparison(Comparison.Operator.LESS_THAN, leftJoinColumnSymbol.toSymbolReference(), new Constant(INTEGER, 10L)); return pb.join( INNER, pb.values(leftJoinColumnSymbol, leftJoinColumnSymbol2), diff --git a/core/trino-main/src/test/java/io/trino/cost/TestScalarStatsCalculator.java b/core/trino-main/src/test/java/io/trino/cost/TestScalarStatsCalculator.java index 96373346b55b..628f95c2a097 100644 --- a/core/trino-main/src/test/java/io/trino/cost/TestScalarStatsCalculator.java +++ b/core/trino-main/src/test/java/io/trino/cost/TestScalarStatsCalculator.java @@ -24,12 +24,12 @@ import io.trino.spi.function.OperatorType; import io.trino.spi.type.Decimals; import io.trino.spi.type.VarcharType; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.CoalesceExpression; +import io.trino.sql.ir.Coalesce; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.transaction.TestingTransactionManager; import io.trino.transaction.TransactionManager; @@ -46,11 +46,11 @@ import static io.trino.spi.type.TinyintType.TINYINT; import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.spi.type.VarcharType.createVarcharType; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.DIVIDE; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MODULUS; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MULTIPLY; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.SUBTRACT; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.DIVIDE; +import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; +import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; +import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; import static io.trino.testing.TestingSession.testSessionBuilder; import static io.trino.testing.TransactionBuilder.transaction; import static io.trino.type.UnknownType.UNKNOWN; @@ -139,7 +139,7 @@ public void testFunctionCall() assertCalculate( functionResolution .functionCallBuilder("length") - .addArgument(createVarcharType(2), new SymbolReference(createVarcharType(2), "x")) + .addArgument(createVarcharType(2), new Reference(createVarcharType(2), "x")) .build(), PlanNodeStatsEstimate.unknown()) .distinctValuesCountUnknown() @@ -174,8 +174,8 @@ public void testSymbolReference() .addSymbolStatistics(new Symbol(UNKNOWN, "x"), xStats) .build(); - assertCalculate(new SymbolReference(INTEGER, "x"), inputStatistics).isEqualTo(xStats); - assertCalculate(new SymbolReference(INTEGER, "y"), inputStatistics).isEqualTo(SymbolStatsEstimate.unknown()); + assertCalculate(new Reference(INTEGER, "x"), inputStatistics).isEqualTo(xStats); + assertCalculate(new Reference(INTEGER, "y"), inputStatistics).isEqualTo(SymbolStatsEstimate.unknown()); } @Test @@ -192,7 +192,7 @@ public void testCastDoubleToBigint() .build(); assertCalculate( - new Cast(new SymbolReference(BIGINT, "a"), BIGINT), + new Cast(new Reference(BIGINT, "a"), BIGINT), inputStatistics) .lowValue(2.0) .highValue(17.0) @@ -215,7 +215,7 @@ public void testCastDoubleToShortRange() .build(); assertCalculate( - new Cast(new SymbolReference(BIGINT, "a"), BIGINT), + new Cast(new Reference(BIGINT, "a"), BIGINT), inputStatistics) .lowValue(2.0) .highValue(3.0) @@ -237,7 +237,7 @@ public void testCastDoubleToShortRangeUnknownDistinctValuesCount() .build(); assertCalculate( - new Cast(new SymbolReference(BIGINT, "a"), BIGINT), + new Cast(new Reference(BIGINT, "a"), BIGINT), inputStatistics) .lowValue(2.0) .highValue(3.0) @@ -260,7 +260,7 @@ public void testCastBigintToDouble() .build(); assertCalculate( - new Cast(new SymbolReference(DOUBLE, "a"), DOUBLE), + new Cast(new Reference(DOUBLE, "a"), DOUBLE), inputStatistics) .lowValue(2.0) .highValue(10.0) @@ -273,7 +273,7 @@ public void testCastBigintToDouble() public void testCastUnknown() { assertCalculate( - new Cast(new SymbolReference(BIGINT, "a"), BIGINT), + new Cast(new Reference(BIGINT, "a"), BIGINT), PlanNodeStatsEstimate.unknown()) .lowValueUnknown() .highValueUnknown() @@ -320,26 +320,26 @@ public void testNonDivideArithmeticBinaryExpression() .setOutputRowCount(10) .build(); - assertCalculate(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), relationStats) + assertCalculate(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), relationStats) .distinctValuesCount(10.0) .lowValue(-3.0) .highValue(15.0) .nullsFraction(0.28) .averageRowSize(2.0); - assertCalculate(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "unknown")), relationStats) + assertCalculate(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "x"), new Reference(BIGINT, "unknown")), relationStats) .isEqualTo(SymbolStatsEstimate.unknown()); - assertCalculate(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "unknown"), new SymbolReference(BIGINT, "unknown")), relationStats) + assertCalculate(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "unknown"), new Reference(BIGINT, "unknown")), relationStats) .isEqualTo(SymbolStatsEstimate.unknown()); - assertCalculate(new ArithmeticBinaryExpression(SUBTRACT_BIGINT, SUBTRACT, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), relationStats) + assertCalculate(new Arithmetic(SUBTRACT_BIGINT, SUBTRACT, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), relationStats) .distinctValuesCount(10.0) .lowValue(-6.0) .highValue(12.0) .nullsFraction(0.28) .averageRowSize(2.0); - assertCalculate(new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), relationStats) + assertCalculate(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), relationStats) .distinctValuesCount(10.0) .lowValue(-20.0) .highValue(50.0) @@ -363,95 +363,95 @@ public void testArithmeticBinaryWithAllNullsSymbol() .setOutputRowCount(10) .build(); - assertCalculate(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "all_null")), relationStats) + assertCalculate(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "x"), new Reference(BIGINT, "all_null")), relationStats) .isEqualTo(allNullStats); - assertCalculate(new ArithmeticBinaryExpression(SUBTRACT_BIGINT, SUBTRACT, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "all_null")), relationStats) + assertCalculate(new Arithmetic(SUBTRACT_BIGINT, SUBTRACT, new Reference(BIGINT, "x"), new Reference(BIGINT, "all_null")), relationStats) .isEqualTo(allNullStats); - assertCalculate(new ArithmeticBinaryExpression(SUBTRACT_BIGINT, SUBTRACT, new SymbolReference(BIGINT, "all_null"), new SymbolReference(BIGINT, "x")), relationStats) + assertCalculate(new Arithmetic(SUBTRACT_BIGINT, SUBTRACT, new Reference(BIGINT, "all_null"), new Reference(BIGINT, "x")), relationStats) .isEqualTo(allNullStats); - assertCalculate(new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new SymbolReference(BIGINT, "all_null"), new SymbolReference(BIGINT, "x")), relationStats) + assertCalculate(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "all_null"), new Reference(BIGINT, "x")), relationStats) .isEqualTo(allNullStats); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "all_null")), relationStats) + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "all_null")), relationStats) .isEqualTo(allNullStats); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "all_null"), new SymbolReference(BIGINT, "x")), relationStats) + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "all_null"), new Reference(BIGINT, "x")), relationStats) .isEqualTo(allNullStats); - assertCalculate(new ArithmeticBinaryExpression(DIVIDE_BIGINT, DIVIDE, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "all_null")), relationStats) + assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "all_null")), relationStats) .isEqualTo(allNullStats); - assertCalculate(new ArithmeticBinaryExpression(DIVIDE_BIGINT, DIVIDE, new SymbolReference(BIGINT, "all_null"), new SymbolReference(BIGINT, "x")), relationStats) + assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "all_null"), new Reference(BIGINT, "x")), relationStats) .isEqualTo(allNullStats); } @Test public void testDivideArithmeticBinaryExpression() { - assertCalculate(new ArithmeticBinaryExpression(DIVIDE_BIGINT, DIVIDE, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-11, -3, -5, -4)).lowValue(0.6).highValue(2.75); - assertCalculate(new ArithmeticBinaryExpression(DIVIDE_BIGINT, DIVIDE, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-11, -3, -5, 4)).lowValue(NEGATIVE_INFINITY).highValue(POSITIVE_INFINITY); - assertCalculate(new ArithmeticBinaryExpression(DIVIDE_BIGINT, DIVIDE, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-11, -3, 4, 5)).lowValue(-2.75).highValue(-0.6); + assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-11, -3, -5, -4)).lowValue(0.6).highValue(2.75); + assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-11, -3, -5, 4)).lowValue(NEGATIVE_INFINITY).highValue(POSITIVE_INFINITY); + assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-11, -3, 4, 5)).lowValue(-2.75).highValue(-0.6); - assertCalculate(new ArithmeticBinaryExpression(DIVIDE_BIGINT, DIVIDE, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-11, 0, -5, -4)).lowValue(0).highValue(2.75); - assertCalculate(new ArithmeticBinaryExpression(DIVIDE_BIGINT, DIVIDE, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-11, 0, -5, 4)).lowValue(NEGATIVE_INFINITY).highValue(POSITIVE_INFINITY); - assertCalculate(new ArithmeticBinaryExpression(DIVIDE_BIGINT, DIVIDE, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-11, 0, 4, 5)).lowValue(-2.75).highValue(0); + assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-11, 0, -5, -4)).lowValue(0).highValue(2.75); + assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-11, 0, -5, 4)).lowValue(NEGATIVE_INFINITY).highValue(POSITIVE_INFINITY); + assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-11, 0, 4, 5)).lowValue(-2.75).highValue(0); - assertCalculate(new ArithmeticBinaryExpression(DIVIDE_BIGINT, DIVIDE, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-11, 3, -5, -4)).lowValue(-0.75).highValue(2.75); - assertCalculate(new ArithmeticBinaryExpression(DIVIDE_BIGINT, DIVIDE, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-11, 3, -5, 4)).lowValue(NEGATIVE_INFINITY).highValue(POSITIVE_INFINITY); - assertCalculate(new ArithmeticBinaryExpression(DIVIDE_BIGINT, DIVIDE, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-11, 3, 4, 5)).lowValue(-2.75).highValue(0.75); + assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-11, 3, -5, -4)).lowValue(-0.75).highValue(2.75); + assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-11, 3, -5, 4)).lowValue(NEGATIVE_INFINITY).highValue(POSITIVE_INFINITY); + assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-11, 3, 4, 5)).lowValue(-2.75).highValue(0.75); - assertCalculate(new ArithmeticBinaryExpression(DIVIDE_BIGINT, DIVIDE, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(0, 3, -5, -4)).lowValue(-0.75).highValue(0); - assertCalculate(new ArithmeticBinaryExpression(DIVIDE_BIGINT, DIVIDE, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(0, 3, -5, 4)).lowValue(NEGATIVE_INFINITY).highValue(POSITIVE_INFINITY); - assertCalculate(new ArithmeticBinaryExpression(DIVIDE_BIGINT, DIVIDE, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(0, 3, 4, 5)).lowValue(0).highValue(0.75); + assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(0, 3, -5, -4)).lowValue(-0.75).highValue(0); + assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(0, 3, -5, 4)).lowValue(NEGATIVE_INFINITY).highValue(POSITIVE_INFINITY); + assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(0, 3, 4, 5)).lowValue(0).highValue(0.75); - assertCalculate(new ArithmeticBinaryExpression(DIVIDE_BIGINT, DIVIDE, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(3, 11, -5, -4)).lowValue(-2.75).highValue(-0.6); - assertCalculate(new ArithmeticBinaryExpression(DIVIDE_BIGINT, DIVIDE, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(3, 11, -5, 4)).lowValue(NEGATIVE_INFINITY).highValue(POSITIVE_INFINITY); - assertCalculate(new ArithmeticBinaryExpression(DIVIDE_BIGINT, DIVIDE, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(3, 11, 4, 5)).lowValue(0.6).highValue(2.75); + assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(3, 11, -5, -4)).lowValue(-2.75).highValue(-0.6); + assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(3, 11, -5, 4)).lowValue(NEGATIVE_INFINITY).highValue(POSITIVE_INFINITY); + assertCalculate(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(3, 11, 4, 5)).lowValue(0.6).highValue(2.75); } @Test public void testModulusArithmeticBinaryExpression() { // negative - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-1, 0, -6, -4)).lowValue(-1).highValue(0); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-5, 0, -6, -4)).lowValue(-5).highValue(0); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-8, 0, -6, -4)).lowValue(-6).highValue(0); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-8, 0, -6, -4)).lowValue(-6).highValue(0); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-8, 0, -6, 4)).lowValue(-6).highValue(0); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-8, 0, -6, 6)).lowValue(-6).highValue(0); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-8, 0, 4, 6)).lowValue(-6).highValue(0); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-1, 0, 4, 6)).lowValue(-1).highValue(0); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-5, 0, 4, 6)).lowValue(-5).highValue(0); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-8, 0, 4, 6)).lowValue(-6).highValue(0); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-1, 0, -6, -4)).lowValue(-1).highValue(0); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-5, 0, -6, -4)).lowValue(-5).highValue(0); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-8, 0, -6, -4)).lowValue(-6).highValue(0); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-8, 0, -6, -4)).lowValue(-6).highValue(0); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-8, 0, -6, 4)).lowValue(-6).highValue(0); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-8, 0, -6, 6)).lowValue(-6).highValue(0); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-8, 0, 4, 6)).lowValue(-6).highValue(0); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-1, 0, 4, 6)).lowValue(-1).highValue(0); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-5, 0, 4, 6)).lowValue(-5).highValue(0); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-8, 0, 4, 6)).lowValue(-6).highValue(0); // positive - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(0, 5, -6, -4)).lowValue(0).highValue(5); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(0, 8, -6, -4)).lowValue(0).highValue(6); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(0, 1, -6, 4)).lowValue(0).highValue(1); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(0, 5, -6, 4)).lowValue(0).highValue(5); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(0, 8, -6, 4)).lowValue(0).highValue(6); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(0, 1, 4, 6)).lowValue(0).highValue(1); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(0, 5, 4, 6)).lowValue(0).highValue(5); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(0, 8, 4, 6)).lowValue(0).highValue(6); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(0, 5, -6, -4)).lowValue(0).highValue(5); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(0, 8, -6, -4)).lowValue(0).highValue(6); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(0, 1, -6, 4)).lowValue(0).highValue(1); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(0, 5, -6, 4)).lowValue(0).highValue(5); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(0, 8, -6, 4)).lowValue(0).highValue(6); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(0, 1, 4, 6)).lowValue(0).highValue(1); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(0, 5, 4, 6)).lowValue(0).highValue(5); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(0, 8, 4, 6)).lowValue(0).highValue(6); // mix - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-1, 1, -6, -4)).lowValue(-1).highValue(1); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-1, 5, -6, -4)).lowValue(-1).highValue(5); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-5, 1, -6, -4)).lowValue(-5).highValue(1); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-5, 5, -6, -4)).lowValue(-5).highValue(5); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-5, 8, -6, -4)).lowValue(-5).highValue(6); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-8, 5, -6, -4)).lowValue(-6).highValue(5); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-8, 8, -6, -4)).lowValue(-6).highValue(6); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-1, 1, -6, 4)).lowValue(-1).highValue(1); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-1, 5, -6, 4)).lowValue(-1).highValue(5); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-5, 1, -6, 4)).lowValue(-5).highValue(1); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-5, 5, -6, 4)).lowValue(-5).highValue(5); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-5, 8, -6, 4)).lowValue(-5).highValue(6); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-8, 5, -6, 4)).lowValue(-6).highValue(5); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-8, 8, -6, 4)).lowValue(-6).highValue(6); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-1, 1, 4, 6)).lowValue(-1).highValue(1); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-1, 5, 4, 6)).lowValue(-1).highValue(5); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-5, 1, 4, 6)).lowValue(-5).highValue(1); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-5, 5, 4, 6)).lowValue(-5).highValue(5); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-5, 8, 4, 6)).lowValue(-5).highValue(6); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-8, 5, 4, 6)).lowValue(-6).highValue(5); - assertCalculate(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), xyStats(-8, 8, 4, 6)).lowValue(-6).highValue(6); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-1, 1, -6, -4)).lowValue(-1).highValue(1); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-1, 5, -6, -4)).lowValue(-1).highValue(5); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-5, 1, -6, -4)).lowValue(-5).highValue(1); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-5, 5, -6, -4)).lowValue(-5).highValue(5); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-5, 8, -6, -4)).lowValue(-5).highValue(6); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-8, 5, -6, -4)).lowValue(-6).highValue(5); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-8, 8, -6, -4)).lowValue(-6).highValue(6); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-1, 1, -6, 4)).lowValue(-1).highValue(1); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-1, 5, -6, 4)).lowValue(-1).highValue(5); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-5, 1, -6, 4)).lowValue(-5).highValue(1); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-5, 5, -6, 4)).lowValue(-5).highValue(5); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-5, 8, -6, 4)).lowValue(-5).highValue(6); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-8, 5, -6, 4)).lowValue(-6).highValue(5); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-8, 8, -6, 4)).lowValue(-6).highValue(6); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-1, 1, 4, 6)).lowValue(-1).highValue(1); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-1, 5, 4, 6)).lowValue(-1).highValue(5); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-5, 1, 4, 6)).lowValue(-5).highValue(1); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-5, 5, 4, 6)).lowValue(-5).highValue(5); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-5, 8, 4, 6)).lowValue(-5).highValue(6); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-8, 5, 4, 6)).lowValue(-6).highValue(5); + assertCalculate(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), xyStats(-8, 8, 4, 6)).lowValue(-6).highValue(6); } private PlanNodeStatsEstimate xyStats(double lowX, double highX, double lowY, double highY) @@ -489,14 +489,14 @@ public void testCoalesceExpression() .setOutputRowCount(10) .build(); - assertCalculate(new CoalesceExpression(new SymbolReference(INTEGER, "x"), new SymbolReference(INTEGER, "y")), relationStats) + assertCalculate(new Coalesce(new Reference(INTEGER, "x"), new Reference(INTEGER, "y")), relationStats) .distinctValuesCount(5) .lowValue(-2) .highValue(10) .nullsFraction(0.02) .averageRowSize(2.0); - assertCalculate(new CoalesceExpression(new SymbolReference(INTEGER, "y"), new SymbolReference(INTEGER, "x")), relationStats) + assertCalculate(new Coalesce(new Reference(INTEGER, "y"), new Reference(INTEGER, "x")), relationStats) .distinctValuesCount(5) .lowValue(-2) .highValue(10) diff --git a/core/trino-main/src/test/java/io/trino/cost/TestSimpleFilterProjectSemiJoinStatsRule.java b/core/trino-main/src/test/java/io/trino/cost/TestSimpleFilterProjectSemiJoinStatsRule.java index 87ae5671c2a5..9e33dcd33af5 100644 --- a/core/trino-main/src/test/java/io/trino/cost/TestSimpleFilterProjectSemiJoinStatsRule.java +++ b/core/trino-main/src/test/java/io/trino/cost/TestSimpleFilterProjectSemiJoinStatsRule.java @@ -14,11 +14,11 @@ package io.trino.cost; import com.google.common.collect.ImmutableList; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.PlanNodeId; @@ -29,8 +29,8 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Logical.Operator.AND; import static io.trino.type.UnknownType.UNKNOWN; public class TestSimpleFilterProjectSemiJoinStatsRule @@ -128,7 +128,7 @@ public void testFilterPositiveNarrowingProjectSemiJoin() Symbol c = pb.symbol("c", BIGINT); Symbol semiJoinOutput = pb.symbol("sjo", BOOLEAN); return pb.filter( - new SymbolReference(BOOLEAN, "sjo"), + new Reference(BOOLEAN, "sjo"), pb.project(Assignments.identity(semiJoinOutput, a), pb.semiJoin( pb.values(LEFT_SOURCE_ID, a, b), @@ -167,7 +167,7 @@ public void testFilterPositivePlusExtraConjunctSemiJoin() Symbol c = pb.symbol("c", BIGINT); Symbol semiJoinOutput = pb.symbol("sjo", BOOLEAN); return pb.filter( - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "sjo"), new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 8L)))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "sjo"), new Comparison(LESS_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 8L)))), pb.semiJoin( pb.values(LEFT_SOURCE_ID, a, b), pb.values(RIGHT_SOURCE_ID, c), @@ -205,7 +205,7 @@ public void testFilterNegativeSemiJoin() Symbol c = pb.symbol("c", BIGINT); Symbol semiJoinOutput = pb.symbol("sjo", BOOLEAN); return pb.filter( - new NotExpression(new SymbolReference(BOOLEAN, "sjo")), + new Not(new Reference(BOOLEAN, "sjo")), pb.semiJoin( pb.values(LEFT_SOURCE_ID, a, b), pb.values(RIGHT_SOURCE_ID, c), diff --git a/core/trino-main/src/test/java/io/trino/cost/TestValuesNodeStats.java b/core/trino-main/src/test/java/io/trino/cost/TestValuesNodeStats.java index c27d9a65ab6d..923ded8b1e46 100644 --- a/core/trino-main/src/test/java/io/trino/cost/TestValuesNodeStats.java +++ b/core/trino-main/src/test/java/io/trino/cost/TestValuesNodeStats.java @@ -19,7 +19,7 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; import io.trino.spi.type.VarcharType; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Constant; import io.trino.sql.planner.Symbol; import org.junit.jupiter.api.Test; @@ -29,8 +29,8 @@ import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.createVarcharType; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.DIVIDE; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.DIVIDE; import static io.trino.type.UnknownType.UNKNOWN; public class TestValuesNodeStats @@ -47,7 +47,7 @@ public void testStatsForValuesNode() tester().assertStatsFor(pb -> pb .values(ImmutableList.of(pb.symbol("a", BIGINT), pb.symbol("b", DOUBLE)), ImmutableList.of( - ImmutableList.of(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new Constant(BIGINT, 3L), new Constant(BIGINT, 3L)), new Constant(DOUBLE, 13.5e0)), + ImmutableList.of(new Arithmetic(ADD_BIGINT, ADD, new Constant(BIGINT, 3L), new Constant(BIGINT, 3L)), new Constant(DOUBLE, 13.5e0)), ImmutableList.of(new Constant(BIGINT, 55L), new Constant(UNKNOWN, null)), ImmutableList.of(new Constant(BIGINT, 6L), new Constant(DOUBLE, 13.5e0))))) .check(outputStats -> outputStats.equalTo( @@ -96,7 +96,7 @@ public void testDivisionByZero() { tester().assertStatsFor(pb -> pb .values(ImmutableList.of(pb.symbol("a", BIGINT)), - ImmutableList.of(ImmutableList.of(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 1L), new Constant(INTEGER, 0L)))))) + ImmutableList.of(ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 1L), new Constant(INTEGER, 0L)))))) .check(outputStats -> outputStats.equalTo(unknown())); } @@ -111,7 +111,7 @@ public void testStatsForValuesNodeWithJustNulls() tester().assertStatsFor(pb -> pb .values(ImmutableList.of(pb.symbol("a", BIGINT)), ImmutableList.of( - ImmutableList.of(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new Constant(INTEGER, 3L), new Constant(UNKNOWN, null)))))) + ImmutableList.of(new Arithmetic(ADD_INTEGER, ADD, new Constant(INTEGER, 3L), new Constant(UNKNOWN, null)))))) .check(outputStats -> outputStats.equalTo(nullAStats)); tester().assertStatsFor(pb -> pb diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSchedulingUtils.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSchedulingUtils.java index d6c6435a7f56..5f3f2858a0a4 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSchedulingUtils.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSchedulingUtils.java @@ -20,7 +20,7 @@ import io.airlift.slice.Slices; import io.trino.cost.StatsAndCosts; import io.trino.operator.RetryPolicy; -import io.trino.sql.ir.BooleanLiteral; +import io.trino.sql.ir.Booleans; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Row; import io.trino.sql.planner.Partitioning; @@ -334,7 +334,7 @@ private static SpatialJoinNode spatialJoin(String id, PlanNode left, PlanNode ri left, right, left.getOutputSymbols(), - BooleanLiteral.TRUE_LITERAL, + Booleans.TRUE, Optional.empty(), Optional.empty(), Optional.empty()); diff --git a/core/trino-main/src/test/java/io/trino/metadata/TestingFunctionResolution.java b/core/trino-main/src/test/java/io/trino/metadata/TestingFunctionResolution.java index 5a78beda338d..eafe51ef59c9 100644 --- a/core/trino-main/src/test/java/io/trino/metadata/TestingFunctionResolution.java +++ b/core/trino-main/src/test/java/io/trino/metadata/TestingFunctionResolution.java @@ -26,8 +26,8 @@ import io.trino.sql.analyzer.TypeSignatureProvider; import io.trino.sql.gen.ExpressionCompiler; import io.trino.sql.gen.PageFunctionCompiler; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; import io.trino.sql.planner.TestingPlannerContext; import io.trino.testing.QueryRunner; import io.trino.transaction.TransactionManager; @@ -213,9 +213,9 @@ public TestingFunctionCallBuilder setArguments(List types, List createInputPages(List types) private RowExpression getFilter(Type type) { if (type == VARCHAR) { - return rowExpression(new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(MODULUS_INTEGER, MODULUS, new Cast(new SymbolReference(VARCHAR, "varchar0"), BIGINT), new Constant(INTEGER, 2L)), new Constant(INTEGER, 0L))); + return rowExpression(new Comparison(EQUAL, new Arithmetic(MODULUS_INTEGER, MODULUS, new Cast(new Reference(VARCHAR, "varchar0"), BIGINT), new Constant(INTEGER, 2L)), new Constant(INTEGER, 0L))); } if (type == BIGINT) { - return rowExpression(new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(INTEGER, "bigint0"), new Constant(INTEGER, 2L)), new Constant(INTEGER, 0L))); + return rowExpression(new Comparison(EQUAL, new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(INTEGER, "bigint0"), new Constant(INTEGER, 2L)), new Constant(INTEGER, 0L))); } throw new IllegalArgumentException("filter not supported for type : " + type); } @@ -231,14 +231,14 @@ private List getProjections(Type type) ImmutableList.Builder builder = ImmutableList.builder(); if (type == BIGINT) { for (int i = 0; i < columnCount; i++) { - builder.add(rowExpression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "bigint" + i), new Constant(BIGINT, 5L)))); + builder.add(rowExpression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "bigint" + i), new Constant(BIGINT, 5L)))); } } else if (type == VARCHAR) { for (int i = 0; i < columnCount; i++) { // alternatively use identity expression rowExpression("varchar" + i, type) or // rowExpression("substr(varchar" + i + ", 1, 1)", type) - builder.add(rowExpression(new FunctionCall(CONCAT, ImmutableList.of(new SymbolReference(VARCHAR, "varchar" + i), new Constant(VARCHAR, Slices.utf8Slice("foo")))))); + builder.add(rowExpression(new Call(CONCAT, ImmutableList.of(new Reference(VARCHAR, "varchar" + i), new Constant(VARCHAR, Slices.utf8Slice("foo")))))); } } return builder.build(); diff --git a/core/trino-main/src/test/java/io/trino/operator/project/TestPageFieldsToInputParametersRewriter.java b/core/trino-main/src/test/java/io/trino/operator/project/TestPageFieldsToInputParametersRewriter.java index 5658ed15b42b..7ede8742a832 100644 --- a/core/trino-main/src/test/java/io/trino/operator/project/TestPageFieldsToInputParametersRewriter.java +++ b/core/trino-main/src/test/java/io/trino/operator/project/TestPageFieldsToInputParametersRewriter.java @@ -24,22 +24,22 @@ import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ArithmeticNegation; -import io.trino.sql.ir.BetweenPredicate; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Between; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Case; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.CoalesceExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Coalesce; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.InPredicate; -import io.trino.sql.ir.LambdaExpression; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.NullIfExpression; -import io.trino.sql.ir.SearchedCaseExpression; -import io.trino.sql.ir.SimpleCaseExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.In; +import io.trino.sql.ir.Lambda; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Negation; +import io.trino.sql.ir.NullIf; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Switch; import io.trino.sql.ir.WhenClause; import io.trino.sql.planner.Symbol; import io.trino.sql.relational.RowExpression; @@ -64,14 +64,14 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.DIVIDE; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MODULUS; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MULTIPLY; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; -import static io.trino.sql.ir.LogicalExpression.Operator.OR; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.DIVIDE; +import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; +import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Logical.Operator.AND; +import static io.trino.sql.ir.Logical.Operator.OR; import static io.trino.sql.planner.TestingPlannerContext.plannerContextBuilder; import static org.assertj.core.api.Assertions.assertThat; @@ -100,34 +100,34 @@ public void testEagerLoading() RowExpressionBuilder builder = RowExpressionBuilder.create() .addSymbol("bigint0", BIGINT) .addSymbol("bigint1", BIGINT); - verifyEagerlyLoadedColumns(builder.buildExpression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "bigint0"), new Constant(INTEGER, 5L))), 1); - verifyEagerlyLoadedColumns(builder.buildExpression(new Cast(new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "bigint0"), new Constant(INTEGER, 10L)), INTEGER)), 1); - verifyEagerlyLoadedColumns(builder.buildExpression(new CoalesceExpression(new ArithmeticBinaryExpression(MODULUS_INTEGER, MODULUS, new SymbolReference(BIGINT, "bigint0"), new Constant(BIGINT, 2L)), new SymbolReference(BIGINT, "bigint0"))), 1); - verifyEagerlyLoadedColumns(builder.buildExpression(new InPredicate(new SymbolReference(BIGINT, "bigint0"), ImmutableList.of(new Constant(BIGINT, 1L), new Constant(BIGINT, 2L), new Constant(BIGINT, 3L)))), 1); - verifyEagerlyLoadedColumns(builder.buildExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "bigint0"), new Constant(INTEGER, 0L))), 1); - verifyEagerlyLoadedColumns(builder.buildExpression(new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "bigint0"), new Constant(BIGINT, 1L)), new Constant(BIGINT, 0L))), 1); - verifyEagerlyLoadedColumns(builder.buildExpression(new BetweenPredicate(new SymbolReference(BIGINT, "bigint0"), new Constant(INTEGER, 1L), new Constant(INTEGER, 10L))), 1); - verifyEagerlyLoadedColumns(builder.buildExpression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "bigint0"), new Constant(INTEGER, 0L)), new SymbolReference(BIGINT, "bigint0"))), Optional.of(new Constant(BIGINT, null)))), 1); - verifyEagerlyLoadedColumns(builder.buildExpression(new SimpleCaseExpression(new SymbolReference(BIGINT, "bigint0"), ImmutableList.of(new WhenClause(new Constant(BIGINT, 1L), new Constant(BIGINT, 1L))), Optional.of(new ArithmeticNegation(new SymbolReference(BIGINT, "bigint0"))))), 1); - verifyEagerlyLoadedColumns(builder.buildExpression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new CoalesceExpression(new Constant(BIGINT, 0L), new SymbolReference(BIGINT, "bigint0")), new SymbolReference(BIGINT, "bigint0"))), 1); + verifyEagerlyLoadedColumns(builder.buildExpression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "bigint0"), new Constant(INTEGER, 5L))), 1); + verifyEagerlyLoadedColumns(builder.buildExpression(new Cast(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "bigint0"), new Constant(INTEGER, 10L)), INTEGER)), 1); + verifyEagerlyLoadedColumns(builder.buildExpression(new Coalesce(new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(BIGINT, "bigint0"), new Constant(BIGINT, 2L)), new Reference(BIGINT, "bigint0"))), 1); + verifyEagerlyLoadedColumns(builder.buildExpression(new In(new Reference(BIGINT, "bigint0"), ImmutableList.of(new Constant(BIGINT, 1L), new Constant(BIGINT, 2L), new Constant(BIGINT, 3L)))), 1); + verifyEagerlyLoadedColumns(builder.buildExpression(new Comparison(GREATER_THAN, new Reference(INTEGER, "bigint0"), new Constant(INTEGER, 0L))), 1); + verifyEagerlyLoadedColumns(builder.buildExpression(new Comparison(EQUAL, new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "bigint0"), new Constant(BIGINT, 1L)), new Constant(BIGINT, 0L))), 1); + verifyEagerlyLoadedColumns(builder.buildExpression(new Between(new Reference(BIGINT, "bigint0"), new Constant(INTEGER, 1L), new Constant(INTEGER, 10L))), 1); + verifyEagerlyLoadedColumns(builder.buildExpression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Reference(INTEGER, "bigint0"), new Constant(INTEGER, 0L)), new Reference(BIGINT, "bigint0"))), Optional.of(new Constant(BIGINT, null)))), 1); + verifyEagerlyLoadedColumns(builder.buildExpression(new Switch(new Reference(BIGINT, "bigint0"), ImmutableList.of(new WhenClause(new Constant(BIGINT, 1L), new Constant(BIGINT, 1L))), Optional.of(new Negation(new Reference(BIGINT, "bigint0"))))), 1); + verifyEagerlyLoadedColumns(builder.buildExpression(new Arithmetic(ADD_BIGINT, ADD, new Coalesce(new Constant(BIGINT, 0L), new Reference(BIGINT, "bigint0")), new Reference(BIGINT, "bigint0"))), 1); - verifyEagerlyLoadedColumns(builder.buildExpression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "bigint0"), new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new Constant(BIGINT, 2L), new SymbolReference(BIGINT, "bigint1")))), 2); - verifyEagerlyLoadedColumns(builder.buildExpression(new NullIfExpression(new SymbolReference(BIGINT, "bigint0"), new SymbolReference(BIGINT, "bigint1"))), 2); - verifyEagerlyLoadedColumns(builder.buildExpression(new CoalesceExpression(new FunctionCall(CEIL, ImmutableList.of(new ArithmeticBinaryExpression(DIVIDE_BIGINT, DIVIDE, new SymbolReference(BIGINT, "bigint0"), new SymbolReference(BIGINT, "bigint1")))), new Constant(BIGINT, 0L))), 2); - verifyEagerlyLoadedColumns(builder.buildExpression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "bigint0"), new SymbolReference(BIGINT, "bigint1")), new Constant(INTEGER, 1L))), Optional.of(new Constant(INTEGER, 0L)))), 2); + verifyEagerlyLoadedColumns(builder.buildExpression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "bigint0"), new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Constant(BIGINT, 2L), new Reference(BIGINT, "bigint1")))), 2); + verifyEagerlyLoadedColumns(builder.buildExpression(new NullIf(new Reference(BIGINT, "bigint0"), new Reference(BIGINT, "bigint1"))), 2); + verifyEagerlyLoadedColumns(builder.buildExpression(new Coalesce(new Call(CEIL, ImmutableList.of(new Arithmetic(DIVIDE_BIGINT, DIVIDE, new Reference(BIGINT, "bigint0"), new Reference(BIGINT, "bigint1")))), new Constant(BIGINT, 0L))), 2); + verifyEagerlyLoadedColumns(builder.buildExpression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Reference(BIGINT, "bigint0"), new Reference(BIGINT, "bigint1")), new Constant(INTEGER, 1L))), Optional.of(new Constant(INTEGER, 0L)))), 2); verifyEagerlyLoadedColumns( - builder.buildExpression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "bigint0"), new Constant(BIGINT, 0L)), new SymbolReference(BIGINT, "bigint1"))), Optional.of(new Constant(BIGINT, 0L)))), 2, ImmutableSet.of(0)); - verifyEagerlyLoadedColumns(builder.buildExpression(new CoalesceExpression(new FunctionCall(ROUND, ImmutableList.of(new SymbolReference(BIGINT, "bigint0"))), new SymbolReference(BIGINT, "bigint1"))), 2, ImmutableSet.of(0)); - verifyEagerlyLoadedColumns(builder.buildExpression(new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "bigint0"), new Constant(BIGINT, 0L)), new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "bigint1"), new Constant(BIGINT, 0L))))), 2, ImmutableSet.of(0)); - verifyEagerlyLoadedColumns(builder.buildExpression(new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "bigint0"), new Constant(BIGINT, 0L)), new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "bigint1"), new Constant(BIGINT, 0L))))), 2, ImmutableSet.of(0)); - verifyEagerlyLoadedColumns(builder.buildExpression(new BetweenPredicate(new SymbolReference(BIGINT, "bigint0"), new Constant(BIGINT, 0L), new SymbolReference(BIGINT, "bigint1"))), 2, ImmutableSet.of(0)); + builder.buildExpression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Reference(BIGINT, "bigint0"), new Constant(BIGINT, 0L)), new Reference(BIGINT, "bigint1"))), Optional.of(new Constant(BIGINT, 0L)))), 2, ImmutableSet.of(0)); + verifyEagerlyLoadedColumns(builder.buildExpression(new Coalesce(new Call(ROUND, ImmutableList.of(new Reference(BIGINT, "bigint0"))), new Reference(BIGINT, "bigint1"))), 2, ImmutableSet.of(0)); + verifyEagerlyLoadedColumns(builder.buildExpression(new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "bigint0"), new Constant(BIGINT, 0L)), new Comparison(GREATER_THAN, new Reference(BIGINT, "bigint1"), new Constant(BIGINT, 0L))))), 2, ImmutableSet.of(0)); + verifyEagerlyLoadedColumns(builder.buildExpression(new Logical(OR, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "bigint0"), new Constant(BIGINT, 0L)), new Comparison(GREATER_THAN, new Reference(BIGINT, "bigint1"), new Constant(BIGINT, 0L))))), 2, ImmutableSet.of(0)); + verifyEagerlyLoadedColumns(builder.buildExpression(new Between(new Reference(BIGINT, "bigint0"), new Constant(BIGINT, 0L), new Reference(BIGINT, "bigint1"))), 2, ImmutableSet.of(0)); builder = RowExpressionBuilder.create() .addSymbol("array_bigint0", new ArrayType(BIGINT)) .addSymbol("array_bigint1", new ArrayType(BIGINT)); - verifyEagerlyLoadedColumns(builder.buildExpression(new FunctionCall(TRANSFORM, ImmutableList.of(new SymbolReference(new ArrayType(BIGINT), "array_bigint0"), new LambdaExpression(ImmutableList.of(new Symbol(INTEGER, "x")), new Constant(INTEGER, 1L))))), 1, ImmutableSet.of()); - verifyEagerlyLoadedColumns(builder.buildExpression(new FunctionCall(TRANSFORM, ImmutableList.of(new SymbolReference(new ArrayType(BIGINT), "array_bigint0"), new LambdaExpression(ImmutableList.of(new Symbol(INTEGER, "x")), new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new Constant(INTEGER, 2L), new SymbolReference(INTEGER, "x")))))), 1, ImmutableSet.of()); - verifyEagerlyLoadedColumns(builder.buildExpression(new FunctionCall(ZIP_WITH, ImmutableList.of(new SymbolReference(new ArrayType(BIGINT), "array_bigint0"), new SymbolReference(new ArrayType(BIGINT), "array_bigint1"), new LambdaExpression(ImmutableList.of(new Symbol(INTEGER, "x"), new Symbol(INTEGER, "y")), new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new Constant(BIGINT, 2L), new SymbolReference(BIGINT, "x")))))), 2, ImmutableSet.of()); + verifyEagerlyLoadedColumns(builder.buildExpression(new Call(TRANSFORM, ImmutableList.of(new Reference(new ArrayType(BIGINT), "array_bigint0"), new Lambda(ImmutableList.of(new Symbol(INTEGER, "x")), new Constant(INTEGER, 1L))))), 1, ImmutableSet.of()); + verifyEagerlyLoadedColumns(builder.buildExpression(new Call(TRANSFORM, ImmutableList.of(new Reference(new ArrayType(BIGINT), "array_bigint0"), new Lambda(ImmutableList.of(new Symbol(INTEGER, "x")), new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Constant(INTEGER, 2L), new Reference(INTEGER, "x")))))), 1, ImmutableSet.of()); + verifyEagerlyLoadedColumns(builder.buildExpression(new Call(ZIP_WITH, ImmutableList.of(new Reference(new ArrayType(BIGINT), "array_bigint0"), new Reference(new ArrayType(BIGINT), "array_bigint1"), new Lambda(ImmutableList.of(new Symbol(INTEGER, "x"), new Symbol(INTEGER, "y")), new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Constant(BIGINT, 2L), new Reference(BIGINT, "x")))))), 2, ImmutableSet.of()); } private static void verifyEagerlyLoadedColumns(RowExpression rowExpression, int columnCount) diff --git a/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java b/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java index 0df40ddd5786..de746c682618 100644 --- a/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java +++ b/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java @@ -75,7 +75,7 @@ import io.trino.spi.type.TypeOperators; import io.trino.spi.type.TypeSignature; import io.trino.sql.DynamicFilters; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolAllocator; import io.trino.sql.planner.SymbolKeyDeserializer; @@ -230,8 +230,8 @@ public void testDynamicFilters() SymbolAllocator symbolAllocator = new SymbolAllocator(); Symbol symbol1 = symbolAllocator.newSymbol("DF_SYMBOL1", BIGINT); Symbol symbol2 = symbolAllocator.newSymbol("DF_SYMBOL2", BIGINT); - SymbolReference df1 = symbol1.toSymbolReference(); - SymbolReference df2 = symbol2.toSymbolReference(); + Reference df1 = symbol1.toSymbolReference(); + Reference df2 = symbol2.toSymbolReference(); ColumnHandle handle1 = new TestingColumnHandle("column1"); ColumnHandle handle2 = new TestingColumnHandle("column2"); QueryId queryId = new QueryId("test"); @@ -310,8 +310,8 @@ public void testOutboundDynamicFilters() SymbolAllocator symbolAllocator = new SymbolAllocator(); Symbol symbol1 = symbolAllocator.newSymbol("DF_SYMBOL1", BIGINT); Symbol symbol2 = symbolAllocator.newSymbol("DF_SYMBOL2", BIGINT); - SymbolReference df1 = symbol1.toSymbolReference(); - SymbolReference df2 = symbol2.toSymbolReference(); + Reference df1 = symbol1.toSymbolReference(); + Reference df2 = symbol2.toSymbolReference(); ColumnHandle handle1 = new TestingColumnHandle("column1"); ColumnHandle handle2 = new TestingColumnHandle("column2"); QueryId queryId = new QueryId("test"); diff --git a/core/trino-main/src/test/java/io/trino/sql/BenchmarkExpressionInterpreter.java b/core/trino-main/src/test/java/io/trino/sql/BenchmarkExpressionInterpreter.java index f0e6025341b5..883fc5e98a94 100644 --- a/core/trino-main/src/test/java/io/trino/sql/BenchmarkExpressionInterpreter.java +++ b/core/trino-main/src/test/java/io/trino/sql/BenchmarkExpressionInterpreter.java @@ -16,8 +16,8 @@ import com.google.common.collect.ImmutableList; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.InPredicate; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.In; +import io.trino.sql.ir.Reference; import org.junit.jupiter.api.Test; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; @@ -64,7 +64,7 @@ public static class BenchmarkData public void setup() { expressions = ImmutableList.of( - new InPredicate(new SymbolReference(INTEGER, "bound_value"), IntStream.range(0, inValuesCount).mapToObj(i -> new Constant(INTEGER, (long) i)) + new In(new Reference(INTEGER, "bound_value"), IntStream.range(0, inValuesCount).mapToObj(i -> new Constant(INTEGER, (long) i)) .collect(Collectors.toList()))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java b/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java index 0fa2358a1b8d..4386103f5bf9 100644 --- a/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java +++ b/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java @@ -19,25 +19,25 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ArithmeticNegation; -import io.trino.sql.ir.BetweenPredicate; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Between; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Case; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.CoalesceExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Coalesce; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.InPredicate; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.NullIfExpression; +import io.trino.sql.ir.In; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Negation; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.NullIf; +import io.trino.sql.ir.Reference; import io.trino.sql.ir.Row; -import io.trino.sql.ir.SearchedCaseExpression; -import io.trino.sql.ir.SimpleCaseExpression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Subscript; +import io.trino.sql.ir.Switch; import io.trino.sql.ir.WhenClause; import io.trino.sql.planner.IrExpressionInterpreter; import io.trino.sql.planner.Symbol; @@ -60,17 +60,17 @@ import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.ExpressionTestUtils.assertExpressionEquals; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.DIVIDE; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MULTIPLY; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.SUBTRACT; -import static io.trino.sql.ir.BooleanLiteral.FALSE_LITERAL; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.IS_DISTINCT_FROM; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.DIVIDE; +import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; +import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; +import static io.trino.sql.ir.Booleans.FALSE; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.IS_DISTINCT_FROM; import static io.trino.sql.ir.IrExpressions.ifExpression; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; -import static io.trino.sql.ir.LogicalExpression.Operator.OR; +import static io.trino.sql.ir.Logical.Operator.AND; +import static io.trino.sql.ir.Logical.Operator.OR; import static io.trino.sql.planner.TestingPlannerContext.plannerContextBuilder; import static io.trino.testing.assertions.TrinoExceptionAssert.assertTrinoExceptionThrownBy; import static io.trino.type.UnknownType.UNKNOWN; @@ -108,29 +108,29 @@ public class TestExpressionInterpreter public void testAnd() { assertOptimizedEquals( - new LogicalExpression(AND, ImmutableList.of(TRUE_LITERAL, FALSE_LITERAL)), - FALSE_LITERAL); + new Logical(AND, ImmutableList.of(TRUE, FALSE)), + FALSE); assertOptimizedEquals( - new LogicalExpression(AND, ImmutableList.of(FALSE_LITERAL, TRUE_LITERAL)), - FALSE_LITERAL); + new Logical(AND, ImmutableList.of(FALSE, TRUE)), + FALSE); assertOptimizedEquals( - new LogicalExpression(AND, ImmutableList.of(FALSE_LITERAL, FALSE_LITERAL)), - FALSE_LITERAL); + new Logical(AND, ImmutableList.of(FALSE, FALSE)), + FALSE); assertOptimizedEquals( - new LogicalExpression(AND, ImmutableList.of(TRUE_LITERAL, new Constant(UNKNOWN, null))), + new Logical(AND, ImmutableList.of(TRUE, new Constant(UNKNOWN, null))), new Constant(UNKNOWN, null)); assertOptimizedEquals( - new LogicalExpression(AND, ImmutableList.of(FALSE_LITERAL, new Constant(UNKNOWN, null))), - FALSE_LITERAL); + new Logical(AND, ImmutableList.of(FALSE, new Constant(UNKNOWN, null))), + FALSE); assertOptimizedEquals( - new LogicalExpression(AND, ImmutableList.of(new Constant(UNKNOWN, null), TRUE_LITERAL)), + new Logical(AND, ImmutableList.of(new Constant(UNKNOWN, null), TRUE)), new Constant(UNKNOWN, null)); assertOptimizedEquals( - new LogicalExpression(AND, ImmutableList.of(new Constant(UNKNOWN, null), FALSE_LITERAL)), - FALSE_LITERAL); + new Logical(AND, ImmutableList.of(new Constant(UNKNOWN, null), FALSE)), + FALSE); assertOptimizedEquals( - new LogicalExpression(AND, ImmutableList.of(new Constant(UNKNOWN, null), new Constant(UNKNOWN, null))), + new Logical(AND, ImmutableList.of(new Constant(UNKNOWN, null), new Constant(UNKNOWN, null))), new Constant(UNKNOWN, null)); } @@ -138,33 +138,33 @@ public void testAnd() public void testOr() { assertOptimizedEquals( - new LogicalExpression(OR, ImmutableList.of(TRUE_LITERAL, TRUE_LITERAL)), - TRUE_LITERAL); + new Logical(OR, ImmutableList.of(TRUE, TRUE)), + TRUE); assertOptimizedEquals( - new LogicalExpression(OR, ImmutableList.of(TRUE_LITERAL, FALSE_LITERAL)), - TRUE_LITERAL); + new Logical(OR, ImmutableList.of(TRUE, FALSE)), + TRUE); assertOptimizedEquals( - new LogicalExpression(OR, ImmutableList.of(FALSE_LITERAL, TRUE_LITERAL)), - TRUE_LITERAL); + new Logical(OR, ImmutableList.of(FALSE, TRUE)), + TRUE); assertOptimizedEquals( - new LogicalExpression(OR, ImmutableList.of(FALSE_LITERAL, FALSE_LITERAL)), - FALSE_LITERAL); + new Logical(OR, ImmutableList.of(FALSE, FALSE)), + FALSE); assertOptimizedEquals( - new LogicalExpression(OR, ImmutableList.of(TRUE_LITERAL, new Constant(UNKNOWN, null))), - TRUE_LITERAL); + new Logical(OR, ImmutableList.of(TRUE, new Constant(UNKNOWN, null))), + TRUE); assertOptimizedEquals( - new LogicalExpression(OR, ImmutableList.of(new Constant(UNKNOWN, null), TRUE_LITERAL)), - TRUE_LITERAL); + new Logical(OR, ImmutableList.of(new Constant(UNKNOWN, null), TRUE)), + TRUE); assertOptimizedEquals( - new LogicalExpression(OR, ImmutableList.of(new Constant(UNKNOWN, null), new Constant(UNKNOWN, null))), + new Logical(OR, ImmutableList.of(new Constant(UNKNOWN, null), new Constant(UNKNOWN, null))), new Constant(UNKNOWN, null)); assertOptimizedEquals( - new LogicalExpression(OR, ImmutableList.of(FALSE_LITERAL, new Constant(UNKNOWN, null))), + new Logical(OR, ImmutableList.of(FALSE, new Constant(UNKNOWN, null))), new Constant(UNKNOWN, null)); assertOptimizedEquals( - new LogicalExpression(OR, ImmutableList.of(new Constant(UNKNOWN, null), FALSE_LITERAL)), + new Logical(OR, ImmutableList.of(new Constant(UNKNOWN, null), FALSE)), new Constant(UNKNOWN, null)); } @@ -172,266 +172,266 @@ public void testOr() public void testComparison() { assertOptimizedEquals( - new ComparisonExpression(EQUAL, new Constant(UNKNOWN, null), new Constant(UNKNOWN, null)), + new Comparison(EQUAL, new Constant(UNKNOWN, null), new Constant(UNKNOWN, null)), new Constant(UNKNOWN, null)); assertOptimizedEquals( - new ComparisonExpression(EQUAL, new Constant(VARCHAR, Slices.utf8Slice("a")), new Constant(VARCHAR, Slices.utf8Slice("b"))), - FALSE_LITERAL); + new Comparison(EQUAL, new Constant(VARCHAR, Slices.utf8Slice("a")), new Constant(VARCHAR, Slices.utf8Slice("b"))), + FALSE); assertOptimizedEquals( - new ComparisonExpression(EQUAL, new Constant(VARCHAR, Slices.utf8Slice("a")), new Constant(VARCHAR, Slices.utf8Slice("a"))), - TRUE_LITERAL); + new Comparison(EQUAL, new Constant(VARCHAR, Slices.utf8Slice("a")), new Constant(VARCHAR, Slices.utf8Slice("a"))), + TRUE); assertOptimizedEquals( - new ComparisonExpression(EQUAL, new Constant(VARCHAR, Slices.utf8Slice("a")), new Constant(UNKNOWN, null)), + new Comparison(EQUAL, new Constant(VARCHAR, Slices.utf8Slice("a")), new Constant(UNKNOWN, null)), new Constant(UNKNOWN, null)); assertOptimizedEquals( - new ComparisonExpression(EQUAL, new Constant(UNKNOWN, null), new Constant(VARCHAR, Slices.utf8Slice("a"))), + new Comparison(EQUAL, new Constant(UNKNOWN, null), new Constant(VARCHAR, Slices.utf8Slice("a"))), new Constant(UNKNOWN, null)); assertOptimizedEquals( - new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "bound_value"), new Constant(INTEGER, 1234L)), - TRUE_LITERAL); + new Comparison(EQUAL, new Reference(INTEGER, "bound_value"), new Constant(INTEGER, 1234L)), + TRUE); assertOptimizedEquals( - new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "bound_value"), new Constant(INTEGER, 1L)), - FALSE_LITERAL); + new Comparison(EQUAL, new Reference(INTEGER, "bound_value"), new Constant(INTEGER, 1L)), + FALSE); } @Test public void testIsDistinctFrom() { assertOptimizedEquals( - new ComparisonExpression(IS_DISTINCT_FROM, new Constant(UNKNOWN, null), new Constant(UNKNOWN, null)), - FALSE_LITERAL); + new Comparison(IS_DISTINCT_FROM, new Constant(UNKNOWN, null), new Constant(UNKNOWN, null)), + FALSE); assertOptimizedEquals( - new ComparisonExpression(IS_DISTINCT_FROM, new Constant(INTEGER, 3L), new Constant(INTEGER, 4L)), - TRUE_LITERAL); + new Comparison(IS_DISTINCT_FROM, new Constant(INTEGER, 3L), new Constant(INTEGER, 4L)), + TRUE); assertOptimizedEquals( - new ComparisonExpression(IS_DISTINCT_FROM, new Constant(INTEGER, 3L), new Constant(INTEGER, 3L)), - FALSE_LITERAL); + new Comparison(IS_DISTINCT_FROM, new Constant(INTEGER, 3L), new Constant(INTEGER, 3L)), + FALSE); assertOptimizedEquals( - new ComparisonExpression(IS_DISTINCT_FROM, new Constant(INTEGER, 3L), new Constant(UNKNOWN, null)), - TRUE_LITERAL); + new Comparison(IS_DISTINCT_FROM, new Constant(INTEGER, 3L), new Constant(UNKNOWN, null)), + TRUE); assertOptimizedEquals( - new ComparisonExpression(IS_DISTINCT_FROM, new Constant(UNKNOWN, null), new Constant(INTEGER, 3L)), - TRUE_LITERAL); + new Comparison(IS_DISTINCT_FROM, new Constant(UNKNOWN, null), new Constant(INTEGER, 3L)), + TRUE); assertOptimizedMatches( - new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L)), - new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L))); + new Comparison(IS_DISTINCT_FROM, new Reference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L)), + new Comparison(IS_DISTINCT_FROM, new Reference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L))); assertOptimizedMatches( - new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(UNKNOWN, "unbound_value"), new Constant(UNKNOWN, null)), - new NotExpression(new IsNullPredicate(new SymbolReference(INTEGER, "unbound_value")))); + new Comparison(IS_DISTINCT_FROM, new Reference(UNKNOWN, "unbound_value"), new Constant(UNKNOWN, null)), + new Not(new IsNull(new Reference(INTEGER, "unbound_value")))); assertOptimizedMatches( - new ComparisonExpression(IS_DISTINCT_FROM, new Constant(UNKNOWN, null), new SymbolReference(INTEGER, "unbound_value")), - new NotExpression(new IsNullPredicate(new SymbolReference(INTEGER, "unbound_value")))); + new Comparison(IS_DISTINCT_FROM, new Constant(UNKNOWN, null), new Reference(INTEGER, "unbound_value")), + new Not(new IsNull(new Reference(INTEGER, "unbound_value")))); } @Test public void testIsNull() { assertOptimizedEquals( - new IsNullPredicate(new Constant(UNKNOWN, null)), - TRUE_LITERAL); + new IsNull(new Constant(UNKNOWN, null)), + TRUE); assertOptimizedEquals( - new IsNullPredicate(new Constant(INTEGER, 1L)), - FALSE_LITERAL); + new IsNull(new Constant(INTEGER, 1L)), + FALSE); assertOptimizedEquals( - new IsNullPredicate(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new Constant(INTEGER, null), new Constant(INTEGER, 1L))), - TRUE_LITERAL); + new IsNull(new Arithmetic(ADD_INTEGER, ADD, new Constant(INTEGER, null), new Constant(INTEGER, 1L))), + TRUE); } @Test public void testIsNotNull() { assertOptimizedEquals( - new NotExpression(new IsNullPredicate(new Constant(UNKNOWN, null))), - FALSE_LITERAL); + new Not(new IsNull(new Constant(UNKNOWN, null))), + FALSE); assertOptimizedEquals( - new NotExpression(new IsNullPredicate(new Constant(INTEGER, 1L))), - TRUE_LITERAL); + new Not(new IsNull(new Constant(INTEGER, 1L))), + TRUE); assertOptimizedEquals( - new NotExpression(new IsNullPredicate(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new Constant(INTEGER, null), new Constant(INTEGER, 1L)))), - FALSE_LITERAL); + new Not(new IsNull(new Arithmetic(ADD_INTEGER, ADD, new Constant(INTEGER, null), new Constant(INTEGER, 1L)))), + FALSE); } @Test public void testNullIf() { assertOptimizedEquals( - new NullIfExpression(new Constant(VARCHAR, Slices.utf8Slice("a")), new Constant(VARCHAR, Slices.utf8Slice("a"))), + new NullIf(new Constant(VARCHAR, Slices.utf8Slice("a")), new Constant(VARCHAR, Slices.utf8Slice("a"))), new Constant(UNKNOWN, null)); assertOptimizedEquals( - new NullIfExpression(new Constant(VARCHAR, Slices.utf8Slice("a")), new Constant(VARCHAR, Slices.utf8Slice("b"))), + new NullIf(new Constant(VARCHAR, Slices.utf8Slice("a")), new Constant(VARCHAR, Slices.utf8Slice("b"))), new Constant(VARCHAR, Slices.utf8Slice("a"))); assertOptimizedEquals( - new NullIfExpression(new Constant(UNKNOWN, null), new Constant(VARCHAR, Slices.utf8Slice("b"))), + new NullIf(new Constant(UNKNOWN, null), new Constant(VARCHAR, Slices.utf8Slice("b"))), new Constant(UNKNOWN, null)); assertOptimizedEquals( - new NullIfExpression(new Constant(VARCHAR, Slices.utf8Slice("a")), new Constant(UNKNOWN, null)), + new NullIf(new Constant(VARCHAR, Slices.utf8Slice("a")), new Constant(UNKNOWN, null)), new Constant(VARCHAR, Slices.utf8Slice("a"))); assertOptimizedEquals( - new NullIfExpression(new SymbolReference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L)), - new NullIfExpression(new SymbolReference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L))); + new NullIf(new Reference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L)), + new NullIf(new Reference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L))); } @Test public void testNegative() { assertOptimizedEquals( - new ArithmeticNegation(new Constant(INTEGER, 1L)), + new Negation(new Constant(INTEGER, 1L)), new Constant(INTEGER, -1L)); assertOptimizedEquals( - new ArithmeticNegation(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L))), - new ArithmeticNegation(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L)))); + new Negation(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L))), + new Negation(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L)))); } @Test public void testNot() { assertOptimizedEquals( - new NotExpression(TRUE_LITERAL), - FALSE_LITERAL); + new Not(TRUE), + FALSE); assertOptimizedEquals( - new NotExpression(FALSE_LITERAL), - TRUE_LITERAL); + new Not(FALSE), + TRUE); assertOptimizedEquals( - new NotExpression(new Constant(UNKNOWN, null)), + new Not(new Constant(UNKNOWN, null)), new Constant(UNKNOWN, null)); assertOptimizedEquals( - new NotExpression(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L))), - new NotExpression(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L)))); + new Not(new Comparison(EQUAL, new Reference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L))), + new Not(new Comparison(EQUAL, new Reference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L)))); } @Test public void testFunctionCall() { assertOptimizedEquals( - new FunctionCall(ABS, ImmutableList.of(new Constant(INTEGER, 5L))), + new Call(ABS, ImmutableList.of(new Constant(INTEGER, 5L))), new Constant(INTEGER, 5L)); assertOptimizedEquals( - new FunctionCall(ABS, ImmutableList.of(new SymbolReference(INTEGER, "unbound_value"))), - new FunctionCall(ABS, ImmutableList.of(new SymbolReference(INTEGER, "unbound_value")))); + new Call(ABS, ImmutableList.of(new Reference(INTEGER, "unbound_value"))), + new Call(ABS, ImmutableList.of(new Reference(INTEGER, "unbound_value")))); } @Test public void testNonDeterministicFunctionCall() { assertOptimizedEquals( - new FunctionCall(RANDOM, ImmutableList.of()), - new FunctionCall(RANDOM, ImmutableList.of())); + new Call(RANDOM, ImmutableList.of()), + new Call(RANDOM, ImmutableList.of())); } @Test public void testBetween() { assertOptimizedEquals( - new BetweenPredicate(new Constant(INTEGER, 3L), new Constant(INTEGER, 2L), new Constant(INTEGER, 4L)), - TRUE_LITERAL); + new Between(new Constant(INTEGER, 3L), new Constant(INTEGER, 2L), new Constant(INTEGER, 4L)), + TRUE); assertOptimizedEquals( - new BetweenPredicate(new Constant(INTEGER, 2L), new Constant(INTEGER, 3L), new Constant(INTEGER, 4L)), - FALSE_LITERAL); + new Between(new Constant(INTEGER, 2L), new Constant(INTEGER, 3L), new Constant(INTEGER, 4L)), + FALSE); assertOptimizedEquals( - new BetweenPredicate(new Constant(UNKNOWN, null), new Constant(INTEGER, 2L), new Constant(INTEGER, 4L)), + new Between(new Constant(UNKNOWN, null), new Constant(INTEGER, 2L), new Constant(INTEGER, 4L)), new Constant(UNKNOWN, null)); assertOptimizedEquals( - new BetweenPredicate(new Constant(INTEGER, 3L), new Constant(UNKNOWN, null), new Constant(INTEGER, 4L)), + new Between(new Constant(INTEGER, 3L), new Constant(UNKNOWN, null), new Constant(INTEGER, 4L)), new Constant(UNKNOWN, null)); assertOptimizedEquals( - new BetweenPredicate(new Constant(INTEGER, 3L), new Constant(INTEGER, 2L), new Constant(UNKNOWN, null)), + new Between(new Constant(INTEGER, 3L), new Constant(INTEGER, 2L), new Constant(UNKNOWN, null)), new Constant(UNKNOWN, null)); assertOptimizedEquals( - new BetweenPredicate(new Constant(INTEGER, 2L), new Constant(INTEGER, 3L), new Constant(UNKNOWN, null)), - FALSE_LITERAL); + new Between(new Constant(INTEGER, 2L), new Constant(INTEGER, 3L), new Constant(UNKNOWN, null)), + FALSE); assertOptimizedEquals( - new BetweenPredicate(new Constant(INTEGER, 8L), new Constant(UNKNOWN, null), new Constant(INTEGER, 6L)), - FALSE_LITERAL); + new Between(new Constant(INTEGER, 8L), new Constant(UNKNOWN, null), new Constant(INTEGER, 6L)), + FALSE); assertOptimizedEquals( - new BetweenPredicate(new SymbolReference(INTEGER, "bound_value"), new Constant(INTEGER, 1000L), new Constant(INTEGER, 2000L)), - TRUE_LITERAL); + new Between(new Reference(INTEGER, "bound_value"), new Constant(INTEGER, 1000L), new Constant(INTEGER, 2000L)), + TRUE); assertOptimizedEquals( - new BetweenPredicate(new SymbolReference(INTEGER, "bound_value"), new Constant(INTEGER, 3L), new Constant(INTEGER, 4L)), - FALSE_LITERAL); + new Between(new Reference(INTEGER, "bound_value"), new Constant(INTEGER, 3L), new Constant(INTEGER, 4L)), + FALSE); } @Test public void testIn() { assertOptimizedEquals( - new InPredicate(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Constant(INTEGER, 5L))), - TRUE_LITERAL); + new In(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Constant(INTEGER, 5L))), + TRUE); assertOptimizedEquals( - new InPredicate(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 9L), new Constant(INTEGER, 5L))), - FALSE_LITERAL); + new In(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 9L), new Constant(INTEGER, 5L))), + FALSE); assertOptimizedEquals( - new InPredicate(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, null), new Constant(INTEGER, 3L), new Constant(INTEGER, 5L))), - TRUE_LITERAL); + new In(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, null), new Constant(INTEGER, 3L), new Constant(INTEGER, 5L))), + TRUE); assertOptimizedEquals( - new InPredicate(new Constant(INTEGER, null), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, null), new Constant(INTEGER, 3L), new Constant(INTEGER, 5L))), + new In(new Constant(INTEGER, null), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, null), new Constant(INTEGER, 3L), new Constant(INTEGER, 5L))), new Constant(UNKNOWN, null)); assertOptimizedEquals( - new InPredicate(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, null))), + new In(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, null))), new Constant(UNKNOWN, null)); assertOptimizedEquals( - new InPredicate(new SymbolReference(INTEGER, "bound_value"), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 1234L), new Constant(INTEGER, 3L), new Constant(INTEGER, 5L))), - TRUE_LITERAL); + new In(new Reference(INTEGER, "bound_value"), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 1234L), new Constant(INTEGER, 3L), new Constant(INTEGER, 5L))), + TRUE); assertOptimizedEquals( - new InPredicate(new SymbolReference(INTEGER, "bound_value"), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Constant(INTEGER, 5L))), - FALSE_LITERAL); + new In(new Reference(INTEGER, "bound_value"), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Constant(INTEGER, 5L))), + FALSE); assertOptimizedEquals( - new InPredicate(new Constant(INTEGER, 1234L), ImmutableList.of(new Constant(INTEGER, 2L), new SymbolReference(INTEGER, "bound_value"), new Constant(INTEGER, 3L), new Constant(INTEGER, 5L))), - TRUE_LITERAL); + new In(new Constant(INTEGER, 1234L), ImmutableList.of(new Constant(INTEGER, 2L), new Reference(INTEGER, "bound_value"), new Constant(INTEGER, 3L), new Constant(INTEGER, 5L))), + TRUE); assertOptimizedEquals( - new InPredicate(new Constant(INTEGER, 99L), ImmutableList.of(new Constant(INTEGER, 2L), new SymbolReference(INTEGER, "bound_value"), new Constant(INTEGER, 3L), new Constant(INTEGER, 5L))), - FALSE_LITERAL); + new In(new Constant(INTEGER, 99L), ImmutableList.of(new Constant(INTEGER, 2L), new Reference(INTEGER, "bound_value"), new Constant(INTEGER, 3L), new Constant(INTEGER, 5L))), + FALSE); assertOptimizedEquals( - new InPredicate(new SymbolReference(INTEGER, "bound_value"), ImmutableList.of(new Constant(INTEGER, 2L), new SymbolReference(INTEGER, "bound_value"), new Constant(INTEGER, 3L), new Constant(INTEGER, 5L))), - TRUE_LITERAL); + new In(new Reference(INTEGER, "bound_value"), ImmutableList.of(new Constant(INTEGER, 2L), new Reference(INTEGER, "bound_value"), new Constant(INTEGER, 3L), new Constant(INTEGER, 5L))), + TRUE); assertOptimizedEquals( - new InPredicate(new SymbolReference(INTEGER, "unbound_value"), ImmutableList.of(new Constant(INTEGER, 1L))), - new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L))); + new In(new Reference(INTEGER, "unbound_value"), ImmutableList.of(new Constant(INTEGER, 1L))), + new Comparison(EQUAL, new Reference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L))); assertOptimizedEquals( - new InPredicate(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)))), - new InPredicate(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L))))); + new In(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)))), + new In(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L))))); assertOptimizedEquals( - new InPredicate(new Constant(INTEGER, null), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)))), - new InPredicate(new Constant(INTEGER, null), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L))))); + new In(new Constant(INTEGER, null), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)))), + new In(new Constant(INTEGER, null), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L))))); assertOptimizedEquals( - new InPredicate(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Constant(INTEGER, null), new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)))), - new InPredicate(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Constant(INTEGER, null), new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L))))); + new In(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Constant(INTEGER, null), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)))), + new In(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Constant(INTEGER, null), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L))))); assertOptimizedEquals( - new InPredicate(new Constant(INTEGER, null), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, null), new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)))), - new InPredicate(new Constant(INTEGER, null), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, null), new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L))))); + new In(new Constant(INTEGER, null), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, null), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)))), + new In(new Constant(INTEGER, null), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, null), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L))))); assertOptimizedEquals( - new InPredicate(new Constant(INTEGER, 3L), ImmutableList.of(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)), new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)))), - new InPredicate(new Constant(INTEGER, 3L), ImmutableList.of(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)), new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L))))); - assertTrinoExceptionThrownBy(() -> evaluate(new InPredicate(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)))))) + new In(new Constant(INTEGER, 3L), ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)))), + new In(new Constant(INTEGER, 3L), ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L))))); + assertTrinoExceptionThrownBy(() -> evaluate(new In(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 5L), new Constant(INTEGER, 0L)))))) .hasErrorCode(DIVISION_BY_ZERO); assertOptimizedEquals( - new InPredicate(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Constant(INTEGER, 5L))), - new InPredicate(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Constant(INTEGER, 5L)))); + new In(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Constant(INTEGER, 5L))), + new In(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 3L), new Constant(INTEGER, 5L)))); assertOptimizedEquals( - new InPredicate(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 2L), new Constant(INTEGER, 4L))), - new InPredicate(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L)))); + new In(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L), new Constant(INTEGER, 2L), new Constant(INTEGER, 4L))), + new In(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L)))); assertOptimizedEquals( - new InPredicate(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 2L))), - new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))); + new In(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 2L))), + new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))); } @Test public void testCastOptimization() { assertOptimizedEquals( - new Cast(new SymbolReference(INTEGER, "bound_value"), VARCHAR), + new Cast(new Reference(INTEGER, "bound_value"), VARCHAR), new Constant(VARCHAR, Slices.utf8Slice("1234"))); assertOptimizedMatches( - new Cast(new SymbolReference(INTEGER, "unbound_value"), INTEGER), - new SymbolReference(INTEGER, "unbound_value")); + new Cast(new Reference(INTEGER, "unbound_value"), INTEGER), + new Reference(INTEGER, "unbound_value")); } @Test @@ -455,94 +455,94 @@ public void testTryCast() public void testSearchCase() { assertOptimizedEquals( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(TRUE_LITERAL, new Constant(INTEGER, 33L))), + new Case(ImmutableList.of( + new WhenClause(TRUE, new Constant(INTEGER, 33L))), Optional.empty()), new Constant(INTEGER, 33L)); assertOptimizedEquals( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(FALSE_LITERAL, new Constant(INTEGER, 1L))), + new Case(ImmutableList.of( + new WhenClause(FALSE, new Constant(INTEGER, 1L))), Optional.of(new Constant(INTEGER, 33L))), new Constant(INTEGER, 33L)); assertOptimizedEquals( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "bound_value"), new Constant(INTEGER, 1234L)), new Constant(INTEGER, 33L))), + new Case(ImmutableList.of( + new WhenClause(new Comparison(EQUAL, new Reference(INTEGER, "bound_value"), new Constant(INTEGER, 1234L)), new Constant(INTEGER, 33L))), Optional.empty()), new Constant(INTEGER, 33L)); assertOptimizedEquals( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(TRUE_LITERAL, new SymbolReference(INTEGER, "bound_value"))), + new Case(ImmutableList.of( + new WhenClause(TRUE, new Reference(INTEGER, "bound_value"))), Optional.empty()), new Constant(INTEGER, 1234L)); assertOptimizedEquals( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(FALSE_LITERAL, new Constant(INTEGER, 1L))), - Optional.of(new SymbolReference(INTEGER, "bound_value"))), + new Case(ImmutableList.of( + new WhenClause(FALSE, new Constant(INTEGER, 1L))), + Optional.of(new Reference(INTEGER, "bound_value"))), new Constant(INTEGER, 1234L)); assertOptimizedEquals( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "bound_value"), new Constant(INTEGER, 1234L)), new Constant(INTEGER, 33L))), - Optional.of(new SymbolReference(INTEGER, "unbound_value"))), + new Case(ImmutableList.of( + new WhenClause(new Comparison(EQUAL, new Reference(INTEGER, "bound_value"), new Constant(INTEGER, 1234L)), new Constant(INTEGER, 33L))), + Optional.of(new Reference(INTEGER, "unbound_value"))), new Constant(INTEGER, 33L)); assertOptimizedMatches( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), + new Case(ImmutableList.of( + new WhenClause(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), Optional.empty()), - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), + new Case(ImmutableList.of( + new WhenClause(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), Optional.empty())); assertOptimizedEquals( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(TRUE_LITERAL, new Constant(VARCHAR, Slices.utf8Slice("a"))), new WhenClause(new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("b")))), + new Case(ImmutableList.of( + new WhenClause(TRUE, new Constant(VARCHAR, Slices.utf8Slice("a"))), new WhenClause(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("b")))), Optional.of(new Constant(VARCHAR, Slices.utf8Slice("c")))), new Constant(VARCHAR, Slices.utf8Slice("a"))); assertOptimizedEquals( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a"))), new WhenClause(TRUE_LITERAL, new Constant(VARCHAR, Slices.utf8Slice("b")))), + new Case(ImmutableList.of( + new WhenClause(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a"))), new WhenClause(TRUE, new Constant(VARCHAR, Slices.utf8Slice("b")))), Optional.of(new Constant(VARCHAR, Slices.utf8Slice("c")))), - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a")))), + new Case(ImmutableList.of( + new WhenClause(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a")))), Optional.of(new Constant(VARCHAR, Slices.utf8Slice("b"))))); assertOptimizedEquals( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a"))), new WhenClause(FALSE_LITERAL, new Constant(VARCHAR, Slices.utf8Slice("b")))), + new Case(ImmutableList.of( + new WhenClause(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a"))), new WhenClause(FALSE, new Constant(VARCHAR, Slices.utf8Slice("b")))), Optional.of(new Constant(VARCHAR, Slices.utf8Slice("c")))), - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a")))), + new Case(ImmutableList.of( + new WhenClause(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a")))), Optional.of(new Constant(VARCHAR, Slices.utf8Slice("c"))))); assertOptimizedEquals( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a"))), - new WhenClause(FALSE_LITERAL, new Constant(VARCHAR, Slices.utf8Slice("b")))), + new Case(ImmutableList.of( + new WhenClause(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a"))), + new WhenClause(FALSE, new Constant(VARCHAR, Slices.utf8Slice("b")))), Optional.empty()), - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a")))), + new Case(ImmutableList.of( + new WhenClause(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a")))), Optional.empty())); assertOptimizedEquals( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(TRUE_LITERAL, new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), - new WhenClause(FALSE_LITERAL, new Constant(INTEGER, 1L))), + new Case(ImmutableList.of( + new WhenClause(TRUE, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), + new WhenClause(FALSE, new Constant(INTEGER, 1L))), Optional.empty()), - new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))); + new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))); assertOptimizedEquals( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(FALSE_LITERAL, new Constant(INTEGER, 1L)), new WhenClause(FALSE_LITERAL, new Constant(INTEGER, 2L))), - Optional.of(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), - new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))); + new Case(ImmutableList.of( + new WhenClause(FALSE, new Constant(INTEGER, 1L)), new WhenClause(FALSE, new Constant(INTEGER, 2L))), + Optional.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), + new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))); assertEvaluatedEquals( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(FALSE_LITERAL, new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new WhenClause(TRUE_LITERAL, new Constant(INTEGER, 1L))), - Optional.of(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), + new Case(ImmutableList.of( + new WhenClause(FALSE, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new WhenClause(TRUE, new Constant(INTEGER, 1L))), + Optional.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), new Constant(INTEGER, 1L)); assertEvaluatedEquals( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(TRUE_LITERAL, new Constant(INTEGER, 1L)), new WhenClause(new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), - Optional.of(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), + new Case(ImmutableList.of( + new WhenClause(TRUE, new Constant(INTEGER, 1L)), new WhenClause(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), + Optional.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), new Constant(INTEGER, 1L)); } @@ -550,7 +550,7 @@ public void testSearchCase() public void testSimpleCase() { assertOptimizedEquals( - new SimpleCaseExpression( + new Switch( new Constant(INTEGER, 1L), ImmutableList.of( new WhenClause(new Constant(INTEGER, 1L), new Constant(INTEGER, 33L)), @@ -559,141 +559,141 @@ public void testSimpleCase() new Constant(INTEGER, 33L)); assertOptimizedEquals( - new SimpleCaseExpression( + new Switch( new Constant(BOOLEAN, null), ImmutableList.of( - new WhenClause(TRUE_LITERAL, new Constant(INTEGER, 33L))), + new WhenClause(TRUE, new Constant(INTEGER, 33L))), Optional.empty()), new Constant(UNKNOWN, null)); - for (SimpleCaseExpression simpleCaseExpression : Arrays.asList(new SimpleCaseExpression( + for (Switch aSwitch : Arrays.asList(new Switch( new Constant(BOOLEAN, null), ImmutableList.of( - new WhenClause(TRUE_LITERAL, new Constant(INTEGER, 33L))), + new WhenClause(TRUE, new Constant(INTEGER, 33L))), Optional.of(new Constant(INTEGER, 33L))), - new SimpleCaseExpression( + new Switch( new Constant(INTEGER, 33L), ImmutableList.of( new WhenClause(new Constant(INTEGER, null), new Constant(INTEGER, 1L))), Optional.of(new Constant(INTEGER, 33L))), - new SimpleCaseExpression( - new SymbolReference(INTEGER, "bound_value"), + new Switch( + new Reference(INTEGER, "bound_value"), ImmutableList.of( new WhenClause(new Constant(INTEGER, 1234L), new Constant(INTEGER, 33L))), Optional.empty()), - new SimpleCaseExpression( + new Switch( new Constant(INTEGER, 1234L), ImmutableList.of( - new WhenClause(new SymbolReference(INTEGER, "bound_value"), new Constant(INTEGER, 33L))), + new WhenClause(new Reference(INTEGER, "bound_value"), new Constant(INTEGER, 33L))), Optional.empty()))) { assertOptimizedEquals( - simpleCaseExpression, + aSwitch, new Constant(INTEGER, 33L)); } assertOptimizedEquals( - new SimpleCaseExpression( - TRUE_LITERAL, + new Switch( + TRUE, ImmutableList.of( - new WhenClause(TRUE_LITERAL, new SymbolReference(INTEGER, "bound_value"))), + new WhenClause(TRUE, new Reference(INTEGER, "bound_value"))), Optional.empty()), new Constant(INTEGER, 1234L)); assertOptimizedEquals( - new SimpleCaseExpression( - TRUE_LITERAL, + new Switch( + TRUE, ImmutableList.of( - new WhenClause(FALSE_LITERAL, new Constant(INTEGER, 1L))), - Optional.of(new SymbolReference(INTEGER, "bound_value"))), + new WhenClause(FALSE, new Constant(INTEGER, 1L))), + Optional.of(new Reference(INTEGER, "bound_value"))), new Constant(INTEGER, 1234L)); assertOptimizedEquals( - new SimpleCaseExpression( - TRUE_LITERAL, + new Switch( + TRUE, ImmutableList.of( - new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L)), - new WhenClause(new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), + new WhenClause(new Comparison(EQUAL, new Reference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L)), + new WhenClause(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), Optional.of(new Constant(INTEGER, 33L))), - new SimpleCaseExpression( - TRUE_LITERAL, + new Switch( + TRUE, ImmutableList.of( - new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L)), - new WhenClause(new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), + new WhenClause(new Comparison(EQUAL, new Reference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L)), + new WhenClause(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), Optional.of(new Constant(INTEGER, 33L)))); assertOptimizedMatches( - new SimpleCaseExpression( + new Switch( new Constant(INTEGER, 1L), ImmutableList.of( - new WhenClause(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L)), - new WhenClause(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), + new WhenClause(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L)), + new WhenClause(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), Optional.of(new Constant(INTEGER, 1L))), - new SimpleCaseExpression( + new Switch( new Constant(INTEGER, 1L), ImmutableList.of( - new WhenClause(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L)), - new WhenClause(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), + new WhenClause(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L)), + new WhenClause(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), Optional.of(new Constant(INTEGER, 1L)))); assertOptimizedEquals( - new SimpleCaseExpression( + new Switch( new Constant(INTEGER, null), ImmutableList.of( - new WhenClause(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), + new WhenClause(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), Optional.of(new Constant(INTEGER, 1L))), new Constant(INTEGER, 1L)); assertOptimizedEquals( - new SimpleCaseExpression( + new Switch( new Constant(INTEGER, null), ImmutableList.of( new WhenClause(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), - Optional.of(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), - new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))); + Optional.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), + new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))); assertOptimizedEquals( - new SimpleCaseExpression( - new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), + new Switch( + new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), ImmutableList.of( new WhenClause(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), Optional.of(new Constant(INTEGER, 3L))), - new SimpleCaseExpression( - new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), + new Switch( + new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), ImmutableList.of( new WhenClause(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), Optional.of(new Constant(INTEGER, 3L)))); assertOptimizedEquals( - new SimpleCaseExpression( + new Switch( new Constant(INTEGER, 1L), ImmutableList.of( - new WhenClause(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), + new WhenClause(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), Optional.of(new Constant(INTEGER, 3L))), - new SimpleCaseExpression( + new Switch( new Constant(INTEGER, 1L), ImmutableList.of( - new WhenClause(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), + new WhenClause(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), Optional.of(new Constant(INTEGER, 3L)))); assertOptimizedEquals( - new SimpleCaseExpression( + new Switch( new Constant(INTEGER, 1L), ImmutableList.of( new WhenClause(new Constant(INTEGER, 2L), new Constant(INTEGER, 2L)), - new WhenClause(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 3L))), + new WhenClause(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 3L))), Optional.of(new Constant(INTEGER, 4L))), - new SimpleCaseExpression( + new Switch( new Constant(INTEGER, 1L), ImmutableList.of( - new WhenClause(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 3L))), + new WhenClause(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 3L))), Optional.of(new Constant(INTEGER, 4L)))); assertOptimizedEquals( - new SimpleCaseExpression( + new Switch( new Constant(INTEGER, 1L), ImmutableList.of( - new WhenClause(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), + new WhenClause(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), Optional.empty()), - new SimpleCaseExpression( + new Switch( new Constant(INTEGER, 1L), ImmutableList.of( - new WhenClause(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), + new WhenClause(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), Optional.empty())); assertOptimizedEquals( - new SimpleCaseExpression( + new Switch( new Constant(INTEGER, 1L), ImmutableList.of( new WhenClause(new Constant(INTEGER, 2L), new Constant(INTEGER, 2L)), @@ -702,33 +702,33 @@ public void testSimpleCase() new Constant(UNKNOWN, null)); assertEvaluatedEquals( - new SimpleCaseExpression( + new Switch( new Constant(INTEGER, null), ImmutableList.of( - new WhenClause(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), + new WhenClause(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), Optional.of(new Constant(INTEGER, 1L))), new Constant(INTEGER, 1L)); assertEvaluatedEquals( - new SimpleCaseExpression( + new Switch( new Constant(INTEGER, 1L), ImmutableList.of( - new WhenClause(new Constant(INTEGER, 2L), new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), + new WhenClause(new Constant(INTEGER, 2L), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), Optional.of(new Constant(INTEGER, 3L))), new Constant(INTEGER, 3L)); assertEvaluatedEquals( - new SimpleCaseExpression( + new Switch( new Constant(INTEGER, 1L), ImmutableList.of( new WhenClause(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)), - new WhenClause(new Constant(INTEGER, 1L), new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), + new WhenClause(new Constant(INTEGER, 1L), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), Optional.empty()), new Constant(INTEGER, 2L)); assertEvaluatedEquals( - new SimpleCaseExpression( + new Switch( new Constant(INTEGER, 1L), ImmutableList.of( new WhenClause(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), - Optional.of(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), + Optional.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), new Constant(INTEGER, 2L)); } @@ -736,47 +736,47 @@ public void testSimpleCase() public void testCoalesce() { assertOptimizedEquals( - new CoalesceExpression(new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "unbound_value"), new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new Constant(INTEGER, 2L), new Constant(INTEGER, 3L))), new ArithmeticBinaryExpression(SUBTRACT_INTEGER, SUBTRACT, new Constant(INTEGER, 1L), new Constant(INTEGER, 1L)), new Constant(INTEGER, null)), - new CoalesceExpression(new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "unbound_value"), new Constant(INTEGER, 6L)), new Constant(INTEGER, 0L))); + new Coalesce(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "unbound_value"), new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Constant(INTEGER, 2L), new Constant(INTEGER, 3L))), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Constant(INTEGER, 1L), new Constant(INTEGER, 1L)), new Constant(INTEGER, null)), + new Coalesce(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "unbound_value"), new Constant(INTEGER, 6L)), new Constant(INTEGER, 0L))); assertOptimizedMatches( - new CoalesceExpression(new SymbolReference(INTEGER, "unbound_value"), new SymbolReference(INTEGER, "unbound_value")), - new SymbolReference(INTEGER, "unbound_value")); + new Coalesce(new Reference(INTEGER, "unbound_value"), new Reference(INTEGER, "unbound_value")), + new Reference(INTEGER, "unbound_value")); assertOptimizedEquals( - new CoalesceExpression(new Constant(INTEGER, 6L), new SymbolReference(INTEGER, "unbound_value")), + new Coalesce(new Constant(INTEGER, 6L), new Reference(INTEGER, "unbound_value")), new Constant(INTEGER, 6L)); assertOptimizedMatches( - new CoalesceExpression(new FunctionCall(RANDOM, ImmutableList.of()), new FunctionCall(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 5.0)), - new CoalesceExpression(new FunctionCall(RANDOM, ImmutableList.of()), new FunctionCall(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 5.0))); + new Coalesce(new Call(RANDOM, ImmutableList.of()), new Call(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 5.0)), + new Coalesce(new Call(RANDOM, ImmutableList.of()), new Call(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 5.0))); assertOptimizedEquals( - new CoalesceExpression(new Constant(UNKNOWN, null), new CoalesceExpression(new Constant(UNKNOWN, null), new Constant(UNKNOWN, null))), + new Coalesce(new Constant(UNKNOWN, null), new Coalesce(new Constant(UNKNOWN, null), new Constant(UNKNOWN, null))), new Constant(UNKNOWN, null)); assertOptimizedEquals( - new CoalesceExpression(new Constant(INTEGER, null), new CoalesceExpression(new Constant(INTEGER, null), new CoalesceExpression(new Constant(INTEGER, null), new Constant(INTEGER, null), new Constant(INTEGER, 1L)))), + new Coalesce(new Constant(INTEGER, null), new Coalesce(new Constant(INTEGER, null), new Coalesce(new Constant(INTEGER, null), new Constant(INTEGER, null), new Constant(INTEGER, 1L)))), new Constant(INTEGER, 1L)); assertOptimizedEquals( - new CoalesceExpression(new Constant(INTEGER, 1L), new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), + new Coalesce(new Constant(INTEGER, 1L), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 1L)); assertOptimizedEquals( - new CoalesceExpression(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L)), - new CoalesceExpression(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))); + new Coalesce(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L)), + new Coalesce(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))); assertOptimizedEquals( - new CoalesceExpression(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L), new Constant(INTEGER, null)), - new CoalesceExpression(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))); + new Coalesce(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L), new Constant(INTEGER, null)), + new Coalesce(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))); assertOptimizedEquals( - new CoalesceExpression(new Constant(INTEGER, 1L), new CoalesceExpression(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), + new Coalesce(new Constant(INTEGER, 1L), new Coalesce(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), new Constant(INTEGER, 1L)); assertOptimizedEquals( - new CoalesceExpression(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, null), new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 1L), new Constant(INTEGER, 0L)), new Constant(INTEGER, null), new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), - new CoalesceExpression(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 1L), new Constant(INTEGER, 0L)))); + new Coalesce(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, null), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 1L), new Constant(INTEGER, 0L)), new Constant(INTEGER, null), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), + new Coalesce(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 1L), new Constant(INTEGER, 0L)))); assertOptimizedEquals( - new CoalesceExpression(new FunctionCall(RANDOM, ImmutableList.of()), new FunctionCall(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 1.0), new FunctionCall(RANDOM, ImmutableList.of())), - new CoalesceExpression(new FunctionCall(RANDOM, ImmutableList.of()), new FunctionCall(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 1.0))); + new Coalesce(new Call(RANDOM, ImmutableList.of()), new Call(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 1.0), new Call(RANDOM, ImmutableList.of())), + new Coalesce(new Call(RANDOM, ImmutableList.of()), new Call(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 1.0))); assertEvaluatedEquals( - new CoalesceExpression(new Constant(INTEGER, 1L), new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), + new Coalesce(new Constant(INTEGER, 1L), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 1L)); - assertTrinoExceptionThrownBy(() -> evaluate(new CoalesceExpression(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L)))) + assertTrinoExceptionThrownBy(() -> evaluate(new Coalesce(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L)))) .hasErrorCode(DIVISION_BY_ZERO); } @@ -784,64 +784,64 @@ public void testCoalesce() public void testIf() { assertOptimizedEquals( - ifExpression(new ComparisonExpression(EQUAL, new Constant(INTEGER, 2L), new Constant(INTEGER, 2L)), new Constant(INTEGER, 3L), new Constant(INTEGER, 4L)), + ifExpression(new Comparison(EQUAL, new Constant(INTEGER, 2L), new Constant(INTEGER, 2L)), new Constant(INTEGER, 3L), new Constant(INTEGER, 4L)), new Constant(INTEGER, 3L)); assertOptimizedEquals( - ifExpression(new ComparisonExpression(EQUAL, new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)), new Constant(INTEGER, 3L), new Constant(INTEGER, 4L)), + ifExpression(new Comparison(EQUAL, new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)), new Constant(INTEGER, 3L), new Constant(INTEGER, 4L)), new Constant(INTEGER, 4L)); assertOptimizedEquals( - ifExpression(TRUE_LITERAL, new Constant(INTEGER, 3L), new Constant(INTEGER, 4L)), + ifExpression(TRUE, new Constant(INTEGER, 3L), new Constant(INTEGER, 4L)), new Constant(INTEGER, 3L)); assertOptimizedEquals( - ifExpression(FALSE_LITERAL, new Constant(INTEGER, 3L), new Constant(INTEGER, 4L)), + ifExpression(FALSE, new Constant(INTEGER, 3L), new Constant(INTEGER, 4L)), new Constant(INTEGER, 4L)); assertOptimizedEquals( ifExpression(new Constant(BOOLEAN, null), new Constant(INTEGER, 3L), new Constant(INTEGER, 4L)), new Constant(INTEGER, 4L)); assertOptimizedEquals( - ifExpression(TRUE_LITERAL, new Constant(INTEGER, 3L), new Constant(INTEGER, null)), + ifExpression(TRUE, new Constant(INTEGER, 3L), new Constant(INTEGER, null)), new Constant(INTEGER, 3L)); assertOptimizedEquals( - ifExpression(FALSE_LITERAL, new Constant(INTEGER, 3L), new Constant(INTEGER, null)), + ifExpression(FALSE, new Constant(INTEGER, 3L), new Constant(INTEGER, null)), new Constant(UNKNOWN, null)); assertOptimizedEquals( - ifExpression(TRUE_LITERAL, new Constant(INTEGER, null), new Constant(INTEGER, 4L)), + ifExpression(TRUE, new Constant(INTEGER, null), new Constant(INTEGER, 4L)), new Constant(UNKNOWN, null)); assertOptimizedEquals( - ifExpression(FALSE_LITERAL, new Constant(INTEGER, null), new Constant(INTEGER, 4L)), + ifExpression(FALSE, new Constant(INTEGER, null), new Constant(INTEGER, 4L)), new Constant(INTEGER, 4L)); assertOptimizedEquals( - ifExpression(TRUE_LITERAL, new Constant(INTEGER, null), new Constant(INTEGER, null)), + ifExpression(TRUE, new Constant(INTEGER, null), new Constant(INTEGER, null)), new Constant(UNKNOWN, null)); assertOptimizedEquals( - ifExpression(FALSE_LITERAL, new Constant(INTEGER, null), new Constant(INTEGER, null)), + ifExpression(FALSE, new Constant(INTEGER, null), new Constant(INTEGER, null)), new Constant(UNKNOWN, null)); assertOptimizedEquals( - ifExpression(TRUE_LITERAL, new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), - new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))); + ifExpression(TRUE, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), + new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))); assertOptimizedEquals( - ifExpression(TRUE_LITERAL, new Constant(INTEGER, 1L), new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), + ifExpression(TRUE, new Constant(INTEGER, 1L), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 1L)); assertOptimizedEquals( - ifExpression(FALSE_LITERAL, new Constant(INTEGER, 1L), new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), - new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))); + ifExpression(FALSE, new Constant(INTEGER, 1L), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), + new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))); assertOptimizedEquals( - ifExpression(FALSE_LITERAL, new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L)), + ifExpression(FALSE, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L)); assertOptimizedEquals( - ifExpression(new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)), - ifExpression(new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))); + ifExpression(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)), + ifExpression(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))); assertEvaluatedEquals( - ifExpression(TRUE_LITERAL, new Constant(INTEGER, 1L), new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), + ifExpression(TRUE, new Constant(INTEGER, 1L), new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 1L)); assertEvaluatedEquals( - ifExpression(FALSE_LITERAL, new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L)), + ifExpression(FALSE, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L)); - assertTrinoExceptionThrownBy(() -> evaluate(ifExpression(new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)))) + assertTrinoExceptionThrownBy(() -> evaluate(ifExpression(new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)))) .hasErrorCode(DIVISION_BY_ZERO); } @@ -849,10 +849,10 @@ public void testIf() public void testOptimizeDivideByZero() { assertOptimizedEquals( - new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), - new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))); + new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), + new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))); - assertTrinoExceptionThrownBy(() -> evaluate(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))) + assertTrinoExceptionThrownBy(() -> evaluate(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))) .hasErrorCode(DIVISION_BY_ZERO); } @@ -860,14 +860,14 @@ public void testOptimizeDivideByZero() public void testRowSubscript() { assertOptimizedEquals( - new SubscriptExpression(BOOLEAN, new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VARCHAR, Slices.utf8Slice("a")), TRUE_LITERAL)), new Constant(INTEGER, 3L)), - TRUE_LITERAL); + new Subscript(BOOLEAN, new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VARCHAR, Slices.utf8Slice("a")), TRUE)), new Constant(INTEGER, 3L)), + TRUE); assertOptimizedEquals( - new SubscriptExpression( + new Subscript( VARCHAR, - new SubscriptExpression( + new Subscript( anonymousRow(INTEGER, VARCHAR), - new SubscriptExpression( + new Subscript( anonymousRow(INTEGER, VARCHAR, anonymousRow(INTEGER, VARCHAR)), new Row(ImmutableList.of( new Constant(INTEGER, 1L), @@ -882,18 +882,18 @@ public void testRowSubscript() new Constant(VARCHAR, Slices.utf8Slice("c"))); assertOptimizedEquals( - new SubscriptExpression(UNKNOWN, new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(UNKNOWN, null))), new Constant(INTEGER, 2L)), + new Subscript(UNKNOWN, new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(UNKNOWN, null))), new Constant(INTEGER, 2L)), new Constant(UNKNOWN, null)); assertOptimizedEquals( - new SubscriptExpression(INTEGER, new Row(ImmutableList.of(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), new Constant(INTEGER, 1L)), - new SubscriptExpression(INTEGER, new Row(ImmutableList.of(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), new Constant(INTEGER, 1L))); + new Subscript(INTEGER, new Row(ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), new Constant(INTEGER, 1L)), + new Subscript(INTEGER, new Row(ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), new Constant(INTEGER, 1L))); assertOptimizedEquals( - new SubscriptExpression(INTEGER, new Row(ImmutableList.of(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), new Constant(INTEGER, 2L)), - new SubscriptExpression(INTEGER, new Row(ImmutableList.of(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), new Constant(INTEGER, 2L))); + new Subscript(INTEGER, new Row(ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), new Constant(INTEGER, 2L)), + new Subscript(INTEGER, new Row(ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), new Constant(INTEGER, 2L))); - assertTrinoExceptionThrownBy(() -> evaluate(new SubscriptExpression(INTEGER, new Row(ImmutableList.of(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), new Constant(INTEGER, 2L)))) + assertTrinoExceptionThrownBy(() -> evaluate(new Subscript(INTEGER, new Row(ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), new Constant(INTEGER, 2L)))) .hasErrorCode(DIVISION_BY_ZERO); - assertTrinoExceptionThrownBy(() -> evaluate(new SubscriptExpression(INTEGER, new Row(ImmutableList.of(new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), new Constant(INTEGER, 2L)))) + assertTrinoExceptionThrownBy(() -> evaluate(new Subscript(INTEGER, new Row(ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), new Constant(INTEGER, 2L)))) .hasErrorCode(DIVISION_BY_ZERO); } diff --git a/core/trino-main/src/test/java/io/trino/sql/TestExpressionUtils.java b/core/trino-main/src/test/java/io/trino/sql/TestExpressionUtils.java index 6eec0d4ea76c..6b6ed4c90714 100644 --- a/core/trino-main/src/test/java/io/trino/sql/TestExpressionUtils.java +++ b/core/trino-main/src/test/java/io/trino/sql/TestExpressionUtils.java @@ -15,14 +15,14 @@ import com.google.common.collect.ImmutableList; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; import org.junit.jupiter.api.Test; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.sql.ir.IrUtils.and; import static io.trino.sql.ir.IrUtils.combineConjuncts; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; +import static io.trino.sql.ir.Logical.Operator.AND; import static org.assertj.core.api.Assertions.assertThat; public class TestExpressionUtils @@ -30,14 +30,14 @@ public class TestExpressionUtils @Test public void testAnd() { - Expression a = new SymbolReference(BOOLEAN, "a"); - Expression b = new SymbolReference(BOOLEAN, "b"); - Expression c = new SymbolReference(BOOLEAN, "c"); - Expression d = new SymbolReference(BOOLEAN, "d"); - Expression e = new SymbolReference(BOOLEAN, "e"); + Expression a = new Reference(BOOLEAN, "a"); + Expression b = new Reference(BOOLEAN, "b"); + Expression c = new Reference(BOOLEAN, "c"); + Expression d = new Reference(BOOLEAN, "d"); + Expression e = new Reference(BOOLEAN, "e"); - assertThat(and(a, b, c, d, e)).isEqualTo(new LogicalExpression(AND, ImmutableList.of(a, b, c, d, e))); + assertThat(and(a, b, c, d, e)).isEqualTo(new Logical(AND, ImmutableList.of(a, b, c, d, e))); - assertThat(combineConjuncts(a, b, a, c, d, c, e)).isEqualTo(new LogicalExpression(AND, ImmutableList.of(a, b, c, d, e))); + assertThat(combineConjuncts(a, b, a, c, d, c, e)).isEqualTo(new Logical(AND, ImmutableList.of(a, b, c, d, e))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/TestSqlToRowExpressionTranslator.java b/core/trino-main/src/test/java/io/trino/sql/TestSqlToRowExpressionTranslator.java index 75c56d5d3c3d..db77a79acc9d 100644 --- a/core/trino-main/src/test/java/io/trino/sql/TestSqlToRowExpressionTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/TestSqlToRowExpressionTranslator.java @@ -16,7 +16,7 @@ import com.google.common.collect.ImmutableMap; import io.trino.spi.type.Decimals; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.CoalesceExpression; +import io.trino.sql.ir.Coalesce; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.planner.IrExpressionInterpreter; @@ -44,7 +44,7 @@ public void testPossibleExponentialOptimizationTime() { Expression expression = new Constant(BIGINT, 1L); for (int i = 0; i < 100; i++) { - expression = new CoalesceExpression(expression, new Constant(BIGINT, 2L)); + expression = new Coalesce(expression, new Constant(BIGINT, 2L)); } translateAndOptimize(expression); } diff --git a/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor2.java b/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor2.java index 3472b17ddee8..d9fdb7acf61d 100644 --- a/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor2.java +++ b/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor2.java @@ -31,13 +31,13 @@ import io.trino.spi.function.OperatorType; import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.relational.RowExpression; import io.trino.sql.relational.SqlToRowExpressionTranslator; @@ -69,9 +69,9 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MODULUS; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.planner.TestingPlannerContext.plannerContextBuilder; import static io.trino.type.UnknownType.UNKNOWN; import static java.util.Locale.ENGLISH; @@ -163,10 +163,10 @@ public List> columnOriented() private RowExpression getFilter(Type type) { if (type == VARCHAR) { - return rowExpression(new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new Cast(new SymbolReference(VARCHAR, "varchar0"), BIGINT), new Constant(INTEGER, 2L)), new Constant(INTEGER, 0L))); + return rowExpression(new Comparison(EQUAL, new Arithmetic(MODULUS_BIGINT, MODULUS, new Cast(new Reference(VARCHAR, "varchar0"), BIGINT), new Constant(INTEGER, 2L)), new Constant(INTEGER, 0L))); } if (type == BIGINT) { - return rowExpression(new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(INTEGER, "bigint0"), new Constant(INTEGER, 2L)), new Constant(INTEGER, 0L))); + return rowExpression(new Comparison(EQUAL, new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(INTEGER, "bigint0"), new Constant(INTEGER, 2L)), new Constant(INTEGER, 0L))); } throw new IllegalArgumentException("filter not supported for type : " + type); } @@ -176,14 +176,14 @@ private List getProjections(Type type) ImmutableList.Builder builder = ImmutableList.builder(); if (type == BIGINT) { for (int i = 0; i < columnCount; i++) { - builder.add(rowExpression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "bigint" + i), new Constant(BIGINT, 5L)))); + builder.add(rowExpression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "bigint" + i), new Constant(BIGINT, 5L)))); } } else if (type == VARCHAR) { for (int i = 0; i < columnCount; i++) { // alternatively use identity expression rowExpression("varchar" + i, type) or // rowExpression("substr(varchar" + i + ", 1, 1)", type) - builder.add(rowExpression(new FunctionCall(CONCAT, ImmutableList.of(new SymbolReference(VARCHAR, "varchar" + i), new Constant(VARCHAR, Slices.utf8Slice("foo")))))); + builder.add(rowExpression(new Call(CONCAT, ImmutableList.of(new Reference(VARCHAR, "varchar" + i), new Constant(VARCHAR, Slices.utf8Slice("foo")))))); } } return builder.build(); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/AbstractPredicatePushdownTest.java b/core/trino-main/src/test/java/io/trino/sql/planner/AbstractPredicatePushdownTest.java index 5353458bd77d..505dbae5849d 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/AbstractPredicatePushdownTest.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/AbstractPredicatePushdownTest.java @@ -20,14 +20,14 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.InPredicate; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.In; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.assertions.BasePlanTest; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.optimizations.PlanOptimizer; @@ -47,17 +47,17 @@ import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.DIVIDE; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MULTIPLY; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.SUBTRACT; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.NOT_EQUAL; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; -import static io.trino.sql.ir.LogicalExpression.Operator.OR; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.DIVIDE; +import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; +import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.NOT_EQUAL; +import static io.trino.sql.ir.Logical.Operator.AND; +import static io.trino.sql.ir.Logical.Operator.OR; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.assignUniqueId; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; @@ -106,7 +106,7 @@ public void testPushDownToLhsOfSemiJoin() semiJoin("LINE_ORDER_KEY", "ORDERS_ORDER_KEY", "SEMI_JOIN_RESULT", enableDynamicFiltering, anyTree( filter( - new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "LINE_NUMBER"), new Constant(INTEGER, 2L)), + new Comparison(EQUAL, new Reference(INTEGER, "LINE_NUMBER"), new Constant(INTEGER, 2L)), tableScan("lineitem", ImmutableMap.of( "LINE_ORDER_KEY", "orderkey", "LINE_NUMBER", "linenumber", @@ -122,7 +122,7 @@ public void testNonDeterministicPredicatePropagatesOnlyToSourceSideOfSemiJoin() anyTree( semiJoin("LINE_ORDER_KEY", "ORDERS_ORDER_KEY", "SEMI_JOIN_RESULT", enableDynamicFiltering, filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "LINE_ORDER_KEY"), new Cast(new FunctionCall(RANDOM, ImmutableList.of(new Constant(INTEGER, 5L))), BIGINT)), + new Comparison(EQUAL, new Reference(BIGINT, "LINE_ORDER_KEY"), new Cast(new Call(RANDOM, ImmutableList.of(new Constant(INTEGER, 5L))), BIGINT)), tableScan("lineitem", ImmutableMap.of( "LINE_ORDER_KEY", "orderkey"))), node(ExchangeNode.class, // NO filter here @@ -132,7 +132,7 @@ public void testNonDeterministicPredicatePropagatesOnlyToSourceSideOfSemiJoin() anyTree( semiJoin("LINE_ORDER_KEY", "ORDERS_ORDER_KEY", "SEMI_JOIN_RESULT", filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "LINE_ORDER_KEY"), new Cast(new FunctionCall(RANDOM, ImmutableList.of(new Constant(INTEGER, 5L))), BIGINT)), + new Comparison(EQUAL, new Reference(BIGINT, "LINE_ORDER_KEY"), new Cast(new Call(RANDOM, ImmutableList.of(new Constant(INTEGER, 5L))), BIGINT)), tableScan("lineitem", ImmutableMap.of( "LINE_ORDER_KEY", "orderkey"))), anyTree( @@ -146,13 +146,13 @@ public void testGreaterPredicateFromFilterSidePropagatesToSourceSideOfSemiJoin() noSemiJoinRewrite(), anyTree( semiJoin("LINE_ORDER_KEY", "ORDERS_ORDER_KEY", "SEMI_JOIN_RESULT", enableDynamicFiltering, - filter(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "LINE_ORDER_KEY"), new Constant(BIGINT, 2L)), + filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "LINE_ORDER_KEY"), new Constant(BIGINT, 2L)), tableScan("lineitem", ImmutableMap.of( "LINE_ORDER_KEY", "orderkey", "LINE_QUANTITY", "quantity"))), anyTree( filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "ORDERS_ORDER_KEY"), new Constant(BIGINT, 2L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "ORDERS_ORDER_KEY"), new Constant(BIGINT, 2L)), tableScan("orders", ImmutableMap.of("ORDERS_ORDER_KEY", "orderkey"))))))); } @@ -164,13 +164,13 @@ public void testEqualsPredicateFromFilterSidePropagatesToSourceSideOfSemiJoin() anyTree( semiJoin("LINE_ORDER_KEY", "ORDERS_ORDER_KEY", "SEMI_JOIN_RESULT", enableDynamicFiltering, filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "LINE_ORDER_KEY"), new Constant(BIGINT, 2L)), + new Comparison(EQUAL, new Reference(BIGINT, "LINE_ORDER_KEY"), new Constant(BIGINT, 2L)), tableScan("lineitem", ImmutableMap.of( "LINE_ORDER_KEY", "orderkey", "LINE_QUANTITY", "quantity"))), anyTree( filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "ORDERS_ORDER_KEY"), new Constant(BIGINT, 2L)), + new Comparison(EQUAL, new Reference(BIGINT, "ORDERS_ORDER_KEY"), new Constant(BIGINT, 2L)), tableScan("orders", ImmutableMap.of("ORDERS_ORDER_KEY", "orderkey"))))))); } @@ -187,7 +187,7 @@ public void testPredicateFromFilterSideNotPropagatesToSourceSideOfSemiJoinIfNotI "LINE_QUANTITY", "quantity")), anyTree( filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "ORDERS_ORDER_KEY"), new Constant(BIGINT, 2L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "ORDERS_ORDER_KEY"), new Constant(BIGINT, 2L)), tableScan("orders", ImmutableMap.of("ORDERS_ORDER_KEY", "orderkey"))))))); } @@ -199,13 +199,13 @@ public void testGreaterPredicateFromSourceSidePropagatesToFilterSideOfSemiJoin() anyTree( semiJoin("LINE_ORDER_KEY", "ORDERS_ORDER_KEY", "SEMI_JOIN_RESULT", enableDynamicFiltering, filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "LINE_ORDER_KEY"), new Constant(BIGINT, 2L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "LINE_ORDER_KEY"), new Constant(BIGINT, 2L)), tableScan("lineitem", ImmutableMap.of( "LINE_ORDER_KEY", "orderkey", "LINE_QUANTITY", "quantity"))), anyTree( filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "ORDERS_ORDER_KEY"), new Constant(BIGINT, 2L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "ORDERS_ORDER_KEY"), new Constant(BIGINT, 2L)), tableScan("orders", ImmutableMap.of("ORDERS_ORDER_KEY", "orderkey"))))))); } @@ -217,13 +217,13 @@ public void testEqualPredicateFromSourceSidePropagatesToFilterSideOfSemiJoin() anyTree( semiJoin("LINE_ORDER_KEY", "ORDERS_ORDER_KEY", "SEMI_JOIN_RESULT", enableDynamicFiltering, filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "LINE_ORDER_KEY"), new Constant(BIGINT, 2L)), + new Comparison(EQUAL, new Reference(BIGINT, "LINE_ORDER_KEY"), new Constant(BIGINT, 2L)), tableScan("lineitem", ImmutableMap.of( "LINE_ORDER_KEY", "orderkey", "LINE_QUANTITY", "quantity"))), anyTree( filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "ORDERS_ORDER_KEY"), new Constant(BIGINT, 2L)), + new Comparison(EQUAL, new Reference(BIGINT, "ORDERS_ORDER_KEY"), new Constant(BIGINT, 2L)), tableScan("orders", ImmutableMap.of("ORDERS_ORDER_KEY", "orderkey"))))))); } @@ -234,7 +234,7 @@ public void testPredicateFromSourceSideNotPropagatesToFilterSideOfSemiJoinIfNotI anyTree( semiJoin("LINE_ORDER_KEY", "ORDERS_ORDER_KEY", "SEMI_JOIN_RESULT", filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "LINE_ORDER_KEY"), new Constant(BIGINT, 2L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "LINE_ORDER_KEY"), new Constant(BIGINT, 2L)), tableScan("lineitem", ImmutableMap.of( "LINE_ORDER_KEY", "orderkey", "LINE_QUANTITY", "quantity"))), @@ -253,7 +253,7 @@ public void testPredicateFromFilterSideNotPropagatesToSourceSideOfSemiJoinUsedIn "LINE_ORDER_KEY", "orderkey")), anyTree( filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "ORDERS_ORDER_KEY"), new Constant(BIGINT, 2L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "ORDERS_ORDER_KEY"), new Constant(BIGINT, 2L)), tableScan("orders", ImmutableMap.of("ORDERS_ORDER_KEY", "orderkey"))))))); } @@ -294,9 +294,9 @@ public void testPredicatePushDownThroughMarkDistinct() join(LEFT, builder -> builder .equiCriteria("A", "B") .left( - assignUniqueId("unique", filter(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "A"), new Constant(INTEGER, 1L)), values("A")))) + assignUniqueId("unique", filter(new Comparison(EQUAL, new Reference(INTEGER, "A"), new Constant(INTEGER, 1L)), values("A")))) .right( - filter(new ComparisonExpression(EQUAL, new Constant(INTEGER, 1L), new SymbolReference(INTEGER, "B")), values("B")))))); + filter(new Comparison(EQUAL, new Constant(INTEGER, 1L), new Reference(INTEGER, "B")), values("B")))))); } @Test @@ -308,8 +308,8 @@ public void testPredicatePushDownOverProjection() "SELECT * FROM t WHERE x + x > 1", anyTree( filter( - new ComparisonExpression(GREATER_THAN, new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "expr"), new SymbolReference(BIGINT, "expr")), new Constant(BIGINT, 1L)), - project(ImmutableMap.of("expr", expression(new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new SymbolReference(BIGINT, "orderkey"), new Constant(BIGINT, 2L)))), + new Comparison(GREATER_THAN, new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "expr"), new Reference(BIGINT, "expr")), new Constant(BIGINT, 1L)), + project(ImmutableMap.of("expr", expression(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "orderkey"), new Constant(BIGINT, 2L)))), tableScan("orders", ImmutableMap.of("orderkey", "orderkey")))))); // constant non-singleton should be pushed down @@ -319,7 +319,7 @@ public void testPredicatePushDownOverProjection() anyTree( project( filter( - new ComparisonExpression(GREATER_THAN, new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new SymbolReference(BIGINT, "orderkey"), new Constant(BIGINT, 2L)), new Constant(BIGINT, 1L)), new Constant(BIGINT, 1L)), new Constant(BIGINT, 1L)), + new Comparison(GREATER_THAN, new Arithmetic(ADD_BIGINT, ADD, new Arithmetic(ADD_BIGINT, ADD, new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "orderkey"), new Constant(BIGINT, 2L)), new Constant(BIGINT, 1L)), new Constant(BIGINT, 1L)), new Constant(BIGINT, 1L)), tableScan("orders", ImmutableMap.of( "orderkey", "orderkey")))))); @@ -330,7 +330,7 @@ public void testPredicatePushDownOverProjection() anyTree( project( filter( - new ComparisonExpression(GREATER_THAN, new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new SymbolReference(BIGINT, "orderkey"), new Constant(BIGINT, 2L)), new Constant(BIGINT, 1L)), + new Comparison(GREATER_THAN, new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "orderkey"), new Constant(BIGINT, 2L)), new Constant(BIGINT, 1L)), tableScan("orders", ImmutableMap.of( "orderkey", "orderkey")))))); @@ -341,7 +341,7 @@ public void testPredicatePushDownOverProjection() anyTree( project( filter( - new ComparisonExpression(GREATER_THAN, new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new SymbolReference(BIGINT, "orderkey"), new Constant(BIGINT, 2L)), new SymbolReference(BIGINT, "orderkey")), new Constant(BIGINT, 1L)), + new Comparison(GREATER_THAN, new Arithmetic(ADD_BIGINT, ADD, new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "orderkey"), new Constant(BIGINT, 2L)), new Reference(BIGINT, "orderkey")), new Constant(BIGINT, 1L)), tableScan("orders", ImmutableMap.of( "orderkey", "orderkey")))))); @@ -351,7 +351,7 @@ public void testPredicatePushDownOverProjection() "SELECT * FROM t WHERE x >1", anyTree( filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "orderkey"), new Constant(BIGINT, 1L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "orderkey"), new Constant(BIGINT, 1L)), tableScan("orders", ImmutableMap.of( "orderkey", "orderkey"))))); @@ -361,8 +361,8 @@ public void testPredicatePushDownOverProjection() "SELECT * FROM t WHERE x > 5000", anyTree( filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "expr"), new Constant(DOUBLE, 5000.0)), - project(ImmutableMap.of("expr", expression(new ArithmeticBinaryExpression(MULTIPLY_DOUBLE, MULTIPLY, new FunctionCall(RANDOM, ImmutableList.of()), new Cast(new SymbolReference(BIGINT, "orderkey"), DOUBLE)))), + new Comparison(GREATER_THAN, new Reference(DOUBLE, "expr"), new Constant(DOUBLE, 5000.0)), + project(ImmutableMap.of("expr", expression(new Arithmetic(MULTIPLY_DOUBLE, MULTIPLY, new Call(RANDOM, ImmutableList.of()), new Cast(new Reference(BIGINT, "orderkey"), DOUBLE)))), tableScan("orders", ImmutableMap.of( "orderkey", "orderkey")))))); } @@ -376,7 +376,7 @@ public void testPredicatePushDownOverSymbolReferences() "SELECT * FROM t WHERE x > 1 OR x < 0", anyTree( filter( - new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "orderkey"), new Constant(BIGINT, 0L)), new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "orderkey"), new Constant(BIGINT, 1L)))), + new Logical(OR, ImmutableList.of(new Comparison(LESS_THAN, new Reference(BIGINT, "orderkey"), new Constant(BIGINT, 0L)), new Comparison(GREATER_THAN, new Reference(BIGINT, "orderkey"), new Constant(BIGINT, 1L)))), tableScan("orders", ImmutableMap.of( "orderkey", "orderkey"))))); } @@ -398,7 +398,7 @@ public void testConjunctsOrder() // Order matters: size<>1 should be before 100/(size-1)=2. // In this particular example, reversing the order leads to div-by-zero error. filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(NOT_EQUAL, new SymbolReference(INTEGER, "size"), new Constant(INTEGER, 1L)), new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 100L), new ArithmeticBinaryExpression(SUBTRACT_INTEGER, SUBTRACT, new SymbolReference(INTEGER, "size"), new Constant(INTEGER, 1L))), new Constant(INTEGER, 2L)))), + new Logical(AND, ImmutableList.of(new Comparison(NOT_EQUAL, new Reference(INTEGER, "size"), new Constant(INTEGER, 1L)), new Comparison(EQUAL, new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 100L), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "size"), new Constant(INTEGER, 1L))), new Constant(INTEGER, 2L)))), tableScan("part", ImmutableMap.of( "partkey", "partkey", "size", "size"))))); @@ -419,12 +419,12 @@ public void testPredicateOnPartitionSymbolsPushedThroughWindow() ") WHERE custkey = 0 AND orderkey > 0", anyTree( filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "ORDER_KEY"), new Constant(BIGINT, 0L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "ORDER_KEY"), new Constant(BIGINT, 0L)), anyTree( node(WindowNode.class, anyTree( filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "CUST_KEY"), new Constant(BIGINT, 0L)), + new Comparison(EQUAL, new Reference(BIGINT, "CUST_KEY"), new Constant(BIGINT, 0L)), tableScan))))))); } @@ -440,8 +440,8 @@ public void testPredicateOnNonDeterministicSymbolsPushedDown() node(WindowNode.class, anyTree( filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "ROUND"), new Constant(DOUBLE, 100.0)), - project(ImmutableMap.of("ROUND", expression(new FunctionCall(ROUND, ImmutableList.of(new ArithmeticBinaryExpression(MULTIPLY_DOUBLE, MULTIPLY, new Cast(new SymbolReference(BIGINT, "CUST_KEY"), DOUBLE), new FunctionCall(RANDOM, ImmutableList.of())))))), + new Comparison(GREATER_THAN, new Reference(DOUBLE, "ROUND"), new Constant(DOUBLE, 100.0)), + project(ImmutableMap.of("ROUND", expression(new Call(ROUND, ImmutableList.of(new Arithmetic(MULTIPLY_DOUBLE, MULTIPLY, new Cast(new Reference(BIGINT, "CUST_KEY"), DOUBLE), new Call(RANDOM, ImmutableList.of())))))), tableScan( "orders", ImmutableMap.of("CUST_KEY", "custkey")))))))); @@ -457,7 +457,7 @@ public void testNonDeterministicPredicateNotPushedDown() ") WHERE custkey > 100*rand()", anyTree( filter( - new ComparisonExpression(GREATER_THAN, new Cast(new SymbolReference(BIGINT, "CUST_KEY"), DOUBLE), new ArithmeticBinaryExpression(MULTIPLY_DOUBLE, MULTIPLY, new FunctionCall(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 100.0))), + new Comparison(GREATER_THAN, new Cast(new Reference(BIGINT, "CUST_KEY"), DOUBLE), new Arithmetic(MULTIPLY_DOUBLE, MULTIPLY, new Call(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 100.0))), anyTree( node(WindowNode.class, anyTree( @@ -478,7 +478,7 @@ public void testRemovesRedundantTableScanPredicate() JoinNode.class, node(ProjectNode.class, filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "ORDERKEY"), new Constant(BIGINT, 123L)), new ComparisonExpression(EQUAL, new FunctionCall(RANDOM, ImmutableList.of()), new Cast(new SymbolReference(BIGINT, "ORDERKEY"), DOUBLE)), new ComparisonExpression(LESS_THAN, new FunctionCall(LENGTH, ImmutableList.of(new SymbolReference(createVarcharType(1), "ORDERSTATUS"))), new Constant(BIGINT, 42L)))), + new Logical(AND, ImmutableList.of(new Comparison(EQUAL, new Reference(BIGINT, "ORDERKEY"), new Constant(BIGINT, 123L)), new Comparison(EQUAL, new Call(RANDOM, ImmutableList.of()), new Cast(new Reference(BIGINT, "ORDERKEY"), DOUBLE)), new Comparison(LESS_THAN, new Call(LENGTH, ImmutableList.of(new Reference(createVarcharType(1), "ORDERSTATUS"))), new Constant(BIGINT, 42L)))), tableScan( "orders", ImmutableMap.of( @@ -496,11 +496,11 @@ public void testTablePredicateIsExtracted() anyTree( node(JoinNode.class, filter( - new InPredicate(new SymbolReference(createVarcharType(1), "ORDERSTATUS"), ImmutableList.of(new Constant(createVarcharType(1), Slices.utf8Slice("F")), new Constant(createVarcharType(1), Slices.utf8Slice("O")))), + new In(new Reference(createVarcharType(1), "ORDERSTATUS"), ImmutableList.of(new Constant(createVarcharType(1), Slices.utf8Slice("F")), new Constant(createVarcharType(1), Slices.utf8Slice("O")))), tableScan("orders", ImmutableMap.of("ORDERSTATUS", "orderstatus"))), anyTree( filter( - new InPredicate(new Cast(new SymbolReference(VARCHAR, "NAME"), createVarcharType(1)), ImmutableList.of(new Constant(createVarcharType(1), Slices.utf8Slice("F")), new Constant(createVarcharType(1), Slices.utf8Slice("O")))), + new In(new Cast(new Reference(VARCHAR, "NAME"), createVarcharType(1)), ImmutableList.of(new Constant(createVarcharType(1), Slices.utf8Slice("F")), new Constant(createVarcharType(1), Slices.utf8Slice("O")))), tableScan( "nation", ImmutableMap.of("NAME", "name"))))))); @@ -510,10 +510,10 @@ public void testTablePredicateIsExtracted() "SELECT * FROM orders JOIN nation ON orderstatus = CAST(nation.name AS varchar(1))", anyTree( node(JoinNode.class, - enableDynamicFiltering ? filter(TRUE_LITERAL, ordersTableScan) : ordersTableScan, + enableDynamicFiltering ? filter(TRUE, ordersTableScan) : ordersTableScan, anyTree( filter( - new InPredicate(new Cast(new SymbolReference(VARCHAR, "NAME"), createVarcharType(1)), ImmutableList.of(new Constant(createVarcharType(1), Slices.utf8Slice("F")), new Constant(createVarcharType(1), Slices.utf8Slice("O")), new Constant(createVarcharType(1), Slices.utf8Slice("P")))), + new In(new Cast(new Reference(VARCHAR, "NAME"), createVarcharType(1)), ImmutableList.of(new Constant(createVarcharType(1), Slices.utf8Slice("F")), new Constant(createVarcharType(1), Slices.utf8Slice("O")), new Constant(createVarcharType(1), Slices.utf8Slice("P")))), tableScan( "nation", ImmutableMap.of("NAME", "name"))))))); @@ -541,12 +541,12 @@ public void testSimplifyNonInferrableInheritedPredicate() .equiCriteria(ImmutableList.of()) .left( filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "L_NATIONKEY"), new SymbolReference(BIGINT, "L_REGIONKEY")), new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "L_REGIONKEY"), new Constant(BIGINT, 5L)))), + new Logical(AND, ImmutableList.of(new Comparison(EQUAL, new Reference(BIGINT, "L_NATIONKEY"), new Reference(BIGINT, "L_REGIONKEY")), new Comparison(EQUAL, new Reference(BIGINT, "L_REGIONKEY"), new Constant(BIGINT, 5L)))), tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey", "L_REGIONKEY", "regionkey")))) .right( anyTree( filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "R_NATIONKEY"), new Constant(BIGINT, 5L)), + new Comparison(EQUAL, new Reference(BIGINT, "R_NATIONKEY"), new Constant(BIGINT, 5L)), tableScan("nation", ImmutableMap.of("R_NATIONKEY", "nationkey")))))))); } @@ -558,7 +558,7 @@ public void testDoesNotCreatePredicateFromInferredPredicate() join(INNER, builder -> builder .equiCriteria("L_NATIONKEY2", "R_NATIONKEY") .left( - project(ImmutableMap.of("L_NATIONKEY2", expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "L_NATIONKEY"), new Constant(BIGINT, 1L)))), + project(ImmutableMap.of("L_NATIONKEY2", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "L_NATIONKEY"), new Constant(BIGINT, 1L)))), tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey")))) .right( anyTree( @@ -570,12 +570,12 @@ public void testDoesNotCreatePredicateFromInferredPredicate() .equiCriteria(ImmutableList.of()) .left( filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "L_NATIONKEY"), new Constant(BIGINT, 5L)), + new Comparison(EQUAL, new Reference(BIGINT, "L_NATIONKEY"), new Constant(BIGINT, 5L)), tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey")))) .right( anyTree( filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "R_NATIONKEY"), new Constant(BIGINT, 5L)), + new Comparison(EQUAL, new Reference(BIGINT, "R_NATIONKEY"), new Constant(BIGINT, 5L)), tableScan("nation", ImmutableMap.of("R_NATIONKEY", "nationkey")))))))); } @@ -585,17 +585,17 @@ public void testSimplifiesStraddlingPredicate() assertPlan("SELECT * FROM (SELECT * FROM NATION WHERE nationkey = 5) a JOIN nation b ON a.nationkey = b.nationkey AND a.nationkey = a.regionkey + b.regionkey", output( filter( - new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "L_REGIONKEY"), new SymbolReference(BIGINT, "R_REGIONKEY")), new Constant(BIGINT, 5L)), + new Comparison(EQUAL, new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "L_REGIONKEY"), new Reference(BIGINT, "R_REGIONKEY")), new Constant(BIGINT, 5L)), join(INNER, builder -> builder .equiCriteria(ImmutableList.of()) .left( filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "L_NATIONKEY"), new Constant(BIGINT, 5L)), + new Comparison(EQUAL, new Reference(BIGINT, "L_NATIONKEY"), new Constant(BIGINT, 5L)), tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey", "L_REGIONKEY", "regionkey")))) .right( anyTree( filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "R_NATIONKEY"), new Constant(BIGINT, 5L)), + new Comparison(EQUAL, new Reference(BIGINT, "R_NATIONKEY"), new Constant(BIGINT, 5L)), tableScan("nation", ImmutableMap.of("R_NATIONKEY", "nationkey", "R_REGIONKEY", "regionkey"))))))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestAddDynamicFilterSource.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestAddDynamicFilterSource.java index 83edfe494c77..4d1429707be8 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestAddDynamicFilterSource.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestAddDynamicFilterSource.java @@ -24,12 +24,12 @@ import io.trino.operator.RetryPolicy; import io.trino.spi.function.OperatorType; import io.trino.sql.DynamicFilters; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.BetweenPredicate; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Between; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.OptimizerConfig.JoinDistributionType; import io.trino.sql.planner.OptimizerConfig.JoinReorderingStrategy; import io.trino.sql.planner.assertions.BasePlanTest; @@ -53,13 +53,13 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MODULUS; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.SUBTRACT; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; +import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; +import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Logical.Operator.AND; import static io.trino.sql.planner.OptimizerConfig.JoinDistributionType.BROADCAST; import static io.trino.sql.planner.OptimizerConfig.JoinDistributionType.PARTITIONED; import static io.trino.sql.planner.assertions.PlanMatchPattern.DynamicFilterPattern; @@ -160,7 +160,7 @@ public void testSemiJoin() noSemiJoinRewrite(joinDistributionType), anyTree( filter( - new SymbolReference(BOOLEAN, "S"), + new Reference(BOOLEAN, "S"), semiJoin("X", "Y", "S", Optional.of(semiJoinDistributionType), Optional.of(true), node( FilterNode.class, @@ -175,7 +175,7 @@ public void testSemiJoin() DynamicFilterSourceNode.class, project( filter( - new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(MODULUS_INTEGER, MODULUS, new SymbolReference(INTEGER, "Z"), new Constant(INTEGER, 4L)), new Constant(INTEGER, 0L)), + new Comparison(EQUAL, new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "Z"), new Constant(INTEGER, 4L)), new Constant(INTEGER, 0L)), tableScan("lineitem", ImmutableMap.of("Y", "orderkey", "Z", "linenumber"))))))))))); } } @@ -244,14 +244,14 @@ public void testCrossJoinInequality() "SELECT o.orderkey FROM orders o, lineitem l WHERE o.orderkey BETWEEN l.orderkey AND l.partkey", anyTree( filter( - new BetweenPredicate(new SymbolReference(BIGINT, "O_ORDERKEY"), new SymbolReference(BIGINT, "L_ORDERKEY"), new SymbolReference(BIGINT, "L_PARTKEY")), + new Between(new Reference(BIGINT, "O_ORDERKEY"), new Reference(BIGINT, "L_ORDERKEY"), new Reference(BIGINT, "L_PARTKEY")), join(INNER, builder -> builder .dynamicFilter(ImmutableList.of( - new DynamicFilterPattern(new SymbolReference(BIGINT, "L_ORDERKEY"), LESS_THAN_OR_EQUAL, "O_ORDERKEY"), - new DynamicFilterPattern(new SymbolReference(BIGINT, "L_PARTKEY"), GREATER_THAN_OR_EQUAL, "O_ORDERKEY"))) + new DynamicFilterPattern(new Reference(BIGINT, "L_ORDERKEY"), LESS_THAN_OR_EQUAL, "O_ORDERKEY"), + new DynamicFilterPattern(new Reference(BIGINT, "L_PARTKEY"), GREATER_THAN_OR_EQUAL, "O_ORDERKEY"))) .left( filter( - TRUE_LITERAL, + TRUE, tableScan("lineitem", ImmutableMap.of("L_ORDERKEY", "orderkey", "L_PARTKEY", "partkey")))) .right( exchange( @@ -267,7 +267,7 @@ public void testCrossJoinInequality() withJoinDistributionType(PARTITIONED), anyTree( filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(BIGINT, "O_ORDERKEY"), new SymbolReference(BIGINT, "L_ORDERKEY")), new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "O_ORDERKEY"), new SymbolReference(BIGINT, "expr")))), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(BIGINT, "O_ORDERKEY"), new Reference(BIGINT, "L_ORDERKEY")), new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "O_ORDERKEY"), new Reference(BIGINT, "expr")))), join(INNER, builder -> builder .left( tableScan("orders", ImmutableMap.of("O_ORDERKEY", "orderkey"))) @@ -275,7 +275,7 @@ public void testCrossJoinInequality() exchange( LOCAL, project( - ImmutableMap.of("expr", expression(new ArithmeticBinaryExpression(SUBTRACT_BIGINT, SUBTRACT, new SymbolReference(BIGINT, "L_PARTKEY"), new Constant(BIGINT, 1L)))), + ImmutableMap.of("expr", expression(new Arithmetic(SUBTRACT_BIGINT, SUBTRACT, new Reference(BIGINT, "L_PARTKEY"), new Constant(BIGINT, 1L)))), exchange( REMOTE, tableScan("lineitem", ImmutableMap.of("L_ORDERKEY", "orderkey", "L_PARTKEY", "partkey")))))))))); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java index 9e19f3ff4043..cba3605f740b 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java @@ -22,7 +22,6 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.operator.scalar.JsonPath; import io.trino.security.AllowAllAccessControl; -import io.trino.spi.expression.Call; import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.expression.FieldDereference; import io.trino.spi.expression.FunctionName; @@ -32,21 +31,21 @@ import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ArithmeticNegation; -import io.trino.sql.ir.BetweenPredicate; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Between; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.InPredicate; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.NullIfExpression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.In; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Negation; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.NullIf; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.testing.TestingSession; import io.trino.transaction.TestingTransactionManager; import io.trino.transaction.TransactionManager; @@ -138,16 +137,16 @@ private void testTranslateConstant(Object nativeValue, Type type) @Test public void testTranslateSymbol() { - assertTranslationRoundTrips(new SymbolReference(DOUBLE, "double_symbol_1"), new Variable("double_symbol_1", DOUBLE)); + assertTranslationRoundTrips(new Reference(DOUBLE, "double_symbol_1"), new Variable("double_symbol_1", DOUBLE)); } @Test public void testTranslateRowSubscript() { assertTranslationRoundTrips( - new SubscriptExpression( + new Subscript( INTEGER, - new SymbolReference(ROW_TYPE, "row_symbol_1"), + new Reference(ROW_TYPE, "row_symbol_1"), new Constant(INTEGER, 1L)), new FieldDereference( INTEGER, @@ -158,22 +157,22 @@ public void testTranslateRowSubscript() @Test public void testTranslateLogicalExpression() { - for (LogicalExpression.Operator operator : LogicalExpression.Operator.values()) { + for (Logical.Operator operator : Logical.Operator.values()) { assertTranslationRoundTrips( - new LogicalExpression( + new Logical( operator, List.of( - new ComparisonExpression(ComparisonExpression.Operator.LESS_THAN, new SymbolReference(DOUBLE, "double_symbol_1"), new SymbolReference(DOUBLE, "double_symbol_2")), - new ComparisonExpression(ComparisonExpression.Operator.EQUAL, new SymbolReference(DOUBLE, "double_symbol_1"), new SymbolReference(DOUBLE, "double_symbol_2")))), - new Call( + new Comparison(Comparison.Operator.LESS_THAN, new Reference(DOUBLE, "double_symbol_1"), new Reference(DOUBLE, "double_symbol_2")), + new Comparison(Comparison.Operator.EQUAL, new Reference(DOUBLE, "double_symbol_1"), new Reference(DOUBLE, "double_symbol_2")))), + new io.trino.spi.expression.Call( BOOLEAN, - operator == LogicalExpression.Operator.AND ? StandardFunctions.AND_FUNCTION_NAME : StandardFunctions.OR_FUNCTION_NAME, + operator == Logical.Operator.AND ? StandardFunctions.AND_FUNCTION_NAME : StandardFunctions.OR_FUNCTION_NAME, List.of( - new Call( + new io.trino.spi.expression.Call( BOOLEAN, StandardFunctions.LESS_THAN_OPERATOR_FUNCTION_NAME, List.of(new Variable("double_symbol_1", DOUBLE), new Variable("double_symbol_2", DOUBLE))), - new Call( + new io.trino.spi.expression.Call( BOOLEAN, StandardFunctions.EQUAL_OPERATOR_FUNCTION_NAME, List.of(new Variable("double_symbol_1", DOUBLE), new Variable("double_symbol_2", DOUBLE)))))); @@ -183,10 +182,10 @@ public void testTranslateLogicalExpression() @Test public void testTranslateComparisonExpression() { - for (ComparisonExpression.Operator operator : ComparisonExpression.Operator.values()) { + for (Comparison.Operator operator : Comparison.Operator.values()) { assertTranslationRoundTrips( - new ComparisonExpression(operator, new SymbolReference(DOUBLE, "double_symbol_1"), new SymbolReference(DOUBLE, "double_symbol_2")), - new Call( + new Comparison(operator, new Reference(DOUBLE, "double_symbol_1"), new Reference(DOUBLE, "double_symbol_2")), + new io.trino.spi.expression.Call( BOOLEAN, ConnectorExpressionTranslator.functionNameForComparisonOperator(operator), List.of(new Variable("double_symbol_1", DOUBLE), new Variable("double_symbol_2", DOUBLE)))); @@ -197,9 +196,9 @@ public void testTranslateComparisonExpression() public void testTranslateArithmeticBinary() { TestingFunctionResolution resolver = new TestingFunctionResolution(); - for (ArithmeticBinaryExpression.Operator operator : ArithmeticBinaryExpression.Operator.values()) { + for (Arithmetic.Operator operator : Arithmetic.Operator.values()) { assertTranslationRoundTrips( - new ArithmeticBinaryExpression( + new Arithmetic( resolver.resolveOperator( switch (operator) { case ADD -> OperatorType.ADD; @@ -210,9 +209,9 @@ public void testTranslateArithmeticBinary() }, ImmutableList.of(DOUBLE, DOUBLE)), operator, - new SymbolReference(DOUBLE, "double_symbol_1"), - new SymbolReference(DOUBLE, "double_symbol_2")), - new Call( + new Reference(DOUBLE, "double_symbol_1"), + new Reference(DOUBLE, "double_symbol_2")), + new io.trino.spi.expression.Call( DOUBLE, ConnectorExpressionTranslator.functionNameForArithmeticBinaryOperator(operator), List.of(new Variable("double_symbol_1", DOUBLE), new Variable("double_symbol_2", DOUBLE)))); @@ -223,8 +222,8 @@ public void testTranslateArithmeticBinary() public void testTranslateArithmeticUnaryMinus() { assertTranslationRoundTrips( - new ArithmeticNegation(new SymbolReference(DOUBLE, "double_symbol_1")), - new Call(DOUBLE, NEGATE_FUNCTION_NAME, List.of(new Variable("double_symbol_1", DOUBLE)))); + new Negation(new Reference(DOUBLE, "double_symbol_1")), + new io.trino.spi.expression.Call(DOUBLE, NEGATE_FUNCTION_NAME, List.of(new Variable("double_symbol_1", DOUBLE)))); } @Test @@ -232,21 +231,21 @@ public void testTranslateBetween() { assertTranslationToConnectorExpression( TEST_SESSION, - new BetweenPredicate( - new SymbolReference(DOUBLE, "double_symbol_1"), + new Between( + new Reference(DOUBLE, "double_symbol_1"), new Constant(DOUBLE, 1.2), - new SymbolReference(DOUBLE, "double_symbol_2")), - new Call( + new Reference(DOUBLE, "double_symbol_2")), + new io.trino.spi.expression.Call( BOOLEAN, AND_FUNCTION_NAME, List.of( - new Call( + new io.trino.spi.expression.Call( BOOLEAN, GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME, List.of( new Variable("double_symbol_1", DOUBLE), new io.trino.spi.expression.Constant(1.2d, DOUBLE))), - new Call( + new io.trino.spi.expression.Call( BOOLEAN, LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME, List.of( @@ -258,8 +257,8 @@ public void testTranslateBetween() public void testTranslateIsNull() { assertTranslationRoundTrips( - new IsNullPredicate(new SymbolReference(VARCHAR, "varchar_symbol_1")), - new Call( + new IsNull(new Reference(VARCHAR, "varchar_symbol_1")), + new io.trino.spi.expression.Call( BOOLEAN, IS_NULL_FUNCTION_NAME, List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE)))); @@ -269,8 +268,8 @@ public void testTranslateIsNull() public void testTranslateNotExpression() { assertTranslationRoundTrips( - new NotExpression(new SymbolReference(BOOLEAN, "boolean_symbol_1")), - new Call( + new Not(new Reference(BOOLEAN, "boolean_symbol_1")), + new io.trino.spi.expression.Call( BOOLEAN, NOT_FUNCTION_NAME, List.of(new Variable("boolean_symbol_1", BOOLEAN)))); @@ -280,19 +279,19 @@ public void testTranslateNotExpression() public void testTranslateIsNotNull() { assertTranslationRoundTrips( - new NotExpression(new IsNullPredicate(new SymbolReference(VARCHAR, "varchar_symbol_1"))), - new Call( + new Not(new IsNull(new Reference(VARCHAR, "varchar_symbol_1"))), + new io.trino.spi.expression.Call( BOOLEAN, NOT_FUNCTION_NAME, - List.of(new Call(BOOLEAN, IS_NULL_FUNCTION_NAME, List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE)))))); + List.of(new io.trino.spi.expression.Call(BOOLEAN, IS_NULL_FUNCTION_NAME, List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE)))))); } @Test public void testTranslateCast() { assertTranslationRoundTrips( - new Cast(new SymbolReference(VARCHAR, "varchar_symbol_1"), VARCHAR_TYPE), - new Call( + new Cast(new Reference(VARCHAR, "varchar_symbol_1"), VARCHAR_TYPE), + new io.trino.spi.expression.Call( VARCHAR_TYPE, CAST_FUNCTION_NAME, List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE)))); @@ -301,7 +300,7 @@ public void testTranslateCast() assertTranslationToConnectorExpression( TEST_SESSION, new Cast( - new SymbolReference(VARCHAR, "varchar_symbol_1"), + new Reference(VARCHAR, "varchar_symbol_1"), BIGINT, true), Optional.empty()); @@ -316,7 +315,7 @@ public void testTranslateLike() .readOnly() .execute(TEST_SESSION, transactionSession -> { String pattern = "%pattern%"; - Call translated = new Call(BOOLEAN, + io.trino.spi.expression.Call translated = new io.trino.spi.expression.Call(BOOLEAN, StandardFunctions.LIKE_FUNCTION_NAME, List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE), new io.trino.spi.expression.Constant(Slices.wrappedBuffer(pattern.getBytes(UTF_8)), createVarcharType(pattern.length())))); @@ -324,7 +323,7 @@ public void testTranslateLike() assertTranslationToConnectorExpression( transactionSession, BuiltinFunctionCallBuilder.resolve(PLANNER_CONTEXT.getMetadata()) - .setName(LikeFunctions.LIKE_FUNCTION_NAME).addArgument(VARCHAR_TYPE, new SymbolReference(VARCHAR_TYPE, "varchar_symbol_1")) + .setName(LikeFunctions.LIKE_FUNCTION_NAME).addArgument(VARCHAR_TYPE, new Reference(VARCHAR_TYPE, "varchar_symbol_1")) .addArgument(LIKE_PATTERN, new Constant(LIKE_PATTERN, likePattern(utf8Slice(pattern)))) .build(), Optional.of(translated)); @@ -333,7 +332,7 @@ public void testTranslateLike() transactionSession, translated, BuiltinFunctionCallBuilder.resolve(PLANNER_CONTEXT.getMetadata()) - .setName(LikeFunctions.LIKE_FUNCTION_NAME).addArgument(VARCHAR_TYPE, new SymbolReference(VARCHAR_TYPE, "varchar_symbol_1")) + .setName(LikeFunctions.LIKE_FUNCTION_NAME).addArgument(VARCHAR_TYPE, new Reference(VARCHAR_TYPE, "varchar_symbol_1")) .addArgument(LIKE_PATTERN, BuiltinFunctionCallBuilder.resolve(PLANNER_CONTEXT.getMetadata()) .setName(LikeFunctions.LIKE_PATTERN_FUNCTION_NAME) @@ -342,7 +341,7 @@ public void testTranslateLike() .build()); String escape = "\\"; - translated = new Call(BOOLEAN, + translated = new io.trino.spi.expression.Call(BOOLEAN, StandardFunctions.LIKE_FUNCTION_NAME, List.of( new Variable("varchar_symbol_1", VARCHAR_TYPE), @@ -352,7 +351,7 @@ public void testTranslateLike() assertTranslationToConnectorExpression( transactionSession, BuiltinFunctionCallBuilder.resolve(PLANNER_CONTEXT.getMetadata()) - .setName(LikeFunctions.LIKE_FUNCTION_NAME).addArgument(VARCHAR_TYPE, new SymbolReference(VARCHAR_TYPE, "varchar_symbol_1")) + .setName(LikeFunctions.LIKE_FUNCTION_NAME).addArgument(VARCHAR_TYPE, new Reference(VARCHAR_TYPE, "varchar_symbol_1")) .addArgument(LIKE_PATTERN, new Constant(LIKE_PATTERN, likePattern(utf8Slice(pattern), utf8Slice(escape)))) .build(), Optional.of(translated)); @@ -361,7 +360,7 @@ public void testTranslateLike() transactionSession, translated, BuiltinFunctionCallBuilder.resolve(PLANNER_CONTEXT.getMetadata()) - .setName(LikeFunctions.LIKE_FUNCTION_NAME).addArgument(VARCHAR_TYPE, new SymbolReference(VARCHAR_TYPE, "varchar_symbol_1")) + .setName(LikeFunctions.LIKE_FUNCTION_NAME).addArgument(VARCHAR_TYPE, new Reference(VARCHAR_TYPE, "varchar_symbol_1")) .addArgument(LIKE_PATTERN, BuiltinFunctionCallBuilder.resolve(PLANNER_CONTEXT.getMetadata()) .setName(LikeFunctions.LIKE_PATTERN_FUNCTION_NAME) @@ -376,10 +375,10 @@ public void testTranslateLike() public void testTranslateNullIf() { assertTranslationRoundTrips( - new NullIfExpression( - new SymbolReference(VARCHAR, "varchar_symbol_1"), - new SymbolReference(VARCHAR, "varchar_symbol_1")), - new Call( + new NullIf( + new Reference(VARCHAR, "varchar_symbol_1"), + new Reference(VARCHAR, "varchar_symbol_1")), + new io.trino.spi.expression.Call( VARCHAR_TYPE, NULLIF_FUNCTION_NAME, List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE), @@ -397,9 +396,9 @@ public void testTranslateResolvedFunction() assertTranslationRoundTrips( transactionSession, BuiltinFunctionCallBuilder.resolve(PLANNER_CONTEXT.getMetadata()) - .setName("lower").addArgument(VARCHAR_TYPE, new SymbolReference(VARCHAR_TYPE, "varchar_symbol_1")) + .setName("lower").addArgument(VARCHAR_TYPE, new Reference(VARCHAR_TYPE, "varchar_symbol_1")) .build(), - new Call(VARCHAR_TYPE, + new io.trino.spi.expression.Call(VARCHAR_TYPE, new FunctionName("lower"), List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE)))); }); @@ -416,18 +415,18 @@ public void testTranslateRegularExpression() transaction(transactionManager, metadata, new AllowAllAccessControl()) .readOnly() .execute(TEST_SESSION, transactionSession -> { - FunctionCall input = BuiltinFunctionCallBuilder.resolve(PLANNER_CONTEXT.getMetadata()) - .setName("regexp_like").addArgument(VARCHAR_TYPE, new SymbolReference(VARCHAR_TYPE, "varchar_symbol_1")) + Call input = BuiltinFunctionCallBuilder.resolve(PLANNER_CONTEXT.getMetadata()) + .setName("regexp_like").addArgument(VARCHAR_TYPE, new Reference(VARCHAR_TYPE, "varchar_symbol_1")) .addArgument(new Constant(JONI_REGEXP, joniRegexp(utf8Slice("a+")))) .build(); - Call translated = new Call( + io.trino.spi.expression.Call translated = new io.trino.spi.expression.Call( BOOLEAN, new FunctionName("regexp_like"), List.of( new Variable("varchar_symbol_1", VARCHAR_TYPE), new io.trino.spi.expression.Constant(utf8Slice("a+"), createVarcharType(2)))); - FunctionCall translatedBack = BuiltinFunctionCallBuilder.resolve(PLANNER_CONTEXT.getMetadata()) - .setName("regexp_like").addArgument(VARCHAR_TYPE, new SymbolReference(VARCHAR_TYPE, "varchar_symbol_1")) + Call translatedBack = BuiltinFunctionCallBuilder.resolve(PLANNER_CONTEXT.getMetadata()) + .setName("regexp_like").addArgument(VARCHAR_TYPE, new Reference(VARCHAR_TYPE, "varchar_symbol_1")) // Note: The result is not an optimized expression .addArgument(JONI_REGEXP, new Cast(new Constant(createVarcharType(2), utf8Slice("a+")), JONI_REGEXP)) .build(); @@ -440,7 +439,7 @@ public void testTranslateRegularExpression() @Test void testTranslateJsonPath() { - Call connectorExpression = new Call( + io.trino.spi.expression.Call connectorExpression = new io.trino.spi.expression.Call( VARCHAR_TYPE, new FunctionName("json_extract_scalar"), List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE), @@ -451,7 +450,7 @@ void testTranslateJsonPath() assertTranslationToConnectorExpression( TEST_SESSION, BuiltinFunctionCallBuilder.resolve(PLANNER_CONTEXT.getMetadata()) - .setName("json_extract_scalar").addArgument(VARCHAR_TYPE, new SymbolReference(VARCHAR_TYPE, "varchar_symbol_1")) + .setName("json_extract_scalar").addArgument(VARCHAR_TYPE, new Reference(VARCHAR_TYPE, "varchar_symbol_1")) .addArgument(JSON_PATH, new Constant(JSON_PATH, new JsonPath("$.path"))) .build(), Optional.of(connectorExpression)); @@ -460,7 +459,7 @@ void testTranslateJsonPath() TEST_SESSION, connectorExpression, BuiltinFunctionCallBuilder.resolve(PLANNER_CONTEXT.getMetadata()) - .setName("json_extract_scalar").addArgument(VARCHAR_TYPE, new SymbolReference(VARCHAR_TYPE, "varchar_symbol_1")) + .setName("json_extract_scalar").addArgument(VARCHAR_TYPE, new Reference(VARCHAR_TYPE, "varchar_symbol_1")) .addArgument(JSON_PATH, new Cast(new Constant(createVarcharType(6), utf8Slice("$.path")), JSON_PATH)) .build()); } @@ -470,15 +469,15 @@ public void testTranslateIn() { String value = "value_1"; assertTranslationRoundTrips( - new InPredicate( - new SymbolReference(VARCHAR, "varchar_symbol_1"), - List.of(new SymbolReference(VARCHAR, "varchar_symbol_1"), new Constant(VARCHAR, utf8Slice(value)))), - new Call( + new In( + new Reference(VARCHAR, "varchar_symbol_1"), + List.of(new Reference(VARCHAR, "varchar_symbol_1"), new Constant(VARCHAR, utf8Slice(value)))), + new io.trino.spi.expression.Call( BOOLEAN, StandardFunctions.IN_PREDICATE_FUNCTION_NAME, List.of( new Variable("varchar_symbol_1", VARCHAR_TYPE), - new Call(VARCHAR_ARRAY_TYPE, ARRAY_CONSTRUCTOR_FUNCTION_NAME, + new io.trino.spi.expression.Call(VARCHAR_ARRAY_TYPE, ARRAY_CONSTRUCTOR_FUNCTION_NAME, List.of( new Variable("varchar_symbol_1", VARCHAR_TYPE), new io.trino.spi.expression.Constant(Slices.wrappedBuffer(value.getBytes(UTF_8)), VARCHAR_TYPE)))))); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestDereferencePushDown.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestDereferencePushDown.java index 34c5cf23b17f..bec2d39494db 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestDereferencePushDown.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestDereferencePushDown.java @@ -20,13 +20,13 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; import io.trino.spi.type.RowType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.assertions.BasePlanTest; import org.junit.jupiter.api.Test; @@ -36,9 +36,9 @@ import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MULTIPLY; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.LogicalExpression.Operator.OR; +import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Logical.Operator.OR; import static io.trino.sql.planner.assertions.PlanMatchPattern.any; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; @@ -68,9 +68,9 @@ public void testDereferencePushdownMultiLevel() output(ImmutableList.of("a_msg_x", "a_msg", "b_msg_y"), strictProject( ImmutableMap.of( - "a_msg_x", expression(new SubscriptExpression(BIGINT, new SymbolReference(BIGINT, "a_msg"), new Constant(INTEGER, 1L))), - "a_msg", expression(new SymbolReference(BIGINT, "a_msg")), - "b_msg_y", expression(new SymbolReference(DOUBLE, "b_msg_y"))), + "a_msg_x", expression(new Subscript(BIGINT, new Reference(BIGINT, "a_msg"), new Constant(INTEGER, 1L))), + "a_msg", expression(new Reference(BIGINT, "a_msg")), + "b_msg_y", expression(new Reference(DOUBLE, "b_msg_y"))), join(INNER, builder -> builder .left(values("a_msg")) .right( @@ -87,9 +87,9 @@ public void testDereferencePushdownJoin() "WHERE a.msg.y = b.msg.y", output( project( - ImmutableMap.of("b_x", expression(new SymbolReference(BIGINT, "b_x"))), + ImmutableMap.of("b_x", expression(new Reference(BIGINT, "b_x"))), filter( - new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "a_y"), new SymbolReference(DOUBLE, "b_y")), + new Comparison(EQUAL, new Reference(DOUBLE, "a_y"), new Reference(DOUBLE, "b_y")), values( ImmutableList.of("b_x", "b_y", "a_y"), ImmutableList.of(ImmutableList.of( @@ -112,11 +112,11 @@ public void testDereferencePushdownJoin() join(INNER, builder -> builder .left( project(filter( - new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "a_y"), new Constant(DOUBLE, 2.0)), + new Comparison(EQUAL, new Reference(DOUBLE, "a_y"), new Constant(DOUBLE, 2.0)), values(ImmutableList.of("a_y"), ImmutableList.of(ImmutableList.of(new Constant(DOUBLE, 2e0))))))) .right( project(filter( - new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "b_y"), new Constant(DOUBLE, 2.0)), + new Comparison(EQUAL, new Reference(DOUBLE, "b_y"), new Constant(DOUBLE, 2.0)), values( ImmutableList.of("b_y", "b_x"), ImmutableList.of(ImmutableList.of(new Constant(DOUBLE, 2.0), new Constant(BIGINT, 1L)))))))))); @@ -132,11 +132,11 @@ public void testDereferencePushdownFilter() "WHERE a.msg.x = 7 OR IS_FINITE(b.msg.y)", any( project( - ImmutableMap.of("a_y", expression(new SymbolReference(DOUBLE, "a_y")), "b_x", expression(new SymbolReference(BIGINT, "b_x"))), + ImmutableMap.of("a_y", expression(new Reference(DOUBLE, "a_y")), "b_x", expression(new Reference(BIGINT, "b_x"))), filter( - new LogicalExpression(OR, ImmutableList.of( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "a_x"), new Constant(BIGINT, 7L)), - new FunctionCall(IS_FINITE, ImmutableList.of(new SymbolReference(DOUBLE, "b_y"))))), + new Logical(OR, ImmutableList.of( + new Comparison(EQUAL, new Reference(BIGINT, "a_x"), new Constant(BIGINT, 7L)), + new Call(IS_FINITE, ImmutableList.of(new Reference(DOUBLE, "b_y"))))), values( ImmutableList.of("b_x", "b_y", "a_y", "a_x"), ImmutableList.of(ImmutableList.of( @@ -188,11 +188,11 @@ public void testDereferencePushdownWindow() anyTree( project( ImmutableMap.of( - "msg1", expression(new SymbolReference(RowType.anonymousRow(BIGINT, DOUBLE), "msg1")), // not pushed down because used in partition by - "msg2", expression(new SymbolReference(RowType.anonymousRow(BIGINT, DOUBLE), "msg2")), // not pushed down because used in order by - "msg3", expression(new SymbolReference(RowType.anonymousRow(BIGINT, DOUBLE), "msg3")), // not pushed down because used in window function - "msg4_x", expression(new SubscriptExpression(BIGINT, new SymbolReference(RowType.anonymousRow(BIGINT, DOUBLE), "msg4"), new Constant(INTEGER, 1L))), // pushed down because msg4.x used in window function - "msg5_x", expression(new SubscriptExpression(BIGINT, new SymbolReference(RowType.anonymousRow(BIGINT, DOUBLE), "msg5"), new Constant(INTEGER, 1L)))), // pushed down because window node does not refer it + "msg1", expression(new Reference(RowType.anonymousRow(BIGINT, DOUBLE), "msg1")), // not pushed down because used in partition by + "msg2", expression(new Reference(RowType.anonymousRow(BIGINT, DOUBLE), "msg2")), // not pushed down because used in order by + "msg3", expression(new Reference(RowType.anonymousRow(BIGINT, DOUBLE), "msg3")), // not pushed down because used in window function + "msg4_x", expression(new Subscript(BIGINT, new Reference(RowType.anonymousRow(BIGINT, DOUBLE), "msg4"), new Constant(INTEGER, 1L))), // pushed down because msg4.x used in window function + "msg5_x", expression(new Subscript(BIGINT, new Reference(RowType.anonymousRow(BIGINT, DOUBLE), "msg5"), new Constant(INTEGER, 1L)))), // pushed down because window node does not refer it values("msg1", "msg2", "msg3", "msg4", "msg5")))); } @@ -210,7 +210,7 @@ public void testDereferencePushdownSemiJoin() anyTree( semiJoin("a_x", "b_z", "semi_join_symbol", project( - ImmutableMap.of("a_y", expression(new SubscriptExpression(DOUBLE, new SymbolReference(RowType.anonymousRow(BIGINT, DOUBLE, BIGINT), "msg"), new Constant(INTEGER, 2L)))), + ImmutableMap.of("a_y", expression(new Subscript(DOUBLE, new Reference(RowType.anonymousRow(BIGINT, DOUBLE, BIGINT), "msg"), new Constant(INTEGER, 2L)))), values(ImmutableList.of("msg", "a_x"), ImmutableList.of())), values(ImmutableList.of("b_z"), ImmutableList.of())))); } @@ -221,9 +221,9 @@ public void testDereferencePushdownLimit() assertPlan("WITH t(msg) AS (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))), ROW(CAST(ROW(3, 4.0) AS ROW(x BIGINT, y DOUBLE))))" + "SELECT msg.x * 3 FROM t limit 1", anyTree( - strictProject(ImmutableMap.of("x_into_3", expression(new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new SymbolReference(BIGINT, "msg_x"), new Constant(BIGINT, 3L)))), + strictProject(ImmutableMap.of("x_into_3", expression(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "msg_x"), new Constant(BIGINT, 3L)))), limit(1, - strictProject(ImmutableMap.of("msg_x", expression(new SubscriptExpression(BIGINT, new SymbolReference(RowType.anonymousRow(BIGINT, DOUBLE), "msg"), new Constant(INTEGER, 1L)))), + strictProject(ImmutableMap.of("msg_x", expression(new Subscript(BIGINT, new Reference(RowType.anonymousRow(BIGINT, DOUBLE), "msg"), new Constant(INTEGER, 1L)))), values("msg")))))); // dereference pushdown + constant folding @@ -236,9 +236,9 @@ public void testDereferencePushdownLimit() limit( 100, project( - ImmutableMap.of("b_x", expression(new SymbolReference(BIGINT, "b_x"))), + ImmutableMap.of("b_x", expression(new Reference(BIGINT, "b_x"))), filter( - new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "a_y"), new SymbolReference(DOUBLE, "b_y")), + new Comparison(EQUAL, new Reference(DOUBLE, "a_y"), new Reference(DOUBLE, "b_y")), values( ImmutableList.of("b_x", "b_y", "a_y"), ImmutableList.of(ImmutableList.of( @@ -262,11 +262,11 @@ public void testDereferencePushdownLimit() join(INNER, builder -> builder .left( project(filter( - new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "a_y"), new Constant(DOUBLE, 2.0)), + new Comparison(EQUAL, new Reference(DOUBLE, "a_y"), new Constant(DOUBLE, 2.0)), values(ImmutableList.of("a_y"), ImmutableList.of(ImmutableList.of(new Constant(DOUBLE, 2e0))))))) .right( project(filter( - new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "b_y"), new Constant(DOUBLE, 2.0)), + new Comparison(EQUAL, new Reference(DOUBLE, "b_y"), new Constant(DOUBLE, 2.0)), values( ImmutableList.of("b_y", "b_x"), ImmutableList.of(ImmutableList.of(new Constant(DOUBLE, 2.0), new Constant(BIGINT, 1L)))))))))); @@ -281,18 +281,18 @@ public void testDereferencePushdownUnnest() "CROSS JOIN UNNEST (a.array) " + "WHERE a.msg.x + b.msg.x < BIGINT '10'", output(ImmutableList.of("expr"), - strictProject(ImmutableMap.of("expr", expression(new SymbolReference(BIGINT, "a_x"))), + strictProject(ImmutableMap.of("expr", expression(new Reference(BIGINT, "a_x"))), unnest( join(INNER, builder -> builder .left( project( filter( - new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "a_y"), new Constant(DOUBLE, 2.0)), + new Comparison(EQUAL, new Reference(DOUBLE, "a_y"), new Constant(DOUBLE, 2.0)), values("array", "a_y", "a_x")))) .right( project( filter( - new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "b_y"), new Constant(DOUBLE, 2.0)), + new Comparison(EQUAL, new Reference(DOUBLE, "b_y"), new Constant(DOUBLE, 2.0)), values(ImmutableList.of("b_y"), ImmutableList.of(ImmutableList.of(new Constant(DOUBLE, 2e0)))))))))))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestDeterminismEvaluator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestDeterminismEvaluator.java index 72d878ff19f4..2c1f856ed31c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestDeterminismEvaluator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestDeterminismEvaluator.java @@ -17,12 +17,12 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.LambdaExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Lambda; +import io.trino.sql.ir.Reference; import io.trino.type.FunctionType; import io.trino.type.UnknownType; import org.junit.jupiter.api.Test; @@ -33,7 +33,7 @@ import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static org.assertj.core.api.Assertions.assertThat; public class TestDeterminismEvaluator @@ -48,34 +48,34 @@ public void testSanity() assertThat(DeterminismEvaluator.isDeterministic(function("shuffle", ImmutableList.of(new ArrayType(VARCHAR)), ImmutableList.of(new Constant(UnknownType.UNKNOWN, null))) )).isFalse(); assertThat(DeterminismEvaluator.isDeterministic(function("uuid"))).isFalse(); - assertThat(DeterminismEvaluator.isDeterministic(function("abs", ImmutableList.of(DOUBLE), ImmutableList.of(new SymbolReference(DOUBLE, "symbol"))))).isTrue(); + assertThat(DeterminismEvaluator.isDeterministic(function("abs", ImmutableList.of(DOUBLE), ImmutableList.of(new Reference(DOUBLE, "symbol"))))).isTrue(); assertThat(DeterminismEvaluator.isDeterministic(function("abs", ImmutableList.of(DOUBLE), ImmutableList.of(function("rand"))))).isFalse(); assertThat(DeterminismEvaluator.isDeterministic( function( "abs", ImmutableList.of(DOUBLE), - ImmutableList.of(function("abs", ImmutableList.of(DOUBLE), ImmutableList.of(new SymbolReference(DOUBLE, "symbol"))))) + ImmutableList.of(function("abs", ImmutableList.of(DOUBLE), ImmutableList.of(new Reference(DOUBLE, "symbol"))))) )).isTrue(); assertThat(DeterminismEvaluator.isDeterministic( function( "filter", ImmutableList.of(new ArrayType(INTEGER), new FunctionType(ImmutableList.of(INTEGER), BOOLEAN)), - ImmutableList.of(lambda(new Symbol(INTEGER, "a"), comparison(GREATER_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L))))) + ImmutableList.of(lambda(new Symbol(INTEGER, "a"), comparison(GREATER_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L))))) )).isTrue(); assertThat(DeterminismEvaluator.isDeterministic( function( "filter", ImmutableList.of(new ArrayType(INTEGER), new FunctionType(ImmutableList.of(INTEGER), BOOLEAN)), - ImmutableList.of(lambda(new Symbol(INTEGER, "a"), comparison(GREATER_THAN, function("rand", ImmutableList.of(INTEGER), ImmutableList.of(new SymbolReference(INTEGER, "a"))), new Constant(INTEGER, 0L))))) + ImmutableList.of(lambda(new Symbol(INTEGER, "a"), comparison(GREATER_THAN, function("rand", ImmutableList.of(INTEGER), ImmutableList.of(new Reference(INTEGER, "a"))), new Constant(INTEGER, 0L))))) )).isFalse(); } - private FunctionCall function(String name) + private Call function(String name) { return function(name, ImmutableList.of(), ImmutableList.of()); } - private FunctionCall function(String name, List types, List arguments) + private Call function(String name, List types, List arguments) { return functionResolution .functionCallBuilder(name) @@ -83,13 +83,13 @@ private FunctionCall function(String name, List types, List ar .build(); } - private static ComparisonExpression comparison(ComparisonExpression.Operator operator, Expression left, Expression right) + private static Comparison comparison(Comparison.Operator operator, Expression left, Expression right) { - return new ComparisonExpression(operator, left, right); + return new Comparison(operator, left, right); } - private static LambdaExpression lambda(Symbol symbol, Expression body) + private static Lambda lambda(Symbol symbol, Expression body) { - return new LambdaExpression(ImmutableList.of(symbol), body); + return new Lambda(ImmutableList.of(symbol), body); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestDomainTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestDomainTranslator.java index 83bb7985a949..0a3efff233de 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestDomainTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestDomainTranslator.java @@ -26,15 +26,15 @@ import io.trino.spi.type.DoubleType; import io.trino.spi.type.RealType; import io.trino.spi.type.Type; -import io.trino.sql.ir.BetweenPredicate; +import io.trino.sql.ir.Between; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.InPredicate; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.NotExpression; +import io.trino.sql.ir.In; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Not; import io.trino.sql.planner.DomainTranslator.ExtractionResult; import io.trino.type.LikePattern; import io.trino.type.LikePatternType; @@ -70,15 +70,15 @@ import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.BooleanLiteral.FALSE_LITERAL; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.IS_DISTINCT_FROM; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.NOT_EQUAL; +import static io.trino.sql.ir.Booleans.FALSE; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.IS_DISTINCT_FROM; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.NOT_EQUAL; import static io.trino.sql.ir.IrUtils.and; import static io.trino.sql.ir.IrUtils.or; import static io.trino.testing.TestingConnectorSession.SESSION; @@ -152,7 +152,7 @@ public void testNoneRoundTrip() { TupleDomain tupleDomain = TupleDomain.none(); ExtractionResult result = fromPredicate(toPredicate(tupleDomain)); - assertThat(result.getRemainingExpression()).isEqualTo(TRUE_LITERAL); + assertThat(result.getRemainingExpression()).isEqualTo(TRUE); assertThat(result.getTupleDomain()).isEqualTo(tupleDomain); } @@ -161,7 +161,7 @@ public void testAllRoundTrip() { TupleDomain tupleDomain = TupleDomain.all(); ExtractionResult result = fromPredicate(toPredicate(tupleDomain)); - assertThat(result.getRemainingExpression()).isEqualTo(TRUE_LITERAL); + assertThat(result.getRemainingExpression()).isEqualTo(TRUE); assertThat(result.getTupleDomain()).isEqualTo(tupleDomain); } @@ -247,7 +247,7 @@ public void testToPredicateNone() .put(C_BOOLEAN, Domain.none(BOOLEAN)) .buildOrThrow()); - assertThat(toPredicate(tupleDomain)).isEqualTo(FALSE_LITERAL); + assertThat(toPredicate(tupleDomain)).isEqualTo(FALSE); } @Test @@ -261,7 +261,7 @@ public void testToPredicateAllIgnored() .buildOrThrow()); ExtractionResult result = fromPredicate(toPredicate(tupleDomain)); - assertThat(result.getRemainingExpression()).isEqualTo(TRUE_LITERAL); + assertThat(result.getRemainingExpression()).isEqualTo(TRUE); assertThat(result.getTupleDomain()).isEqualTo(tupleDomain(ImmutableMap.builder() .put(C_BIGINT, Domain.singleValue(BIGINT, 1L)) .put(C_DOUBLE, Domain.onlyNull(DOUBLE)) @@ -281,10 +281,10 @@ public void testToPredicate() assertThat(toPredicate(tupleDomain)).isEqualTo(isNull(C_BIGINT)); tupleDomain = tupleDomain(C_BIGINT, Domain.none(BIGINT)); - assertThat(toPredicate(tupleDomain)).isEqualTo(FALSE_LITERAL); + assertThat(toPredicate(tupleDomain)).isEqualTo(FALSE); tupleDomain = tupleDomain(C_BIGINT, Domain.all(BIGINT)); - assertThat(toPredicate(tupleDomain)).isEqualTo(TRUE_LITERAL); + assertThat(toPredicate(tupleDomain)).isEqualTo(TRUE); tupleDomain = tupleDomain(C_BIGINT, Domain.create(ValueSet.ofRanges(Range.greaterThan(BIGINT, 1L)), false)); assertThat(toPredicate(tupleDomain)).isEqualTo(greaterThan(C_BIGINT, bigintLiteral(1L))); @@ -524,18 +524,18 @@ public void testFromSingleBooleanReference() Expression originalPredicate = C_BOOLEAN.toSymbolReference(); ExtractionResult result = fromPredicate(originalPredicate); assertThat(result.getTupleDomain()).isEqualTo(tupleDomain(C_BOOLEAN, Domain.create(ValueSet.ofRanges(Range.equal(BOOLEAN, true)), false))); - assertThat(result.getRemainingExpression()).isEqualTo(TRUE_LITERAL); + assertThat(result.getRemainingExpression()).isEqualTo(TRUE); originalPredicate = not(C_BOOLEAN.toSymbolReference()); result = fromPredicate(originalPredicate); assertThat(result.getTupleDomain()).isEqualTo(tupleDomain(C_BOOLEAN, Domain.create(ValueSet.ofRanges(Range.equal(BOOLEAN, true)).complement(), false))); - assertThat(result.getRemainingExpression()).isEqualTo(TRUE_LITERAL); + assertThat(result.getRemainingExpression()).isEqualTo(TRUE); originalPredicate = and(C_BOOLEAN.toSymbolReference(), C_BOOLEAN_1.toSymbolReference()); result = fromPredicate(originalPredicate); Domain domain = Domain.create(ValueSet.ofRanges(Range.equal(BOOLEAN, true)), false); assertThat(result.getTupleDomain()).isEqualTo(tupleDomain(C_BOOLEAN, domain, C_BOOLEAN_1, domain)); - assertThat(result.getRemainingExpression()).isEqualTo(TRUE_LITERAL); + assertThat(result.getRemainingExpression()).isEqualTo(TRUE); originalPredicate = or(C_BOOLEAN.toSymbolReference(), C_BOOLEAN_1.toSymbolReference()); result = fromPredicate(originalPredicate); @@ -554,7 +554,7 @@ public void testFromNotPredicate() assertUnsupportedPredicate(not(and(equal(C_BIGINT, bigintLiteral(1L)), unprocessableExpression1(C_BIGINT)))); assertUnsupportedPredicate(not(unprocessableExpression1(C_BIGINT))); - assertPredicateIsAlwaysFalse(not(TRUE_LITERAL)); + assertPredicateIsAlwaysFalse(not(TRUE)); assertPredicateTranslates( not(equal(C_BIGINT, bigintLiteral(1L))), @@ -943,12 +943,12 @@ public void testPredicateWithVarcharCastToDate() @Test public void testFromUnprocessableInPredicate() { - assertUnsupportedPredicate(new InPredicate(unprocessableExpression1(C_BIGINT), ImmutableList.of(TRUE_LITERAL))); - assertUnsupportedPredicate(new InPredicate(C_BOOLEAN.toSymbolReference(), ImmutableList.of(unprocessableExpression1(C_BOOLEAN)))); + assertUnsupportedPredicate(new In(unprocessableExpression1(C_BIGINT), ImmutableList.of(TRUE))); + assertUnsupportedPredicate(new In(C_BOOLEAN.toSymbolReference(), ImmutableList.of(unprocessableExpression1(C_BOOLEAN)))); assertUnsupportedPredicate( - new InPredicate(C_BOOLEAN.toSymbolReference(), ImmutableList.of(TRUE_LITERAL, unprocessableExpression1(C_BOOLEAN)))); + new In(C_BOOLEAN.toSymbolReference(), ImmutableList.of(TRUE, unprocessableExpression1(C_BOOLEAN)))); assertPredicateTranslates( - not(new InPredicate(C_BOOLEAN.toSymbolReference(), ImmutableList.of(unprocessableExpression1(C_BOOLEAN)))), + not(new In(C_BOOLEAN.toSymbolReference(), ImmutableList.of(unprocessableExpression1(C_BOOLEAN)))), tupleDomain(C_BOOLEAN, Domain.notNull(BOOLEAN)), not(equal(C_BOOLEAN, unprocessableExpression1(C_BOOLEAN)))); } @@ -1064,7 +1064,7 @@ private void testInPredicate(Symbol symbol, Symbol symbol2, Type type, Object on assertPredicateTranslates( not(in(symbol, List.of(oneExpression, nullExpression, twoExpression))), TupleDomain.none(), - TRUE_LITERAL); + TRUE); // NOT IN, with expression assertPredicateTranslates( @@ -1206,11 +1206,11 @@ public void testInPredicateWithEquitableType() @Test public void testFromInPredicateWithCastsAndNulls() { - assertPredicateIsAlwaysFalse(new InPredicate( + assertPredicateIsAlwaysFalse(new In( C_BIGINT.toSymbolReference(), ImmutableList.of(new Constant(BIGINT, null)))); - assertUnsupportedPredicate(not(new InPredicate( + assertUnsupportedPredicate(not(new In( cast(C_SMALLINT, BIGINT), ImmutableList.of(new Constant(BIGINT, null))))); } @@ -1277,10 +1277,10 @@ public void testFromIsNotNullPredicate() @Test public void testFromBooleanLiteralPredicate() { - assertPredicateIsAlwaysTrue(TRUE_LITERAL); - assertPredicateIsAlwaysFalse(not(TRUE_LITERAL)); - assertPredicateIsAlwaysFalse(FALSE_LITERAL); - assertPredicateIsAlwaysTrue(not(FALSE_LITERAL)); + assertPredicateIsAlwaysTrue(TRUE); + assertPredicateIsAlwaysFalse(not(TRUE)); + assertPredicateIsAlwaysFalse(FALSE); + assertPredicateIsAlwaysTrue(not(FALSE)); } @Test @@ -1475,10 +1475,10 @@ public void testStartsWithFunction() @Test public void testUnsupportedFunctions() { - assertUnsupportedPredicate(new FunctionCall( + assertUnsupportedPredicate(new Call( functionResolution.resolveFunction("length", fromTypes(VARCHAR)), ImmutableList.of(C_VARCHAR.toSymbolReference()))); - assertUnsupportedPredicate(new FunctionCall( + assertUnsupportedPredicate(new Call( functionResolution.resolveFunction("replace", fromTypes(VARCHAR, VARCHAR)), ImmutableList.of(C_VARCHAR.toSymbolReference(), stringLiteral("abc")))); } @@ -1497,12 +1497,12 @@ public void testCharComparedToVarcharExpression() private void assertPredicateIsAlwaysTrue(Expression expression) { - assertPredicateTranslates(expression, TupleDomain.all(), TRUE_LITERAL); + assertPredicateTranslates(expression, TupleDomain.all(), TRUE); } private void assertPredicateIsAlwaysFalse(Expression expression) { - assertPredicateTranslates(expression, TupleDomain.none(), TRUE_LITERAL); + assertPredicateTranslates(expression, TupleDomain.none(), TRUE); } private void assertUnsupportedPredicate(Expression expression) @@ -1512,7 +1512,7 @@ private void assertUnsupportedPredicate(Expression expression) private void assertPredicateTranslates(Expression expression, TupleDomain tupleDomain) { - assertPredicateTranslates(expression, tupleDomain, TRUE_LITERAL); + assertPredicateTranslates(expression, tupleDomain, TRUE); } private void assertPredicateDerives(Expression expression, TupleDomain tupleDomain) @@ -1531,7 +1531,7 @@ private void assertNoFullPushdown(Expression expression) { ExtractionResult result = fromPredicate(expression); assertThat(result.getRemainingExpression()) - .isNotEqualTo(TRUE_LITERAL); + .isNotEqualTo(TRUE); } private ExtractionResult fromPredicate(Expression originalPredicate) @@ -1556,74 +1556,74 @@ private static Expression unprocessableExpression2(Symbol symbol) private Expression randPredicate(Symbol symbol, Type type) { - FunctionCall rand = functionResolution + Call rand = functionResolution .functionCallBuilder("rand") .build(); return comparison(GREATER_THAN, symbol.toSymbolReference(), cast(rand, type)); } - private static ComparisonExpression equal(Symbol symbol, Expression expression) + private static Comparison equal(Symbol symbol, Expression expression) { return equal(symbol.toSymbolReference(), expression); } - private static ComparisonExpression notEqual(Symbol symbol, Expression expression) + private static Comparison notEqual(Symbol symbol, Expression expression) { return notEqual(symbol.toSymbolReference(), expression); } - private static ComparisonExpression greaterThan(Symbol symbol, Expression expression) + private static Comparison greaterThan(Symbol symbol, Expression expression) { return greaterThan(symbol.toSymbolReference(), expression); } - private static ComparisonExpression greaterThanOrEqual(Symbol symbol, Expression expression) + private static Comparison greaterThanOrEqual(Symbol symbol, Expression expression) { return greaterThanOrEqual(symbol.toSymbolReference(), expression); } - private static ComparisonExpression lessThan(Symbol symbol, Expression expression) + private static Comparison lessThan(Symbol symbol, Expression expression) { return lessThan(symbol.toSymbolReference(), expression); } - private static ComparisonExpression lessThanOrEqual(Symbol symbol, Expression expression) + private static Comparison lessThanOrEqual(Symbol symbol, Expression expression) { return lessThanOrEqual(symbol.toSymbolReference(), expression); } - private static ComparisonExpression isDistinctFrom(Symbol symbol, Expression expression) + private static Comparison isDistinctFrom(Symbol symbol, Expression expression) { return isDistinctFrom(symbol.toSymbolReference(), expression); } - private FunctionCall like(Symbol symbol, String pattern) + private Call like(Symbol symbol, String pattern) { - return new FunctionCall( + return new Call( functionResolution.resolveFunction(LIKE_FUNCTION_NAME, fromTypes(VARCHAR, LikePatternType.LIKE_PATTERN)), ImmutableList.of(symbol.toSymbolReference(), new Constant(LikePatternType.LIKE_PATTERN, LikePattern.compile(pattern, Optional.empty())))); } - private FunctionCall like(Symbol symbol, Expression pattern, Expression escape) + private Call like(Symbol symbol, Expression pattern, Expression escape) { - FunctionCall likePattern = new FunctionCall( + Call likePattern = new Call( functionResolution.resolveFunction(LIKE_PATTERN_FUNCTION_NAME, fromTypes(VARCHAR, VARCHAR)), ImmutableList.of(symbol.toSymbolReference(), pattern, escape)); - return new FunctionCall( + return new Call( functionResolution.resolveFunction(LIKE_FUNCTION_NAME, fromTypes(VARCHAR, LikePatternType.LIKE_PATTERN)), ImmutableList.of(symbol.toSymbolReference(), pattern, likePattern)); } - private FunctionCall like(Symbol symbol, String pattern, Character escape) + private Call like(Symbol symbol, String pattern, Character escape) { - return new FunctionCall( + return new Call( functionResolution.resolveFunction(LIKE_FUNCTION_NAME, fromTypes(VARCHAR, LikePatternType.LIKE_PATTERN)), ImmutableList.of(symbol.toSymbolReference(), new Constant(LikePatternType.LIKE_PATTERN, LikePattern.compile(pattern, Optional.of(escape))))); } - private FunctionCall startsWith(Symbol symbol, Expression expression) + private Call startsWith(Symbol symbol, Expression expression) { - return new FunctionCall( + return new Call( functionResolution.resolveFunction("starts_with", fromTypes(VARCHAR, VARCHAR)), ImmutableList.of(symbol.toSymbolReference(), expression)); } @@ -1633,29 +1633,29 @@ private static Expression isNotNull(Symbol symbol) return isNotNull(symbol.toSymbolReference()); } - private static IsNullPredicate isNull(Symbol symbol) + private static IsNull isNull(Symbol symbol) { - return new IsNullPredicate(symbol.toSymbolReference()); + return new IsNull(symbol.toSymbolReference()); } - private InPredicate in(Symbol symbol, List values) + private In in(Symbol symbol, List values) { return in(symbol.toSymbolReference(), symbol.getType(), values); } - private static BetweenPredicate between(Symbol symbol, Expression min, Expression max) + private static Between between(Symbol symbol, Expression min, Expression max) { - return new BetweenPredicate(symbol.toSymbolReference(), min, max); + return new Between(symbol.toSymbolReference(), min, max); } private static Expression isNotNull(Expression expression) { - return new NotExpression(new IsNullPredicate(expression)); + return new Not(new IsNull(expression)); } - private InPredicate in(Expression expression, Type type, List values) + private In in(Expression expression, Type type, List values) { - return new InPredicate( + return new In( expression, values.stream() .map(value -> value instanceof Expression valueExpression ? @@ -1664,54 +1664,54 @@ private InPredicate in(Expression expression, Type type, List values) .collect(toImmutableList())); } - private static BetweenPredicate between(Expression expression, Expression min, Expression max) + private static Between between(Expression expression, Expression min, Expression max) { - return new BetweenPredicate(expression, min, max); + return new Between(expression, min, max); } - private static ComparisonExpression equal(Expression left, Expression right) + private static Comparison equal(Expression left, Expression right) { return comparison(EQUAL, left, right); } - private static ComparisonExpression notEqual(Expression left, Expression right) + private static Comparison notEqual(Expression left, Expression right) { return comparison(NOT_EQUAL, left, right); } - private static ComparisonExpression greaterThan(Expression left, Expression right) + private static Comparison greaterThan(Expression left, Expression right) { return comparison(GREATER_THAN, left, right); } - private static ComparisonExpression greaterThanOrEqual(Expression left, Expression right) + private static Comparison greaterThanOrEqual(Expression left, Expression right) { return comparison(GREATER_THAN_OR_EQUAL, left, right); } - private static ComparisonExpression lessThan(Expression left, Expression expression) + private static Comparison lessThan(Expression left, Expression expression) { return comparison(LESS_THAN, left, expression); } - private static ComparisonExpression lessThanOrEqual(Expression left, Expression right) + private static Comparison lessThanOrEqual(Expression left, Expression right) { return comparison(LESS_THAN_OR_EQUAL, left, right); } - private static ComparisonExpression isDistinctFrom(Expression left, Expression right) + private static Comparison isDistinctFrom(Expression left, Expression right) { return comparison(IS_DISTINCT_FROM, left, right); } - private static NotExpression not(Expression expression) + private static Not not(Expression expression) { - return new NotExpression(expression); + return new Not(expression); } - private static ComparisonExpression comparison(ComparisonExpression.Operator operator, Expression expression1, Expression expression2) + private static Comparison comparison(Comparison.Operator operator, Expression expression1, Expression expression2) { - return new ComparisonExpression(operator, expression1, expression2); + return new Comparison(operator, expression1, expression2); } private static Constant bigintLiteral(long value) @@ -1761,7 +1761,7 @@ private void testSimpleComparison(Expression expression, Symbol symbol, Range ex private void testSimpleComparison(Expression expression, Symbol symbol, Domain expectedDomain) { - testSimpleComparison(expression, symbol, TRUE_LITERAL, expectedDomain); + testSimpleComparison(expression, symbol, TRUE, expectedDomain); } private void testSimpleComparison(Expression expression, Symbol symbol, Expression expectedRemainingExpression, Domain expectedDomain) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestDynamicFilter.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestDynamicFilter.java index b2c5c4572c28..97971cb56c64 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestDynamicFilter.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestDynamicFilter.java @@ -23,16 +23,16 @@ import io.trino.spi.function.OperatorType; import io.trino.spi.type.Type; import io.trino.sql.DynamicFilters; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.BetweenPredicate; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Between; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.OptimizerConfig.JoinDistributionType; import io.trino.sql.planner.OptimizerConfig.JoinReorderingStrategy; import io.trino.sql.planner.assertions.BasePlanTest; @@ -55,16 +55,16 @@ import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.SUBTRACT; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.IS_DISTINCT_FROM; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.IS_DISTINCT_FROM; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Logical.Operator.AND; import static io.trino.sql.planner.assertions.PlanMatchPattern.DynamicFilterPattern; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyNot; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; @@ -158,7 +158,7 @@ public void testRightEquiJoinWithLeftExpression() .right( anyTree( project( - ImmutableMap.of("expr", expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "LINEITEM_OK"), new Constant(BIGINT, 1L)))), + ImmutableMap.of("expr", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "LINEITEM_OK"), new Constant(BIGINT, 1L)))), tableScan("lineitem", ImmutableMap.of("LINEITEM_OK", "orderkey")))))))); } @@ -170,13 +170,13 @@ public void testRightNonEquiJoin() join(RIGHT, builder -> builder .left( filter( - TRUE_LITERAL, + TRUE, tableScan("orders", ImmutableMap.of("ORDERS_OK", "orderkey")))) .right( exchange( tableScan("lineitem", ImmutableMap.of("LINEITEM_OK", "orderkey")))) - .filter(new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "LINEITEM_OK"), new SymbolReference(BIGINT, "ORDERS_OK"))) - .dynamicFilter(ImmutableList.of(new DynamicFilterPattern(new SymbolReference(BIGINT, "ORDERS_OK"), GREATER_THAN, "LINEITEM_OK")))))); + .filter(new Comparison(LESS_THAN, new Reference(BIGINT, "LINEITEM_OK"), new Reference(BIGINT, "ORDERS_OK"))) + .dynamicFilter(ImmutableList.of(new DynamicFilterPattern(new Reference(BIGINT, "ORDERS_OK"), GREATER_THAN, "LINEITEM_OK")))))); } @Test @@ -196,12 +196,12 @@ public void testCrossJoinInequalityDF() { assertPlan("SELECT o.orderkey FROM orders o, lineitem l WHERE o.orderkey > l.orderkey", anyTree(filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "O_ORDERKEY"), new SymbolReference(BIGINT, "L_ORDERKEY")), + new Comparison(GREATER_THAN, new Reference(BIGINT, "O_ORDERKEY"), new Reference(BIGINT, "L_ORDERKEY")), join(INNER, builder -> builder - .dynamicFilter(ImmutableList.of(new DynamicFilterPattern(new SymbolReference(BIGINT, "O_ORDERKEY"), GREATER_THAN, "L_ORDERKEY"))) + .dynamicFilter(ImmutableList.of(new DynamicFilterPattern(new Reference(BIGINT, "O_ORDERKEY"), GREATER_THAN, "L_ORDERKEY"))) .left( filter( - TRUE_LITERAL, + TRUE, tableScan("orders", ImmutableMap.of("O_ORDERKEY", "orderkey")))) .right( exchange( @@ -213,12 +213,12 @@ public void testCrossJoinInequalityDFWithConditionReversed() { assertPlan("SELECT o.orderkey FROM orders o, lineitem l WHERE l.orderkey < o.orderkey", anyTree(filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "O_ORDERKEY"), new SymbolReference(BIGINT, "L_ORDERKEY")), + new Comparison(GREATER_THAN, new Reference(BIGINT, "O_ORDERKEY"), new Reference(BIGINT, "L_ORDERKEY")), join(INNER, builder -> builder - .dynamicFilter(ImmutableList.of(new DynamicFilterPattern(new SymbolReference(BIGINT, "O_ORDERKEY"), GREATER_THAN, "L_ORDERKEY"))) + .dynamicFilter(ImmutableList.of(new DynamicFilterPattern(new Reference(BIGINT, "O_ORDERKEY"), GREATER_THAN, "L_ORDERKEY"))) .left( filter( - TRUE_LITERAL, + TRUE, tableScan("orders", ImmutableMap.of("O_ORDERKEY", "orderkey")))) .right( exchange( @@ -230,15 +230,15 @@ public void testCrossJoinBetweenDF() { assertPlan("SELECT o.orderkey FROM orders o, lineitem l WHERE o.orderkey BETWEEN l.orderkey AND l.partkey", anyTree(filter( - new BetweenPredicate(new SymbolReference(BIGINT, "O_ORDERKEY"), new SymbolReference(BIGINT, "L_ORDERKEY"), new SymbolReference(BIGINT, "L_PARTKEY")), + new Between(new Reference(BIGINT, "O_ORDERKEY"), new Reference(BIGINT, "L_ORDERKEY"), new Reference(BIGINT, "L_PARTKEY")), join(INNER, builder -> builder .dynamicFilter( ImmutableList.of( - new DynamicFilterPattern(new SymbolReference(BIGINT, "O_ORDERKEY"), GREATER_THAN_OR_EQUAL, "L_ORDERKEY"), - new DynamicFilterPattern(new SymbolReference(BIGINT, "O_ORDERKEY"), LESS_THAN_OR_EQUAL, "L_PARTKEY"))) + new DynamicFilterPattern(new Reference(BIGINT, "O_ORDERKEY"), GREATER_THAN_OR_EQUAL, "L_ORDERKEY"), + new DynamicFilterPattern(new Reference(BIGINT, "O_ORDERKEY"), LESS_THAN_OR_EQUAL, "L_PARTKEY"))) .left( filter( - TRUE_LITERAL, + TRUE, tableScan("orders", ImmutableMap.of("O_ORDERKEY", "orderkey")))) .right( exchange( @@ -246,12 +246,12 @@ public void testCrossJoinBetweenDF() assertPlan("SELECT o.orderkey FROM orders o, lineitem l WHERE o.orderkey BETWEEN l.orderkey AND l.partkey - 1", anyTree(filter( - new BetweenPredicate(new SymbolReference(BIGINT, "O_ORDERKEY"), new SymbolReference(BIGINT, "L_ORDERKEY"), new ArithmeticBinaryExpression(SUBTRACT_BIGINT, SUBTRACT, new SymbolReference(BIGINT, "L_PARTKEY"), new Constant(BIGINT, 1L))), + new Between(new Reference(BIGINT, "O_ORDERKEY"), new Reference(BIGINT, "L_ORDERKEY"), new Arithmetic(SUBTRACT_BIGINT, SUBTRACT, new Reference(BIGINT, "L_PARTKEY"), new Constant(BIGINT, 1L))), join(INNER, builder -> builder - .dynamicFilter(ImmutableList.of(new DynamicFilterPattern(new SymbolReference(BIGINT, "O_ORDERKEY"), GREATER_THAN_OR_EQUAL, "L_ORDERKEY"))) + .dynamicFilter(ImmutableList.of(new DynamicFilterPattern(new Reference(BIGINT, "O_ORDERKEY"), GREATER_THAN_OR_EQUAL, "L_ORDERKEY"))) .left( filter( - TRUE_LITERAL, + TRUE, tableScan("orders", ImmutableMap.of("O_ORDERKEY", "orderkey")))) .right( exchange( @@ -259,13 +259,13 @@ public void testCrossJoinBetweenDF() assertPlan("SELECT o.orderkey FROM orders o, lineitem l WHERE o.orderkey BETWEEN l.orderkey + 1 AND l.partkey", anyTree(filter( - new BetweenPredicate(new SymbolReference(BIGINT, "O_ORDERKEY"), new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "L_ORDERKEY"), new Constant(BIGINT, 1L)), new SymbolReference(BIGINT, "L_PARTKEY")), + new Between(new Reference(BIGINT, "O_ORDERKEY"), new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "L_ORDERKEY"), new Constant(BIGINT, 1L)), new Reference(BIGINT, "L_PARTKEY")), join(INNER, builder -> builder .dynamicFilter(ImmutableList.of( - new DynamicFilterPattern(new SymbolReference(BIGINT, "O_ORDERKEY"), LESS_THAN_OR_EQUAL, "L_PARTKEY"))) + new DynamicFilterPattern(new Reference(BIGINT, "O_ORDERKEY"), LESS_THAN_OR_EQUAL, "L_PARTKEY"))) .left( filter( - TRUE_LITERAL, + TRUE, tableScan("orders", ImmutableMap.of("O_ORDERKEY", "orderkey")))) .right( exchange( @@ -277,12 +277,12 @@ public void testCrossJoinInequalityWithCastOnTheLeft() { assertPlan("SELECT o.comment, l.comment FROM lineitem l, orders o WHERE o.comment < l.comment", anyTree(filter( - new ComparisonExpression(GREATER_THAN, new Cast(new SymbolReference(VARCHAR, "L_COMMENT"), createVarcharType(79)), new SymbolReference(createVarcharType(79), "O_COMMENT")), + new Comparison(GREATER_THAN, new Cast(new Reference(VARCHAR, "L_COMMENT"), createVarcharType(79)), new Reference(createVarcharType(79), "O_COMMENT")), join(INNER, builder -> builder .dynamicFilter(ImmutableList.of( new DynamicFilterPattern(typeOnlyCast("L_COMMENT", createVarcharType(79)), GREATER_THAN, "O_COMMENT", false))) .left( - filter(TRUE_LITERAL, + filter(TRUE, tableScan("lineitem", ImmutableMap.of("L_COMMENT", "comment")))) .right( exchange( @@ -303,17 +303,17 @@ public void testCrossJoinInequalityWithCastOnTheRight() anyTree( project( filter( - new ComparisonExpression(LESS_THAN, new SymbolReference(VARCHAR, "O_COMMENT"), new SymbolReference(VARCHAR, "expr")), + new Comparison(LESS_THAN, new Reference(VARCHAR, "O_COMMENT"), new Reference(VARCHAR, "expr")), join(INNER, builder -> builder - .dynamicFilter(ImmutableList.of(new DynamicFilterPattern(new SymbolReference(VARCHAR, "O_COMMENT"), LESS_THAN, "expr"))) + .dynamicFilter(ImmutableList.of(new DynamicFilterPattern(new Reference(VARCHAR, "O_COMMENT"), LESS_THAN, "expr"))) .left( filter( - TRUE_LITERAL, + TRUE, tableScan("orders", ImmutableMap.of("O_COMMENT", "comment")))) .right( anyTree( project( - ImmutableMap.of("expr", expression(new Cast(new SymbolReference(VARCHAR, "L_COMMENT"), createVarcharType(79)))), + ImmutableMap.of("expr", expression(new Cast(new Reference(VARCHAR, "L_COMMENT"), createVarcharType(79)))), tableScan("lineitem", ImmutableMap.of("L_COMMENT", "comment")))))))))); } @@ -355,12 +355,12 @@ public void testIsNotDistinctFromJoin() { assertPlan("SELECT o.orderkey FROM orders o, lineitem l WHERE l.orderkey IS NOT DISTINCT FROM o.orderkey", anyTree(filter( - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BIGINT, "O_ORDERKEY"), new SymbolReference(BIGINT, "L_ORDERKEY"))), + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(BIGINT, "O_ORDERKEY"), new Reference(BIGINT, "L_ORDERKEY"))), join(INNER, builder -> builder - .dynamicFilter(ImmutableList.of(new DynamicFilterPattern(new SymbolReference(BIGINT, "O_ORDERKEY"), EQUAL, "L_ORDERKEY", true))) + .dynamicFilter(ImmutableList.of(new DynamicFilterPattern(new Reference(BIGINT, "O_ORDERKEY"), EQUAL, "L_ORDERKEY", true))) .left( filter( - TRUE_LITERAL, + TRUE, tableScan("orders", ImmutableMap.of("O_ORDERKEY", "orderkey")))) .right( exchange( @@ -369,7 +369,7 @@ public void testIsNotDistinctFromJoin() // Dynamic filter is not supported for IS DISTINCT FROM assertPlan("SELECT o.orderkey FROM orders o, lineitem l WHERE l.orderkey IS DISTINCT FROM o.orderkey", anyTree(filter( - new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BIGINT, "O_ORDERKEY"), new SymbolReference(BIGINT, "L_ORDERKEY")), + new Comparison(IS_DISTINCT_FROM, new Reference(BIGINT, "O_ORDERKEY"), new Reference(BIGINT, "L_ORDERKEY")), join(INNER, builder -> builder .left( tableScan("orders", ImmutableMap.of("O_ORDERKEY", "orderkey"))) @@ -380,7 +380,7 @@ public void testIsNotDistinctFromJoin() // extendedprice and totalprice are of DOUBLE type, dynamic filter is not supported with IS NOT DISTINCT FROM clause on DOUBLE or REAL types assertPlan("SELECT o.orderkey FROM orders o, lineitem l WHERE l.extendedprice IS NOT DISTINCT FROM o.totalprice", anyTree(filter( - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(DOUBLE, "O_TOTALPRICE"), new SymbolReference(DOUBLE, "L_EXTENDEDPRICE"))), + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(DOUBLE, "O_TOTALPRICE"), new Reference(DOUBLE, "L_EXTENDEDPRICE"))), join(INNER, builder -> builder .left( tableScan("orders", ImmutableMap.of("O_TOTALPRICE", "totalprice"))) @@ -436,12 +436,12 @@ public void testJoinOnCast() .equiCriteria("expr_orders", "expr_lineitem") .left( project( - ImmutableMap.of("expr_orders", expression(new Cast(new SymbolReference(BIGINT, "ORDERS_OK"), INTEGER))), + ImmutableMap.of("expr_orders", expression(new Cast(new Reference(BIGINT, "ORDERS_OK"), INTEGER))), tableScan("orders", ImmutableMap.of("ORDERS_OK", "orderkey")))) .right( anyTree( project( - ImmutableMap.of("expr_lineitem", expression(new Cast(new SymbolReference(BIGINT, "LINEITEM_OK"), INTEGER))), + ImmutableMap.of("expr_lineitem", expression(new Cast(new Reference(BIGINT, "LINEITEM_OK"), INTEGER))), tableScan("lineitem", ImmutableMap.of("LINEITEM_OK", "orderkey")))))))); // Dynamic filter is removed due to double cast on orders.orderkey @@ -451,7 +451,7 @@ public void testJoinOnCast() .equiCriteria("expr_orders", "LINEITEM_OK") .left( project( - ImmutableMap.of("expr_orders", expression(new Cast(new Cast(new SymbolReference(BIGINT, "ORDERS_OK"), INTEGER), BIGINT))), + ImmutableMap.of("expr_orders", expression(new Cast(new Cast(new Reference(BIGINT, "ORDERS_OK"), INTEGER), BIGINT))), tableScan("orders", ImmutableMap.of("ORDERS_OK", "orderkey")))) .right( anyTree( @@ -468,7 +468,7 @@ public void testJoinImplicitCoercions() .equiCriteria("expr_linenumber", "ORDERS_OK") .left( project( - ImmutableMap.of("expr_linenumber", expression(new Cast(new SymbolReference(INTEGER, "LINEITEM_LN"), BIGINT))), + ImmutableMap.of("expr_linenumber", expression(new Cast(new Reference(INTEGER, "LINEITEM_LN"), BIGINT))), node(FilterNode.class, tableScan("lineitem", ImmutableMap.of("LINEITEM_LN", "linenumber"))) .with(numberOfDynamicFilters(1)))) @@ -487,8 +487,8 @@ public void testJoinMultipleEquiJoinClauses() equiJoinClause("ORDERS_OK", "LINEITEM_OK"), equiJoinClause("ORDERS_CK", "LINEITEM_PK"))) .dynamicFilter(ImmutableMap.of( - new SymbolReference(BIGINT, "ORDERS_OK"), "LINEITEM_OK", - new SymbolReference(BIGINT, "ORDERS_CK"), "LINEITEM_PK")) + new Reference(BIGINT, "ORDERS_OK"), "LINEITEM_OK", + new Reference(BIGINT, "ORDERS_CK"), "LINEITEM_PK")) .left( anyTree( tableScan("orders", ImmutableMap.of("ORDERS_OK", "orderkey", "ORDERS_CK", "custkey")))) @@ -540,10 +540,10 @@ public void testInnerInequalityJoinWithEquiJoinConjuncts() FilterNode.class, join(INNER, builder -> builder .equiCriteria("O_SHIPPRIORITY", "L_LINENUMBER") - .filter(new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "O_ORDERKEY"), new SymbolReference(BIGINT, "L_ORDERKEY"))) + .filter(new Comparison(LESS_THAN, new Reference(BIGINT, "O_ORDERKEY"), new Reference(BIGINT, "L_ORDERKEY"))) .dynamicFilter(ImmutableList.of( - new DynamicFilterPattern(new SymbolReference(createVarcharType(1), "O_SHIPPRIORITY"), EQUAL, "L_LINENUMBER"), - new DynamicFilterPattern(new SymbolReference(BIGINT, "O_ORDERKEY"), LESS_THAN, "L_ORDERKEY"))) + new DynamicFilterPattern(new Reference(createVarcharType(1), "O_SHIPPRIORITY"), EQUAL, "L_LINENUMBER"), + new DynamicFilterPattern(new Reference(BIGINT, "O_ORDERKEY"), LESS_THAN, "L_ORDERKEY"))) .left( anyTree(tableScan("orders", ImmutableMap.of( "O_SHIPPRIORITY", "shippriority", @@ -587,8 +587,8 @@ public void testSubTreeJoinDFOnBuildSide() join(INNER, builder -> builder .equiCriteria("LINEITEM_OK", "PART_PK") .dynamicFilter(ImmutableMap.of( - new SymbolReference(BIGINT, "LINEITEM_OK"), "PART_PK", - new SymbolReference(BIGINT, "ORDERS_OK"), "PART_PK")) + new Reference(BIGINT, "LINEITEM_OK"), "PART_PK", + new Reference(BIGINT, "ORDERS_OK"), "PART_PK")) .left( join(INNER, leftJoinBuilder -> leftJoinBuilder .equiCriteria("LINEITEM_OK", "ORDERS_OK") @@ -654,7 +654,7 @@ public void testNonPushedDownJoinFilterRemoval() .equiCriteria(ImmutableList.of(equiJoinClause("K0", "K2"), equiJoinClause("S", "V2"))) .left( project( - ImmutableMap.of("S", expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "V0"), new SymbolReference(BIGINT, "V1")))), + ImmutableMap.of("S", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "V0"), new Reference(BIGINT, "V1")))), join(INNER, leftJoinBuilder -> leftJoinBuilder .equiCriteria("K0", "K1") .dynamicFilter(BIGINT, "K0", "K1") @@ -682,7 +682,7 @@ public void testSemiJoin() noSemiJoinRewrite(), anyTree( filter( - new SymbolReference(BOOLEAN, "S"), + new Reference(BOOLEAN, "S"), semiJoin("X", "Y", "S", true, anyTree( tableScan("orders", ImmutableMap.of("X", "orderkey"))), @@ -698,7 +698,7 @@ public void testNonFilteringSemiJoin() "SELECT * FROM orders WHERE orderkey NOT IN (SELECT orderkey FROM lineitem WHERE linenumber < 0)", anyTree( filter( - new NotExpression(new SymbolReference(BOOLEAN, "S")), + new Not(new Reference(BOOLEAN, "S")), semiJoin("X", "Y", "S", false, tableScan("orders", ImmutableMap.of("X", "orderkey")), anyTree( @@ -721,10 +721,10 @@ public void testSemiJoinWithStaticFiltering() noSemiJoinRewrite(), anyTree( filter( - new SymbolReference(BOOLEAN, "S"), + new Reference(BOOLEAN, "S"), semiJoin("X", "Y", "S", true, filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "X"), new Constant(BIGINT, 0L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "X"), new Constant(BIGINT, 0L)), tableScan("orders", ImmutableMap.of("X", "orderkey"))), anyTree( tableScan("lineitem", ImmutableMap.of("Y", "orderkey"))))))); @@ -739,13 +739,13 @@ public void testMultiSemiJoin() noSemiJoinRewrite(), anyTree( filter( - new SymbolReference(BOOLEAN, "S0"), + new Reference(BOOLEAN, "S0"), semiJoin("PART_PK", "LINEITEM_PK", "S0", true, anyTree( tableScan("part", ImmutableMap.of("PART_PK", "partkey"))), anyTree( filter( - new SymbolReference(BOOLEAN, "S1"), + new Reference(BOOLEAN, "S1"), project( semiJoin("LINEITEM_OK", "ORDERS_OK", "S1", true, anyTree( @@ -764,9 +764,9 @@ public void testSemiJoinUnsupportedDynamicFilterRemoval() noSemiJoinRewrite(), anyTree( filter( - new SymbolReference(BOOLEAN, "S0"), + new Reference(BOOLEAN, "S0"), semiJoin("LINEITEM_PK_PLUS_1000", "PART_PK", "S0", false, - project(ImmutableMap.of("LINEITEM_PK_PLUS_1000", expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "LINEITEM_PK"), new Constant(BIGINT, 1000L)))), + project(ImmutableMap.of("LINEITEM_PK_PLUS_1000", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "LINEITEM_PK"), new Constant(BIGINT, 1000L)))), tableScan("lineitem", ImmutableMap.of("LINEITEM_PK", "partkey"))), anyTree( tableScan("part", ImmutableMap.of("PART_PK", "partkey"))))))); @@ -778,11 +778,11 @@ public void testExpressionPushedDownToLeftJoinSourceWhenUsingOn() assertPlan("SELECT o.orderkey FROM orders o JOIN lineitem l ON o.orderkey + 1 < l.orderkey", anyTree( filter( - new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "expr"), new SymbolReference(BIGINT, "LINEITEM_OK")), + new Comparison(LESS_THAN, new Reference(BIGINT, "expr"), new Reference(BIGINT, "LINEITEM_OK")), join(INNER, builder -> builder .left( project( - ImmutableMap.of("expr", expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "ORDERS_OK"), new Constant(BIGINT, 1L)))), + ImmutableMap.of("expr", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "ORDERS_OK"), new Constant(BIGINT, 1L)))), tableScan("orders", ImmutableMap.of("ORDERS_OK", "orderkey")))) .right( anyTree( @@ -795,16 +795,16 @@ public void testExpressionPushedDownToRightJoinSourceWhenUsingOn() assertPlan("SELECT o.orderkey FROM orders o JOIN lineitem l ON o.orderkey < l.orderkey + 1", anyTree( filter( - new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "ORDERS_OK"), new SymbolReference(BIGINT, "expr")), + new Comparison(LESS_THAN, new Reference(BIGINT, "ORDERS_OK"), new Reference(BIGINT, "expr")), join(INNER, builder -> builder - .dynamicFilter(ImmutableList.of(new DynamicFilterPattern(new SymbolReference(BIGINT, "ORDERS_OK"), LESS_THAN, "expr"))) + .dynamicFilter(ImmutableList.of(new DynamicFilterPattern(new Reference(BIGINT, "ORDERS_OK"), LESS_THAN, "expr"))) .left( - filter(TRUE_LITERAL, + filter(TRUE, tableScan("orders", ImmutableMap.of("ORDERS_OK", "orderkey")))) .right( anyTree( project( - ImmutableMap.of("expr", expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "LINEITEM_OK"), new Constant(BIGINT, 1L)))), + ImmutableMap.of("expr", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "LINEITEM_OK"), new Constant(BIGINT, 1L)))), tableScan("lineitem", ImmutableMap.of("LINEITEM_OK", "orderkey"))))))))); } @@ -813,7 +813,7 @@ public void testExpressionNotPushedDownToLeftJoinSource() { assertPlan("SELECT o.orderkey FROM orders o, lineitem l WHERE o.orderkey + 1 < l.orderkey", anyTree(filter( - new ComparisonExpression(LESS_THAN, new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "ORDERS_OK"), new Constant(BIGINT, 1L)), new SymbolReference(BIGINT, "LINEITEM_OK")), + new Comparison(LESS_THAN, new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "ORDERS_OK"), new Constant(BIGINT, 1L)), new Reference(BIGINT, "LINEITEM_OK")), join(INNER, builder -> builder .left(tableScan("orders", ImmutableMap.of("ORDERS_OK", "orderkey"))) .right(exchange( @@ -826,17 +826,17 @@ public void testExpressionPushedDownToRightJoinSource() assertPlan("SELECT o.orderkey FROM orders o, lineitem l WHERE o.orderkey < l.orderkey + 1", anyTree( filter( - new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "ORDERS_OK"), new SymbolReference(BIGINT, "expr")), + new Comparison(LESS_THAN, new Reference(BIGINT, "ORDERS_OK"), new Reference(BIGINT, "expr")), join(INNER, builder -> builder - .dynamicFilter(ImmutableList.of(new DynamicFilterPattern(new SymbolReference(BIGINT, "ORDERS_OK"), LESS_THAN, "expr"))) + .dynamicFilter(ImmutableList.of(new DynamicFilterPattern(new Reference(BIGINT, "ORDERS_OK"), LESS_THAN, "expr"))) .left( filter( - TRUE_LITERAL, + TRUE, tableScan("orders", ImmutableMap.of("ORDERS_OK", "orderkey")))) .right( anyTree( project( - ImmutableMap.of("expr", expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "LINEITEM_OK"), new Constant(BIGINT, 1L)))), + ImmutableMap.of("expr", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "LINEITEM_OK"), new Constant(BIGINT, 1L)))), tableScan("lineitem", ImmutableMap.of("LINEITEM_OK", "orderkey"))))))))); } @@ -849,19 +849,19 @@ public void testDynamicFilterAliasDeDuplicated() "WHERE f.nationkey >= mod(d.nationkey, 2) AND f.suppkey >= mod(d.nationkey, 2)", anyTree( filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(BIGINT, "nationkey"), new SymbolReference(BIGINT, "mod")), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(BIGINT, "suppkey"), new SymbolReference(BIGINT, "mod")))), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(BIGINT, "nationkey"), new Reference(BIGINT, "mod")), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(BIGINT, "suppkey"), new Reference(BIGINT, "mod")))), join(INNER, builder -> builder .dynamicFilter( ImmutableList.of( - new DynamicFilterPattern(new SymbolReference(BIGINT, "nationkey"), GREATER_THAN_OR_EQUAL, "mod"), - new DynamicFilterPattern(new SymbolReference(BIGINT, "suppkey"), GREATER_THAN_OR_EQUAL, "mod"))) + new DynamicFilterPattern(new Reference(BIGINT, "nationkey"), GREATER_THAN_OR_EQUAL, "mod"), + new DynamicFilterPattern(new Reference(BIGINT, "suppkey"), GREATER_THAN_OR_EQUAL, "mod"))) .left( anyTree( tableScan("supplier", ImmutableMap.of("nationkey", "nationkey", "suppkey", "suppkey")))) .right( anyTree( project( - ImmutableMap.of("mod", expression(new FunctionCall(MOD, ImmutableList.of(new SymbolReference(BIGINT, "n_nationkey"), new Constant(BIGINT, 2L))))), + ImmutableMap.of("mod", expression(new Call(MOD, ImmutableList.of(new Reference(BIGINT, "n_nationkey"), new Constant(BIGINT, 2L))))), tableScan("nation", ImmutableMap.of("n_nationkey", "nationkey"))))))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java index 4c1e3cfa05ee..3cffa4f87554 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java @@ -40,19 +40,19 @@ import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; import io.trino.sql.analyzer.TypeSignatureProvider; -import io.trino.sql.ir.BetweenPredicate; +import io.trino.sql.ir.Between; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.ir.ExpressionTreeRewriter; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.InPredicate; +import io.trino.sql.ir.In; import io.trino.sql.ir.IrUtils; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.NotExpression; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.ir.Row; -import io.trino.sql.ir.SymbolReference; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.AggregationNode.Aggregation; import io.trino.sql.planner.plan.Assignments; @@ -99,9 +99,9 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.RealType.REAL; -import static io.trino.sql.ir.BooleanLiteral.FALSE_LITERAL; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; +import static io.trino.sql.ir.Booleans.FALSE; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.IrUtils.and; import static io.trino.sql.ir.IrUtils.combineConjuncts; import static io.trino.sql.ir.IrUtils.or; @@ -249,7 +249,7 @@ public void testGroupByEmpty() { PlanNode node = new AggregationNode( newId(), - filter(baseTableScan, FALSE_LITERAL), + filter(baseTableScan, FALSE), ImmutableMap.of(), globalAggregation(), ImmutableList.of(), @@ -259,7 +259,7 @@ public void testGroupByEmpty() Expression effectivePredicate = effectivePredicateExtractor.extract(SESSION, node); - assertThat(effectivePredicate).isEqualTo(TRUE_LITERAL); + assertThat(effectivePredicate).isEqualTo(TRUE); } @Test @@ -455,7 +455,7 @@ public void testTableScan() false, Optional.empty()); Expression effectivePredicate = effectivePredicateExtractor.extract(SESSION, node); - assertThat(effectivePredicate).isEqualTo(TRUE_LITERAL); + assertThat(effectivePredicate).isEqualTo(TRUE); node = new TableScanNode( newId(), @@ -467,7 +467,7 @@ public void testTableScan() false, Optional.empty()); effectivePredicate = effectivePredicateExtractor.extract(SESSION, node); - assertThat(effectivePredicate).isEqualTo(FALSE_LITERAL); + assertThat(effectivePredicate).isEqualTo(FALSE); TupleDomain predicate = TupleDomain.withColumnDomains(ImmutableMap.of(scanAssignments.get(A), Domain.singleValue(BIGINT, 1L))); node = new TableScanNode( @@ -533,7 +533,7 @@ public void testTableScan() false, Optional.empty()); effectivePredicate = effectivePredicateExtractor.extract(SESSION, node); - assertThat(effectivePredicate).isEqualTo(TRUE_LITERAL); + assertThat(effectivePredicate).isEqualTo(TRUE); } @Test @@ -548,7 +548,7 @@ public void testValues() ImmutableList.of( new Row(ImmutableList.of(bigintLiteral(1))), new Row(ImmutableList.of(bigintLiteral(2))))) - )).isEqualTo(new InPredicate(AE, ImmutableList.of(bigintLiteral(1), bigintLiteral(2)))); + )).isEqualTo(new In(AE, ImmutableList.of(bigintLiteral(1), bigintLiteral(2)))); // one column with null assertThat(effectivePredicateExtractor.extract( @@ -561,8 +561,8 @@ public void testValues() new Row(ImmutableList.of(bigintLiteral(2))), new Row(ImmutableList.of(new Constant(BIGINT, null))))))) .isEqualTo(or( - new InPredicate(AE, ImmutableList.of(bigintLiteral(1), bigintLiteral(2))), - new IsNullPredicate(AE))); + new In(AE, ImmutableList.of(bigintLiteral(1), bigintLiteral(2))), + new IsNull(AE))); // all nulls assertThat(effectivePredicateExtractor.extract( @@ -571,7 +571,7 @@ public void testValues() newId(), ImmutableList.of(A), ImmutableList.of(new Row(ImmutableList.of(new Constant(BIGINT, null))))))) - .isEqualTo(new IsNullPredicate(AE)); + .isEqualTo(new IsNull(AE)); // nested row assertThat(effectivePredicateExtractor.extract( @@ -580,7 +580,7 @@ public void testValues() newId(), ImmutableList.of(R), ImmutableList.of(new Row(ImmutableList.of(new Row(ImmutableList.of(bigintLiteral(1), new Constant(UNKNOWN, null))))))))) - .isEqualTo(TRUE_LITERAL); + .isEqualTo(TRUE); // many rows List rows = IntStream.range(0, 500) @@ -594,7 +594,7 @@ public void testValues() newId(), ImmutableList.of(A), rows) - )).isEqualTo(new BetweenPredicate(AE, bigintLiteral(0), bigintLiteral(499))); + )).isEqualTo(new Between(AE, bigintLiteral(0), bigintLiteral(499))); // NaN assertThat(effectivePredicateExtractor.extract( @@ -603,7 +603,7 @@ public void testValues() newId(), ImmutableList.of(new Symbol(DOUBLE, "c")), ImmutableList.of(new Row(ImmutableList.of(doubleLiteral(Double.NaN))))) - )).isEqualTo(new NotExpression(new IsNullPredicate(new SymbolReference(DOUBLE, "c")))); + )).isEqualTo(new Not(new IsNull(new Reference(DOUBLE, "c")))); // NaN and NULL assertThat(effectivePredicateExtractor.extract( @@ -614,7 +614,7 @@ public void testValues() ImmutableList.of( new Row(ImmutableList.of(new Constant(DOUBLE, null))), new Row(ImmutableList.of(doubleLiteral(Double.NaN))))) - )).isEqualTo(TRUE_LITERAL); + )).isEqualTo(TRUE); // NaN and value assertThat(effectivePredicateExtractor.extract( @@ -625,7 +625,7 @@ public void testValues() ImmutableList.of( new Row(ImmutableList.of(doubleLiteral(42.))), new Row(ImmutableList.of(doubleLiteral(Double.NaN))))) - )).isEqualTo(new NotExpression(new IsNullPredicate(new SymbolReference(DOUBLE, "x")))); + )).isEqualTo(new Not(new IsNull(new Reference(DOUBLE, "x")))); // Real NaN assertThat(effectivePredicateExtractor.extract( @@ -634,7 +634,7 @@ public void testValues() newId(), ImmutableList.of(D), ImmutableList.of(new Row(ImmutableList.of(new Cast(doubleLiteral(Double.NaN), REAL))))))) - .isEqualTo(new NotExpression(new IsNullPredicate(DE))); + .isEqualTo(new Not(new IsNull(DE))); // multiple columns assertThat(effectivePredicateExtractor.extract( @@ -646,8 +646,8 @@ public void testValues() new Row(ImmutableList.of(bigintLiteral(1), bigintLiteral(100))), new Row(ImmutableList.of(bigintLiteral(2), bigintLiteral(200))))))) .isEqualTo(and( - new InPredicate(AE, ImmutableList.of(bigintLiteral(1), bigintLiteral(2))), - new InPredicate(BE, ImmutableList.of(bigintLiteral(100), bigintLiteral(200))))); + new In(AE, ImmutableList.of(bigintLiteral(1), bigintLiteral(2))), + new In(BE, ImmutableList.of(bigintLiteral(100), bigintLiteral(200))))); // multiple columns with null assertThat(effectivePredicateExtractor.extract( @@ -659,16 +659,16 @@ public void testValues() new Row(ImmutableList.of(bigintLiteral(1), new Constant(BIGINT, null))), new Row(ImmutableList.of(new Constant(BIGINT, null), bigintLiteral(200))))) )).isEqualTo(and( - or(new ComparisonExpression(EQUAL, AE, bigintLiteral(1)), new IsNullPredicate(AE)), - or(new ComparisonExpression(EQUAL, BE, bigintLiteral(200)), new IsNullPredicate(BE)))); + or(new Comparison(EQUAL, AE, bigintLiteral(1)), new IsNull(AE)), + or(new Comparison(EQUAL, BE, bigintLiteral(200)), new IsNull(BE)))); // non-deterministic ResolvedFunction rand = functionResolution.resolveFunction("rand", ImmutableList.of()); ValuesNode node = new ValuesNode( newId(), ImmutableList.of(A, B), - ImmutableList.of(new Row(ImmutableList.of(bigintLiteral(1), new FunctionCall(rand, ImmutableList.of()))))); - assertThat(extract(node)).isEqualTo(new ComparisonExpression(EQUAL, AE, bigintLiteral(1))); + ImmutableList.of(new Row(ImmutableList.of(bigintLiteral(1), new Call(rand, ImmutableList.of()))))); + assertThat(extract(node)).isEqualTo(new Comparison(EQUAL, AE, bigintLiteral(1))); // non-constant assertThat(effectivePredicateExtractor.extract( @@ -679,7 +679,7 @@ public void testValues() ImmutableList.of( new Row(ImmutableList.of(bigintLiteral(1))), new Row(ImmutableList.of(BE)))) - )).isEqualTo(TRUE_LITERAL); + )).isEqualTo(TRUE); // non-comparable and non-orderable assertThat(effectivePredicateExtractor.extract( @@ -690,7 +690,7 @@ public void testValues() ImmutableList.of( new Row(ImmutableList.of(bigintLiteral(1))), new Row(ImmutableList.of(bigintLiteral(2))))) - )).isEqualTo(TRUE_LITERAL); + )).isEqualTo(TRUE); } private Expression extract(PlanNode node) @@ -829,7 +829,7 @@ public void testInnerJoinWithFalseFilter() leftScan.getOutputSymbols(), rightScan.getOutputSymbols(), false, - Optional.of(FALSE_LITERAL), + Optional.of(FALSE), Optional.empty(), Optional.empty(), Optional.empty(), @@ -839,7 +839,7 @@ public void testInnerJoinWithFalseFilter() Expression effectivePredicate = effectivePredicateExtractor.extract(SESSION, node); - assertThat(effectivePredicate).isEqualTo(FALSE_LITERAL); + assertThat(effectivePredicate).isEqualTo(FALSE); } @Test @@ -913,7 +913,7 @@ public void testLeftJoinWithFalseInner() lessThan(BE, AE), lessThan(CE, bigintLiteral(10)), equals(GE, bigintLiteral(10)))); - FilterNode right = filter(rightScan, FALSE_LITERAL); + FilterNode right = filter(rightScan, FALSE); PlanNode node = new JoinNode( newId(), JoinType.LEFT, @@ -1005,7 +1005,7 @@ public void testRightJoinWithFalseInner() Map rightAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(D, E, F))); TableScanNode rightScan = tableScanNode(rightAssignments); - FilterNode left = filter(leftScan, FALSE_LITERAL); + FilterNode left = filter(leftScan, FALSE); FilterNode right = filter( rightScan, and( @@ -1089,29 +1089,29 @@ private static Expression doubleLiteral(double value) return new Constant(DOUBLE, value); } - private static ComparisonExpression equals(Expression expression1, Expression expression2) + private static Comparison equals(Expression expression1, Expression expression2) { - return new ComparisonExpression(EQUAL, expression1, expression2); + return new Comparison(EQUAL, expression1, expression2); } - private static ComparisonExpression lessThan(Expression expression1, Expression expression2) + private static Comparison lessThan(Expression expression1, Expression expression2) { - return new ComparisonExpression(ComparisonExpression.Operator.LESS_THAN, expression1, expression2); + return new Comparison(Comparison.Operator.LESS_THAN, expression1, expression2); } - private static ComparisonExpression lessThanOrEqual(Expression expression1, Expression expression2) + private static Comparison lessThanOrEqual(Expression expression1, Expression expression2) { - return new ComparisonExpression(ComparisonExpression.Operator.LESS_THAN_OR_EQUAL, expression1, expression2); + return new Comparison(Comparison.Operator.LESS_THAN_OR_EQUAL, expression1, expression2); } - private static ComparisonExpression greaterThan(Expression expression1, Expression expression2) + private static Comparison greaterThan(Expression expression1, Expression expression2) { - return new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN, expression1, expression2); + return new Comparison(Comparison.Operator.GREATER_THAN, expression1, expression2); } - private static IsNullPredicate isNull(Expression expression) + private static IsNull isNull(Expression expression) { - return new IsNullPredicate(expression); + return new IsNull(expression); } private static ResolvedFunction fakeFunction(String name) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestEqualityInference.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestEqualityInference.java index fd7990f84880..09b1b29e02cf 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestEqualityInference.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestEqualityInference.java @@ -22,19 +22,19 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.operator.scalar.TryFunction; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Case; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.InPredicate; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.LambdaExpression; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.NullIfExpression; -import io.trino.sql.ir.SearchedCaseExpression; -import io.trino.sql.ir.SimpleCaseExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.In; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Lambda; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.NullIf; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Switch; import io.trino.sql.ir.WhenClause; import io.trino.type.FunctionType; import io.trino.type.UnknownType; @@ -52,8 +52,8 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.ir.IrUtils.and; import static io.trino.sql.planner.EqualityInference.isInferenceCandidate; import static io.trino.type.UnknownType.UNKNOWN; @@ -72,21 +72,21 @@ public void testDoesNotInferRedundantStraddlingPredicates() { EqualityInference inference = new EqualityInference( equals("a1", "b1"), - equals(add(new SymbolReference(BIGINT, "a1"), number(1)), number(0)), - equals(new SymbolReference(BIGINT, "a2"), add(new SymbolReference(BIGINT, "a1"), number(2))), - equals(new SymbolReference(BIGINT, "a1"), add("a3", "b3")), - equals(new SymbolReference(BIGINT, "b2"), add("a4", "b4"))); + equals(add(new Reference(BIGINT, "a1"), number(1)), number(0)), + equals(new Reference(BIGINT, "a2"), add(new Reference(BIGINT, "a1"), number(2))), + equals(new Reference(BIGINT, "a1"), add("a3", "b3")), + equals(new Reference(BIGINT, "b2"), add("a4", "b4"))); EqualityInference.EqualityPartition partition = inference.generateEqualitiesPartitionedBy(symbols("a1", "a2", "a3", "a4")); assertThat(partition.getScopeEqualities()).containsExactly( - equals(number(0), add(new SymbolReference(BIGINT, "a1"), number(1))), - equals(new SymbolReference(BIGINT, "a2"), add(new SymbolReference(BIGINT, "a1"), number(2)))); + equals(number(0), add(new Reference(BIGINT, "a1"), number(1))), + equals(new Reference(BIGINT, "a2"), add(new Reference(BIGINT, "a1"), number(2)))); assertThat(partition.getScopeComplementEqualities()).containsExactly( - equals(number(0), add(new SymbolReference(BIGINT, "b1"), number(1)))); + equals(number(0), add(new Reference(BIGINT, "b1"), number(1)))); // there shouldn't be equality a2 = b1 + 1 as it can be derived from a2 = a1 + 1, a1 = b1 assertThat(partition.getScopeStraddlingEqualities()).containsExactly( equals("a1", "b1"), - equals(new SymbolReference(BIGINT, "a1"), add("a3", "b3")), - equals(new SymbolReference(BIGINT, "b2"), add("a4", "b4"))); + equals(new Reference(BIGINT, "a1"), add("a3", "b3")), + equals(new Reference(BIGINT, "b2"), add("a4", "b4"))); } @Test @@ -109,8 +109,8 @@ public void testTransitivity() assertThat(inference.rewrite(someExpression("a1", "a2"), symbols("b1", "d2", "c3"))).isEqualTo(someExpression("b1", "d2")); // Both starting expressions should canonicalize to the same expression - assertThat(inference.getScopedCanonical(new SymbolReference(BIGINT, "a2"), matchesSymbols("c2", "d2"))).isEqualTo(inference.getScopedCanonical(new SymbolReference(BIGINT, "b2"), matchesSymbols("c2", "d2"))); - Expression canonical = inference.getScopedCanonical(new SymbolReference(BIGINT, "a2"), matchesSymbols("c2", "d2")); + assertThat(inference.getScopedCanonical(new Reference(BIGINT, "a2"), matchesSymbols("c2", "d2"))).isEqualTo(inference.getScopedCanonical(new Reference(BIGINT, "b2"), matchesSymbols("c2", "d2"))); + Expression canonical = inference.getScopedCanonical(new Reference(BIGINT, "a2"), matchesSymbols("c2", "d2")); // Given multiple translatable candidates, should choose the canonical assertThat(inference.rewrite(someExpression("a2", "b2"), symbols("c2", "d2"))).isEqualTo(someExpression(canonical, canonical)); @@ -155,21 +155,21 @@ public void testExtractInferrableEqualities() and(equals("a1", "b1"), equals("b1", "c1"), someExpression("c1", "d1"))); // Able to rewrite to c1 due to equalities - assertThat(new SymbolReference(BIGINT, "c1")).isEqualTo(inference.rewrite(new SymbolReference(BIGINT, "a1"), symbols("c1"))); + assertThat(new Reference(BIGINT, "c1")).isEqualTo(inference.rewrite(new Reference(BIGINT, "a1"), symbols("c1"))); // But not be able to rewrite to d1 which is not connected via equality - assertThat(inference.rewrite(new SymbolReference(BIGINT, "a1"), symbols("d1"))).isNull(); + assertThat(inference.rewrite(new Reference(BIGINT, "a1"), symbols("d1"))).isNull(); } @Test public void testEqualityPartitionGeneration() { EqualityInference inference = new EqualityInference( - equals(new SymbolReference(BIGINT, "a1"), new SymbolReference(BIGINT, "b1")), - equals(add("a1", "a1"), multiply(new SymbolReference(BIGINT, "a1"), number(2))), - equals(new SymbolReference(BIGINT, "b1"), new SymbolReference(BIGINT, "c1")), - equals(add("a1", "a1"), new SymbolReference(BIGINT, "c1")), - equals(add("a1", "b1"), new SymbolReference(BIGINT, "c1"))); + equals(new Reference(BIGINT, "a1"), new Reference(BIGINT, "b1")), + equals(add("a1", "a1"), multiply(new Reference(BIGINT, "a1"), number(2))), + equals(new Reference(BIGINT, "b1"), new Reference(BIGINT, "c1")), + equals(add("a1", "a1"), new Reference(BIGINT, "c1")), + equals(add("a1", "b1"), new Reference(BIGINT, "c1"))); EqualityInference.EqualityPartition emptyScopePartition = inference.generateEqualitiesPartitionedBy(ImmutableSet.of()); // Cannot generate any scope equalities with no matching symbols @@ -261,18 +261,18 @@ public void testMultipleEqualitySetsPredicateGeneration() public void testSubExpressionRewrites() { EqualityInference inference = new EqualityInference( - equals(new SymbolReference(BIGINT, "a1"), add("b", "c")), // a1 = b + c - equals(new SymbolReference(BIGINT, "a2"), multiply(new SymbolReference(BIGINT, "b"), add("b", "c"))), // a2 = b * (b + c) - equals(new SymbolReference(BIGINT, "a3"), multiply(new SymbolReference(BIGINT, "a1"), add("b", "c")))); // a3 = a1 * (b + c) + equals(new Reference(BIGINT, "a1"), add("b", "c")), // a1 = b + c + equals(new Reference(BIGINT, "a2"), multiply(new Reference(BIGINT, "b"), add("b", "c"))), // a2 = b * (b + c) + equals(new Reference(BIGINT, "a3"), multiply(new Reference(BIGINT, "a1"), add("b", "c")))); // a3 = a1 * (b + c) // Expression (b + c) should get entirely rewritten as a1 - assertThat(inference.rewrite(add("b", "c"), symbols("a1", "a2"))).isEqualTo(new SymbolReference(BIGINT, "a1")); + assertThat(inference.rewrite(add("b", "c"), symbols("a1", "a2"))).isEqualTo(new Reference(BIGINT, "a1")); // Only the sub-expression (b + c) should get rewritten in terms of a* - assertThat(inference.rewrite(multiply(new SymbolReference(BIGINT, "ax"), add("b", "c")), symbols("ax", "a1", "a2", "a3"))).isEqualTo(multiply(new SymbolReference(BIGINT, "ax"), new SymbolReference(BIGINT, "a1"))); + assertThat(inference.rewrite(multiply(new Reference(BIGINT, "ax"), add("b", "c")), symbols("ax", "a1", "a2", "a3"))).isEqualTo(multiply(new Reference(BIGINT, "ax"), new Reference(BIGINT, "a1"))); // To be compliant, could rewrite either the whole expression, or just the sub-expression. Rewriting larger expressions are preferred - assertThat(inference.rewrite(multiply(new SymbolReference(BIGINT, "a1"), add("b", "c")), symbols("a1", "a2", "a3"))).isEqualTo(new SymbolReference(BIGINT, "a3")); + assertThat(inference.rewrite(multiply(new Reference(BIGINT, "a1"), add("b", "c")), symbols("a1", "a2", "a3"))).isEqualTo(new Reference(BIGINT, "a3")); } @Test @@ -281,15 +281,15 @@ public void testConstantEqualities() EqualityInference inference = new EqualityInference( equals("a1", "b1"), equals("b1", "c1"), - equals(new SymbolReference(BIGINT, "c1"), number(1))); + equals(new Reference(BIGINT, "c1"), number(1))); // Should always prefer a constant if available (constant is part of all scopes) - assertThat(inference.rewrite(new SymbolReference(BIGINT, "a1"), symbols("a1", "b1"))).isEqualTo(number(1)); + assertThat(inference.rewrite(new Reference(BIGINT, "a1"), symbols("a1", "b1"))).isEqualTo(number(1)); // All scope equalities should utilize the constant if possible EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(symbols("a1", "b1")); - assertThat(equalitiesAsSets(equalityPartition.getScopeEqualities())).isEqualTo(set(set(new SymbolReference(BIGINT, "a1"), number(1)), set(new SymbolReference(BIGINT, "b1"), number(1)))); - assertThat(equalitiesAsSets(equalityPartition.getScopeComplementEqualities())).isEqualTo(set(set(new SymbolReference(BIGINT, "c1"), number(1)))); + assertThat(equalitiesAsSets(equalityPartition.getScopeEqualities())).isEqualTo(set(set(new Reference(BIGINT, "a1"), number(1)), set(new Reference(BIGINT, "b1"), number(1)))); + assertThat(equalitiesAsSets(equalityPartition.getScopeComplementEqualities())).isEqualTo(set(set(new Reference(BIGINT, "c1"), number(1)))); // There should be no scope straddling equalities as the full set of equalities should be already represented by the scope and inverse scope assertThat(equalityPartition.getScopeStraddlingEqualities().isEmpty()).isTrue(); @@ -299,36 +299,36 @@ public void testConstantEqualities() public void testEqualityGeneration() { EqualityInference inference = new EqualityInference( - equals(new SymbolReference(BIGINT, "a1"), add("b", "c")), // a1 = b + c - equals(new SymbolReference(BIGINT, "e1"), add("b", "d")), // e1 = b + d + equals(new Reference(BIGINT, "a1"), add("b", "c")), // a1 = b + c + equals(new Reference(BIGINT, "e1"), add("b", "d")), // e1 = b + d equals("c", "d")); - Expression scopedCanonical = inference.getScopedCanonical(new SymbolReference(BIGINT, "e1"), symbolBeginsWith("a")); - assertThat(scopedCanonical).isEqualTo(new SymbolReference(BIGINT, "a1")); + Expression scopedCanonical = inference.getScopedCanonical(new Reference(BIGINT, "e1"), symbolBeginsWith("a")); + assertThat(scopedCanonical).isEqualTo(new Reference(BIGINT, "a1")); } @Test public void testExpressionsThatMayReturnNullOnNonNullInput() { List candidates = ImmutableList.of( - new Cast(new SymbolReference(BIGINT, "b"), BIGINT, true), // try_cast + new Cast(new Reference(BIGINT, "b"), BIGINT, true), // try_cast functionResolution .functionCallBuilder(TryFunction.NAME) - .addArgument(new FunctionType(ImmutableList.of(), VARCHAR), new LambdaExpression(ImmutableList.of(), new SymbolReference(BIGINT, "b"))) + .addArgument(new FunctionType(ImmutableList.of(), VARCHAR), new Lambda(ImmutableList.of(), new Reference(BIGINT, "b"))) .build(), - new NullIfExpression(new SymbolReference(BIGINT, "b"), number(1)), - new InPredicate(new SymbolReference(BIGINT, "b"), ImmutableList.of(new Constant(UnknownType.UNKNOWN, null))), - new SearchedCaseExpression(ImmutableList.of(new WhenClause(new NotExpression(new IsNullPredicate(new SymbolReference(BIGINT, "b"))), new Constant(UnknownType.UNKNOWN, null))), Optional.empty()), - new SimpleCaseExpression(new SymbolReference(BIGINT, "b"), ImmutableList.of(new WhenClause(number(1), new Constant(UnknownType.UNKNOWN, null))), Optional.empty())); + new NullIf(new Reference(BIGINT, "b"), number(1)), + new In(new Reference(BIGINT, "b"), ImmutableList.of(new Constant(UnknownType.UNKNOWN, null))), + new Case(ImmutableList.of(new WhenClause(new Not(new IsNull(new Reference(BIGINT, "b"))), new Constant(UnknownType.UNKNOWN, null))), Optional.empty()), + new Switch(new Reference(BIGINT, "b"), ImmutableList.of(new WhenClause(number(1), new Constant(UnknownType.UNKNOWN, null))), Optional.empty())); for (Expression candidate : candidates) { EqualityInference inference = new EqualityInference( - equals(new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "x")), - equals(new SymbolReference(BIGINT, "a"), candidate)); + equals(new Reference(BIGINT, "b"), new Reference(BIGINT, "x")), + equals(new Reference(BIGINT, "a"), candidate)); List equalities = inference.generateEqualitiesPartitionedBy(symbols("b")).getScopeStraddlingEqualities(); assertThat(equalities.size()).isEqualTo(1); - assertThat(equalities.get(0).equals(equals(new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "b"))) || equalities.get(0).equals(equals(new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "x")))).isTrue(); + assertThat(equalities.get(0).equals(equals(new Reference(BIGINT, "x"), new Reference(BIGINT, "b"))) || equalities.get(0).equals(equals(new Reference(BIGINT, "b"), new Reference(BIGINT, "x")))).isTrue(); } } @@ -347,37 +347,37 @@ private static Predicate matchesStraddlingScope(Predicate sy private static Expression someExpression(String symbol1, String symbol2) { - return someExpression(new SymbolReference(BIGINT, symbol1), new SymbolReference(BIGINT, symbol2)); + return someExpression(new Reference(BIGINT, symbol1), new Reference(BIGINT, symbol2)); } private static Expression someExpression(Expression expression1, Expression expression2) { - return new ComparisonExpression(GREATER_THAN, expression1, expression2); + return new Comparison(GREATER_THAN, expression1, expression2); } private static Expression add(String symbol1, String symbol2) { - return add(new SymbolReference(BIGINT, symbol1), new SymbolReference(BIGINT, symbol2)); + return add(new Reference(BIGINT, symbol1), new Reference(BIGINT, symbol2)); } private static Expression add(Expression expression1, Expression expression2) { - return new ArithmeticBinaryExpression(ADD_BIGINT, ArithmeticBinaryExpression.Operator.ADD, expression1, expression2); + return new Arithmetic(ADD_BIGINT, Arithmetic.Operator.ADD, expression1, expression2); } private static Expression multiply(Expression expression1, Expression expression2) { - return new ArithmeticBinaryExpression(MULTIPLY_BIGINT, ArithmeticBinaryExpression.Operator.MULTIPLY, expression1, expression2); + return new Arithmetic(MULTIPLY_BIGINT, Arithmetic.Operator.MULTIPLY, expression1, expression2); } private static Expression equals(String symbol1, String symbol2) { - return equals(new SymbolReference(BIGINT, symbol1), new SymbolReference(BIGINT, symbol2)); + return equals(new Reference(BIGINT, symbol1), new Reference(BIGINT, symbol2)); } private static Expression equals(Expression expression1, Expression expression2) { - return new ComparisonExpression(EQUAL, expression1, expression2); + return new Comparison(EQUAL, expression1, expression2); } private static Constant number(long number) @@ -438,10 +438,10 @@ private static Set> equalitiesAsSets(Iterable expres private static Set equalityAsSet(Expression expression) { - checkArgument(expression instanceof ComparisonExpression); - ComparisonExpression comparisonExpression = (ComparisonExpression) expression; - checkArgument(comparisonExpression.getOperator() == EQUAL); - return ImmutableSet.of(comparisonExpression.getLeft(), comparisonExpression.getRight()); + checkArgument(expression instanceof Comparison); + Comparison comparison = (Comparison) expression; + checkArgument(comparison.operator() == EQUAL); + return ImmutableSet.of(comparison.left(), comparison.right()); } @SafeVarargs diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestJsonTable.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestJsonTable.java index 28dbcd4ea1a2..a8121e73e662 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestJsonTable.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestJsonTable.java @@ -29,11 +29,11 @@ import io.trino.operator.table.json.JsonTablePlanUnion; import io.trino.operator.table.json.JsonTableQueryColumn; import io.trino.operator.table.json.JsonTableValueColumn; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.FunctionCall; +import io.trino.sql.ir.Reference; import io.trino.sql.ir.Row; -import io.trino.sql.ir.SymbolReference; import io.trino.sql.planner.assertions.BasePlanTest; import io.trino.sql.planner.optimizations.PlanNodeSearcher; import io.trino.sql.planner.plan.TableFunctionNode; @@ -56,7 +56,7 @@ import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.sql.analyzer.ExpressionAnalyzer.JSON_NO_PARAMETERS_ROW_TYPE; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.BooleanLiteral.FALSE_LITERAL; +import static io.trino.sql.ir.Booleans.FALSE; import static io.trino.sql.planner.JsonTablePlanComparator.planComparator; import static io.trino.sql.planner.LogicalPlanner.Stage.CREATED; import static io.trino.sql.planner.PathNodes.contextVariable; @@ -108,7 +108,7 @@ public void testJsonTableInitialPlan() ImmutableList.of("json_col", "int_col", "bigint_col", "formatted_varchar_col"), anyTree( project( - ImmutableMap.of("formatted_varchar_col", expression(new FunctionCall(JSON_TO_VARCHAR, ImmutableList.of(new SymbolReference(VARCHAR, "varchar_col"), new Constant(TINYINT, 1L), FALSE_LITERAL)))), + ImmutableMap.of("formatted_varchar_col", expression(new Call(JSON_TO_VARCHAR, ImmutableList.of(new Reference(VARCHAR, "varchar_col"), new Constant(TINYINT, 1L), FALSE)))), tableFunction(builder -> builder .name("$json_table") .addTableArgument( @@ -120,14 +120,14 @@ public void testJsonTableInitialPlan() .properOutputs(ImmutableList.of("bigint_col", "varchar_col")), project( ImmutableMap.of( - "context_item", expression(new FunctionCall(VARCHAR_TO_JSON, ImmutableList.of(new SymbolReference(VARCHAR, "json_col_coerced"), FALSE_LITERAL))), // apply input function to context item - "parameters_row", expression(new Cast(new Row(ImmutableList.of(new SymbolReference(INTEGER, "int_col"), new FunctionCall(VARCHAR_TO_JSON, ImmutableList.of(new SymbolReference(VARCHAR, "name_coerced"), FALSE_LITERAL)))), rowType(field("id", INTEGER), field("name", JSON_2016))))), // apply input function to formatted path parameter and gather path parameters in a row + "context_item", expression(new Call(VARCHAR_TO_JSON, ImmutableList.of(new Reference(VARCHAR, "json_col_coerced"), FALSE))), // apply input function to context item + "parameters_row", expression(new Cast(new Row(ImmutableList.of(new Reference(INTEGER, "int_col"), new Call(VARCHAR_TO_JSON, ImmutableList.of(new Reference(VARCHAR, "name_coerced"), FALSE)))), rowType(field("id", INTEGER), field("name", JSON_2016))))), // apply input function to formatted path parameter and gather path parameters in a row project(// coerce context item, path parameters and default expressions ImmutableMap.of( - "name_coerced", expression(new Cast(new SymbolReference(createVarcharType(5), "name"), VARCHAR)), // cast formatted path parameter to VARCHAR for the input function - "default_value_coerced", expression(new Cast(new SymbolReference(INTEGER, "default_value"), BIGINT)), // cast default value to BIGINT to match declared return type for the column - "json_col_coerced", expression(new Cast(new SymbolReference(createVarcharType(9), "json_col"), VARCHAR)), // cast context item to VARCHAR for the input function - "int_col_coerced", expression(new Cast(new SymbolReference(INTEGER, "int_col"), BIGINT))), // cast default value to BIGINT to match declared return type for the column + "name_coerced", expression(new Cast(new Reference(createVarcharType(5), "name"), VARCHAR)), // cast formatted path parameter to VARCHAR for the input function + "default_value_coerced", expression(new Cast(new Reference(INTEGER, "default_value"), BIGINT)), // cast default value to BIGINT to match declared return type for the column + "json_col_coerced", expression(new Cast(new Reference(createVarcharType(9), "json_col"), VARCHAR)), // cast context item to VARCHAR for the input function + "int_col_coerced", expression(new Cast(new Reference(INTEGER, "int_col"), BIGINT))), // cast default value to BIGINT to match declared return type for the column project(// pre-project context item, path parameters and default expressions ImmutableMap.of( "name", expression(new Constant(createVarcharType(5), Slices.utf8Slice("[ala]"))), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestLocalDynamicFiltersCollector.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestLocalDynamicFiltersCollector.java index 76478cfb7708..a3d424ba8d00 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestLocalDynamicFiltersCollector.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestLocalDynamicFiltersCollector.java @@ -37,9 +37,9 @@ import static io.trino.SessionTestUtils.TEST_SESSION; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT; import static org.assertj.core.api.Assertions.assertThat; diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java index 7666d36afb54..60c06b7ae4b2 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java @@ -30,20 +30,20 @@ import io.trino.spi.predicate.TupleDomain; import io.trino.spi.predicate.ValueSet; import io.trino.spi.type.RowType; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.CoalesceExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Coalesce; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.InPredicate; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.NotExpression; +import io.trino.sql.ir.In; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.ir.Row; -import io.trino.sql.ir.SimpleCaseExpression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Subscript; +import io.trino.sql.ir.Switch; import io.trino.sql.ir.WhenClause; import io.trino.sql.planner.OptimizerConfig.JoinDistributionType; import io.trino.sql.planner.OptimizerConfig.JoinReorderingStrategy; @@ -116,15 +116,15 @@ import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MULTIPLY; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.SUBTRACT; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; -import static io.trino.sql.ir.LogicalExpression.Operator.OR; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; +import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Logical.Operator.AND; +import static io.trino.sql.ir.Logical.Operator.OR; import static io.trino.sql.planner.LogicalPlanner.Stage.CREATED; import static io.trino.sql.planner.LogicalPlanner.Stage.OPTIMIZED; import static io.trino.sql.planner.assertions.PlanMatchPattern.DynamicFilterPattern; @@ -299,10 +299,10 @@ public void testAllFieldsDereferenceOnSubquery() any( project( ImmutableMap.of( - "output_1", expression(new SubscriptExpression(BIGINT, new SymbolReference(RowType.anonymousRow(BIGINT, BIGINT), "row"), new Constant(INTEGER, 1L))), - "output_2", expression(new SubscriptExpression(BIGINT, new SymbolReference(RowType.anonymousRow(BIGINT, BIGINT), "row"), new Constant(INTEGER, 2L)))), + "output_1", expression(new Subscript(BIGINT, new Reference(RowType.anonymousRow(BIGINT, BIGINT), "row"), new Constant(INTEGER, 1L))), + "output_2", expression(new Subscript(BIGINT, new Reference(RowType.anonymousRow(BIGINT, BIGINT), "row"), new Constant(INTEGER, 2L)))), project( - ImmutableMap.of("row", expression(new Row(ImmutableList.of(new SymbolReference(BIGINT, "min"), new SymbolReference(BIGINT, "max"))))), + ImmutableMap.of("row", expression(new Row(ImmutableList.of(new Reference(BIGINT, "min"), new Reference(BIGINT, "max"))))), aggregation( ImmutableMap.of( "min", aggregationFunction("min", ImmutableList.of("min_regionkey")), @@ -320,7 +320,7 @@ public void testAllFieldsDereferenceOnSubquery() @Test public void testAllFieldsDereferenceFromNonDeterministic() { - FunctionCall randomFunction = new FunctionCall( + Call randomFunction = new Call( getPlanTester().getPlannerContext().getMetadata().resolveBuiltinFunction("rand", ImmutableList.of()), ImmutableList.of()); @@ -328,10 +328,10 @@ public void testAllFieldsDereferenceFromNonDeterministic() any( project( ImmutableMap.of( - "output_1", expression(new SubscriptExpression(DOUBLE, new SymbolReference(RowType.anonymousRow(DOUBLE, DOUBLE), "row"), new Constant(INTEGER, 1L))), - "output_2", expression(new SubscriptExpression(DOUBLE, new SymbolReference(RowType.anonymousRow(DOUBLE, DOUBLE), "row"), new Constant(INTEGER, 2L)))), + "output_1", expression(new Subscript(DOUBLE, new Reference(RowType.anonymousRow(DOUBLE, DOUBLE), "row"), new Constant(INTEGER, 1L))), + "output_2", expression(new Subscript(DOUBLE, new Reference(RowType.anonymousRow(DOUBLE, DOUBLE), "row"), new Constant(INTEGER, 2L)))), project( - ImmutableMap.of("row", expression(new Row(ImmutableList.of(new SymbolReference(DOUBLE, "rand"), new SymbolReference(DOUBLE, "rand"))))), + ImmutableMap.of("row", expression(new Row(ImmutableList.of(new Reference(DOUBLE, "rand"), new Reference(DOUBLE, "rand"))))), values( ImmutableList.of("rand"), ImmutableList.of(ImmutableList.of(randomFunction))))))); @@ -340,8 +340,8 @@ public void testAllFieldsDereferenceFromNonDeterministic() any( project( ImmutableMap.of( - "output_1", expression(new SubscriptExpression(DOUBLE, new SymbolReference(RowType.anonymousRow(DOUBLE, DOUBLE), "r"), new Constant(INTEGER, 1L))), - "output_2", expression(new SubscriptExpression(DOUBLE, new SymbolReference(RowType.anonymousRow(DOUBLE, DOUBLE), "r"), new Constant(INTEGER, 2L)))), + "output_1", expression(new Subscript(DOUBLE, new Reference(RowType.anonymousRow(DOUBLE, DOUBLE), "r"), new Constant(INTEGER, 1L))), + "output_2", expression(new Subscript(DOUBLE, new Reference(RowType.anonymousRow(DOUBLE, DOUBLE), "r"), new Constant(INTEGER, 2L)))), values( ImmutableList.of("r"), ImmutableList.of(ImmutableList.of(new Row(ImmutableList.of(randomFunction, randomFunction)))))))); @@ -353,8 +353,8 @@ public void testAllFieldsDereferenceFromNonDeterministic() any( project( ImmutableMap.of( - "output_1", expression(new SubscriptExpression(DOUBLE, new SymbolReference(RowType.anonymousRow(DOUBLE, DOUBLE), "row"), new Constant(INTEGER, 1L))), - "output_2", expression(new SubscriptExpression(DOUBLE, new SymbolReference(RowType.anonymousRow(DOUBLE, DOUBLE), "row"), new Constant(INTEGER, 2L)))), + "output_1", expression(new Subscript(DOUBLE, new Reference(RowType.anonymousRow(DOUBLE, DOUBLE), "row"), new Constant(INTEGER, 1L))), + "output_2", expression(new Subscript(DOUBLE, new Reference(RowType.anonymousRow(DOUBLE, DOUBLE), "row"), new Constant(INTEGER, 2L)))), values( ImmutableList.of("row"), ImmutableList.of( @@ -383,11 +383,11 @@ public void testDistinctLimitOverInequalityJoin() node(DistinctLimitNode.class, anyTree( filter( - new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "O_ORDERKEY"), new SymbolReference(BIGINT, "L_ORDERKEY")), + new Comparison(LESS_THAN, new Reference(BIGINT, "O_ORDERKEY"), new Reference(BIGINT, "L_ORDERKEY")), join(INNER, builder -> builder - .dynamicFilter(ImmutableList.of(new DynamicFilterPattern(new SymbolReference(BIGINT, "O_ORDERKEY"), LESS_THAN, "L_ORDERKEY"))) + .dynamicFilter(ImmutableList.of(new DynamicFilterPattern(new Reference(BIGINT, "O_ORDERKEY"), LESS_THAN, "L_ORDERKEY"))) .left( - filter(TRUE_LITERAL, + filter(TRUE, tableScan("orders", ImmutableMap.of("O_ORDERKEY", "orderkey")))) .right( any(tableScan("lineitem", ImmutableMap.of("L_ORDERKEY", "orderkey"))))) @@ -401,7 +401,7 @@ public void testDistinctLimitOverInequalityJoin() anyTree( join(INNER, builder -> builder .equiCriteria("O_SHIPPRIORITY", "L_LINENUMBER") - .filter(new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "O_ORDERKEY"), new SymbolReference(BIGINT, "L_ORDERKEY"))) + .filter(new Comparison(LESS_THAN, new Reference(BIGINT, "O_ORDERKEY"), new Reference(BIGINT, "L_ORDERKEY"))) .left( anyTree(tableScan("orders", ImmutableMap.of( "O_SHIPPRIORITY", "shippriority", @@ -431,12 +431,12 @@ public void testInnerInequalityJoinNoEquiJoinConjuncts() assertPlan("SELECT 1 FROM orders o JOIN lineitem l ON o.orderkey < l.orderkey", anyTree( filter( - new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "O_ORDERKEY"), new SymbolReference(BIGINT, "L_ORDERKEY")), + new Comparison(LESS_THAN, new Reference(BIGINT, "O_ORDERKEY"), new Reference(BIGINT, "L_ORDERKEY")), join(INNER, builder -> builder - .dynamicFilter(ImmutableList.of(new DynamicFilterPattern(new SymbolReference(BIGINT, "O_ORDERKEY"), LESS_THAN, "L_ORDERKEY"))) + .dynamicFilter(ImmutableList.of(new DynamicFilterPattern(new Reference(BIGINT, "O_ORDERKEY"), LESS_THAN, "L_ORDERKEY"))) .left( filter( - TRUE_LITERAL, + TRUE, tableScan("orders", ImmutableMap.of("O_ORDERKEY", "orderkey")))) .right( any(tableScan("lineitem", ImmutableMap.of("L_ORDERKEY", "orderkey")))))))); @@ -450,13 +450,13 @@ public void testInnerInequalityJoinWithEquiJoinConjuncts() anyNot(FilterNode.class, join(INNER, builder -> builder .equiCriteria("L_LINENUMBER", "O_SHIPPRIORITY") - .filter(new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "O_ORDERKEY"), new SymbolReference(BIGINT, "L_ORDERKEY"))) + .filter(new Comparison(LESS_THAN, new Reference(BIGINT, "O_ORDERKEY"), new Reference(BIGINT, "L_ORDERKEY"))) .dynamicFilter( ImmutableList.of( - new DynamicFilterPattern(new SymbolReference(INTEGER, "L_LINENUMBER"), EQUAL, "O_SHIPPRIORITY"), - new DynamicFilterPattern(new SymbolReference(BIGINT, "L_ORDERKEY"), GREATER_THAN, "O_ORDERKEY"))) + new DynamicFilterPattern(new Reference(INTEGER, "L_LINENUMBER"), EQUAL, "O_SHIPPRIORITY"), + new DynamicFilterPattern(new Reference(BIGINT, "L_ORDERKEY"), GREATER_THAN, "O_ORDERKEY"))) .left( - filter(TRUE_LITERAL, + filter(TRUE, tableScan("lineitem", ImmutableMap.of( "L_LINENUMBER", "linenumber", @@ -475,16 +475,16 @@ public void testLeftConvertedToInnerInequalityJoinNoEquiJoinConjuncts() assertPlan("SELECT 1 FROM orders o LEFT JOIN lineitem l ON o.orderkey < l.orderkey WHERE l.orderkey IS NOT NULL", anyTree( filter( - new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "O_ORDERKEY"), new SymbolReference(BIGINT, "L_ORDERKEY")), + new Comparison(LESS_THAN, new Reference(BIGINT, "O_ORDERKEY"), new Reference(BIGINT, "L_ORDERKEY")), join(INNER, builder -> builder .left( filter( - TRUE_LITERAL, + TRUE, tableScan("orders", ImmutableMap.of("O_ORDERKEY", "orderkey")))) .right( any( filter( - new NotExpression(new IsNullPredicate(new SymbolReference(BIGINT, "L_ORDERKEY"))), + new Not(new IsNull(new Reference(BIGINT, "L_ORDERKEY"))), tableScan("lineitem", ImmutableMap.of("L_ORDERKEY", "orderkey"))))))))); } @@ -529,10 +529,10 @@ public void testInequalityPredicatePushdownWithOuterJoin() anyTree( // predicate above outer join is not pushed to build side filter( - new ComparisonExpression(LESS_THAN, new ArithmeticBinaryExpression(SUBTRACT_BIGINT, SUBTRACT, new SymbolReference(BIGINT, "O_CUSTKEY"), new Constant(BIGINT, 24L)), new CoalesceExpression(new ArithmeticBinaryExpression(SUBTRACT_BIGINT, SUBTRACT, new SymbolReference(BIGINT, "L_PARTKEY"), new Constant(BIGINT, 24L)), new Constant(BIGINT, 0L))), + new Comparison(LESS_THAN, new Arithmetic(SUBTRACT_BIGINT, SUBTRACT, new Reference(BIGINT, "O_CUSTKEY"), new Constant(BIGINT, 24L)), new Coalesce(new Arithmetic(SUBTRACT_BIGINT, SUBTRACT, new Reference(BIGINT, "L_PARTKEY"), new Constant(BIGINT, 24L)), new Constant(BIGINT, 0L))), join(LEFT, builder -> builder .equiCriteria("O_ORDERKEY", "L_ORDERKEY") - .filter(new ComparisonExpression(LESS_THAN, new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "O_CUSTKEY"), new Constant(BIGINT, 42L)), new SymbolReference(BIGINT, "EXPR"))) + .filter(new Comparison(LESS_THAN, new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "O_CUSTKEY"), new Constant(BIGINT, 42L)), new Reference(BIGINT, "EXPR"))) .left( tableScan( "orders", @@ -542,7 +542,7 @@ public void testInequalityPredicatePushdownWithOuterJoin() .right( anyTree( project( - ImmutableMap.of("EXPR", expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "L_PARTKEY"), new Constant(BIGINT, 42L)))), + ImmutableMap.of("EXPR", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "L_PARTKEY"), new Constant(BIGINT, 42L)))), tableScan( "lineitem", ImmutableMap.of( @@ -575,7 +575,7 @@ public void testUncorrelatedSubqueries() join(INNER, builder -> builder .equiCriteria("X", "Y") .left( - filter(TRUE_LITERAL, + filter(TRUE, tableScan("orders", ImmutableMap.of("X", "orderkey")))) .right( node(EnforceSingleRowNode.class, @@ -586,7 +586,7 @@ public void testUncorrelatedSubqueries() noSemiJoinRewrite(), anyTree( filter( - new SymbolReference(BOOLEAN, "S"), + new Reference(BOOLEAN, "S"), semiJoin("X", "Y", "S", anyTree( tableScan("orders", ImmutableMap.of("X", "orderkey"))), @@ -596,7 +596,7 @@ public void testUncorrelatedSubqueries() assertPlan("SELECT * FROM orders WHERE orderkey NOT IN (SELECT orderkey FROM lineitem WHERE linenumber < 0)", anyTree( filter( - new NotExpression(new SymbolReference(BOOLEAN, "S")), + new Not(new Reference(BOOLEAN, "S")), semiJoin("X", "Y", "S", tableScan("orders", ImmutableMap.of("X", "orderkey")), anyTree( @@ -617,7 +617,7 @@ public void testPushDownJoinConditionConjunctsToInnerSideBasedOnInheritedPredica equiJoinClause("NATION_REGIONKEY", "REGION_REGIONKEY"))) .left( filter( - new ComparisonExpression(EQUAL, new SymbolReference(createVarcharType(25), "NATION_NAME"), new Constant(createVarcharType(25), utf8Slice("blah"))), + new Comparison(EQUAL, new Reference(createVarcharType(25), "NATION_NAME"), new Constant(createVarcharType(25), utf8Slice("blah"))), constrainedTableScan( "nation", ImmutableMap.of(), @@ -627,7 +627,7 @@ public void testPushDownJoinConditionConjunctsToInnerSideBasedOnInheritedPredica .right( anyTree( filter( - new ComparisonExpression(EQUAL, new SymbolReference(createVarcharType(25), "REGION_NAME"), new Constant(createVarcharType(25), utf8Slice("blah"))), + new Comparison(EQUAL, new Reference(createVarcharType(25), "REGION_NAME"), new Constant(createVarcharType(25), utf8Slice("blah"))), constrainedTableScan( "region", ImmutableMap.of(), @@ -781,7 +781,7 @@ public void testCorrelatedSubqueries() OPTIMIZED, any( filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "X"), new Constant(BIGINT, 3L)), + new Comparison(EQUAL, new Reference(BIGINT, "X"), new Constant(BIGINT, 3L)), tableScan("orders", ImmutableMap.of("X", "orderkey"))))); } @@ -805,7 +805,7 @@ public void testCorrelatedJoinWithLimit() assertPlan("SELECT regionkey, n.nationkey FROM region LEFT JOIN LATERAL (SELECT nationkey FROM nation WHERE region.regionkey = 3 LIMIT 2) n ON TRUE", any( join(LEFT, builder -> builder - .filter(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "region_regionkey"), new Constant(BIGINT, 3L))) + .filter(new Comparison(EQUAL, new Reference(BIGINT, "region_regionkey"), new Constant(BIGINT, 3L))) .left(tableScan("region", ImmutableMap.of("region_regionkey", "regionkey"))) .right( limit( @@ -880,10 +880,10 @@ public void testCorrelatedScalarSubqueryInSelect() noJoinReordering(), anyTree( filter( - new SimpleCaseExpression( - new SymbolReference(BOOLEAN, "is_distinct"), - ImmutableList.of(new WhenClause(TRUE_LITERAL, TRUE_LITERAL)), - Optional.of(new Cast(new FunctionCall(FAIL, ImmutableList.of(new Constant(INTEGER, (long) SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode()), new Constant(VARCHAR, utf8Slice("Scalar sub-query has returned multiple rows")))), BOOLEAN))), + new Switch( + new Reference(BOOLEAN, "is_distinct"), + ImmutableList.of(new WhenClause(TRUE, TRUE)), + Optional.of(new Cast(new Call(FAIL, ImmutableList.of(new Constant(INTEGER, (long) SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode()), new Constant(VARCHAR, utf8Slice("Scalar sub-query has returned multiple rows")))), BOOLEAN))), project( markDistinct("is_distinct", ImmutableList.of("unique"), join(LEFT, builder -> builder @@ -898,10 +898,10 @@ public void testCorrelatedScalarSubqueryInSelect() automaticJoinDistribution(), anyTree( filter( - new SimpleCaseExpression( - new SymbolReference(BOOLEAN, "is_distinct"), - ImmutableList.of(new WhenClause(TRUE_LITERAL, TRUE_LITERAL)), - Optional.of(new Cast(new FunctionCall(FAIL, ImmutableList.of(new Constant(INTEGER, (long) SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode()), new Constant(VARCHAR, utf8Slice("Scalar sub-query has returned multiple rows")))), BOOLEAN))), + new Switch( + new Reference(BOOLEAN, "is_distinct"), + ImmutableList.of(new WhenClause(TRUE, TRUE)), + Optional.of(new Cast(new Call(FAIL, ImmutableList.of(new Constant(INTEGER, (long) SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode()), new Constant(VARCHAR, utf8Slice("Scalar sub-query has returned multiple rows")))), BOOLEAN))), project( markDistinct("is_distinct", ImmutableList.of("unique"), join(LEFT, builder -> builder @@ -935,7 +935,7 @@ public void testStreamingAggregationForCorrelatedSubquery() tableScan("nation", ImmutableMap.of("n_name", "name", "n_regionkey", "regionkey")))), anyTree( project( - ImmutableMap.of("non_null", expression(TRUE_LITERAL)), + ImmutableMap.of("non_null", expression(TRUE)), tableScan("region", ImmutableMap.of("r_name", "name")))))))); // Don't use equi-clauses to trigger replicated join @@ -954,7 +954,7 @@ public void testStreamingAggregationForCorrelatedSubquery() tableScan("nation", ImmutableMap.of("n_name", "name", "n_regionkey", "regionkey"))), anyTree( project( - ImmutableMap.of("non_null", expression(TRUE_LITERAL)), + ImmutableMap.of("non_null", expression(TRUE)), tableScan("region", ImmutableMap.of("r_name", "name")))))))); } @@ -1030,7 +1030,7 @@ public void testCorrelatedInUncorrelatedFiltersPushDown() anyTree(tableScan("lineitem")), anyTree( filter( - new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "orderkey"), new Constant(BIGINT, 7L)), // pushed down + new Comparison(LESS_THAN, new Reference(BIGINT, "orderkey"), new Constant(BIGINT, 7L)), // pushed down tableScan("orders", ImmutableMap.of("orderkey", "orderkey"))))))); } @@ -1061,7 +1061,7 @@ public void testDoubleNestedCorrelatedSubqueries() OPTIMIZED, anyTree( filter( - new SymbolReference(BOOLEAN, "OUTER_FILTER"), + new Reference(BOOLEAN, "OUTER_FILTER"), project( apply(ImmutableList.of("O", "C"), ImmutableMap.of("OUTER_FILTER", setExpression(new ApplyNode.In(new Symbol(UNKNOWN, "THREE"), new Symbol(UNKNOWN, "C")))), @@ -1100,8 +1100,8 @@ public void testCorrelatedScalarAggregationRewriteToLeftOuterJoin() output( strictProject( ImmutableMap.of( - "ORDERKEY", expression(new SymbolReference(BIGINT, "ORDERKEY")), - "exists", expression(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "FINAL_COUNT"), new Constant(BIGINT, 0L)))), + "ORDERKEY", expression(new Reference(BIGINT, "ORDERKEY")), + "exists", expression(new Comparison(GREATER_THAN, new Reference(BIGINT, "FINAL_COUNT"), new Constant(BIGINT, 0L)))), aggregation( singleGroupingSet("ORDERKEY", "UNIQUE"), ImmutableMap.of(Optional.of("FINAL_COUNT"), aggregationFunction("count", ImmutableList.of())), @@ -1110,13 +1110,13 @@ public void testCorrelatedScalarAggregationRewriteToLeftOuterJoin() Optional.empty(), SINGLE, join(LEFT, builder -> builder - .filter(new ComparisonExpression(EQUAL, new Constant(BIGINT, 3L), new SymbolReference(BIGINT, "ORDERKEY"))) + .filter(new Comparison(EQUAL, new Constant(BIGINT, 3L), new Reference(BIGINT, "ORDERKEY"))) .left( assignUniqueId( "UNIQUE", tableScan("orders", ImmutableMap.of("ORDERKEY", "orderkey")))) .right( - project(ImmutableMap.of("NON_NULL", expression(TRUE_LITERAL)), + project(ImmutableMap.of("NON_NULL", expression(TRUE)), node(ValuesNode.class)))))))); } @@ -1139,7 +1139,7 @@ public void testCorrelatedDistinctAggregationRewriteToLeftOuterJoin() ImmutableList.of("non_null"), Optional.empty(), SINGLE, - project(ImmutableMap.of("non_null", expression(TRUE_LITERAL)), + project(ImmutableMap.of("non_null", expression(TRUE)), aggregation( singleGroupingSet("o_orderkey", "o_custkey"), ImmutableMap.of(), @@ -1156,10 +1156,10 @@ public void testCorrelatedDistinctGroupedAggregationRewriteToLeftOuterJoin() "SELECT (SELECT count(DISTINCT o.orderkey) FROM orders o WHERE c.custkey = o.custkey GROUP BY o.orderstatus), c.custkey FROM customer c", output( project(filter( - new SimpleCaseExpression( - new SymbolReference(BOOLEAN, "is_distinct"), - ImmutableList.of(new WhenClause(TRUE_LITERAL, TRUE_LITERAL)), - Optional.of(new Cast(new FunctionCall(FAIL, ImmutableList.of(new Constant(INTEGER, (long) SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode()), new Constant(VARCHAR, utf8Slice("Scalar sub-query has returned multiple rows")))), BOOLEAN))), + new Switch( + new Reference(BOOLEAN, "is_distinct"), + ImmutableList.of(new WhenClause(TRUE, TRUE)), + Optional.of(new Cast(new Call(FAIL, ImmutableList.of(new Constant(INTEGER, (long) SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode()), new Constant(VARCHAR, utf8Slice("Scalar sub-query has returned multiple rows")))), BOOLEAN))), project(markDistinct( "is_distinct", ImmutableList.of("unique"), @@ -1299,7 +1299,7 @@ public void testPickTableLayoutWithFilter() "SELECT orderkey FROM orders WHERE orderkey=5", output( filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "orderkey"), new Constant(BIGINT, 5L)), + new Comparison(EQUAL, new Reference(BIGINT, "orderkey"), new Constant(BIGINT, 5L)), constrainedTableScanWithTableLayout( "orders", ImmutableMap.of(), @@ -1528,7 +1528,7 @@ public void testCorrelatedIn() "SELECT name FROM region r WHERE regionkey IN (SELECT regionkey FROM nation WHERE name < r.name)", output( project( - ImmutableMap.of("region_name", expression(new SymbolReference(VARCHAR, "region_name"))), + ImmutableMap.of("region_name", expression(new Reference(VARCHAR, "region_name"))), aggregation( singleGroupingSet("region_regionkey", "region_name", "unique"), ImmutableMap.of(), @@ -1536,25 +1536,25 @@ public void testCorrelatedIn() SINGLE, project( ImmutableMap.of( - "region_regionkey", expression(new SymbolReference(BIGINT, "region_regionkey")), - "region_name", expression(new SymbolReference(VARCHAR, "region_name")), - "unique", expression(new SymbolReference(BIGINT, "unique"))), + "region_regionkey", expression(new Reference(BIGINT, "region_regionkey")), + "region_name", expression(new Reference(VARCHAR, "region_name")), + "unique", expression(new Reference(BIGINT, "unique"))), filter( - new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(BIGINT, "region_regionkey")), new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "region_regionkey"), new SymbolReference(BIGINT, "nation_regionkey")), new IsNullPredicate(new SymbolReference(BIGINT, "nation_regionkey")))), new ComparisonExpression(LESS_THAN, new SymbolReference(VARCHAR, "nation_name"), new SymbolReference(VARCHAR, "region_name")))), + new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new IsNull(new Reference(BIGINT, "region_regionkey")), new Comparison(EQUAL, new Reference(BIGINT, "region_regionkey"), new Reference(BIGINT, "nation_regionkey")), new IsNull(new Reference(BIGINT, "nation_regionkey")))), new Comparison(LESS_THAN, new Reference(VARCHAR, "nation_name"), new Reference(VARCHAR, "region_name")))), join(INNER, builder -> builder - .dynamicFilter(ImmutableList.of(new PlanMatchPattern.DynamicFilterPattern(new SymbolReference(VARCHAR, "region_name"), GREATER_THAN, "nation_name"))) + .dynamicFilter(ImmutableList.of(new PlanMatchPattern.DynamicFilterPattern(new Reference(VARCHAR, "region_name"), GREATER_THAN, "nation_name"))) .left( assignUniqueId( "unique", filter( - new NotExpression(new IsNullPredicate(new SymbolReference(BIGINT, "region_regionkey"))), + new Not(new IsNull(new Reference(BIGINT, "region_regionkey"))), tableScan("region", ImmutableMap.of( "region_regionkey", "regionkey", "region_name", "name"))))) .right( any( filter( - new NotExpression(new IsNullPredicate(new SymbolReference(BIGINT, "nation_regionkey"))), + new Not(new IsNull(new Reference(BIGINT, "nation_regionkey"))), tableScan("nation", ImmutableMap.of( "nation_name", "name", "nation_regionkey", "regionkey")))))))))))); @@ -1574,14 +1574,14 @@ public void testCorrelatedExists() SINGLE, project( filter( - new ComparisonExpression(LESS_THAN, new SymbolReference(VARCHAR, "nation_name"), new SymbolReference(VARCHAR, "region_name")), + new Comparison(LESS_THAN, new Reference(VARCHAR, "nation_name"), new Reference(VARCHAR, "region_name")), join(INNER, builder -> builder - .dynamicFilter(ImmutableList.of(new PlanMatchPattern.DynamicFilterPattern(new SymbolReference(VARCHAR, "region_name"), GREATER_THAN, "nation_name"))) + .dynamicFilter(ImmutableList.of(new PlanMatchPattern.DynamicFilterPattern(new Reference(VARCHAR, "region_name"), GREATER_THAN, "nation_name"))) .left( assignUniqueId( "unique", filter( - TRUE_LITERAL, + TRUE, tableScan("region", ImmutableMap.of( "region_regionkey", "regionkey", "region_name", "name"))))) @@ -1623,9 +1623,9 @@ public void testOffset() "SELECT name FROM nation OFFSET 2 ROWS", any( strictProject( - ImmutableMap.of("name", expression(new SymbolReference(VARCHAR, "name"))), + ImmutableMap.of("name", expression(new Reference(VARCHAR, "name"))), filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "row_num"), new Constant(BIGINT, 2L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "row_num"), new Constant(BIGINT, 2L)), rowNumber( pattern -> pattern .partitionBy(ImmutableList.of()), @@ -1637,9 +1637,9 @@ public void testOffset() "SELECT name FROM nation ORDER BY regionkey OFFSET 2 ROWS", any( strictProject( - ImmutableMap.of("name", expression(new SymbolReference(VARCHAR, "name"))), + ImmutableMap.of("name", expression(new Reference(VARCHAR, "name"))), filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "row_num"), new Constant(BIGINT, 2L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "row_num"), new Constant(BIGINT, 2L)), rowNumber( pattern -> pattern .partitionBy(ImmutableList.of()), @@ -1654,9 +1654,9 @@ public void testOffset() "SELECT name FROM nation ORDER BY regionkey OFFSET 2 ROWS FETCH NEXT 5 ROWS ONLY", any( strictProject( - ImmutableMap.of("name", expression(new SymbolReference(VARCHAR, "name"))), + ImmutableMap.of("name", expression(new Reference(VARCHAR, "name"))), filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "row_num"), new Constant(BIGINT, 2L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "row_num"), new Constant(BIGINT, 2L)), rowNumber( pattern -> pattern .partitionBy(ImmutableList.of()), @@ -1673,9 +1673,9 @@ public void testOffset() "SELECT name FROM nation OFFSET 2 ROWS FETCH NEXT 5 ROWS ONLY", any( strictProject( - ImmutableMap.of("name", expression(new SymbolReference(VARCHAR, "name"))), + ImmutableMap.of("name", expression(new Reference(VARCHAR, "name"))), filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "row_num"), new Constant(BIGINT, 2L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "row_num"), new Constant(BIGINT, 2L)), rowNumber( pattern -> pattern .partitionBy(ImmutableList.of()), @@ -1693,7 +1693,7 @@ public void testWithTies() "SELECT name, regionkey FROM nation ORDER BY regionkey FETCH FIRST 6 ROWS WITH TIES", any( strictProject( - ImmutableMap.of("name", expression(new SymbolReference(VARCHAR, "name")), "regionkey", expression(new SymbolReference(BIGINT, "regionkey"))), + ImmutableMap.of("name", expression(new Reference(VARCHAR, "name")), "regionkey", expression(new Reference(BIGINT, "regionkey"))), topNRanking( pattern -> pattern .specification( @@ -1713,14 +1713,14 @@ public void testWithTies() "SELECT name, regionkey FROM nation ORDER BY regionkey OFFSET 10 ROWS FETCH FIRST 6 ROWS WITH TIES", any( strictProject( - ImmutableMap.of("name", expression(new SymbolReference(VARCHAR, "name")), "regionkey", expression(new SymbolReference(BIGINT, "regionkey"))), + ImmutableMap.of("name", expression(new Reference(VARCHAR, "name")), "regionkey", expression(new Reference(BIGINT, "regionkey"))), filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "row_num"), new Constant(BIGINT, 10L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "row_num"), new Constant(BIGINT, 10L)), rowNumber( pattern -> pattern .partitionBy(ImmutableList.of()), strictProject( - ImmutableMap.of("name", expression(new SymbolReference(VARCHAR, "name")), "regionkey", expression(new SymbolReference(BIGINT, "regionkey"))), + ImmutableMap.of("name", expression(new Reference(VARCHAR, "name")), "regionkey", expression(new Reference(BIGINT, "regionkey"))), topNRanking( pattern -> pattern .specification( @@ -1874,10 +1874,10 @@ public void testRedundantHashRemovalForUnionAll() project( node(AggregationNode.class, exchange(LOCAL, REPARTITION, - project(ImmutableMap.of("hash", expression(new FunctionCall(COMBINE_HASH, ImmutableList.of(new Constant(BIGINT, 0L), new CoalesceExpression(new FunctionCall(HASH_CODE, ImmutableList.of(new SymbolReference(BIGINT, "nationkey"))), new Constant(BIGINT, 0L)))))), + project(ImmutableMap.of("hash", expression(new Call(COMBINE_HASH, ImmutableList.of(new Constant(BIGINT, 0L), new Coalesce(new Call(HASH_CODE, ImmutableList.of(new Reference(BIGINT, "nationkey"))), new Constant(BIGINT, 0L)))))), node(AggregationNode.class, tableScan("customer", ImmutableMap.of("nationkey", "nationkey")))), - project(ImmutableMap.of("hash_1", expression(new FunctionCall(COMBINE_HASH, ImmutableList.of(new Constant(BIGINT, 0L), new CoalesceExpression(new FunctionCall(HASH_CODE, ImmutableList.of(new SymbolReference(BIGINT, "nationkey_6"))), new Constant(BIGINT, 0L)))))), + project(ImmutableMap.of("hash_1", expression(new Call(COMBINE_HASH, ImmutableList.of(new Constant(BIGINT, 0L), new Coalesce(new Call(HASH_CODE, ImmutableList.of(new Reference(BIGINT, "nationkey_6"))), new Constant(BIGINT, 0L)))))), node(AggregationNode.class, tableScan("customer", ImmutableMap.of("nationkey_6", "nationkey"))))))))); } @@ -1897,8 +1897,8 @@ public void testRedundantHashRemovalForMarkDistinct() node(MarkDistinctNode.class, anyTree( project(ImmutableMap.of( - "hash_1", expression(new FunctionCall(COMBINE_HASH, ImmutableList.of(new Constant(BIGINT, 0L), new CoalesceExpression(new FunctionCall(HASH_CODE, ImmutableList.of(new SymbolReference(BIGINT, "suppkey"))), new Constant(BIGINT, 0L))))), - "hash_2", expression(new FunctionCall(COMBINE_HASH, ImmutableList.of(new Constant(BIGINT, 0L), new CoalesceExpression(new FunctionCall(HASH_CODE, ImmutableList.of(new SymbolReference(BIGINT, "partkey"))), new Constant(BIGINT, 0L)))))), + "hash_1", expression(new Call(COMBINE_HASH, ImmutableList.of(new Constant(BIGINT, 0L), new Coalesce(new Call(HASH_CODE, ImmutableList.of(new Reference(BIGINT, "suppkey"))), new Constant(BIGINT, 0L))))), + "hash_2", expression(new Call(COMBINE_HASH, ImmutableList.of(new Constant(BIGINT, 0L), new Coalesce(new Call(HASH_CODE, ImmutableList.of(new Reference(BIGINT, "partkey"))), new Constant(BIGINT, 0L)))))), node(MarkDistinctNode.class, tableScan("lineitem", ImmutableMap.of("suppkey", "suppkey", "partkey", "partkey")))))))))); } @@ -1919,8 +1919,8 @@ public void testRedundantHashRemovalForUnionAllAndMarkDistinct() exchange(LOCAL, REPARTITION, exchange(REMOTE, REPARTITION, project(ImmutableMap.of( - "hash_custkey", expression(new FunctionCall(COMBINE_HASH, ImmutableList.of(new Constant(BIGINT, 0L), new CoalesceExpression(new FunctionCall(HASH_CODE, ImmutableList.of(new SymbolReference(BIGINT, "custkey"))), new Constant(BIGINT, 0L))))), - "hash_nationkey", expression(new FunctionCall(COMBINE_HASH, ImmutableList.of(new Constant(BIGINT, 0L), new CoalesceExpression(new FunctionCall(HASH_CODE, ImmutableList.of(new SymbolReference(BIGINT, "nationkey"))), new Constant(BIGINT, 0L)))))), + "hash_custkey", expression(new Call(COMBINE_HASH, ImmutableList.of(new Constant(BIGINT, 0L), new Coalesce(new Call(HASH_CODE, ImmutableList.of(new Reference(BIGINT, "custkey"))), new Constant(BIGINT, 0L))))), + "hash_nationkey", expression(new Call(COMBINE_HASH, ImmutableList.of(new Constant(BIGINT, 0L), new Coalesce(new Call(HASH_CODE, ImmutableList.of(new Reference(BIGINT, "nationkey"))), new Constant(BIGINT, 0L)))))), tableScan("customer", ImmutableMap.of("custkey", "custkey", "nationkey", "nationkey")))), exchange(REMOTE, REPARTITION, node(ProjectNode.class, @@ -1939,14 +1939,14 @@ public void testRemoveRedundantFilter() join(INNER, builder -> builder .equiCriteria("ORDER_STATUS", "expr") .left( - filter(TRUE_LITERAL, + filter(TRUE, strictConstrainedTableScan( "orders", ImmutableMap.of("ORDER_STATUS", "orderstatus", "ORDER_KEY", "orderkey"), ImmutableMap.of("orderstatus", multipleValues(createVarcharType(1), ImmutableList.of(utf8Slice("F"), utf8Slice("O"))))))) .right( filter( - new InPredicate(new SymbolReference(VARCHAR, "expr"), ImmutableList.of(new Constant(createVarcharType(1), utf8Slice("F")), new Constant(createVarcharType(1), utf8Slice("O")))), + new In(new Reference(VARCHAR, "expr"), ImmutableList.of(new Constant(createVarcharType(1), utf8Slice("F")), new Constant(createVarcharType(1), utf8Slice("O")))), values( ImmutableList.of("expr"), ImmutableList.of(ImmutableList.of(new Constant(createVarcharType(1), utf8Slice("O"))), ImmutableList.of(new Constant(createVarcharType(1), utf8Slice("F")))))))))); @@ -1957,13 +1957,13 @@ public void testRemoveRedundantCrossJoin() { assertPlan("SELECT regionkey FROM nation, (SELECT 1 as a) temp WHERE regionkey = temp.a", output( - filter(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "REGIONKEY"), new Constant(BIGINT, 1L)), + filter(new Comparison(EQUAL, new Reference(BIGINT, "REGIONKEY"), new Constant(BIGINT, 1L)), tableScan("nation", ImmutableMap.of("REGIONKEY", "regionkey"))))); assertPlan("SELECT regionkey FROM (SELECT 1 as a) temp, nation WHERE regionkey > temp.a", output( filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "REGIONKEY"), new Constant(BIGINT, 1L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "REGIONKEY"), new Constant(BIGINT, 1L)), tableScan("nation", ImmutableMap.of("REGIONKEY", "regionkey"))))); assertPlan("SELECT * FROM nation, (SELECT 1 as a) temp WHERE regionkey = a", @@ -1971,7 +1971,7 @@ public void testRemoveRedundantCrossJoin() project( ImmutableMap.of("expr", expression(new Constant(INTEGER, 1L))), filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "REGIONKEY"), new Constant(BIGINT, 1L)), + new Comparison(EQUAL, new Reference(BIGINT, "REGIONKEY"), new Constant(BIGINT, 1L)), tableScan("nation", ImmutableMap.of("REGIONKEY", "regionkey")))))); } @@ -2026,7 +2026,7 @@ public void testMergeProjectWithValues() "JOIN (SELECT '' || x FROM (VALUES 'F') t(x)) t2(s) " + "ON orders.orderstatus = t2.s", any(project( - ImmutableMap.of("cast", expression(new Cast(new SymbolReference(createVarcharType(1), "ORDER_STATUS"), VARCHAR))), + ImmutableMap.of("cast", expression(new Cast(new Reference(createVarcharType(1), "ORDER_STATUS"), VARCHAR))), strictConstrainedTableScan( "orders", ImmutableMap.of("ORDER_STATUS", "orderstatus", "ORDER_KEY", "orderkey"), @@ -2042,14 +2042,14 @@ public void testMergeProjectWithValues() join(INNER, builder -> builder .equiCriteria("ORDER_STATUS", "expr") .left( - filter(TRUE_LITERAL, + filter(TRUE, strictConstrainedTableScan( "orders", ImmutableMap.of("ORDER_STATUS", "orderstatus", "ORDER_KEY", "orderkey"), ImmutableMap.of("orderstatus", multipleValues(createVarcharType(1), ImmutableList.of(utf8Slice("F"), utf8Slice("O"))))))) .right( filter( - new InPredicate(new SymbolReference(VARCHAR, "expr"), ImmutableList.of(new Constant(createVarcharType(1), utf8Slice("F")), new Constant(createVarcharType(1), utf8Slice("O")))), + new In(new Reference(VARCHAR, "expr"), ImmutableList.of(new Constant(createVarcharType(1), utf8Slice("F")), new Constant(createVarcharType(1), utf8Slice("O")))), values(ImmutableList.of("expr"), ImmutableList.of(ImmutableList.of(new Constant(createVarcharType(1), utf8Slice("O"))), ImmutableList.of(new Constant(createVarcharType(1), utf8Slice("F")))))))))); // Constraint for the table is derived, based on constant values in the other branch of the join. @@ -2064,14 +2064,14 @@ public void testMergeProjectWithValues() .equiCriteria("ORDER_KEY", "expr") .left( filter( - new InPredicate(new SymbolReference(BIGINT, "ORDER_KEY"), ImmutableList.of(new Constant(BIGINT, 1L), new Constant(BIGINT, 2L))), + new In(new Reference(BIGINT, "ORDER_KEY"), ImmutableList.of(new Constant(BIGINT, 1L), new Constant(BIGINT, 2L))), strictConstrainedTableScan( "orders", ImmutableMap.of("ORDER_STATUS", "orderstatus", "ORDER_KEY", "orderkey"), ImmutableMap.of()))) .right( filter( - new InPredicate(new SymbolReference(BIGINT, "expr"), ImmutableList.of(new Constant(BIGINT, 1L), new Constant(BIGINT, 2L))), + new In(new Reference(BIGINT, "expr"), ImmutableList.of(new Constant(BIGINT, 1L), new Constant(BIGINT, 2L))), values(ImmutableList.of("expr"), ImmutableList.of(ImmutableList.of(new Constant(BIGINT, 1L)), ImmutableList.of(new Constant(BIGINT, 2L))))))))); } @@ -2241,7 +2241,7 @@ public void testDoNotPlanUnreferencedRowPatternMeasures() .specification(specification(ImmutableList.of(), ImmutableList.of("id"), ImmutableMap.of("id", ASC_NULLS_LAST))) .addMeasure( "val", - new SymbolReference(INTEGER, "val"), + new Reference(INTEGER, "val"), ImmutableMap.of("val", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(), true, true, 0, 0), new Symbol(UNKNOWN, "value"))), @@ -2249,7 +2249,7 @@ public void testDoNotPlanUnreferencedRowPatternMeasures() .rowsPerMatch(WINDOW) .frame(ROWS_FROM_CURRENT) .pattern(new IrQuantified(new IrLabel("A"), oneOrMore(true))) - .addVariableDefinition(new IrLabel("A"), TRUE_LITERAL), + .addVariableDefinition(new IrLabel("A"), TRUE), values( ImmutableList.of("id", "value"), ImmutableList.of(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 90L)))))))); @@ -2275,7 +2275,7 @@ public void testDoNotPlanUnreferencedRowPatternMeasures() .rowsPerMatch(WINDOW) .frame(ROWS_FROM_CURRENT) .pattern(new IrQuantified(new IrLabel("A"), oneOrMore(true))) - .addVariableDefinition(new IrLabel("A"), TRUE_LITERAL), + .addVariableDefinition(new IrLabel("A"), TRUE), values( ImmutableList.of("id", "value"), ImmutableList.of(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 90L)))))))); @@ -2306,7 +2306,7 @@ public void testPruneUnreferencedRowPatternWindowFunctions() .rowsPerMatch(WINDOW) .frame(ROWS_FROM_CURRENT) .pattern(new IrQuantified(new IrLabel("A"), oneOrMore(true))) - .addVariableDefinition(new IrLabel("A"), TRUE_LITERAL), + .addVariableDefinition(new IrLabel("A"), TRUE), values( ImmutableList.of("id", "value"), ImmutableList.of(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 90L)))))))); @@ -2335,7 +2335,7 @@ public void testPruneUnreferencedRowPatternMeasures() .specification(specification(ImmutableList.of(), ImmutableList.of("id"), ImmutableMap.of("id", ASC_NULLS_LAST))) .addMeasure( "val", - new SymbolReference(INTEGER, "val"), + new Reference(INTEGER, "val"), ImmutableMap.of("val", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(), true, true, 0, 0), new Symbol(UNKNOWN, "value"))), @@ -2343,7 +2343,7 @@ public void testPruneUnreferencedRowPatternMeasures() .rowsPerMatch(WINDOW) .frame(ROWS_FROM_CURRENT) .pattern(new IrQuantified(new IrLabel("A"), oneOrMore(true))) - .addVariableDefinition(new IrLabel("A"), TRUE_LITERAL), + .addVariableDefinition(new IrLabel("A"), TRUE), values( ImmutableList.of("id", "value"), ImmutableList.of(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 90L)))))))); @@ -2371,14 +2371,14 @@ public void testMergePatternRecognitionNodes() .specification(specification(ImmutableList.of(), ImmutableList.of("id"), ImmutableMap.of("id", ASC_NULLS_LAST))) .addMeasure( "val", - new SymbolReference(INTEGER, "val"), + new Reference(INTEGER, "val"), ImmutableMap.of("val", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(), true, true, 0, 0), new Symbol(UNKNOWN, "value"))), INTEGER) .addMeasure( "label", - new SymbolReference(VARCHAR, "classy"), + new Reference(VARCHAR, "classy"), ImmutableMap.of("classy", new ClassifierValuePointer( new LogicalIndexPointer(ImmutableSet.of(), true, true, 0, 0))), VARCHAR) @@ -2389,7 +2389,7 @@ public void testMergePatternRecognitionNodes() .rowsPerMatch(WINDOW) .frame(ROWS_FROM_CURRENT) .pattern(new IrQuantified(new IrLabel("A"), oneOrMore(true))) - .addVariableDefinition(new IrLabel("A"), TRUE_LITERAL), + .addVariableDefinition(new IrLabel("A"), TRUE), values( ImmutableList.of("id", "value"), ImmutableList.of(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 90L)))))))); @@ -2414,28 +2414,28 @@ public void testMergePatternRecognitionNodesWithProjections() output( project( ImmutableMap.of( - "output1", expression(new SymbolReference(INTEGER, "id")), - "output2", expression(new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "value"), new Constant(INTEGER, 2L))), - "output3", expression(new FunctionCall(LOWER, ImmutableList.of(new SymbolReference(VARCHAR, "label")))), - "output4", expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "min"), new Constant(INTEGER, 1L)))), + "output1", expression(new Reference(INTEGER, "id")), + "output2", expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "value"), new Constant(INTEGER, 2L))), + "output3", expression(new Call(LOWER, ImmutableList.of(new Reference(VARCHAR, "label")))), + "output4", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "min"), new Constant(INTEGER, 1L)))), project( ImmutableMap.of( - "id", expression(new SymbolReference(INTEGER, "id")), - "value", expression(new SymbolReference(INTEGER, "value")), - "label", expression(new SymbolReference(VARCHAR, "label")), - "min", expression(new SymbolReference(INTEGER, "min"))), + "id", expression(new Reference(INTEGER, "id")), + "value", expression(new Reference(INTEGER, "value")), + "label", expression(new Reference(VARCHAR, "label")), + "min", expression(new Reference(INTEGER, "min"))), patternRecognition(builder -> builder .specification(specification(ImmutableList.of(), ImmutableList.of("id"), ImmutableMap.of("id", ASC_NULLS_LAST))) .addMeasure( "value", - new SymbolReference(INTEGER, "value"), + new Reference(INTEGER, "value"), ImmutableMap.of("value", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(), true, true, 0, 0), new Symbol(UNKNOWN, "input2"))), INTEGER) .addMeasure( "label", - new SymbolReference(VARCHAR, "classy"), + new Reference(VARCHAR, "classy"), ImmutableMap.of("classy", new ClassifierValuePointer( new LogicalIndexPointer(ImmutableSet.of(), true, true, 0, 0))), VARCHAR) @@ -2446,7 +2446,7 @@ public void testMergePatternRecognitionNodesWithProjections() .rowsPerMatch(WINDOW) .frame(ROWS_FROM_CURRENT) .pattern(new IrQuantified(new IrLabel("A"), oneOrMore(true))) - .addVariableDefinition(new IrLabel("A"), TRUE_LITERAL), + .addVariableDefinition(new IrLabel("A"), TRUE), values( ImmutableList.of("id", "input1", "input2"), ImmutableList.of(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L), new Constant(INTEGER, 3L))))))))); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestMaterializedViews.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestMaterializedViews.java index 7d85f0d10a49..d78811bd436e 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestMaterializedViews.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestMaterializedViews.java @@ -40,11 +40,11 @@ import io.trino.spi.type.Type; import io.trino.spi.type.TypeManager; import io.trino.spi.type.TypeParameter; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.assertions.BasePlanTest; import io.trino.testing.PlanTester; import io.trino.testing.TestingAccessControlManager; @@ -65,8 +65,8 @@ import static io.trino.spi.type.TimestampWithTimeZoneType.createTimestampWithTimeZoneType; import static io.trino.spi.type.TinyintType.TINYINT; import static io.trino.spi.type.VarcharType.VARCHAR; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; @@ -321,8 +321,8 @@ public void testMaterializedViewWithCasts() anyTree( project( ImmutableMap.of( - "A_CAST", expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new Cast(new SymbolReference(BIGINT, "A"), BIGINT), new Constant(BIGINT, 1L))), - "B_CAST", expression(new Cast(new SymbolReference(BIGINT, "B"), BIGINT))), + "A_CAST", expression(new Arithmetic(ADD_BIGINT, ADD, new Cast(new Reference(BIGINT, "A"), BIGINT), new Constant(BIGINT, 1L))), + "B_CAST", expression(new Cast(new Reference(BIGINT, "B"), BIGINT))), tableScan("storage_table_with_casts", ImmutableMap.of("A", "a", "B", "b"))))); } @@ -334,8 +334,8 @@ public void testRefreshMaterializedViewWithCasts() tableWriter(List.of("A_CAST", "B_CAST"), List.of("a", "b"), exchange(LOCAL, project(Map.of( - "A_CAST", expression(new Cast(new SymbolReference(BIGINT, "A"), TINYINT)), - "B_CAST", expression(new Cast(new SymbolReference(BIGINT, "B"), VARCHAR))), + "A_CAST", expression(new Cast(new Reference(BIGINT, "A"), TINYINT)), + "B_CAST", expression(new Cast(new Reference(BIGINT, "B"), VARCHAR))), tableScan("test_table", Map.of("A", "a", "B", "b"))))))); // No-op REFRESH @@ -349,9 +349,9 @@ public void testMaterializedViewWithTimestamp() { assertPlan("SELECT * FROM timestamp_mv_test WHERE ts < TIMESTAMP '2024-01-01 00:00:00.000 America/New_York'", anyTree( - project(ImmutableMap.of("ts_0", expression(new Cast(new SymbolReference(TIMESTAMP_TZ_MILLIS, "ts"), TIMESTAMP_TZ_MILLIS))), + project(ImmutableMap.of("ts_0", expression(new Cast(new Reference(TIMESTAMP_TZ_MILLIS, "ts"), TIMESTAMP_TZ_MILLIS))), filter( - new ComparisonExpression(LESS_THAN, new Cast(new SymbolReference(TIMESTAMP_TZ_MILLIS, "ts"), TIMESTAMP_TZ_MILLIS), new Constant(createTimestampWithTimeZoneType(3), DateTimes.parseTimestampWithTimeZone(3, "2024-01-01 00:00:00.000 America/New_York"))), + new Comparison(LESS_THAN, new Cast(new Reference(TIMESTAMP_TZ_MILLIS, "ts"), TIMESTAMP_TZ_MILLIS), new Constant(createTimestampWithTimeZoneType(3), DateTimes.parseTimestampWithTimeZone(3, "2024-01-01 00:00:00.000 America/New_York"))), tableScan("timestamp_test_storage", ImmutableMap.of("ts", "ts", "id", "id")))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java index 584067af50aa..cd528cda5d29 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java @@ -21,13 +21,13 @@ import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.function.OperatorType; import io.trino.spi.type.RowType; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; import io.trino.sql.ir.NodeRef; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.transaction.TransactionId; import org.junit.jupiter.api.Test; @@ -38,7 +38,7 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.planner.ConnectorExpressionTranslator.translate; import static io.trino.sql.planner.PartialTranslator.extractPartialTranslations; import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT; @@ -57,18 +57,18 @@ public class TestPartialTranslator @Test public void testPartialTranslator() { - Expression rowSymbolReference = new SymbolReference(RowType.anonymousRow(INTEGER, INTEGER), "row_symbol_1"); - Expression dereferenceExpression1 = new SubscriptExpression(INTEGER, rowSymbolReference, new Constant(INTEGER, 1L)); - Expression dereferenceExpression2 = new SubscriptExpression(INTEGER, rowSymbolReference, new Constant(INTEGER, 2L)); + Expression rowSymbolReference = new Reference(RowType.anonymousRow(INTEGER, INTEGER), "row_symbol_1"); + Expression dereferenceExpression1 = new Subscript(INTEGER, rowSymbolReference, new Constant(INTEGER, 1L)); + Expression dereferenceExpression2 = new Subscript(INTEGER, rowSymbolReference, new Constant(INTEGER, 2L)); Expression stringLiteral = new Constant(VARCHAR, Slices.utf8Slice("abcd")); - Expression symbolReference1 = new SymbolReference(DOUBLE, "double_symbol_1"); + Expression symbolReference1 = new Reference(DOUBLE, "double_symbol_1"); assertFullTranslation(symbolReference1); assertFullTranslation(dereferenceExpression1); assertFullTranslation(stringLiteral); - assertFullTranslation(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, symbolReference1, dereferenceExpression1)); + assertFullTranslation(new Arithmetic(ADD_INTEGER, ADD, symbolReference1, dereferenceExpression1)); - Expression functionCallExpression = new FunctionCall( + Expression functionCallExpression = new Call( PLANNER_CONTEXT.getMetadata().resolveBuiltinFunction("concat", fromTypes(VARCHAR, VARCHAR)), ImmutableList.of(stringLiteral, dereferenceExpression2)); assertFullTranslation(functionCallExpression); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestPlanMatchingFramework.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestPlanMatchingFramework.java index 74c5dfd5d46a..c6fef26f96e4 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestPlanMatchingFramework.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestPlanMatchingFramework.java @@ -19,10 +19,10 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.OptimizerConfig.JoinDistributionType; import io.trino.sql.planner.OptimizerConfig.JoinReorderingStrategy; import io.trino.sql.planner.assertions.BasePlanTest; @@ -35,7 +35,7 @@ import static io.trino.SystemSessionProperties.JOIN_REORDERING_STRATEGY; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.columnReference; @@ -124,7 +124,7 @@ public void testAliasExpressionFromProject() { assertMinimallyOptimizedPlan("SELECT orderkey, 1 + orderkey FROM lineitem", output(ImmutableList.of("ORDERKEY", "EXPRESSION"), - project(ImmutableMap.of("EXPRESSION", expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new Cast(new Constant(INTEGER, 1L), BIGINT), new SymbolReference(BIGINT, "ORDERKEY")))), + project(ImmutableMap.of("EXPRESSION", expression(new Arithmetic(ADD_BIGINT, ADD, new Cast(new Constant(INTEGER, 1L), BIGINT), new Reference(BIGINT, "ORDERKEY")))), tableScan("lineitem", ImmutableMap.of("ORDERKEY", "orderkey"))))); } @@ -134,8 +134,8 @@ public void testStrictProject() assertMinimallyOptimizedPlan("SELECT orderkey, 1 + orderkey FROM lineitem", output(ImmutableList.of("ORDERKEY", "EXPRESSION"), strictProject(ImmutableMap.of( - "EXPRESSION", expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new Cast(new Constant(INTEGER, 1L), BIGINT), new SymbolReference(BIGINT, "ORDERKEY"))), - "ORDERKEY", expression(new SymbolReference(BIGINT, "ORDERKEY"))), + "EXPRESSION", expression(new Arithmetic(ADD_BIGINT, ADD, new Cast(new Constant(INTEGER, 1L), BIGINT), new Reference(BIGINT, "ORDERKEY"))), + "ORDERKEY", expression(new Reference(BIGINT, "ORDERKEY"))), tableScan("lineitem", ImmutableMap.of("ORDERKEY", "orderkey"))))); } @@ -145,8 +145,8 @@ public void testIdentityAliasFromProject() assertMinimallyOptimizedPlan("SELECT orderkey, 1 + orderkey FROM lineitem", output(ImmutableList.of("ORDERKEY", "EXPRESSION"), project(ImmutableMap.of( - "ORDERKEY", expression(new SymbolReference(BIGINT, "ORDERKEY")), - "EXPRESSION", expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new Cast(new Constant(INTEGER, 1L), BIGINT), new SymbolReference(BIGINT, "ORDERKEY")))), + "ORDERKEY", expression(new Reference(BIGINT, "ORDERKEY")), + "EXPRESSION", expression(new Arithmetic(ADD_BIGINT, ADD, new Cast(new Constant(INTEGER, 1L), BIGINT), new Reference(BIGINT, "ORDERKEY")))), tableScan("lineitem", ImmutableMap.of("ORDERKEY", "orderkey"))))); } @@ -253,7 +253,7 @@ public void testStrictProjectExtraSymbols() { assertThatThrownBy(() -> assertMinimallyOptimizedPlan("SELECT discount, orderkey, 1 + orderkey FROM lineitem", output(ImmutableList.of("ORDERKEY", "EXPRESSION"), - strictProject(ImmutableMap.of("EXPRESSION", expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new Constant(BIGINT, 1L), new SymbolReference(BIGINT, "ORDERKEY"))), "ORDERKEY", expression(new SymbolReference(BIGINT, "ORDERKEY"))), + strictProject(ImmutableMap.of("EXPRESSION", expression(new Arithmetic(ADD_BIGINT, ADD, new Constant(BIGINT, 1L), new Reference(BIGINT, "ORDERKEY"))), "ORDERKEY", expression(new Reference(BIGINT, "ORDERKEY"))), tableScan("lineitem", ImmutableMap.of("ORDERKEY", "orderkey")))))) .isInstanceOf(AssertionError.class) .hasMessageStartingWith("Plan does not match"); @@ -283,7 +283,7 @@ public void testProjectLimitsScope() { assertThatThrownBy(() -> assertMinimallyOptimizedPlan("SELECT 1 + orderkey FROM lineitem", output(ImmutableList.of("ORDERKEY"), - project(ImmutableMap.of("EXPRESSION", expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new Cast(new Constant(INTEGER, 1L), BIGINT), new SymbolReference(BIGINT, "ORDERKEY")))), + project(ImmutableMap.of("EXPRESSION", expression(new Arithmetic(ADD_BIGINT, ADD, new Cast(new Constant(INTEGER, 1L), BIGINT), new Reference(BIGINT, "ORDERKEY")))), tableScan("lineitem", ImmutableMap.of("ORDERKEY", "orderkey")))))) .isInstanceOf(IllegalStateException.class) .hasMessageMatching("missing expression for alias .*"); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestPredicatePushdown.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestPredicatePushdown.java index 10f23c19c0e8..828063de2f78 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestPredicatePushdown.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestPredicatePushdown.java @@ -19,13 +19,13 @@ import io.trino.Session; import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.plan.ExchangeNode; import org.junit.jupiter.api.Test; @@ -35,7 +35,7 @@ import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.join; @@ -78,13 +78,13 @@ public void testCoercions() .left( project( filter( - new ComparisonExpression(EQUAL, new Constant(createVarcharType(4), Slices.utf8Slice("x")), new Cast(new SymbolReference(createVarcharType(4), "t_v"), createVarcharType(4))), + new Comparison(EQUAL, new Constant(createVarcharType(4), Slices.utf8Slice("x")), new Cast(new Reference(createVarcharType(4), "t_v"), createVarcharType(4))), tableScan("nation", ImmutableMap.of("t_k", "nationkey", "t_v", "name"))))) .right( anyTree( project( filter( - new ComparisonExpression(EQUAL, new Constant(createVarcharType(4), Slices.utf8Slice("x")), new Cast(new SymbolReference(createVarcharType(4), "u_v"), createVarcharType(4))), + new Comparison(EQUAL, new Constant(createVarcharType(4), Slices.utf8Slice("x")), new Cast(new Reference(createVarcharType(4), "u_v"), createVarcharType(4))), tableScan("nation", ImmutableMap.of("u_k", "nationkey", "u_v", "name"))))))))); // values have different types (varchar(4) vs varchar(5)) in each table @@ -102,13 +102,13 @@ public void testCoercions() .left( project( filter( - new ComparisonExpression(EQUAL, new Constant(createVarcharType(4), Slices.utf8Slice("x")), new Cast(new SymbolReference(createVarcharType(4), "t_v"), createVarcharType(4))), + new Comparison(EQUAL, new Constant(createVarcharType(4), Slices.utf8Slice("x")), new Cast(new Reference(createVarcharType(4), "t_v"), createVarcharType(4))), tableScan("nation", ImmutableMap.of("t_k", "nationkey", "t_v", "name"))))) .right( anyTree( project( filter( - new ComparisonExpression(EQUAL, new Constant(createVarcharType(5), Slices.utf8Slice("x")), new Cast(new SymbolReference(createVarcharType(5), "u_v"), createVarcharType(5))), + new Comparison(EQUAL, new Constant(createVarcharType(5), Slices.utf8Slice("x")), new Cast(new Reference(createVarcharType(5), "u_v"), createVarcharType(5))), tableScan("nation", ImmutableMap.of("u_k", "nationkey", "u_v", "name"))))))))); } @@ -135,7 +135,7 @@ public void testNormalizeOuterJoinToInner() .right( anyTree( filter( - new NotExpression(new IsNullPredicate(new SymbolReference(VARCHAR, "c_name"))), + new Not(new IsNull(new Reference(VARCHAR, "c_name"))), tableScan("customer", ImmutableMap.of("c_custkey", "custkey", "c_name", "name")))))))); // nested joins @@ -162,7 +162,7 @@ public void testNormalizeOuterJoinToInner() .right( anyTree( filter( - new NotExpression(new IsNullPredicate(new SymbolReference(VARCHAR, "c_name"))), + new Not(new IsNull(new Reference(VARCHAR, "c_name"))), tableScan("customer", ImmutableMap.of("c_custkey", "custkey", "c_name", "name")))))))); } @@ -178,7 +178,7 @@ public void testNonDeterministicPredicateDoesNotPropagateFromFilteringSideToSour "LINE_ORDER_KEY", "orderkey"))), node(ExchangeNode.class, filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "ORDERS_ORDER_KEY"), new Cast(new FunctionCall(RANDOM, ImmutableList.of(new Constant(INTEGER, 5L))), BIGINT)), + new Comparison(EQUAL, new Reference(BIGINT, "ORDERS_ORDER_KEY"), new Cast(new Call(RANDOM, ImmutableList.of(new Constant(INTEGER, 5L))), BIGINT)), tableScan("orders", ImmutableMap.of("ORDERS_ORDER_KEY", "orderkey"))))))); } @@ -192,7 +192,7 @@ public void testNonStraddlingJoinExpression() .equiCriteria("LINEITEM_OK", "ORDERS_OK") .left( filter( - new ComparisonExpression(EQUAL, new Cast(new SymbolReference(INTEGER, "LINEITEM_LINENUMBER"), VARCHAR), new Constant(VARCHAR, Slices.utf8Slice("2"))), + new Comparison(EQUAL, new Cast(new Reference(INTEGER, "LINEITEM_LINENUMBER"), VARCHAR), new Constant(VARCHAR, Slices.utf8Slice("2"))), tableScan("lineitem", ImmutableMap.of( "LINEITEM_OK", "orderkey", "LINEITEM_LINENUMBER", "linenumber")))) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestPredicatePushdownWithoutDynamicFilter.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestPredicatePushdownWithoutDynamicFilter.java index 5eb78247ce11..c5c7e7688350 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestPredicatePushdownWithoutDynamicFilter.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestPredicatePushdownWithoutDynamicFilter.java @@ -19,13 +19,13 @@ import io.trino.Session; import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.plan.ExchangeNode; import org.junit.jupiter.api.Test; @@ -35,7 +35,7 @@ import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.join; @@ -78,13 +78,13 @@ public void testCoercions() .left( project( filter( - new ComparisonExpression(EQUAL, new Constant(createVarcharType(4), Slices.utf8Slice("x")), new Cast(new SymbolReference(createVarcharType(4), "t_v"), createVarcharType(4))), + new Comparison(EQUAL, new Constant(createVarcharType(4), Slices.utf8Slice("x")), new Cast(new Reference(createVarcharType(4), "t_v"), createVarcharType(4))), tableScan("nation", ImmutableMap.of("t_k", "nationkey", "t_v", "name"))))) .right( anyTree( project( filter( - new ComparisonExpression(EQUAL, new Constant(createVarcharType(4), Slices.utf8Slice("x")), new Cast(new SymbolReference(createVarcharType(5), "u_v"), createVarcharType(4))), + new Comparison(EQUAL, new Constant(createVarcharType(4), Slices.utf8Slice("x")), new Cast(new Reference(createVarcharType(5), "u_v"), createVarcharType(4))), tableScan("nation", ImmutableMap.of("u_k", "nationkey", "u_v", "name"))))))))); // values have different types (varchar(4) vs varchar(5)) in each table @@ -101,13 +101,13 @@ public void testCoercions() .left( project( filter( - new ComparisonExpression(EQUAL, new Constant(createVarcharType(4), Slices.utf8Slice("x")), new Cast(new SymbolReference(createVarcharType(4), "t_v"), createVarcharType(4))), + new Comparison(EQUAL, new Constant(createVarcharType(4), Slices.utf8Slice("x")), new Cast(new Reference(createVarcharType(4), "t_v"), createVarcharType(4))), tableScan("nation", ImmutableMap.of("t_k", "nationkey", "t_v", "name"))))) .right( anyTree( project( filter( - new ComparisonExpression(EQUAL, new Constant(createVarcharType(5), Slices.utf8Slice("x")), new Cast(new SymbolReference(createVarcharType(5), "u_v"), createVarcharType(5))), + new Comparison(EQUAL, new Constant(createVarcharType(5), Slices.utf8Slice("x")), new Cast(new Reference(createVarcharType(5), "u_v"), createVarcharType(5))), tableScan("nation", ImmutableMap.of("u_k", "nationkey", "u_v", "name"))))))))); } @@ -133,7 +133,7 @@ public void testNormalizeOuterJoinToInner() .right( anyTree( filter( - new NotExpression(new IsNullPredicate(new SymbolReference(VARCHAR, "c_name"))), + new Not(new IsNull(new Reference(VARCHAR, "c_name"))), tableScan("customer", ImmutableMap.of("c_custkey", "custkey", "c_name", "name")))))))); // nested joins @@ -159,7 +159,7 @@ public void testNormalizeOuterJoinToInner() .right( anyTree( filter( - new NotExpression(new IsNullPredicate(new SymbolReference(VARCHAR, "c_name"))), + new Not(new IsNull(new Reference(VARCHAR, "c_name"))), tableScan("customer", ImmutableMap.of("c_custkey", "custkey", "c_name", "name")))))))); } @@ -174,7 +174,7 @@ public void testNonDeterministicPredicateDoesNotPropagateFromFilteringSideToSour "LINE_ORDER_KEY", "orderkey")), node(ExchangeNode.class, filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "ORDERS_ORDER_KEY"), new Cast(new FunctionCall(RANDOM, ImmutableList.of(new Constant(INTEGER, 5L))), BIGINT)), + new Comparison(EQUAL, new Reference(BIGINT, "ORDERS_ORDER_KEY"), new Cast(new Call(RANDOM, ImmutableList.of(new Constant(INTEGER, 5L))), BIGINT)), tableScan("orders", ImmutableMap.of("ORDERS_ORDER_KEY", "orderkey"))))))); } @@ -188,7 +188,7 @@ public void testNonStraddlingJoinExpression() .equiCriteria("LINEITEM_OK", "ORDERS_OK") .left( filter( - new ComparisonExpression(EQUAL, new Cast(new SymbolReference(INTEGER, "LINEITEM_LINENUMBER"), VARCHAR), new Constant(VARCHAR, Slices.utf8Slice("2"))), + new Comparison(EQUAL, new Cast(new Reference(INTEGER, "LINEITEM_LINENUMBER"), VARCHAR), new Constant(VARCHAR, Slices.utf8Slice("2"))), tableScan("lineitem", ImmutableMap.of( "LINEITEM_OK", "orderkey", "LINEITEM_LINENUMBER", "linenumber")))) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestQuantifiedComparison.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestQuantifiedComparison.java index acc8f6213174..d102584a7ee1 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestQuantifiedComparison.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestQuantifiedComparison.java @@ -14,8 +14,8 @@ package io.trino.sql.planner; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.assertions.BasePlanTest; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.JoinNode; @@ -52,7 +52,7 @@ public void testQuantifiedComparisonNotEqualsAll() String query = "SELECT orderkey, custkey FROM orders WHERE orderkey <> ALL (VALUES ROW(CAST(5 as BIGINT)), ROW(CAST(3 as BIGINT)))"; assertPlan(query, anyTree( filter( - new NotExpression(new SymbolReference(BOOLEAN, "S")), + new Not(new Reference(BOOLEAN, "S")), semiJoin("X", "Y", "S", tableScan("orders", ImmutableMap.of("X", "orderkey")), values(ImmutableMap.of("Y", 0)))))); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestRecursiveCte.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestRecursiveCte.java index 6a038bb5f48c..77bbd30eb050 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestRecursiveCte.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestRecursiveCte.java @@ -20,12 +20,12 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.assertions.BasePlanTest; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.testing.PlanTester; @@ -38,10 +38,10 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; import static io.trino.sql.ir.IrExpressions.ifExpression; import static io.trino.sql.planner.LogicalPlanner.Stage.CREATED; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; @@ -90,32 +90,32 @@ public void testRecursiveQuery() values()))), // first recursion step project(project(project( - ImmutableMap.of("expr_0", expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "expr"), new Constant(INTEGER, 2L)))), + ImmutableMap.of("expr_0", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "expr"), new Constant(INTEGER, 2L)))), filter( - new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "expr"), new Constant(INTEGER, 6L)), + new Comparison(LESS_THAN, new Reference(INTEGER, "expr"), new Constant(INTEGER, 6L)), project(project(project( ImmutableMap.of("expr", expression(new Constant(INTEGER, 1L))), values()))))))), // "post-recursion" step with convergence assertion filter( ifExpression( - new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(BIGINT, "count"), new Constant(BIGINT, 0L)), - new Cast(new FunctionCall(FAIL, ImmutableList.of(new Constant(INTEGER, (long) NOT_SUPPORTED.toErrorCode().getCode()), new Constant(VARCHAR, Slices.utf8Slice("Recursion depth limit exceeded (1). Use 'max_recursion_depth' session property to modify the limit.")))), BOOLEAN), - TRUE_LITERAL), + new Comparison(GREATER_THAN_OR_EQUAL, new Reference(BIGINT, "count"), new Constant(BIGINT, 0L)), + new Cast(new Call(FAIL, ImmutableList.of(new Constant(INTEGER, (long) NOT_SUPPORTED.toErrorCode().getCode()), new Constant(VARCHAR, Slices.utf8Slice("Recursion depth limit exceeded (1). Use 'max_recursion_depth' session property to modify the limit.")))), BOOLEAN), + TRUE), window(windowBuilder -> windowBuilder .addFunction( "count", windowFunction("count", ImmutableList.of(), DEFAULT_FRAME)), project(project(project( - ImmutableMap.of("expr_1", expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "expr"), new Constant(INTEGER, 2L)))), + ImmutableMap.of("expr_1", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "expr"), new Constant(INTEGER, 2L)))), filter( - new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "expr"), new Constant(INTEGER, 6L)), + new Comparison(LESS_THAN, new Reference(INTEGER, "expr"), new Constant(INTEGER, 6L)), project( - ImmutableMap.of("expr", expression(new SymbolReference(INTEGER, "expr_0"))), + ImmutableMap.of("expr", expression(new Reference(INTEGER, "expr_0"))), project(project(project( - ImmutableMap.of("expr_0", expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "expr"), new Constant(INTEGER, 2L)))), + ImmutableMap.of("expr_0", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "expr"), new Constant(INTEGER, 2L)))), filter( - new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "expr"), new Constant(INTEGER, 6L)), + new Comparison(LESS_THAN, new Reference(INTEGER, "expr"), new Constant(INTEGER, 6L)), project(project(project( ImmutableMap.of("expr", expression(new Constant(INTEGER, 1L))), values())))))))))))))))); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestRemoveDuplicatePredicates.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestRemoveDuplicatePredicates.java index 0f4e6942584d..aabf0494e4d3 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestRemoveDuplicatePredicates.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestRemoveDuplicatePredicates.java @@ -13,14 +13,14 @@ */ package io.trino.sql.planner; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.assertions.BasePlanTest; import org.junit.jupiter.api.Test; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -35,7 +35,7 @@ public void testAnd() "SELECT * FROM (VALUES 1) t(a) WHERE a = 1 AND 1 = a AND a = 1", anyTree( filter( - new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "A"), new Constant(INTEGER, 1L)), + new Comparison(EQUAL, new Reference(INTEGER, "A"), new Constant(INTEGER, 1L)), values("A")))); } @@ -46,7 +46,7 @@ public void testOr() "SELECT * FROM (VALUES 1) t(a) WHERE a = 1 OR 1 = a OR a = 1", anyTree( filter( - new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "A"), new Constant(INTEGER, 1L)), + new Comparison(EQUAL, new Reference(INTEGER, "A"), new Constant(INTEGER, 1L)), values("A")))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestSimplifyIn.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestSimplifyIn.java index 41992baf65e5..4d1084d18839 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestSimplifyIn.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestSimplifyIn.java @@ -13,14 +13,14 @@ */ package io.trino.sql.planner; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.assertions.BasePlanTest; import org.junit.jupiter.api.Test; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -35,7 +35,7 @@ public void testInWithSingleElement() "SELECT * FROM (VALUES 0) t(a) WHERE a IN (5)", anyTree( filter( - new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "A"), new Constant(INTEGER, 5L)), + new Comparison(EQUAL, new Reference(INTEGER, "A"), new Constant(INTEGER, 5L)), values("A")))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestSortExpressionExtractor.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestSortExpressionExtractor.java index 7e4e462af979..ab38f2e1f8aa 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestSortExpressionExtractor.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestSortExpressionExtractor.java @@ -18,14 +18,14 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.BetweenPredicate; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Between; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; import org.junit.jupiter.api.Test; import java.util.Arrays; @@ -37,14 +37,14 @@ import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.ir.IrUtils.extractConjuncts; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; -import static io.trino.sql.ir.LogicalExpression.Operator.OR; +import static io.trino.sql.ir.Logical.Operator.AND; +import static io.trino.sql.ir.Logical.Operator.OR; import static io.trino.type.UnknownType.UNKNOWN; import static org.assertj.core.api.Assertions.assertThat; @@ -61,84 +61,84 @@ public class TestSortExpressionExtractor public void testGetSortExpression() { assertGetSortExpression( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "p1"), new SymbolReference(BIGINT, "b1")), + new Comparison(GREATER_THAN, new Reference(BIGINT, "p1"), new Reference(BIGINT, "b1")), "b1"); assertGetSortExpression( - new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "b2"), new SymbolReference(BIGINT, "p1")), + new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "b2"), new Reference(BIGINT, "p1")), "b2"); assertGetSortExpression( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b2"), new SymbolReference(BIGINT, "p1")), + new Comparison(GREATER_THAN, new Reference(BIGINT, "b2"), new Reference(BIGINT, "p1")), "b2"); assertGetSortExpression( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b2"), new FunctionCall(SIN, ImmutableList.of(new SymbolReference(BIGINT, "p1")))), + new Comparison(GREATER_THAN, new Reference(BIGINT, "b2"), new Call(SIN, ImmutableList.of(new Reference(BIGINT, "p1")))), "b2"); - assertNoSortExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b2"), new FunctionCall(RANDOM, ImmutableList.of(new SymbolReference(BIGINT, "p1"))))); + assertNoSortExpression(new Comparison(GREATER_THAN, new Reference(BIGINT, "b2"), new Call(RANDOM, ImmutableList.of(new Reference(BIGINT, "p1"))))); assertGetSortExpression( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b2"), new FunctionCall(RANDOM, ImmutableList.of(new SymbolReference(BIGINT, "p1")))), new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b2"), new SymbolReference(BIGINT, "p1")))), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "b2"), new Call(RANDOM, ImmutableList.of(new Reference(BIGINT, "p1")))), new Comparison(GREATER_THAN, new Reference(BIGINT, "b2"), new Reference(BIGINT, "p1")))), "b2", - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b2"), new SymbolReference(BIGINT, "p1"))); + new Comparison(GREATER_THAN, new Reference(BIGINT, "b2"), new Reference(BIGINT, "p1"))); assertGetSortExpression( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b2"), new FunctionCall(RANDOM, ImmutableList.of(new SymbolReference(BIGINT, "p1")))), new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b1"), new SymbolReference(BIGINT, "p1")))), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "b2"), new Call(RANDOM, ImmutableList.of(new Reference(BIGINT, "p1")))), new Comparison(GREATER_THAN, new Reference(BIGINT, "b1"), new Reference(BIGINT, "p1")))), "b1", - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b1"), new SymbolReference(BIGINT, "p1"))); + new Comparison(GREATER_THAN, new Reference(BIGINT, "b1"), new Reference(BIGINT, "p1"))); - assertNoSortExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b1"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "p1"), new SymbolReference(INTEGER, "b2")))); + assertNoSortExpression(new Comparison(GREATER_THAN, new Reference(BIGINT, "b1"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "p1"), new Reference(INTEGER, "b2")))); - assertNoSortExpression(new ComparisonExpression(GREATER_THAN, new FunctionCall(SIN, ImmutableList.of(new SymbolReference(BIGINT, "b1"))), new SymbolReference(BIGINT, "p1"))); + assertNoSortExpression(new Comparison(GREATER_THAN, new Call(SIN, ImmutableList.of(new Reference(BIGINT, "b1"))), new Reference(BIGINT, "p1"))); - assertNoSortExpression(new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "b1"), new SymbolReference(BIGINT, "p1")), new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "b2"), new SymbolReference(BIGINT, "p1"))))); + assertNoSortExpression(new Logical(OR, ImmutableList.of(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "b1"), new Reference(BIGINT, "p1")), new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "b2"), new Reference(BIGINT, "p1"))))); - assertNoSortExpression(new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new FunctionCall(SIN, ImmutableList.of(new SymbolReference(BIGINT, "b2"))), new SymbolReference(BIGINT, "p1")), new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "b2"), new SymbolReference(BIGINT, "p1")), new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "b2"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "p1"), new Constant(INTEGER, 10L)))))))); + assertNoSortExpression(new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Call(SIN, ImmutableList.of(new Reference(BIGINT, "b2"))), new Reference(BIGINT, "p1")), new Logical(OR, ImmutableList.of(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "b2"), new Reference(BIGINT, "p1")), new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "b2"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "p1"), new Constant(INTEGER, 10L)))))))); assertGetSortExpression( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new FunctionCall(SIN, ImmutableList.of(new SymbolReference(BIGINT, "b2"))), new SymbolReference(BIGINT, "p1")), new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "b2"), new SymbolReference(BIGINT, "p1")), new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "b2"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "p1"), new Constant(INTEGER, 10L))))))), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Call(SIN, ImmutableList.of(new Reference(BIGINT, "b2"))), new Reference(BIGINT, "p1")), new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "b2"), new Reference(BIGINT, "p1")), new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "b2"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "p1"), new Constant(INTEGER, 10L))))))), "b2", - new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "b2"), new SymbolReference(BIGINT, "p1")), - new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "b2"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "p1"), new Constant(INTEGER, 10L)))); + new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "b2"), new Reference(BIGINT, "p1")), + new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "b2"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "p1"), new Constant(INTEGER, 10L)))); assertGetSortExpression( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b1"), new SymbolReference(BIGINT, "p1")), new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "b1"), new SymbolReference(BIGINT, "p1")))), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "b1"), new Reference(BIGINT, "p1")), new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "b1"), new Reference(BIGINT, "p1")))), "b1"); assertGetSortExpression( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b1"), new SymbolReference(BIGINT, "p1")), new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "b1"), new SymbolReference(BIGINT, "p1")), new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b2"), new SymbolReference(BIGINT, "p1")))), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "b1"), new Reference(BIGINT, "p1")), new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "b1"), new Reference(BIGINT, "p1")), new Comparison(GREATER_THAN, new Reference(BIGINT, "b2"), new Reference(BIGINT, "p1")))), "b1", - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b1"), new SymbolReference(BIGINT, "p1")), - new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "b1"), new SymbolReference(BIGINT, "p1"))); + new Comparison(GREATER_THAN, new Reference(BIGINT, "b1"), new Reference(BIGINT, "p1")), + new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "b1"), new Reference(BIGINT, "p1"))); assertGetSortExpression( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b1"), new SymbolReference(BIGINT, "p1")), new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "b1"), new SymbolReference(BIGINT, "p1")), new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b2"), new SymbolReference(BIGINT, "p1")), new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "b2"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "p1"), new Constant(INTEGER, 10L))), new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b2"), new SymbolReference(BIGINT, "p2")))), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "b1"), new Reference(BIGINT, "p1")), new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "b1"), new Reference(BIGINT, "p1")), new Comparison(GREATER_THAN, new Reference(BIGINT, "b2"), new Reference(BIGINT, "p1")), new Comparison(LESS_THAN, new Reference(BIGINT, "b2"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "p1"), new Constant(INTEGER, 10L))), new Comparison(GREATER_THAN, new Reference(BIGINT, "b2"), new Reference(BIGINT, "p2")))), "b2", - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b2"), new SymbolReference(BIGINT, "p1")), - new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "b2"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "p1"), new Constant(INTEGER, 10L))), - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b2"), new SymbolReference(BIGINT, "p2"))); + new Comparison(GREATER_THAN, new Reference(BIGINT, "b2"), new Reference(BIGINT, "p1")), + new Comparison(LESS_THAN, new Reference(BIGINT, "b2"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "p1"), new Constant(INTEGER, 10L))), + new Comparison(GREATER_THAN, new Reference(BIGINT, "b2"), new Reference(BIGINT, "p2"))); assertGetSortExpression( - new BetweenPredicate(new SymbolReference(BIGINT, "p1"), new SymbolReference(BIGINT, "b1"), new SymbolReference(BIGINT, "b2")), + new Between(new Reference(BIGINT, "p1"), new Reference(BIGINT, "b1"), new Reference(BIGINT, "b2")), "b1", - new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(BIGINT, "p1"), new SymbolReference(BIGINT, "b1"))); + new Comparison(GREATER_THAN_OR_EQUAL, new Reference(BIGINT, "p1"), new Reference(BIGINT, "b1"))); assertGetSortExpression( - new BetweenPredicate(new SymbolReference(BIGINT, "p1"), new SymbolReference(BIGINT, "p2"), new SymbolReference(BIGINT, "b1")), + new Between(new Reference(BIGINT, "p1"), new Reference(BIGINT, "p2"), new Reference(BIGINT, "b1")), "b1", - new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "p1"), new SymbolReference(BIGINT, "b1"))); + new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "p1"), new Reference(BIGINT, "b1"))); assertGetSortExpression( - new BetweenPredicate(new SymbolReference(BIGINT, "b1"), new SymbolReference(BIGINT, "p1"), new SymbolReference(BIGINT, "p2")), + new Between(new Reference(BIGINT, "b1"), new Reference(BIGINT, "p1"), new Reference(BIGINT, "p2")), "b1", - new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(BIGINT, "b1"), new SymbolReference(BIGINT, "p1"))); + new Comparison(GREATER_THAN_OR_EQUAL, new Reference(BIGINT, "b1"), new Reference(BIGINT, "p1"))); assertGetSortExpression( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b1"), new SymbolReference(BIGINT, "p1")), new BetweenPredicate(new SymbolReference(BIGINT, "p1"), new SymbolReference(BIGINT, "b1"), new SymbolReference(BIGINT, "b2")))), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "b1"), new Reference(BIGINT, "p1")), new Between(new Reference(BIGINT, "p1"), new Reference(BIGINT, "b1"), new Reference(BIGINT, "b2")))), "b1", - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b1"), new SymbolReference(BIGINT, "p1")), - new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(BIGINT, "p1"), new SymbolReference(BIGINT, "b1"))); + new Comparison(GREATER_THAN, new Reference(BIGINT, "b1"), new Reference(BIGINT, "p1")), + new Comparison(GREATER_THAN_OR_EQUAL, new Reference(BIGINT, "p1"), new Reference(BIGINT, "b1"))); } private void assertNoSortExpression(Expression expression) @@ -160,7 +160,7 @@ private void assertGetSortExpression(Expression expression, String expectedSymbo private void assertGetSortExpression(Expression expression, String expectedSymbol, List searchExpressions) { - Optional expected = Optional.of(new SortExpressionContext(new SymbolReference(BIGINT, expectedSymbol), searchExpressions)); + Optional expected = Optional.of(new SortExpressionContext(new Reference(BIGINT, expectedSymbol), searchExpressions)); Optional actual = SortExpressionExtractor.extractSortExpression(BUILD_SYMBOLS, expression); assertThat(actual).isEqualTo(expected); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestTableFunctionInvocation.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestTableFunctionInvocation.java index 62af8e16f34c..fcf9c1d63cfc 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestTableFunctionInvocation.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestTableFunctionInvocation.java @@ -30,7 +30,7 @@ import io.trino.spi.function.table.Descriptor.Field; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.assertions.BasePlanTest; import io.trino.sql.planner.assertions.RowNumberSymbolMatcher; import io.trino.sql.planner.plan.TableFunctionProcessorNode; @@ -45,7 +45,7 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.SmallintType.SMALLINT; import static io.trino.spi.type.VarcharType.createVarcharType; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.LogicalPlanner.Stage.CREATED; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; @@ -160,7 +160,7 @@ public void testTableFunctionInitialPlanWithCoercionForCopartitioning() .passThroughSymbols(ImmutableSet.of("c2"))) .addCopartitioning(ImmutableList.of("INPUT1", "INPUT2")) .properOutputs(ImmutableList.of("COLUMN")), - project(ImmutableMap.of("c1_coerced", expression(new Cast(new SymbolReference(SMALLINT, "c1"), INTEGER))), + project(ImmutableMap.of("c1_coerced", expression(new Cast(new Reference(SMALLINT, "c1"), INTEGER))), anyTree(values(ImmutableList.of("c1"), ImmutableList.of(ImmutableList.of(new Constant(SMALLINT, 1L)))))), anyTree(values(ImmutableList.of("c2"), ImmutableList.of(ImmutableList.of(new Constant(INTEGER, 2L)))))))); } @@ -214,7 +214,7 @@ public void testPruneTableFunctionColumns() .passThroughSymbols(ImmutableList.of(ImmutableList.of("a", "b"))) .requiredSymbols(ImmutableList.of(ImmutableList.of("a"))) .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())), - values(ImmutableList.of("a", "b"), ImmutableList.of(ImmutableList.of(new Constant(INTEGER, 1L), TRUE_LITERAL)))))); + values(ImmutableList.of("a", "b"), ImmutableList.of(ImmutableList.of(new Constant(INTEGER, 1L), TRUE)))))); // no table function outputs are referenced. All pass-through symbols are pruned from the TableFunctionProcessorNode. The unused symbol "b" is pruned from the source values node. assertPlan("SELECT 'constant' c FROM TABLE(mock.system.pass_through_function(input => TABLE(SELECT 1, true) t(a, b)))", diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestTableScanRedirectionWithPushdown.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestTableScanRedirectionWithPushdown.java index 83f179432fab..c609047c747a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestTableScanRedirectionWithPushdown.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestTableScanRedirectionWithPushdown.java @@ -40,9 +40,9 @@ import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.assertions.PlanAssert; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.testing.PlanTester; @@ -63,8 +63,8 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.RowType.field; import static io.trino.spi.type.VarcharType.VARCHAR; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.output; @@ -149,7 +149,7 @@ public void testRedirectionAfterProjectionPushdown() output( ImmutableList.of("DEST_COL"), filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "DEST_COL"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "DEST_COL"), new Constant(INTEGER, 0L)), tableScan( new MockConnectorTableHandle(DESTINATION_TABLE)::equals, TupleDomain.all(), @@ -206,7 +206,7 @@ public void testPredicatePushdownAfterRedirect() ImmutableList.of("DEST_COL_A", "DEST_COL_B"), filter( - new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "DEST_COL_A"), new Constant(INTEGER, 1L)), + new Comparison(EQUAL, new Reference(INTEGER, "DEST_COL_A"), new Constant(INTEGER, 1L)), tableScan( new MockConnectorTableHandle( DESTINATION_TABLE, @@ -261,9 +261,9 @@ public void testPredicateTypeWithCoercion() "SELECT source_col_b FROM test_table WHERE source_col_c = 'foo'", output( ImmutableList.of("DEST_COL_B"), - project(ImmutableMap.of("DEST_COL_B", expression(new SymbolReference(BIGINT, "DEST_COL_B"))), + project(ImmutableMap.of("DEST_COL_B", expression(new Reference(BIGINT, "DEST_COL_B"))), filter( - new ComparisonExpression(EQUAL, new Cast(new SymbolReference(BIGINT, "DEST_COL_A"), VARCHAR), new Constant(VARCHAR, Slices.utf8Slice("foo"))), + new Comparison(EQUAL, new Cast(new Reference(BIGINT, "DEST_COL_A"), VARCHAR), new Constant(VARCHAR, Slices.utf8Slice("foo"))), tableScan( new MockConnectorTableHandle( DESTINATION_TABLE, diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestTopologicalOrderSubPlanVisitor.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestTopologicalOrderSubPlanVisitor.java index f9a07eba5119..85a4e25a9d43 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestTopologicalOrderSubPlanVisitor.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestTopologicalOrderSubPlanVisitor.java @@ -19,7 +19,7 @@ import io.airlift.slice.Slices; import io.trino.cost.StatsAndCosts; import io.trino.operator.RetryPolicy; -import io.trino.sql.ir.BooleanLiteral; +import io.trino.sql.ir.Booleans; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Row; import io.trino.sql.planner.plan.IndexJoinNode; @@ -171,7 +171,7 @@ private static SpatialJoinNode spatialJoin(String id, PlanNode left, PlanNode ri left, right, left.getOutputSymbols(), - BooleanLiteral.TRUE_LITERAL, + Booleans.TRUE, Optional.empty(), Optional.empty(), Optional.empty()); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestUnwrapCastInComparison.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestUnwrapCastInComparison.java index 353892547394..a8ee5824d5a3 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestUnwrapCastInComparison.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestUnwrapCastInComparison.java @@ -19,15 +19,15 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.type.TimeZoneKey; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.assertions.BasePlanTest; import io.trino.type.DateTimes; import io.trino.util.DateTimeUtils; @@ -48,15 +48,15 @@ import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.IS_DISTINCT_FROM; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.NOT_EQUAL; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; -import static io.trino.sql.ir.LogicalExpression.Operator.OR; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.IS_DISTINCT_FROM; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.NOT_EQUAL; +import static io.trino.sql.ir.Logical.Operator.AND; +import static io.trino.sql.ir.Logical.Operator.OR; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.output; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -75,276 +75,276 @@ public class TestUnwrapCastInComparison public void testEquals() { // representable - testUnwrap("smallint", "a = DOUBLE '1'", new ComparisonExpression(EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 1L))); + testUnwrap("smallint", "a = DOUBLE '1'", new Comparison(EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 1L))); - testUnwrap("bigint", "a = DOUBLE '1'", new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 1L))); + testUnwrap("bigint", "a = DOUBLE '1'", new Comparison(EQUAL, new Reference(BIGINT, "a"), new Constant(BIGINT, 1L))); // non-representable - testUnwrap("smallint", "a = DOUBLE '1.1'", new LogicalExpression(AND, ImmutableList.of(new IsNullPredicate(new SymbolReference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); - testUnwrap("smallint", "a = DOUBLE '1.9'", new LogicalExpression(AND, ImmutableList.of(new IsNullPredicate(new SymbolReference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); + testUnwrap("smallint", "a = DOUBLE '1.1'", new Logical(AND, ImmutableList.of(new IsNull(new Reference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); + testUnwrap("smallint", "a = DOUBLE '1.9'", new Logical(AND, ImmutableList.of(new IsNull(new Reference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); - testUnwrap("bigint", "a = DOUBLE '1.1'", new LogicalExpression(AND, ImmutableList.of(new IsNullPredicate(new SymbolReference(BIGINT, "a")), new Constant(BOOLEAN, null)))); + testUnwrap("bigint", "a = DOUBLE '1.1'", new Logical(AND, ImmutableList.of(new IsNull(new Reference(BIGINT, "a")), new Constant(BOOLEAN, null)))); // below top of range - testUnwrap("smallint", "a = DOUBLE '32766'", new ComparisonExpression(EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 32766L))); + testUnwrap("smallint", "a = DOUBLE '32766'", new Comparison(EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 32766L))); // round to top of range - testUnwrap("smallint", "a = DOUBLE '32766.9'", new LogicalExpression(AND, ImmutableList.of(new IsNullPredicate(new SymbolReference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); + testUnwrap("smallint", "a = DOUBLE '32766.9'", new Logical(AND, ImmutableList.of(new IsNull(new Reference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); // top of range - testUnwrap("smallint", "a = DOUBLE '32767'", new ComparisonExpression(EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 32767L))); + testUnwrap("smallint", "a = DOUBLE '32767'", new Comparison(EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 32767L))); // above range - testUnwrap("smallint", "a = DOUBLE '32768.1'", new LogicalExpression(AND, ImmutableList.of(new IsNullPredicate(new SymbolReference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); + testUnwrap("smallint", "a = DOUBLE '32768.1'", new Logical(AND, ImmutableList.of(new IsNull(new Reference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); // above bottom of range - testUnwrap("smallint", "a = DOUBLE '-32767'", new ComparisonExpression(EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, -32767L))); + testUnwrap("smallint", "a = DOUBLE '-32767'", new Comparison(EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, -32767L))); // round to bottom of range - testUnwrap("smallint", "a = DOUBLE '-32767.9'", new LogicalExpression(AND, ImmutableList.of(new IsNullPredicate(new SymbolReference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); + testUnwrap("smallint", "a = DOUBLE '-32767.9'", new Logical(AND, ImmutableList.of(new IsNull(new Reference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); // bottom of range - testUnwrap("smallint", "a = DOUBLE '-32768'", new ComparisonExpression(EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, -32768L))); + testUnwrap("smallint", "a = DOUBLE '-32768'", new Comparison(EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, -32768L))); // below range - testUnwrap("smallint", "a = DOUBLE '-32768.1'", new LogicalExpression(AND, ImmutableList.of(new IsNullPredicate(new SymbolReference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); + testUnwrap("smallint", "a = DOUBLE '-32768.1'", new Logical(AND, ImmutableList.of(new IsNull(new Reference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); // -2^64 constant - testUnwrap("bigint", "a = DOUBLE '-18446744073709551616'", new LogicalExpression(AND, ImmutableList.of(new IsNullPredicate(new SymbolReference(BIGINT, "a")), new Constant(BOOLEAN, null)))); + testUnwrap("bigint", "a = DOUBLE '-18446744073709551616'", new Logical(AND, ImmutableList.of(new IsNull(new Reference(BIGINT, "a")), new Constant(BOOLEAN, null)))); // varchar and char, same length - testUnwrap("varchar(3)", "a = CAST('abc' AS char(3))", new ComparisonExpression(EQUAL, new SymbolReference(createVarcharType(3), "a"), new Constant(createVarcharType(3), Slices.utf8Slice("abc")))); + testUnwrap("varchar(3)", "a = CAST('abc' AS char(3))", new Comparison(EQUAL, new Reference(createVarcharType(3), "a"), new Constant(createVarcharType(3), Slices.utf8Slice("abc")))); // longer varchar and char // actually unwrapping didn't happen - testUnwrap("varchar(10)", "a = CAST('abc' AS char(3))", new ComparisonExpression(EQUAL, new Cast(new SymbolReference(createVarcharType(10), "a"), createCharType(10)), new Constant(createCharType(10), Slices.utf8Slice("abc")))); + testUnwrap("varchar(10)", "a = CAST('abc' AS char(3))", new Comparison(EQUAL, new Cast(new Reference(createVarcharType(10), "a"), createCharType(10)), new Constant(createCharType(10), Slices.utf8Slice("abc")))); // unbounded varchar and char // actually unwrapping didn't happen - testUnwrap("varchar", "a = CAST('abc' AS char(3))", new ComparisonExpression(EQUAL, new Cast(new SymbolReference(VARCHAR, "a"), createCharType(65536)), new Constant(createCharType(65536), Slices.utf8Slice("abc")))); + testUnwrap("varchar", "a = CAST('abc' AS char(3))", new Comparison(EQUAL, new Cast(new Reference(VARCHAR, "a"), createCharType(65536)), new Constant(createCharType(65536), Slices.utf8Slice("abc")))); } @Test public void testNotEquals() { // representable - testUnwrap("smallint", "a <> DOUBLE '1'", new ComparisonExpression(NOT_EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 1L))); + testUnwrap("smallint", "a <> DOUBLE '1'", new Comparison(NOT_EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 1L))); - testUnwrap("bigint", "a <> DOUBLE '1'", new ComparisonExpression(NOT_EQUAL, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 1L))); + testUnwrap("bigint", "a <> DOUBLE '1'", new Comparison(NOT_EQUAL, new Reference(BIGINT, "a"), new Constant(BIGINT, 1L))); // non-representable - testUnwrap("smallint", "a <> DOUBLE '1.1'", new LogicalExpression(OR, ImmutableList.of(new NotExpression(new IsNullPredicate(new SymbolReference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); - testUnwrap("smallint", "a <> DOUBLE '1.9'", new LogicalExpression(OR, ImmutableList.of(new NotExpression(new IsNullPredicate(new SymbolReference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); + testUnwrap("smallint", "a <> DOUBLE '1.1'", new Logical(OR, ImmutableList.of(new Not(new IsNull(new Reference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); + testUnwrap("smallint", "a <> DOUBLE '1.9'", new Logical(OR, ImmutableList.of(new Not(new IsNull(new Reference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); - testUnwrap("smallint", "a <> DOUBLE '1.9'", new LogicalExpression(OR, ImmutableList.of(new NotExpression(new IsNullPredicate(new SymbolReference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); + testUnwrap("smallint", "a <> DOUBLE '1.9'", new Logical(OR, ImmutableList.of(new Not(new IsNull(new Reference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); - testUnwrap("bigint", "a <> DOUBLE '1.1'", new LogicalExpression(OR, ImmutableList.of(new NotExpression(new IsNullPredicate(new SymbolReference(BIGINT, "a"))), new Constant(BOOLEAN, null)))); + testUnwrap("bigint", "a <> DOUBLE '1.1'", new Logical(OR, ImmutableList.of(new Not(new IsNull(new Reference(BIGINT, "a"))), new Constant(BOOLEAN, null)))); // below top of range - testUnwrap("smallint", "a <> DOUBLE '32766'", new ComparisonExpression(NOT_EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 32766L))); + testUnwrap("smallint", "a <> DOUBLE '32766'", new Comparison(NOT_EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 32766L))); // round to top of range - testUnwrap("smallint", "a <> DOUBLE '32766.9'", new LogicalExpression(OR, ImmutableList.of(new NotExpression(new IsNullPredicate(new SymbolReference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); + testUnwrap("smallint", "a <> DOUBLE '32766.9'", new Logical(OR, ImmutableList.of(new Not(new IsNull(new Reference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); // top of range - testUnwrap("smallint", "a <> DOUBLE '32767'", new ComparisonExpression(NOT_EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 32767L))); + testUnwrap("smallint", "a <> DOUBLE '32767'", new Comparison(NOT_EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 32767L))); // above range - testUnwrap("smallint", "a <> DOUBLE '32768.1'", new LogicalExpression(OR, ImmutableList.of(new NotExpression(new IsNullPredicate(new SymbolReference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); + testUnwrap("smallint", "a <> DOUBLE '32768.1'", new Logical(OR, ImmutableList.of(new Not(new IsNull(new Reference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); // 2^64 constant - testUnwrap("bigint", "a <> DOUBLE '18446744073709551616'", new LogicalExpression(OR, ImmutableList.of(new NotExpression(new IsNullPredicate(new SymbolReference(BIGINT, "a"))), new Constant(BOOLEAN, null)))); + testUnwrap("bigint", "a <> DOUBLE '18446744073709551616'", new Logical(OR, ImmutableList.of(new Not(new IsNull(new Reference(BIGINT, "a"))), new Constant(BOOLEAN, null)))); // above bottom of range - testUnwrap("smallint", "a <> DOUBLE '-32767'", new ComparisonExpression(NOT_EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, -32767L))); + testUnwrap("smallint", "a <> DOUBLE '-32767'", new Comparison(NOT_EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, -32767L))); // round to bottom of range - testUnwrap("smallint", "a <> DOUBLE '-32767.9'", new LogicalExpression(OR, ImmutableList.of(new NotExpression(new IsNullPredicate(new SymbolReference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); + testUnwrap("smallint", "a <> DOUBLE '-32767.9'", new Logical(OR, ImmutableList.of(new Not(new IsNull(new Reference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); // bottom of range - testUnwrap("smallint", "a <> DOUBLE '-32768'", new ComparisonExpression(NOT_EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, -32768L))); + testUnwrap("smallint", "a <> DOUBLE '-32768'", new Comparison(NOT_EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, -32768L))); // below range - testUnwrap("smallint", "a <> DOUBLE '-32768.1'", new LogicalExpression(OR, ImmutableList.of(new NotExpression(new IsNullPredicate(new SymbolReference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); + testUnwrap("smallint", "a <> DOUBLE '-32768.1'", new Logical(OR, ImmutableList.of(new Not(new IsNull(new Reference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); } @Test public void testLessThan() { // representable - testUnwrap("smallint", "a < DOUBLE '1'", new ComparisonExpression(LESS_THAN, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 1L))); + testUnwrap("smallint", "a < DOUBLE '1'", new Comparison(LESS_THAN, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 1L))); - testUnwrap("bigint", "a < DOUBLE '1'", new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 1L))); + testUnwrap("bigint", "a < DOUBLE '1'", new Comparison(LESS_THAN, new Reference(BIGINT, "a"), new Constant(BIGINT, 1L))); // non-representable - testUnwrap("smallint", "a < DOUBLE '1.1'", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 1L))); + testUnwrap("smallint", "a < DOUBLE '1.1'", new Comparison(LESS_THAN_OR_EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 1L))); - testUnwrap("bigint", "a < DOUBLE '1.1'", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 1L))); + testUnwrap("bigint", "a < DOUBLE '1.1'", new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "a"), new Constant(BIGINT, 1L))); - testUnwrap("smallint", "a < DOUBLE '1.9'", new ComparisonExpression(LESS_THAN, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 2L))); + testUnwrap("smallint", "a < DOUBLE '1.9'", new Comparison(LESS_THAN, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 2L))); // below top of range - testUnwrap("smallint", "a < DOUBLE '32766'", new ComparisonExpression(LESS_THAN, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 32766L))); + testUnwrap("smallint", "a < DOUBLE '32766'", new Comparison(LESS_THAN, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 32766L))); // round to top of range - testUnwrap("smallint", "a < DOUBLE '32766.9'", new ComparisonExpression(LESS_THAN, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 32767L))); + testUnwrap("smallint", "a < DOUBLE '32766.9'", new Comparison(LESS_THAN, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 32767L))); // top of range - testUnwrap("smallint", "a < DOUBLE '32767'", new ComparisonExpression(NOT_EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 32767L))); + testUnwrap("smallint", "a < DOUBLE '32767'", new Comparison(NOT_EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 32767L))); // above range - testUnwrap("smallint", "a < DOUBLE '32768.1'", new LogicalExpression(OR, ImmutableList.of(new NotExpression(new IsNullPredicate(new SymbolReference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); + testUnwrap("smallint", "a < DOUBLE '32768.1'", new Logical(OR, ImmutableList.of(new Not(new IsNull(new Reference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); // above bottom of range - testUnwrap("smallint", "a < DOUBLE '-32767'", new ComparisonExpression(LESS_THAN, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, -32767L))); + testUnwrap("smallint", "a < DOUBLE '-32767'", new Comparison(LESS_THAN, new Reference(SMALLINT, "a"), new Constant(SMALLINT, -32767L))); // round to bottom of range - testUnwrap("smallint", "a < DOUBLE '-32767.9'", new ComparisonExpression(EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, -32768L))); + testUnwrap("smallint", "a < DOUBLE '-32767.9'", new Comparison(EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, -32768L))); // bottom of range - testUnwrap("smallint", "a < DOUBLE '-32768'", new LogicalExpression(AND, ImmutableList.of(new IsNullPredicate(new SymbolReference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); + testUnwrap("smallint", "a < DOUBLE '-32768'", new Logical(AND, ImmutableList.of(new IsNull(new Reference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); // below range - testUnwrap("smallint", "a < DOUBLE '-32768.1'", new LogicalExpression(AND, ImmutableList.of(new IsNullPredicate(new SymbolReference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); + testUnwrap("smallint", "a < DOUBLE '-32768.1'", new Logical(AND, ImmutableList.of(new IsNull(new Reference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); // -2^64 constant - testUnwrap("bigint", "a < DOUBLE '-18446744073709551616'", new LogicalExpression(AND, ImmutableList.of(new IsNullPredicate(new SymbolReference(BIGINT, "a")), new Constant(BOOLEAN, null)))); + testUnwrap("bigint", "a < DOUBLE '-18446744073709551616'", new Logical(AND, ImmutableList.of(new IsNull(new Reference(BIGINT, "a")), new Constant(BOOLEAN, null)))); } @Test public void testLessThanOrEqual() { // representable - testUnwrap("smallint", "a <= DOUBLE '1'", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 1L))); + testUnwrap("smallint", "a <= DOUBLE '1'", new Comparison(LESS_THAN_OR_EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 1L))); - testUnwrap("bigint", "a <= DOUBLE '1'", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 1L))); + testUnwrap("bigint", "a <= DOUBLE '1'", new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "a"), new Constant(BIGINT, 1L))); // non-representable - testUnwrap("smallint", "a <= DOUBLE '1.1'", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 1L))); + testUnwrap("smallint", "a <= DOUBLE '1.1'", new Comparison(LESS_THAN_OR_EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 1L))); - testUnwrap("bigint", "a <= DOUBLE '1.1'", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 1L))); + testUnwrap("bigint", "a <= DOUBLE '1.1'", new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "a"), new Constant(BIGINT, 1L))); - testUnwrap("smallint", "a <= DOUBLE '1.9'", new ComparisonExpression(LESS_THAN, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 2L))); + testUnwrap("smallint", "a <= DOUBLE '1.9'", new Comparison(LESS_THAN, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 2L))); // below top of range - testUnwrap("smallint", "a <= DOUBLE '32766'", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 32766L))); + testUnwrap("smallint", "a <= DOUBLE '32766'", new Comparison(LESS_THAN_OR_EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 32766L))); // round to top of range - testUnwrap("smallint", "a <= DOUBLE '32766.9'", new ComparisonExpression(LESS_THAN, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 32767L))); + testUnwrap("smallint", "a <= DOUBLE '32766.9'", new Comparison(LESS_THAN, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 32767L))); // top of range - testUnwrap("smallint", "a <= DOUBLE '32767'", new LogicalExpression(OR, ImmutableList.of(new NotExpression(new IsNullPredicate(new SymbolReference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); + testUnwrap("smallint", "a <= DOUBLE '32767'", new Logical(OR, ImmutableList.of(new Not(new IsNull(new Reference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); // above range - testUnwrap("smallint", "a <= DOUBLE '32768.1'", new LogicalExpression(OR, ImmutableList.of(new NotExpression(new IsNullPredicate(new SymbolReference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); + testUnwrap("smallint", "a <= DOUBLE '32768.1'", new Logical(OR, ImmutableList.of(new Not(new IsNull(new Reference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); // 2^64 constant - testUnwrap("bigint", "a <= DOUBLE '18446744073709551616'", new LogicalExpression(OR, ImmutableList.of(new NotExpression(new IsNullPredicate(new SymbolReference(BIGINT, "a"))), new Constant(BOOLEAN, null)))); + testUnwrap("bigint", "a <= DOUBLE '18446744073709551616'", new Logical(OR, ImmutableList.of(new Not(new IsNull(new Reference(BIGINT, "a"))), new Constant(BOOLEAN, null)))); // above bottom of range - testUnwrap("smallint", "a <= DOUBLE '-32767'", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, -32767L))); + testUnwrap("smallint", "a <= DOUBLE '-32767'", new Comparison(LESS_THAN_OR_EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, -32767L))); // round to bottom of range - testUnwrap("smallint", "a <= DOUBLE '-32767.9'", new ComparisonExpression(EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, -32768L))); + testUnwrap("smallint", "a <= DOUBLE '-32767.9'", new Comparison(EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, -32768L))); // bottom of range - testUnwrap("smallint", "a <= DOUBLE '-32768'", new ComparisonExpression(EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, -32768L))); + testUnwrap("smallint", "a <= DOUBLE '-32768'", new Comparison(EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, -32768L))); // below range - testUnwrap("smallint", "a <= DOUBLE '-32768.1'", new LogicalExpression(AND, ImmutableList.of(new IsNullPredicate(new SymbolReference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); + testUnwrap("smallint", "a <= DOUBLE '-32768.1'", new Logical(AND, ImmutableList.of(new IsNull(new Reference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); } @Test public void testGreaterThan() { // representable - testUnwrap("smallint", "a > DOUBLE '1'", new ComparisonExpression(GREATER_THAN, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 1L))); + testUnwrap("smallint", "a > DOUBLE '1'", new Comparison(GREATER_THAN, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 1L))); - testUnwrap("bigint", "a > DOUBLE '1'", new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 1L))); + testUnwrap("bigint", "a > DOUBLE '1'", new Comparison(GREATER_THAN, new Reference(BIGINT, "a"), new Constant(BIGINT, 1L))); // non-representable - testUnwrap("smallint", "a > DOUBLE '1.1'", new ComparisonExpression(GREATER_THAN, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 1L))); + testUnwrap("smallint", "a > DOUBLE '1.1'", new Comparison(GREATER_THAN, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 1L))); - testUnwrap("smallint", "a > DOUBLE '1.9'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 2L))); + testUnwrap("smallint", "a > DOUBLE '1.9'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 2L))); - testUnwrap("bigint", "a > DOUBLE '1.9'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 2L))); + testUnwrap("bigint", "a > DOUBLE '1.9'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(BIGINT, "a"), new Constant(BIGINT, 2L))); // below top of range - testUnwrap("smallint", "a > DOUBLE '32766'", new ComparisonExpression(GREATER_THAN, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 32766L))); + testUnwrap("smallint", "a > DOUBLE '32766'", new Comparison(GREATER_THAN, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 32766L))); // round to top of range - testUnwrap("smallint", "a > DOUBLE '32766.9'", new ComparisonExpression(EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 32767L))); + testUnwrap("smallint", "a > DOUBLE '32766.9'", new Comparison(EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 32767L))); // top of range - testUnwrap("smallint", "a > DOUBLE '32767'", new LogicalExpression(AND, ImmutableList.of(new IsNullPredicate(new SymbolReference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); + testUnwrap("smallint", "a > DOUBLE '32767'", new Logical(AND, ImmutableList.of(new IsNull(new Reference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); // above range - testUnwrap("smallint", "a > DOUBLE '32768.1'", new LogicalExpression(AND, ImmutableList.of(new IsNullPredicate(new SymbolReference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); + testUnwrap("smallint", "a > DOUBLE '32768.1'", new Logical(AND, ImmutableList.of(new IsNull(new Reference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); // 2^64 constant - testUnwrap("bigint", "a > DOUBLE '18446744073709551616'", new LogicalExpression(AND, ImmutableList.of(new IsNullPredicate(new SymbolReference(BIGINT, "a")), new Constant(BOOLEAN, null)))); + testUnwrap("bigint", "a > DOUBLE '18446744073709551616'", new Logical(AND, ImmutableList.of(new IsNull(new Reference(BIGINT, "a")), new Constant(BOOLEAN, null)))); // above bottom of range - testUnwrap("smallint", "a > DOUBLE '-32767'", new ComparisonExpression(GREATER_THAN, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, -32767L))); + testUnwrap("smallint", "a > DOUBLE '-32767'", new Comparison(GREATER_THAN, new Reference(SMALLINT, "a"), new Constant(SMALLINT, -32767L))); // round to bottom of range - testUnwrap("smallint", "a > DOUBLE '-32767.9'", new ComparisonExpression(GREATER_THAN, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, -32768L))); + testUnwrap("smallint", "a > DOUBLE '-32767.9'", new Comparison(GREATER_THAN, new Reference(SMALLINT, "a"), new Constant(SMALLINT, -32768L))); // bottom of range - testUnwrap("smallint", "a > DOUBLE '-32768'", new ComparisonExpression(NOT_EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, -32768L))); + testUnwrap("smallint", "a > DOUBLE '-32768'", new Comparison(NOT_EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, -32768L))); // below range - testUnwrap("smallint", "a > DOUBLE '-32768.1'", new LogicalExpression(OR, ImmutableList.of(new NotExpression(new IsNullPredicate(new SymbolReference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); + testUnwrap("smallint", "a > DOUBLE '-32768.1'", new Logical(OR, ImmutableList.of(new Not(new IsNull(new Reference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); } @Test public void testGreaterThanOrEqual() { // representable - testUnwrap("smallint", "a >= DOUBLE '1'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 1L))); + testUnwrap("smallint", "a >= DOUBLE '1'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 1L))); - testUnwrap("bigint", "a >= DOUBLE '1'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 1L))); + testUnwrap("bigint", "a >= DOUBLE '1'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(BIGINT, "a"), new Constant(BIGINT, 1L))); // non-representable - testUnwrap("smallint", "a >= DOUBLE '1.1'", new ComparisonExpression(GREATER_THAN, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 1L))); + testUnwrap("smallint", "a >= DOUBLE '1.1'", new Comparison(GREATER_THAN, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 1L))); - testUnwrap("bigint", "a >= DOUBLE '1.1'", new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 1L))); + testUnwrap("bigint", "a >= DOUBLE '1.1'", new Comparison(GREATER_THAN, new Reference(BIGINT, "a"), new Constant(BIGINT, 1L))); - testUnwrap("smallint", "a >= DOUBLE '1.9'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 2L))); + testUnwrap("smallint", "a >= DOUBLE '1.9'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 2L))); // below top of range - testUnwrap("smallint", "a >= DOUBLE '32766'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 32766L))); + testUnwrap("smallint", "a >= DOUBLE '32766'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 32766L))); // round to top of range - testUnwrap("smallint", "a >= DOUBLE '32766.9'", new ComparisonExpression(EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 32767L))); + testUnwrap("smallint", "a >= DOUBLE '32766.9'", new Comparison(EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 32767L))); // top of range - testUnwrap("smallint", "a >= DOUBLE '32767'", new ComparisonExpression(EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 32767L))); + testUnwrap("smallint", "a >= DOUBLE '32767'", new Comparison(EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 32767L))); // above range - testUnwrap("smallint", "a >= DOUBLE '32768.1'", new LogicalExpression(AND, ImmutableList.of(new IsNullPredicate(new SymbolReference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); + testUnwrap("smallint", "a >= DOUBLE '32768.1'", new Logical(AND, ImmutableList.of(new IsNull(new Reference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); // above bottom of range - testUnwrap("smallint", "a >= DOUBLE '-32767'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, -32767L))); + testUnwrap("smallint", "a >= DOUBLE '-32767'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, -32767L))); // round to bottom of range - testUnwrap("smallint", "a >= DOUBLE '-32767.9'", new ComparisonExpression(GREATER_THAN, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, -32768L))); + testUnwrap("smallint", "a >= DOUBLE '-32767.9'", new Comparison(GREATER_THAN, new Reference(SMALLINT, "a"), new Constant(SMALLINT, -32768L))); // bottom of range - testUnwrap("smallint", "a >= DOUBLE '-32768'", new LogicalExpression(OR, ImmutableList.of(new NotExpression(new IsNullPredicate(new SymbolReference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); + testUnwrap("smallint", "a >= DOUBLE '-32768'", new Logical(OR, ImmutableList.of(new Not(new IsNull(new Reference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); // below range - testUnwrap("smallint", "a >= DOUBLE '-32768.1'", new LogicalExpression(OR, ImmutableList.of(new NotExpression(new IsNullPredicate(new SymbolReference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); + testUnwrap("smallint", "a >= DOUBLE '-32768.1'", new Logical(OR, ImmutableList.of(new Not(new IsNull(new Reference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); // -2^64 constant - testUnwrap("bigint", "a >= DOUBLE '-18446744073709551616'", new LogicalExpression(OR, ImmutableList.of(new NotExpression(new IsNullPredicate(new SymbolReference(BIGINT, "a"))), new Constant(BOOLEAN, null)))); + testUnwrap("bigint", "a >= DOUBLE '-18446744073709551616'", new Logical(OR, ImmutableList.of(new Not(new IsNull(new Reference(BIGINT, "a"))), new Constant(BOOLEAN, null)))); } @Test public void testDistinctFrom() { // representable - testUnwrap("smallint", "a IS DISTINCT FROM DOUBLE '1'", new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 1L))); + testUnwrap("smallint", "a IS DISTINCT FROM DOUBLE '1'", new Comparison(IS_DISTINCT_FROM, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 1L))); - testUnwrap("bigint", "a IS DISTINCT FROM DOUBLE '1'", new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 1L))); + testUnwrap("bigint", "a IS DISTINCT FROM DOUBLE '1'", new Comparison(IS_DISTINCT_FROM, new Reference(BIGINT, "a"), new Constant(BIGINT, 1L))); // non-representable testRemoveFilter("smallint", "a IS DISTINCT FROM DOUBLE '1.1'"); @@ -354,13 +354,13 @@ public void testDistinctFrom() testRemoveFilter("bigint", "a IS DISTINCT FROM DOUBLE '1.9'"); // below top of range - testUnwrap("smallint", "a IS DISTINCT FROM DOUBLE '32766'", new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 32766L))); + testUnwrap("smallint", "a IS DISTINCT FROM DOUBLE '32766'", new Comparison(IS_DISTINCT_FROM, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 32766L))); // round to top of range testRemoveFilter("smallint", "a IS DISTINCT FROM DOUBLE '32766.9'"); // top of range - testUnwrap("smallint", "a IS DISTINCT FROM DOUBLE '32767'", new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 32767L))); + testUnwrap("smallint", "a IS DISTINCT FROM DOUBLE '32767'", new Comparison(IS_DISTINCT_FROM, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 32767L))); // above range testRemoveFilter("smallint", "a IS DISTINCT FROM DOUBLE '32768.1'"); @@ -369,13 +369,13 @@ public void testDistinctFrom() testRemoveFilter("bigint", "a IS DISTINCT FROM DOUBLE '18446744073709551616'"); // above bottom of range - testUnwrap("smallint", "a IS DISTINCT FROM DOUBLE '-32767'", new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, -32767L))); + testUnwrap("smallint", "a IS DISTINCT FROM DOUBLE '-32767'", new Comparison(IS_DISTINCT_FROM, new Reference(SMALLINT, "a"), new Constant(SMALLINT, -32767L))); // round to bottom of range testRemoveFilter("smallint", "a IS DISTINCT FROM DOUBLE '-32767.9'"); // bottom of range - testUnwrap("smallint", "a IS DISTINCT FROM DOUBLE '-32768'", new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, -32768L))); + testUnwrap("smallint", "a IS DISTINCT FROM DOUBLE '-32768'", new Comparison(IS_DISTINCT_FROM, new Reference(SMALLINT, "a"), new Constant(SMALLINT, -32768L))); // below range testRemoveFilter("smallint", "a IS DISTINCT FROM DOUBLE '-32768.1'"); @@ -398,33 +398,33 @@ public void testNull() testUnwrap("smallint", "a <= CAST(NULL AS DOUBLE)", new Constant(BOOLEAN, null)); - testUnwrap("smallint", "a IS DISTINCT FROM CAST(NULL AS DOUBLE)", new NotExpression(new IsNullPredicate(new Cast(new SymbolReference(SMALLINT, "a"), DOUBLE)))); + testUnwrap("smallint", "a IS DISTINCT FROM CAST(NULL AS DOUBLE)", new Not(new IsNull(new Cast(new Reference(SMALLINT, "a"), DOUBLE)))); - testUnwrap("bigint", "a IS DISTINCT FROM CAST(NULL AS DOUBLE)", new NotExpression(new IsNullPredicate(new Cast(new SymbolReference(BIGINT, "a"), DOUBLE)))); + testUnwrap("bigint", "a IS DISTINCT FROM CAST(NULL AS DOUBLE)", new Not(new IsNull(new Cast(new Reference(BIGINT, "a"), DOUBLE)))); } @Test public void testNaN() { - testUnwrap("smallint", "a = nan()", new LogicalExpression(AND, ImmutableList.of(new IsNullPredicate(new SymbolReference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); + testUnwrap("smallint", "a = nan()", new Logical(AND, ImmutableList.of(new IsNull(new Reference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); - testUnwrap("bigint", "a = nan()", new LogicalExpression(AND, ImmutableList.of(new IsNullPredicate(new SymbolReference(BIGINT, "a")), new Constant(BOOLEAN, null)))); + testUnwrap("bigint", "a = nan()", new Logical(AND, ImmutableList.of(new IsNull(new Reference(BIGINT, "a")), new Constant(BOOLEAN, null)))); - testUnwrap("smallint", "a < nan()", new LogicalExpression(AND, ImmutableList.of(new IsNullPredicate(new SymbolReference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); + testUnwrap("smallint", "a < nan()", new Logical(AND, ImmutableList.of(new IsNull(new Reference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); - testUnwrap("smallint", "a <> nan()", new LogicalExpression(OR, ImmutableList.of(new NotExpression(new IsNullPredicate(new SymbolReference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); + testUnwrap("smallint", "a <> nan()", new Logical(OR, ImmutableList.of(new Not(new IsNull(new Reference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); testRemoveFilter("smallint", "a IS DISTINCT FROM nan()"); testRemoveFilter("bigint", "a IS DISTINCT FROM nan()"); - testUnwrap("real", "a = nan()", new LogicalExpression(AND, ImmutableList.of(new IsNullPredicate(new SymbolReference(REAL, "a")), new Constant(BOOLEAN, null)))); + testUnwrap("real", "a = nan()", new Logical(AND, ImmutableList.of(new IsNull(new Reference(REAL, "a")), new Constant(BOOLEAN, null)))); - testUnwrap("real", "a < nan()", new LogicalExpression(AND, ImmutableList.of(new IsNullPredicate(new SymbolReference(REAL, "a")), new Constant(BOOLEAN, null)))); + testUnwrap("real", "a < nan()", new Logical(AND, ImmutableList.of(new IsNull(new Reference(REAL, "a")), new Constant(BOOLEAN, null)))); - testUnwrap("real", "a <> nan()", new LogicalExpression(OR, ImmutableList.of(new NotExpression(new IsNullPredicate(new SymbolReference(REAL, "a"))), new Constant(BOOLEAN, null)))); + testUnwrap("real", "a <> nan()", new Logical(OR, ImmutableList.of(new Not(new IsNull(new Reference(REAL, "a"))), new Constant(BOOLEAN, null)))); - testUnwrap("real", "a IS DISTINCT FROM nan()", new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(REAL, "a"), new Constant(REAL, toReal(Float.NaN)))); + testUnwrap("real", "a IS DISTINCT FROM nan()", new Comparison(IS_DISTINCT_FROM, new Reference(REAL, "a"), new Constant(REAL, toReal(Float.NaN)))); } @Test @@ -432,18 +432,18 @@ public void smokeTests() { // smoke tests for various type combinations for (String type : asList("SMALLINT", "INTEGER", "BIGINT", "REAL", "DOUBLE")) { - testUnwrap("tinyint", format("a = %s '1'", type), new ComparisonExpression(EQUAL, new SymbolReference(TINYINT, "a"), new Constant(TINYINT, 1L))); + testUnwrap("tinyint", format("a = %s '1'", type), new Comparison(EQUAL, new Reference(TINYINT, "a"), new Constant(TINYINT, 1L))); } for (String type : asList("INTEGER", "BIGINT", "REAL", "DOUBLE")) { - testUnwrap("smallint", format("a = %s '1'", type), new ComparisonExpression(EQUAL, new SymbolReference(SMALLINT, "a"), new Constant(SMALLINT, 1L))); + testUnwrap("smallint", format("a = %s '1'", type), new Comparison(EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 1L))); } for (String type : asList("BIGINT", "DOUBLE")) { - testUnwrap("integer", format("a = %s '1'", type), new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L))); + testUnwrap("integer", format("a = %s '1'", type), new Comparison(EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))); } - testUnwrap("real", "a = DOUBLE '1'", new ComparisonExpression(EQUAL, new SymbolReference(REAL, "a"), new Constant(REAL, toReal(1.0f)))); + testUnwrap("real", "a = DOUBLE '1'", new Comparison(EQUAL, new Reference(REAL, "a"), new Constant(REAL, toReal(1.0f)))); } @Test @@ -454,7 +454,7 @@ public void testTermOrder() assertPlan("SELECT * FROM (VALUES REAL '1') t(a) WHERE DOUBLE '1' = a", output( filter( - new ComparisonExpression(EQUAL, new SymbolReference(REAL, "A"), new Constant(REAL, toReal(1.0f))), + new Comparison(EQUAL, new Reference(REAL, "A"), new Constant(REAL, toReal(1.0f))), values("A")))); } @@ -470,75 +470,75 @@ public void testCastDateToTimestampWithTimeZone() Session losAngelesSession = withZone(session, TimeZoneKey.getTimeZoneKey("America/Los_Angeles")); // same zone - testUnwrap(utcSession, "date", "a > TIMESTAMP '2020-10-26 11:02:18 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2020-10-26")))); - testUnwrap(warsawSession, "date", "a > TIMESTAMP '2020-10-26 11:02:18 Europe/Warsaw'", new ComparisonExpression(GREATER_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2020-10-26")))); - testUnwrap(losAngelesSession, "date", "a > TIMESTAMP '2020-10-26 11:02:18 America/Los_Angeles'", new ComparisonExpression(GREATER_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2020-10-26")))); + testUnwrap(utcSession, "date", "a > TIMESTAMP '2020-10-26 11:02:18 UTC'", new Comparison(GREATER_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2020-10-26")))); + testUnwrap(warsawSession, "date", "a > TIMESTAMP '2020-10-26 11:02:18 Europe/Warsaw'", new Comparison(GREATER_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2020-10-26")))); + testUnwrap(losAngelesSession, "date", "a > TIMESTAMP '2020-10-26 11:02:18 America/Los_Angeles'", new Comparison(GREATER_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2020-10-26")))); // different zone - testUnwrap(warsawSession, "date", "a > TIMESTAMP '2020-10-26 11:02:18 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2020-10-26")))); - testUnwrap(losAngelesSession, "date", "a > TIMESTAMP '2020-10-26 11:02:18 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2020-10-26")))); + testUnwrap(warsawSession, "date", "a > TIMESTAMP '2020-10-26 11:02:18 UTC'", new Comparison(GREATER_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2020-10-26")))); + testUnwrap(losAngelesSession, "date", "a > TIMESTAMP '2020-10-26 11:02:18 UTC'", new Comparison(GREATER_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2020-10-26")))); // maximum precision - testUnwrap(warsawSession, "date", "a > TIMESTAMP '2020-10-26 11:02:18.123456789321 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2020-10-26")))); - testUnwrap(losAngelesSession, "date", "a > TIMESTAMP '2020-10-26 11:02:18.123456789321 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2020-10-26")))); + testUnwrap(warsawSession, "date", "a > TIMESTAMP '2020-10-26 11:02:18.123456789321 UTC'", new Comparison(GREATER_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2020-10-26")))); + testUnwrap(losAngelesSession, "date", "a > TIMESTAMP '2020-10-26 11:02:18.123456789321 UTC'", new Comparison(GREATER_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2020-10-26")))); // DST forward -- Warsaw changed clock 1h forward on 2020-03-29T01:00 UTC (2020-03-29T02:00 local time) // Note that in given session input TIMESTAMP values 2020-03-29 02:31 and 2020-03-29 03:31 produce the same value 2020-03-29 01:31 UTC (conversion is not monotonic) // last before - testUnwrap(warsawSession, "date", "a > TIMESTAMP '2020-03-29 00:59:59 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2020-03-29")))); - testUnwrap(warsawSession, "date", "a > TIMESTAMP '2020-03-29 00:59:59.999 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2020-03-29")))); - testUnwrap(warsawSession, "date", "a > TIMESTAMP '2020-03-29 00:59:59.13 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2020-03-29")))); - testUnwrap(warsawSession, "date", "a > TIMESTAMP '2020-03-29 00:59:59.999999 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2020-03-29")))); - testUnwrap(warsawSession, "date", "a > TIMESTAMP '2020-03-29 00:59:59.999999999 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2020-03-29")))); - testUnwrap(warsawSession, "date", "a > TIMESTAMP '2020-03-29 00:59:59.999999999999 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2020-03-29")))); + testUnwrap(warsawSession, "date", "a > TIMESTAMP '2020-03-29 00:59:59 UTC'", new Comparison(GREATER_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2020-03-29")))); + testUnwrap(warsawSession, "date", "a > TIMESTAMP '2020-03-29 00:59:59.999 UTC'", new Comparison(GREATER_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2020-03-29")))); + testUnwrap(warsawSession, "date", "a > TIMESTAMP '2020-03-29 00:59:59.13 UTC'", new Comparison(GREATER_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2020-03-29")))); + testUnwrap(warsawSession, "date", "a > TIMESTAMP '2020-03-29 00:59:59.999999 UTC'", new Comparison(GREATER_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2020-03-29")))); + testUnwrap(warsawSession, "date", "a > TIMESTAMP '2020-03-29 00:59:59.999999999 UTC'", new Comparison(GREATER_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2020-03-29")))); + testUnwrap(warsawSession, "date", "a > TIMESTAMP '2020-03-29 00:59:59.999999999999 UTC'", new Comparison(GREATER_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2020-03-29")))); // equal - testUnwrap(utcSession, "date", "a = TIMESTAMP '1981-06-22 00:00:00 UTC'", new ComparisonExpression(EQUAL, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); - testUnwrap(utcSession, "date", "a = TIMESTAMP '1981-06-22 00:00:00.000000 UTC'", new ComparisonExpression(EQUAL, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); - testUnwrap(utcSession, "date", "a = TIMESTAMP '1981-06-22 00:00:00.000000000 UTC'", new ComparisonExpression(EQUAL, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); - testUnwrap(utcSession, "date", "a = TIMESTAMP '1981-06-22 00:00:00.000000000000 UTC'", new ComparisonExpression(EQUAL, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "a = TIMESTAMP '1981-06-22 00:00:00 UTC'", new Comparison(EQUAL, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "a = TIMESTAMP '1981-06-22 00:00:00.000000 UTC'", new Comparison(EQUAL, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "a = TIMESTAMP '1981-06-22 00:00:00.000000000 UTC'", new Comparison(EQUAL, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "a = TIMESTAMP '1981-06-22 00:00:00.000000000000 UTC'", new Comparison(EQUAL, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); // not equal - testUnwrap(utcSession, "date", "a <> TIMESTAMP '1981-06-22 00:00:00 UTC'", new ComparisonExpression(NOT_EQUAL, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); - testUnwrap(utcSession, "date", "a <> TIMESTAMP '1981-06-22 00:00:00.000000 UTC'", new ComparisonExpression(NOT_EQUAL, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); - testUnwrap(utcSession, "date", "a <> TIMESTAMP '1981-06-22 00:00:00.000000000 UTC'", new ComparisonExpression(NOT_EQUAL, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); - testUnwrap(utcSession, "date", "a <> TIMESTAMP '1981-06-22 00:00:00.000000000000 UTC'", new ComparisonExpression(NOT_EQUAL, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "a <> TIMESTAMP '1981-06-22 00:00:00 UTC'", new Comparison(NOT_EQUAL, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "a <> TIMESTAMP '1981-06-22 00:00:00.000000 UTC'", new Comparison(NOT_EQUAL, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "a <> TIMESTAMP '1981-06-22 00:00:00.000000000 UTC'", new Comparison(NOT_EQUAL, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "a <> TIMESTAMP '1981-06-22 00:00:00.000000000000 UTC'", new Comparison(NOT_EQUAL, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); // less than - testUnwrap(utcSession, "date", "a < TIMESTAMP '1981-06-22 00:00:00 UTC'", new ComparisonExpression(LESS_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); - testUnwrap(utcSession, "date", "a < TIMESTAMP '1981-06-22 00:00:00.000000 UTC'", new ComparisonExpression(LESS_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); - testUnwrap(utcSession, "date", "a < TIMESTAMP '1981-06-22 00:00:00.000000000 UTC'", new ComparisonExpression(LESS_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); - testUnwrap(utcSession, "date", "a < TIMESTAMP '1981-06-22 00:00:00.000000000000 UTC'", new ComparisonExpression(LESS_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "a < TIMESTAMP '1981-06-22 00:00:00 UTC'", new Comparison(LESS_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "a < TIMESTAMP '1981-06-22 00:00:00.000000 UTC'", new Comparison(LESS_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "a < TIMESTAMP '1981-06-22 00:00:00.000000000 UTC'", new Comparison(LESS_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "a < TIMESTAMP '1981-06-22 00:00:00.000000000000 UTC'", new Comparison(LESS_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); // less than or equal - testUnwrap(utcSession, "date", "a <= TIMESTAMP '1981-06-22 00:00:00 UTC'", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); - testUnwrap(utcSession, "date", "a <= TIMESTAMP '1981-06-22 00:00:00.000000 UTC'", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); - testUnwrap(utcSession, "date", "a <= TIMESTAMP '1981-06-22 00:00:00.000000000 UTC'", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); - testUnwrap(utcSession, "date", "a <= TIMESTAMP '1981-06-22 00:00:00.000000000000 UTC'", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "a <= TIMESTAMP '1981-06-22 00:00:00 UTC'", new Comparison(LESS_THAN_OR_EQUAL, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "a <= TIMESTAMP '1981-06-22 00:00:00.000000 UTC'", new Comparison(LESS_THAN_OR_EQUAL, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "a <= TIMESTAMP '1981-06-22 00:00:00.000000000 UTC'", new Comparison(LESS_THAN_OR_EQUAL, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "a <= TIMESTAMP '1981-06-22 00:00:00.000000000000 UTC'", new Comparison(LESS_THAN_OR_EQUAL, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); // greater than - testUnwrap(utcSession, "date", "a > TIMESTAMP '1981-06-22 00:00:00 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); - testUnwrap(utcSession, "date", "a > TIMESTAMP '1981-06-22 00:00:00.000000 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); - testUnwrap(utcSession, "date", "a > TIMESTAMP '1981-06-22 00:00:00.000000000 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); - testUnwrap(utcSession, "date", "a > TIMESTAMP '1981-06-22 00:00:00.000000000000 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "a > TIMESTAMP '1981-06-22 00:00:00 UTC'", new Comparison(GREATER_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "a > TIMESTAMP '1981-06-22 00:00:00.000000 UTC'", new Comparison(GREATER_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "a > TIMESTAMP '1981-06-22 00:00:00.000000000 UTC'", new Comparison(GREATER_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "a > TIMESTAMP '1981-06-22 00:00:00.000000000000 UTC'", new Comparison(GREATER_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); // greater than or equal - testUnwrap(utcSession, "date", "a >= TIMESTAMP '1981-06-22 00:00:00 UTC'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); - testUnwrap(utcSession, "date", "a >= TIMESTAMP '1981-06-22 00:00:00.000000 UTC'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); - testUnwrap(utcSession, "date", "a >= TIMESTAMP '1981-06-22 00:00:00.000000000 UTC'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); - testUnwrap(utcSession, "date", "a >= TIMESTAMP '1981-06-22 00:00:00.000000000000 UTC'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "a >= TIMESTAMP '1981-06-22 00:00:00 UTC'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "a >= TIMESTAMP '1981-06-22 00:00:00.000000 UTC'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "a >= TIMESTAMP '1981-06-22 00:00:00.000000000 UTC'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "a >= TIMESTAMP '1981-06-22 00:00:00.000000000000 UTC'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); // is distinct - testUnwrap(utcSession, "date", "a IS DISTINCT FROM TIMESTAMP '1981-06-22 00:00:00 UTC'", new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); - testUnwrap(utcSession, "date", "a IS DISTINCT FROM TIMESTAMP '1981-06-22 00:00:00.000000 UTC'", new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); - testUnwrap(utcSession, "date", "a IS DISTINCT FROM TIMESTAMP '1981-06-22 00:00:00.000000000 UTC'", new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); - testUnwrap(utcSession, "date", "a IS DISTINCT FROM TIMESTAMP '1981-06-22 00:00:00.000000000000 UTC'", new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "a IS DISTINCT FROM TIMESTAMP '1981-06-22 00:00:00 UTC'", new Comparison(IS_DISTINCT_FROM, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "a IS DISTINCT FROM TIMESTAMP '1981-06-22 00:00:00.000000 UTC'", new Comparison(IS_DISTINCT_FROM, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "a IS DISTINCT FROM TIMESTAMP '1981-06-22 00:00:00.000000000 UTC'", new Comparison(IS_DISTINCT_FROM, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "a IS DISTINCT FROM TIMESTAMP '1981-06-22 00:00:00.000000000000 UTC'", new Comparison(IS_DISTINCT_FROM, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); // is not distinct - testUnwrap(utcSession, "date", "a IS NOT DISTINCT FROM TIMESTAMP '1981-06-22 00:00:00 UTC'", new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22"))))); - testUnwrap(utcSession, "date", "a IS NOT DISTINCT FROM TIMESTAMP '1981-06-22 00:00:00.000000 UTC'", new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22"))))); - testUnwrap(utcSession, "date", "a IS NOT DISTINCT FROM TIMESTAMP '1981-06-22 00:00:00.000000000 UTC'", new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22"))))); - testUnwrap(utcSession, "date", "a IS NOT DISTINCT FROM TIMESTAMP '1981-06-22 00:00:00.000000000000 UTC'", new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22"))))); + testUnwrap(utcSession, "date", "a IS NOT DISTINCT FROM TIMESTAMP '1981-06-22 00:00:00 UTC'", new Not(new Comparison(IS_DISTINCT_FROM, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22"))))); + testUnwrap(utcSession, "date", "a IS NOT DISTINCT FROM TIMESTAMP '1981-06-22 00:00:00.000000 UTC'", new Not(new Comparison(IS_DISTINCT_FROM, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22"))))); + testUnwrap(utcSession, "date", "a IS NOT DISTINCT FROM TIMESTAMP '1981-06-22 00:00:00.000000000 UTC'", new Not(new Comparison(IS_DISTINCT_FROM, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22"))))); + testUnwrap(utcSession, "date", "a IS NOT DISTINCT FROM TIMESTAMP '1981-06-22 00:00:00.000000000000 UTC'", new Not(new Comparison(IS_DISTINCT_FROM, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22"))))); // null date literal testUnwrap("date", "CAST(a AS TIMESTAMP WITH TIME ZONE) = NULL", new Constant(BOOLEAN, null)); @@ -546,10 +546,10 @@ public void testCastDateToTimestampWithTimeZone() testUnwrap("date", "CAST(a AS TIMESTAMP WITH TIME ZONE) <= NULL", new Constant(BOOLEAN, null)); testUnwrap("date", "CAST(a AS TIMESTAMP WITH TIME ZONE) > NULL", new Constant(BOOLEAN, null)); testUnwrap("date", "CAST(a AS TIMESTAMP WITH TIME ZONE) >= NULL", new Constant(BOOLEAN, null)); - testUnwrap("date", "CAST(a AS TIMESTAMP WITH TIME ZONE) IS DISTINCT FROM NULL", new NotExpression(new IsNullPredicate(new Cast(new SymbolReference(DATE, "a"), TIMESTAMP_TZ_MILLIS)))); + testUnwrap("date", "CAST(a AS TIMESTAMP WITH TIME ZONE) IS DISTINCT FROM NULL", new Not(new IsNull(new Cast(new Reference(DATE, "a"), TIMESTAMP_TZ_MILLIS)))); // timestamp with time zone value on the left - testUnwrap(utcSession, "date", "TIMESTAMP '1981-06-22 00:00:00 UTC' = a", new ComparisonExpression(EQUAL, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); + testUnwrap(utcSession, "date", "TIMESTAMP '1981-06-22 00:00:00 UTC' = a", new Comparison(EQUAL, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1981-06-22")))); } @Test @@ -564,170 +564,170 @@ public void testCastTimestampToTimestampWithTimeZone() Session losAngelesSession = withZone(session, TimeZoneKey.getTimeZoneKey("America/Los_Angeles")); // same zone - testUnwrap(utcSession, "timestamp(0)", "a > TIMESTAMP '2020-10-26 11:02:18 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2020-10-26 11:02:18")))); - testUnwrap(warsawSession, "timestamp(0)", "a > TIMESTAMP '2020-10-26 11:02:18 Europe/Warsaw'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2020-10-26 11:02:18")))); - testUnwrap(losAngelesSession, "timestamp(0)", "a > TIMESTAMP '2020-10-26 11:02:18 America/Los_Angeles'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2020-10-26 11:02:18")))); + testUnwrap(utcSession, "timestamp(0)", "a > TIMESTAMP '2020-10-26 11:02:18 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2020-10-26 11:02:18")))); + testUnwrap(warsawSession, "timestamp(0)", "a > TIMESTAMP '2020-10-26 11:02:18 Europe/Warsaw'", new Comparison(GREATER_THAN, new Reference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2020-10-26 11:02:18")))); + testUnwrap(losAngelesSession, "timestamp(0)", "a > TIMESTAMP '2020-10-26 11:02:18 America/Los_Angeles'", new Comparison(GREATER_THAN, new Reference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2020-10-26 11:02:18")))); // different zone - testUnwrap(warsawSession, "timestamp(0)", "a > TIMESTAMP '2020-10-26 11:02:18 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2020-10-26 12:02:18")))); - testUnwrap(losAngelesSession, "timestamp(0)", "a > TIMESTAMP '2020-10-26 11:02:18 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2020-10-26 04:02:18")))); + testUnwrap(warsawSession, "timestamp(0)", "a > TIMESTAMP '2020-10-26 11:02:18 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2020-10-26 12:02:18")))); + testUnwrap(losAngelesSession, "timestamp(0)", "a > TIMESTAMP '2020-10-26 11:02:18 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2020-10-26 04:02:18")))); // short timestamp, short timestamp with time zone being coerced to long timestamp with time zone - testUnwrap(warsawSession, "timestamp(6)", "a > TIMESTAMP '2020-10-26 11:02:18.12 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2020-10-26 12:02:18.120000")))); - testUnwrap(losAngelesSession, "timestamp(6)", "a > TIMESTAMP '2020-10-26 11:02:18.12 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2020-10-26 04:02:18.120000")))); + testUnwrap(warsawSession, "timestamp(6)", "a > TIMESTAMP '2020-10-26 11:02:18.12 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2020-10-26 12:02:18.120000")))); + testUnwrap(losAngelesSession, "timestamp(6)", "a > TIMESTAMP '2020-10-26 11:02:18.12 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2020-10-26 04:02:18.120000")))); // long timestamp, short timestamp with time zone being coerced to long timestamp with time zone - testUnwrap(warsawSession, "timestamp(9)", "a > TIMESTAMP '2020-10-26 11:02:18.12 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2020-10-26 12:02:18.120000000")))); - testUnwrap(losAngelesSession, "timestamp(9)", "a > TIMESTAMP '2020-10-26 11:02:18.12 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2020-10-26 04:02:18.120000000")))); + testUnwrap(warsawSession, "timestamp(9)", "a > TIMESTAMP '2020-10-26 11:02:18.12 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2020-10-26 12:02:18.120000000")))); + testUnwrap(losAngelesSession, "timestamp(9)", "a > TIMESTAMP '2020-10-26 11:02:18.12 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2020-10-26 04:02:18.120000000")))); // long timestamp, long timestamp with time zone - testUnwrap(warsawSession, "timestamp(9)", "a > TIMESTAMP '2020-10-26 11:02:18.123456 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2020-10-26 12:02:18.123456000")))); - testUnwrap(losAngelesSession, "timestamp(9)", "a > TIMESTAMP '2020-10-26 11:02:18.123456 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2020-10-26 04:02:18.123456000")))); + testUnwrap(warsawSession, "timestamp(9)", "a > TIMESTAMP '2020-10-26 11:02:18.123456 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2020-10-26 12:02:18.123456000")))); + testUnwrap(losAngelesSession, "timestamp(9)", "a > TIMESTAMP '2020-10-26 11:02:18.123456 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2020-10-26 04:02:18.123456000")))); // maximum precision - testUnwrap(warsawSession, "timestamp(12)", "a > TIMESTAMP '2020-10-26 11:02:18.123456789321 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2020-10-26 12:02:18.123456789321")))); - testUnwrap(losAngelesSession, "timestamp(12)", "a > TIMESTAMP '2020-10-26 11:02:18.123456789321 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2020-10-26 04:02:18.123456789321")))); + testUnwrap(warsawSession, "timestamp(12)", "a > TIMESTAMP '2020-10-26 11:02:18.123456789321 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2020-10-26 12:02:18.123456789321")))); + testUnwrap(losAngelesSession, "timestamp(12)", "a > TIMESTAMP '2020-10-26 11:02:18.123456789321 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2020-10-26 04:02:18.123456789321")))); // DST forward -- Warsaw changed clock 1h forward on 2020-03-29T01:00 UTC (2020-03-29T02:00 local time) // Note that in given session input TIMESTAMP values 2020-03-29 02:31 and 2020-03-29 03:31 produce the same value 2020-03-29 01:31 UTC (conversion is not monotonic) // last before - testUnwrap(warsawSession, "timestamp(0)", "a > TIMESTAMP '2020-03-29 00:59:59 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2020-03-29 01:59:59")))); - testUnwrap(warsawSession, "timestamp(3)", "a > TIMESTAMP '2020-03-29 00:59:59.999 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2020-03-29 01:59:59.999")))); - testUnwrap(warsawSession, "timestamp(6)", "a > TIMESTAMP '2020-03-29 00:59:59.13 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2020-03-29 01:59:59.130000")))); - testUnwrap(warsawSession, "timestamp(6)", "a > TIMESTAMP '2020-03-29 00:59:59.999999 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2020-03-29 01:59:59.999999")))); - testUnwrap(warsawSession, "timestamp(9)", "a > TIMESTAMP '2020-03-29 00:59:59.999999999 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2020-03-29 01:59:59.999999999")))); - testUnwrap(warsawSession, "timestamp(12)", "a > TIMESTAMP '2020-03-29 00:59:59.999999999999 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2020-03-29 01:59:59.999999999999")))); + testUnwrap(warsawSession, "timestamp(0)", "a > TIMESTAMP '2020-03-29 00:59:59 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2020-03-29 01:59:59")))); + testUnwrap(warsawSession, "timestamp(3)", "a > TIMESTAMP '2020-03-29 00:59:59.999 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2020-03-29 01:59:59.999")))); + testUnwrap(warsawSession, "timestamp(6)", "a > TIMESTAMP '2020-03-29 00:59:59.13 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2020-03-29 01:59:59.130000")))); + testUnwrap(warsawSession, "timestamp(6)", "a > TIMESTAMP '2020-03-29 00:59:59.999999 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2020-03-29 01:59:59.999999")))); + testUnwrap(warsawSession, "timestamp(9)", "a > TIMESTAMP '2020-03-29 00:59:59.999999999 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2020-03-29 01:59:59.999999999")))); + testUnwrap(warsawSession, "timestamp(12)", "a > TIMESTAMP '2020-03-29 00:59:59.999999999999 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2020-03-29 01:59:59.999999999999")))); // first after - testUnwrap(warsawSession, "timestamp(0)", "a > TIMESTAMP '2020-03-29 02:00:00 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2020-03-29 04:00:00")))); - testUnwrap(warsawSession, "timestamp(3)", "a > TIMESTAMP '2020-03-29 02:00:00.000 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2020-03-29 04:00:00.000")))); - testUnwrap(warsawSession, "timestamp(6)", "a > TIMESTAMP '2020-03-29 02:00:00.000000 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2020-03-29 04:00:00.000000")))); - testUnwrap(warsawSession, "timestamp(9)", "a > TIMESTAMP '2020-03-29 02:00:00.000000000 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2020-03-29 04:00:00.000000000")))); - testUnwrap(warsawSession, "timestamp(12)", "a > TIMESTAMP '2020-03-29 02:00:00.000000000000 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2020-03-29 04:00:00.000000000000")))); + testUnwrap(warsawSession, "timestamp(0)", "a > TIMESTAMP '2020-03-29 02:00:00 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2020-03-29 04:00:00")))); + testUnwrap(warsawSession, "timestamp(3)", "a > TIMESTAMP '2020-03-29 02:00:00.000 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2020-03-29 04:00:00.000")))); + testUnwrap(warsawSession, "timestamp(6)", "a > TIMESTAMP '2020-03-29 02:00:00.000000 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2020-03-29 04:00:00.000000")))); + testUnwrap(warsawSession, "timestamp(9)", "a > TIMESTAMP '2020-03-29 02:00:00.000000000 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2020-03-29 04:00:00.000000000")))); + testUnwrap(warsawSession, "timestamp(12)", "a > TIMESTAMP '2020-03-29 02:00:00.000000000000 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2020-03-29 04:00:00.000000000000")))); // DST backward -- Warsaw changed clock 1h backward on 2020-10-25T01:00 UTC (2020-03-29T03:00 local time) // Note that in given session no input TIMESTAMP value can produce TIMESTAMP WITH TIME ZONE within [2020-10-25 00:00:00 UTC, 2020-10-25 01:00:00 UTC], so '>=' is OK // last before - testUnwrap(warsawSession, "timestamp(0)", "a > TIMESTAMP '2020-10-25 00:59:59 UTC'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2020-10-25 02:59:59")))); - testUnwrap(warsawSession, "timestamp(3)", "a > TIMESTAMP '2020-10-25 00:59:59.999 UTC'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2020-10-25 02:59:59.999")))); - testUnwrap(warsawSession, "timestamp(6)", "a > TIMESTAMP '2020-10-25 00:59:59.999999 UTC'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2020-10-25 02:59:59.999999")))); - testUnwrap(warsawSession, "timestamp(9)", "a > TIMESTAMP '2020-10-25 00:59:59.999999999 UTC'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2020-10-25 02:59:59.999999999")))); - testUnwrap(warsawSession, "timestamp(12)", "a > TIMESTAMP '2020-10-25 00:59:59.999999999999 UTC'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2020-10-25 02:59:59.999999999999")))); + testUnwrap(warsawSession, "timestamp(0)", "a > TIMESTAMP '2020-10-25 00:59:59 UTC'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2020-10-25 02:59:59")))); + testUnwrap(warsawSession, "timestamp(3)", "a > TIMESTAMP '2020-10-25 00:59:59.999 UTC'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2020-10-25 02:59:59.999")))); + testUnwrap(warsawSession, "timestamp(6)", "a > TIMESTAMP '2020-10-25 00:59:59.999999 UTC'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2020-10-25 02:59:59.999999")))); + testUnwrap(warsawSession, "timestamp(9)", "a > TIMESTAMP '2020-10-25 00:59:59.999999999 UTC'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2020-10-25 02:59:59.999999999")))); + testUnwrap(warsawSession, "timestamp(12)", "a > TIMESTAMP '2020-10-25 00:59:59.999999999999 UTC'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2020-10-25 02:59:59.999999999999")))); // first within - testUnwrap(warsawSession, "timestamp(0)", "a > TIMESTAMP '2020-10-25 01:00:00 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2020-10-25 02:00:00")))); - testUnwrap(warsawSession, "timestamp(3)", "a > TIMESTAMP '2020-10-25 01:00:00.000 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2020-10-25 02:00:00.000")))); - testUnwrap(warsawSession, "timestamp(6)", "a > TIMESTAMP '2020-10-25 01:00:00.000000 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2020-10-25 02:00:00.000000")))); - testUnwrap(warsawSession, "timestamp(9)", "a > TIMESTAMP '2020-10-25 01:00:00.000000000 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2020-10-25 02:00:00.000000000")))); - testUnwrap(warsawSession, "timestamp(12)", "a > TIMESTAMP '2020-10-25 01:00:00.000000000000 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2020-10-25 02:00:00.000000000000")))); + testUnwrap(warsawSession, "timestamp(0)", "a > TIMESTAMP '2020-10-25 01:00:00 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2020-10-25 02:00:00")))); + testUnwrap(warsawSession, "timestamp(3)", "a > TIMESTAMP '2020-10-25 01:00:00.000 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2020-10-25 02:00:00.000")))); + testUnwrap(warsawSession, "timestamp(6)", "a > TIMESTAMP '2020-10-25 01:00:00.000000 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2020-10-25 02:00:00.000000")))); + testUnwrap(warsawSession, "timestamp(9)", "a > TIMESTAMP '2020-10-25 01:00:00.000000000 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2020-10-25 02:00:00.000000000")))); + testUnwrap(warsawSession, "timestamp(12)", "a > TIMESTAMP '2020-10-25 01:00:00.000000000000 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2020-10-25 02:00:00.000000000000")))); // last within - testUnwrap(warsawSession, "timestamp(0)", "a > TIMESTAMP '2020-10-25 01:59:59 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2020-10-25 02:59:59")))); - testUnwrap(warsawSession, "timestamp(3)", "a > TIMESTAMP '2020-10-25 01:59:59.999 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2020-10-25 02:59:59.999")))); - testUnwrap(warsawSession, "timestamp(6)", "a > TIMESTAMP '2020-10-25 01:59:59.999999 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2020-10-25 02:59:59.999999")))); - testUnwrap(warsawSession, "timestamp(9)", "a > TIMESTAMP '2020-10-25 01:59:59.999999999 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2020-10-25 02:59:59.999999999")))); - testUnwrap(warsawSession, "timestamp(12)", "a > TIMESTAMP '2020-10-25 01:59:59.999999999999 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2020-10-25 02:59:59.999999999999")))); + testUnwrap(warsawSession, "timestamp(0)", "a > TIMESTAMP '2020-10-25 01:59:59 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2020-10-25 02:59:59")))); + testUnwrap(warsawSession, "timestamp(3)", "a > TIMESTAMP '2020-10-25 01:59:59.999 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2020-10-25 02:59:59.999")))); + testUnwrap(warsawSession, "timestamp(6)", "a > TIMESTAMP '2020-10-25 01:59:59.999999 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2020-10-25 02:59:59.999999")))); + testUnwrap(warsawSession, "timestamp(9)", "a > TIMESTAMP '2020-10-25 01:59:59.999999999 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2020-10-25 02:59:59.999999999")))); + testUnwrap(warsawSession, "timestamp(12)", "a > TIMESTAMP '2020-10-25 01:59:59.999999999999 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2020-10-25 02:59:59.999999999999")))); // first after - testUnwrap(warsawSession, "timestamp(0)", "a > TIMESTAMP '2020-10-25 02:00:00 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2020-10-25 03:00:00")))); - testUnwrap(warsawSession, "timestamp(3)", "a > TIMESTAMP '2020-10-25 02:00:00.000 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2020-10-25 03:00:00.000")))); - testUnwrap(warsawSession, "timestamp(6)", "a > TIMESTAMP '2020-10-25 02:00:00.000000 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2020-10-25 03:00:00.000000")))); - testUnwrap(warsawSession, "timestamp(9)", "a > TIMESTAMP '2020-10-25 02:00:00.000000000 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2020-10-25 03:00:00.000000000")))); - testUnwrap(warsawSession, "timestamp(12)", "a > TIMESTAMP '2020-10-25 02:00:00.000000000000 UTC'", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2020-10-25 03:00:00.000000000000")))); + testUnwrap(warsawSession, "timestamp(0)", "a > TIMESTAMP '2020-10-25 02:00:00 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2020-10-25 03:00:00")))); + testUnwrap(warsawSession, "timestamp(3)", "a > TIMESTAMP '2020-10-25 02:00:00.000 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2020-10-25 03:00:00.000")))); + testUnwrap(warsawSession, "timestamp(6)", "a > TIMESTAMP '2020-10-25 02:00:00.000000 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2020-10-25 03:00:00.000000")))); + testUnwrap(warsawSession, "timestamp(9)", "a > TIMESTAMP '2020-10-25 02:00:00.000000000 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2020-10-25 03:00:00.000000000")))); + testUnwrap(warsawSession, "timestamp(12)", "a > TIMESTAMP '2020-10-25 02:00:00.000000000000 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2020-10-25 03:00:00.000000000000")))); } @Test public void testNoEffect() { // BIGINT->DOUBLE implicit cast is not injective if the double constant is >= 2^53 and <= double(2^63 - 1) - testUnwrap("bigint", "a = DOUBLE '9007199254740992'", new ComparisonExpression(EQUAL, new Cast(new SymbolReference(BIGINT, "a"), DOUBLE), new Constant(DOUBLE, 9.007199254740992E15))); + testUnwrap("bigint", "a = DOUBLE '9007199254740992'", new Comparison(EQUAL, new Cast(new Reference(BIGINT, "a"), DOUBLE), new Constant(DOUBLE, 9.007199254740992E15))); - testUnwrap("bigint", "a = DOUBLE '9223372036854775807'", new ComparisonExpression(EQUAL, new Cast(new SymbolReference(BIGINT, "a"), DOUBLE), new Constant(DOUBLE, 9.223372036854776E18))); + testUnwrap("bigint", "a = DOUBLE '9223372036854775807'", new Comparison(EQUAL, new Cast(new Reference(BIGINT, "a"), DOUBLE), new Constant(DOUBLE, 9.223372036854776E18))); // BIGINT->DOUBLE implicit cast is not injective if the double constant is <= -2^53 and >= double(-2^63 + 1) - testUnwrap("bigint", "a = DOUBLE '-9007199254740992'", new ComparisonExpression(EQUAL, new Cast(new SymbolReference(BIGINT, "a"), DOUBLE), new Constant(DOUBLE, -9.007199254740992E15))); + testUnwrap("bigint", "a = DOUBLE '-9007199254740992'", new Comparison(EQUAL, new Cast(new Reference(BIGINT, "a"), DOUBLE), new Constant(DOUBLE, -9.007199254740992E15))); - testUnwrap("bigint", "a = DOUBLE '-9223372036854775807'", new ComparisonExpression(EQUAL, new Cast(new SymbolReference(BIGINT, "a"), DOUBLE), new Constant(DOUBLE, -9.223372036854776E18))); + testUnwrap("bigint", "a = DOUBLE '-9223372036854775807'", new Comparison(EQUAL, new Cast(new Reference(BIGINT, "a"), DOUBLE), new Constant(DOUBLE, -9.223372036854776E18))); // BIGINT->REAL implicit cast is not injective if the real constant is >= 2^23 and <= real(2^63 - 1) - testUnwrap("bigint", "a = REAL '8388608'", new ComparisonExpression(EQUAL, new Cast(new SymbolReference(BIGINT, "a"), REAL), new Constant(REAL, toReal(8388608.0f)))); + testUnwrap("bigint", "a = REAL '8388608'", new Comparison(EQUAL, new Cast(new Reference(BIGINT, "a"), REAL), new Constant(REAL, toReal(8388608.0f)))); - testUnwrap("bigint", "a = REAL '9223372036854775807'", new ComparisonExpression(EQUAL, new Cast(new SymbolReference(BIGINT, "a"), REAL), new Constant(REAL, toReal(9.223372E18f)))); + testUnwrap("bigint", "a = REAL '9223372036854775807'", new Comparison(EQUAL, new Cast(new Reference(BIGINT, "a"), REAL), new Constant(REAL, toReal(9.223372E18f)))); // BIGINT->REAL implicit cast is not injective if the real constant is <= -2^23 and >= real(-2^63 + 1) - testUnwrap("bigint", "a = REAL '-8388608'", new ComparisonExpression(EQUAL, new Cast(new SymbolReference(BIGINT, "a"), REAL), new Constant(REAL, toReal(-8388608.0f)))); + testUnwrap("bigint", "a = REAL '-8388608'", new Comparison(EQUAL, new Cast(new Reference(BIGINT, "a"), REAL), new Constant(REAL, toReal(-8388608.0f)))); - testUnwrap("bigint", "a = REAL '-9223372036854775807'", new ComparisonExpression(EQUAL, new Cast(new SymbolReference(BIGINT, "a"), REAL), new Constant(REAL, toReal(-9.223372E18f)))); + testUnwrap("bigint", "a = REAL '-9223372036854775807'", new Comparison(EQUAL, new Cast(new Reference(BIGINT, "a"), REAL), new Constant(REAL, toReal(-9.223372E18f)))); // INTEGER->REAL implicit cast is not injective if the real constant is >= 2^23 and <= 2^31 - 1 - testUnwrap("integer", "a = REAL '8388608'", new ComparisonExpression(EQUAL, new Cast(new SymbolReference(INTEGER, "a"), REAL), new Constant(REAL, toReal(8388608.0f)))); + testUnwrap("integer", "a = REAL '8388608'", new Comparison(EQUAL, new Cast(new Reference(INTEGER, "a"), REAL), new Constant(REAL, toReal(8388608.0f)))); - testUnwrap("integer", "a = REAL '2147483647'", new ComparisonExpression(EQUAL, new Cast(new SymbolReference(INTEGER, "a"), REAL), new Constant(REAL, toReal(2.1474836E9f)))); + testUnwrap("integer", "a = REAL '2147483647'", new Comparison(EQUAL, new Cast(new Reference(INTEGER, "a"), REAL), new Constant(REAL, toReal(2.1474836E9f)))); // INTEGER->REAL implicit cast is not injective if the real constant is <= -2^23 and >= -2^31 + 1 - testUnwrap("integer", "a = REAL '-8388608'", new ComparisonExpression(EQUAL, new Cast(new SymbolReference(INTEGER, "a"), REAL), new Constant(REAL, toReal(-8388608.0f)))); + testUnwrap("integer", "a = REAL '-8388608'", new Comparison(EQUAL, new Cast(new Reference(INTEGER, "a"), REAL), new Constant(REAL, toReal(-8388608.0f)))); - testUnwrap("integer", "a = REAL '-2147483647'", new ComparisonExpression(EQUAL, new Cast(new SymbolReference(INTEGER, "a"), REAL), new Constant(REAL, toReal(-2.1474836E9f)))); + testUnwrap("integer", "a = REAL '-2147483647'", new Comparison(EQUAL, new Cast(new Reference(INTEGER, "a"), REAL), new Constant(REAL, toReal(-2.1474836E9f)))); // DECIMAL(p)->DOUBLE not injective for p > 15 - testUnwrap("decimal(16)", "a = DOUBLE '1'", new ComparisonExpression(EQUAL, new Cast(new SymbolReference(createDecimalType(16), "a"), DOUBLE), new Constant(DOUBLE, 1.0))); + testUnwrap("decimal(16)", "a = DOUBLE '1'", new Comparison(EQUAL, new Cast(new Reference(createDecimalType(16), "a"), DOUBLE), new Constant(DOUBLE, 1.0))); // DECIMAL(p)->REAL not injective for p > 7 - testUnwrap("decimal(8)", "a = REAL '1'", new ComparisonExpression(EQUAL, new Cast(new SymbolReference(createDecimalType(8), "a"), REAL), new Constant(REAL, toReal(1.0f)))); + testUnwrap("decimal(8)", "a = REAL '1'", new Comparison(EQUAL, new Cast(new Reference(createDecimalType(8), "a"), REAL), new Constant(REAL, toReal(1.0f)))); // no implicit cast between VARCHAR->INTEGER - testUnwrap("varchar", "CAST(a AS INTEGER) = INTEGER '1'", new ComparisonExpression(EQUAL, new Cast(new SymbolReference(VARCHAR, "a"), INTEGER), new Constant(INTEGER, 1L))); + testUnwrap("varchar", "CAST(a AS INTEGER) = INTEGER '1'", new Comparison(EQUAL, new Cast(new Reference(VARCHAR, "a"), INTEGER), new Constant(INTEGER, 1L))); // no implicit cast between DOUBLE->INTEGER - testUnwrap("double", "CAST(a AS INTEGER) = INTEGER '1'", new ComparisonExpression(EQUAL, new Cast(new SymbolReference(DOUBLE, "a"), INTEGER), new Constant(INTEGER, 1L))); + testUnwrap("double", "CAST(a AS INTEGER) = INTEGER '1'", new Comparison(EQUAL, new Cast(new Reference(DOUBLE, "a"), INTEGER), new Constant(INTEGER, 1L))); } @Test public void testUnwrapCastTimestampAsDate() { // equal - testUnwrap("timestamp(3)", "CAST(a AS DATE) = DATE '1981-06-22'", new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000"))), new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))))); - testUnwrap("timestamp(6)", "CAST(a AS DATE) = DATE '1981-06-22'", new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-22 00:00:00.000000"))), new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-23 00:00:00.000000")))))); - testUnwrap("timestamp(9)", "CAST(a AS DATE) = DATE '1981-06-22'", new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-22 00:00:00.000000000"))), new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-23 00:00:00.000000000")))))); - testUnwrap("timestamp(12)", "CAST(a AS DATE) = DATE '1981-06-22'", new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-22 00:00:00.000000000000"))), new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-23 00:00:00.000000000000")))))); + testUnwrap("timestamp(3)", "CAST(a AS DATE) = DATE '1981-06-22'", new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000"))), new Comparison(LESS_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))))); + testUnwrap("timestamp(6)", "CAST(a AS DATE) = DATE '1981-06-22'", new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-22 00:00:00.000000"))), new Comparison(LESS_THAN, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-23 00:00:00.000000")))))); + testUnwrap("timestamp(9)", "CAST(a AS DATE) = DATE '1981-06-22'", new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-22 00:00:00.000000000"))), new Comparison(LESS_THAN, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-23 00:00:00.000000000")))))); + testUnwrap("timestamp(12)", "CAST(a AS DATE) = DATE '1981-06-22'", new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-22 00:00:00.000000000000"))), new Comparison(LESS_THAN, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-23 00:00:00.000000000000")))))); // not equal - testUnwrap("timestamp(3)", "CAST(a AS DATE) <> DATE '1981-06-22'", new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000"))), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))))); - testUnwrap("timestamp(6)", "CAST(a AS DATE) <> DATE '1981-06-22'", new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-22 00:00:00.000000"))), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-23 00:00:00.000000")))))); - testUnwrap("timestamp(9)", "CAST(a AS DATE) <> DATE '1981-06-22'", new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-22 00:00:00.000000000"))), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-23 00:00:00.000000000")))))); - testUnwrap("timestamp(12)", "CAST(a AS DATE) <> DATE '1981-06-22'", new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-22 00:00:00.000000000000"))), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-23 00:00:00.000000000000")))))); + testUnwrap("timestamp(3)", "CAST(a AS DATE) <> DATE '1981-06-22'", new Logical(OR, ImmutableList.of(new Comparison(LESS_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))))); + testUnwrap("timestamp(6)", "CAST(a AS DATE) <> DATE '1981-06-22'", new Logical(OR, ImmutableList.of(new Comparison(LESS_THAN, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-22 00:00:00.000000"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-23 00:00:00.000000")))))); + testUnwrap("timestamp(9)", "CAST(a AS DATE) <> DATE '1981-06-22'", new Logical(OR, ImmutableList.of(new Comparison(LESS_THAN, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-22 00:00:00.000000000"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-23 00:00:00.000000000")))))); + testUnwrap("timestamp(12)", "CAST(a AS DATE) <> DATE '1981-06-22'", new Logical(OR, ImmutableList.of(new Comparison(LESS_THAN, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-22 00:00:00.000000000000"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-23 00:00:00.000000000000")))))); // less than - testUnwrap("timestamp(3)", "CAST(a AS DATE) < DATE '1981-06-22'", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000")))); - testUnwrap("timestamp(6)", "CAST(a AS DATE) < DATE '1981-06-22'", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-22 00:00:00.000000")))); - testUnwrap("timestamp(9)", "CAST(a AS DATE) < DATE '1981-06-22'", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-22 00:00:00.000000000")))); - testUnwrap("timestamp(12)", "CAST(a AS DATE) < DATE '1981-06-22'", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-22 00:00:00.000000000000")))); + testUnwrap("timestamp(3)", "CAST(a AS DATE) < DATE '1981-06-22'", new Comparison(LESS_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000")))); + testUnwrap("timestamp(6)", "CAST(a AS DATE) < DATE '1981-06-22'", new Comparison(LESS_THAN, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-22 00:00:00.000000")))); + testUnwrap("timestamp(9)", "CAST(a AS DATE) < DATE '1981-06-22'", new Comparison(LESS_THAN, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-22 00:00:00.000000000")))); + testUnwrap("timestamp(12)", "CAST(a AS DATE) < DATE '1981-06-22'", new Comparison(LESS_THAN, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-22 00:00:00.000000000000")))); // less than or equal - testUnwrap("timestamp(3)", "CAST(a AS DATE) <= DATE '1981-06-22'", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))); - testUnwrap("timestamp(6)", "CAST(a AS DATE) <= DATE '1981-06-22'", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-23 00:00:00.000000")))); - testUnwrap("timestamp(9)", "CAST(a AS DATE) <= DATE '1981-06-22'", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-23 00:00:00.000000000")))); - testUnwrap("timestamp(12)", "CAST(a AS DATE) <= DATE '1981-06-22'", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-23 00:00:00.000000000000")))); + testUnwrap("timestamp(3)", "CAST(a AS DATE) <= DATE '1981-06-22'", new Comparison(LESS_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))); + testUnwrap("timestamp(6)", "CAST(a AS DATE) <= DATE '1981-06-22'", new Comparison(LESS_THAN, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-23 00:00:00.000000")))); + testUnwrap("timestamp(9)", "CAST(a AS DATE) <= DATE '1981-06-22'", new Comparison(LESS_THAN, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-23 00:00:00.000000000")))); + testUnwrap("timestamp(12)", "CAST(a AS DATE) <= DATE '1981-06-22'", new Comparison(LESS_THAN, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-23 00:00:00.000000000000")))); // greater than - testUnwrap("timestamp(3)", "CAST(a AS DATE) > DATE '1981-06-22'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))); - testUnwrap("timestamp(6)", "CAST(a AS DATE) > DATE '1981-06-22'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-23 00:00:00.000000")))); - testUnwrap("timestamp(9)", "CAST(a AS DATE) > DATE '1981-06-22'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-23 00:00:00.000000000")))); - testUnwrap("timestamp(12)", "CAST(a AS DATE) > DATE '1981-06-22'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-23 00:00:00.000000000000")))); + testUnwrap("timestamp(3)", "CAST(a AS DATE) > DATE '1981-06-22'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))); + testUnwrap("timestamp(6)", "CAST(a AS DATE) > DATE '1981-06-22'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-23 00:00:00.000000")))); + testUnwrap("timestamp(9)", "CAST(a AS DATE) > DATE '1981-06-22'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-23 00:00:00.000000000")))); + testUnwrap("timestamp(12)", "CAST(a AS DATE) > DATE '1981-06-22'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-23 00:00:00.000000000000")))); // greater than or equal - testUnwrap("timestamp(3)", "CAST(a AS DATE) >= DATE '1981-06-22'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000")))); - testUnwrap("timestamp(6)", "CAST(a AS DATE) >= DATE '1981-06-22'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-22 00:00:00.000000")))); - testUnwrap("timestamp(9)", "CAST(a AS DATE) >= DATE '1981-06-22'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-22 00:00:00.000000000")))); - testUnwrap("timestamp(12)", "CAST(a AS DATE) >= DATE '1981-06-22'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-22 00:00:00.000000000000")))); + testUnwrap("timestamp(3)", "CAST(a AS DATE) >= DATE '1981-06-22'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000")))); + testUnwrap("timestamp(6)", "CAST(a AS DATE) >= DATE '1981-06-22'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-22 00:00:00.000000")))); + testUnwrap("timestamp(9)", "CAST(a AS DATE) >= DATE '1981-06-22'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-22 00:00:00.000000000")))); + testUnwrap("timestamp(12)", "CAST(a AS DATE) >= DATE '1981-06-22'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-22 00:00:00.000000000000")))); // is distinct - testUnwrap("timestamp(3)", "CAST(a AS DATE) IS DISTINCT FROM DATE '1981-06-22'", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(createTimestampType(3), "a")), new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000"))), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))))); - testUnwrap("timestamp(6)", "CAST(a AS DATE) IS DISTINCT FROM DATE '1981-06-22'", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(createTimestampType(6), "a")), new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-22 00:00:00.000000"))), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-23 00:00:00.000000")))))); - testUnwrap("timestamp(9)", "CAST(a AS DATE) IS DISTINCT FROM DATE '1981-06-22'", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(createTimestampType(9), "a")), new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-22 00:00:00.000000000"))), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-23 00:00:00.000000000")))))); - testUnwrap("timestamp(12)", "CAST(a AS DATE) IS DISTINCT FROM DATE '1981-06-22'", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(createTimestampType(12), "a")), new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-22 00:00:00.000000000000"))), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-23 00:00:00.000000000000")))))); + testUnwrap("timestamp(3)", "CAST(a AS DATE) IS DISTINCT FROM DATE '1981-06-22'", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(3), "a")), new Comparison(LESS_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))))); + testUnwrap("timestamp(6)", "CAST(a AS DATE) IS DISTINCT FROM DATE '1981-06-22'", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(6), "a")), new Comparison(LESS_THAN, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-22 00:00:00.000000"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-23 00:00:00.000000")))))); + testUnwrap("timestamp(9)", "CAST(a AS DATE) IS DISTINCT FROM DATE '1981-06-22'", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(9), "a")), new Comparison(LESS_THAN, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-22 00:00:00.000000000"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-23 00:00:00.000000000")))))); + testUnwrap("timestamp(12)", "CAST(a AS DATE) IS DISTINCT FROM DATE '1981-06-22'", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(12), "a")), new Comparison(LESS_THAN, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-22 00:00:00.000000000000"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-23 00:00:00.000000000000")))))); // is not distinct - testUnwrap("timestamp(3)", "CAST(a AS DATE) IS NOT DISTINCT FROM DATE '1981-06-22'", new LogicalExpression(AND, ImmutableList.of(new NotExpression(new IsNullPredicate(new SymbolReference(createTimestampType(3), "a"))), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000"))), new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))))); - testUnwrap("timestamp(6)", "CAST(a AS DATE) IS NOT DISTINCT FROM DATE '1981-06-22'", new LogicalExpression(AND, ImmutableList.of(new NotExpression(new IsNullPredicate(new SymbolReference(createTimestampType(6), "a"))), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-22 00:00:00.000000"))), new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-23 00:00:00.000000")))))); - testUnwrap("timestamp(9)", "CAST(a AS DATE) IS NOT DISTINCT FROM DATE '1981-06-22'", new LogicalExpression(AND, ImmutableList.of(new NotExpression(new IsNullPredicate(new SymbolReference(createTimestampType(9), "a"))), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-22 00:00:00.000000000"))), new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-23 00:00:00.000000000")))))); - testUnwrap("timestamp(12)", "CAST(a AS DATE) IS NOT DISTINCT FROM DATE '1981-06-22'", new LogicalExpression(AND, ImmutableList.of(new NotExpression(new IsNullPredicate(new SymbolReference(createTimestampType(12), "a"))), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-22 00:00:00.000000000000"))), new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-23 00:00:00.000000000000")))))); + testUnwrap("timestamp(3)", "CAST(a AS DATE) IS NOT DISTINCT FROM DATE '1981-06-22'", new Logical(AND, ImmutableList.of(new Not(new IsNull(new Reference(createTimestampType(3), "a"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000"))), new Comparison(LESS_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))))); + testUnwrap("timestamp(6)", "CAST(a AS DATE) IS NOT DISTINCT FROM DATE '1981-06-22'", new Logical(AND, ImmutableList.of(new Not(new IsNull(new Reference(createTimestampType(6), "a"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-22 00:00:00.000000"))), new Comparison(LESS_THAN, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-23 00:00:00.000000")))))); + testUnwrap("timestamp(9)", "CAST(a AS DATE) IS NOT DISTINCT FROM DATE '1981-06-22'", new Logical(AND, ImmutableList.of(new Not(new IsNull(new Reference(createTimestampType(9), "a"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-22 00:00:00.000000000"))), new Comparison(LESS_THAN, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-23 00:00:00.000000000")))))); + testUnwrap("timestamp(12)", "CAST(a AS DATE) IS NOT DISTINCT FROM DATE '1981-06-22'", new Logical(AND, ImmutableList.of(new Not(new IsNull(new Reference(createTimestampType(12), "a"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-22 00:00:00.000000000000"))), new Comparison(LESS_THAN, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-23 00:00:00.000000000000")))))); // null date literal testUnwrap("timestamp(3)", "CAST(a AS DATE) = NULL", new Constant(BOOLEAN, null)); @@ -735,65 +735,65 @@ public void testUnwrapCastTimestampAsDate() testUnwrap("timestamp(3)", "CAST(a AS DATE) <= NULL", new Constant(BOOLEAN, null)); testUnwrap("timestamp(3)", "CAST(a AS DATE) > NULL", new Constant(BOOLEAN, null)); testUnwrap("timestamp(3)", "CAST(a AS DATE) >= NULL", new Constant(BOOLEAN, null)); - testUnwrap("timestamp(3)", "CAST(a AS DATE) IS DISTINCT FROM NULL", new NotExpression(new IsNullPredicate(new Cast(new SymbolReference(createTimestampType(3), "a"), DATE)))); + testUnwrap("timestamp(3)", "CAST(a AS DATE) IS DISTINCT FROM NULL", new Not(new IsNull(new Cast(new Reference(createTimestampType(3), "a"), DATE)))); // non-optimized expression on the right - testUnwrap("timestamp(3)", "CAST(a AS DATE) = DATE '1981-06-22' + INTERVAL '2' DAY", new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-24 00:00:00.000"))), new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-25 00:00:00.000")))))); + testUnwrap("timestamp(3)", "CAST(a AS DATE) = DATE '1981-06-22' + INTERVAL '2' DAY", new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-24 00:00:00.000"))), new Comparison(LESS_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-25 00:00:00.000")))))); // cast on the right - testUnwrap("timestamp(3)", "DATE '1981-06-22' = CAST(a AS DATE)", new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000"))), new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))))); + testUnwrap("timestamp(3)", "DATE '1981-06-22' = CAST(a AS DATE)", new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000"))), new Comparison(LESS_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))))); } @Test public void testUnwrapConvertTimestampToDate() { // equal - testUnwrap("timestamp(3)", "date(a) = DATE '1981-06-22'", new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000"))), new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))))); - testUnwrap("timestamp(6)", "date(a) = DATE '1981-06-22'", new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-22 00:00:00.000000"))), new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-23 00:00:00.000000")))))); - testUnwrap("timestamp(9)", "date(a) = DATE '1981-06-22'", new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-22 00:00:00.000000000"))), new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-23 00:00:00.000000000")))))); - testUnwrap("timestamp(12)", "date(a) = DATE '1981-06-22'", new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-22 00:00:00.000000000000"))), new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-23 00:00:00.000000000000")))))); + testUnwrap("timestamp(3)", "date(a) = DATE '1981-06-22'", new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000"))), new Comparison(LESS_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))))); + testUnwrap("timestamp(6)", "date(a) = DATE '1981-06-22'", new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-22 00:00:00.000000"))), new Comparison(LESS_THAN, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-23 00:00:00.000000")))))); + testUnwrap("timestamp(9)", "date(a) = DATE '1981-06-22'", new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-22 00:00:00.000000000"))), new Comparison(LESS_THAN, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-23 00:00:00.000000000")))))); + testUnwrap("timestamp(12)", "date(a) = DATE '1981-06-22'", new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-22 00:00:00.000000000000"))), new Comparison(LESS_THAN, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-23 00:00:00.000000000000")))))); // not equal - testUnwrap("timestamp(3)", "date(a) <> DATE '1981-06-22'", new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000"))), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))))); - testUnwrap("timestamp(6)", "date(a) <> DATE '1981-06-22'", new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-22 00:00:00.000000"))), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-23 00:00:00.000000")))))); - testUnwrap("timestamp(9)", "date(a) <> DATE '1981-06-22'", new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-22 00:00:00.000000000"))), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-23 00:00:00.000000000")))))); - testUnwrap("timestamp(12)", "date(a) <> DATE '1981-06-22'", new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-22 00:00:00.000000000000"))), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-23 00:00:00.000000000000")))))); + testUnwrap("timestamp(3)", "date(a) <> DATE '1981-06-22'", new Logical(OR, ImmutableList.of(new Comparison(LESS_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))))); + testUnwrap("timestamp(6)", "date(a) <> DATE '1981-06-22'", new Logical(OR, ImmutableList.of(new Comparison(LESS_THAN, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-22 00:00:00.000000"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-23 00:00:00.000000")))))); + testUnwrap("timestamp(9)", "date(a) <> DATE '1981-06-22'", new Logical(OR, ImmutableList.of(new Comparison(LESS_THAN, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-22 00:00:00.000000000"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-23 00:00:00.000000000")))))); + testUnwrap("timestamp(12)", "date(a) <> DATE '1981-06-22'", new Logical(OR, ImmutableList.of(new Comparison(LESS_THAN, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-22 00:00:00.000000000000"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-23 00:00:00.000000000000")))))); // less than - testUnwrap("timestamp(3)", "date(a) < DATE '1981-06-22'", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000")))); - testUnwrap("timestamp(6)", "date(a) < DATE '1981-06-22'", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-22 00:00:00.000000")))); - testUnwrap("timestamp(9)", "date(a) < DATE '1981-06-22'", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-22 00:00:00.000000000")))); - testUnwrap("timestamp(12)", "date(a) < DATE '1981-06-22'", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-22 00:00:00.000000000000")))); + testUnwrap("timestamp(3)", "date(a) < DATE '1981-06-22'", new Comparison(LESS_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000")))); + testUnwrap("timestamp(6)", "date(a) < DATE '1981-06-22'", new Comparison(LESS_THAN, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-22 00:00:00.000000")))); + testUnwrap("timestamp(9)", "date(a) < DATE '1981-06-22'", new Comparison(LESS_THAN, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-22 00:00:00.000000000")))); + testUnwrap("timestamp(12)", "date(a) < DATE '1981-06-22'", new Comparison(LESS_THAN, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-22 00:00:00.000000000000")))); // less than or equal - testUnwrap("timestamp(3)", "date(a) <= DATE '1981-06-22'", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))); - testUnwrap("timestamp(6)", "date(a) <= DATE '1981-06-22'", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-23 00:00:00.000000")))); - testUnwrap("timestamp(9)", "date(a) <= DATE '1981-06-22'", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-23 00:00:00.000000000")))); - testUnwrap("timestamp(12)", "date(a) <= DATE '1981-06-22'", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-23 00:00:00.000000000000")))); + testUnwrap("timestamp(3)", "date(a) <= DATE '1981-06-22'", new Comparison(LESS_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))); + testUnwrap("timestamp(6)", "date(a) <= DATE '1981-06-22'", new Comparison(LESS_THAN, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-23 00:00:00.000000")))); + testUnwrap("timestamp(9)", "date(a) <= DATE '1981-06-22'", new Comparison(LESS_THAN, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-23 00:00:00.000000000")))); + testUnwrap("timestamp(12)", "date(a) <= DATE '1981-06-22'", new Comparison(LESS_THAN, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-23 00:00:00.000000000000")))); // greater than - testUnwrap("timestamp(3)", "date(a) > DATE '1981-06-22'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))); - testUnwrap("timestamp(6)", "date(a) > DATE '1981-06-22'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-23 00:00:00.000000")))); - testUnwrap("timestamp(9)", "date(a) > DATE '1981-06-22'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-23 00:00:00.000000000")))); - testUnwrap("timestamp(12)", "date(a) > DATE '1981-06-22'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-23 00:00:00.000000000000")))); + testUnwrap("timestamp(3)", "date(a) > DATE '1981-06-22'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))); + testUnwrap("timestamp(6)", "date(a) > DATE '1981-06-22'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-23 00:00:00.000000")))); + testUnwrap("timestamp(9)", "date(a) > DATE '1981-06-22'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-23 00:00:00.000000000")))); + testUnwrap("timestamp(12)", "date(a) > DATE '1981-06-22'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-23 00:00:00.000000000000")))); // greater than or equal - testUnwrap("timestamp(3)", "date(a) >= DATE '1981-06-22'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000")))); - testUnwrap("timestamp(6)", "date(a) >= DATE '1981-06-22'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-22 00:00:00.000000")))); - testUnwrap("timestamp(9)", "date(a) >= DATE '1981-06-22'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-22 00:00:00.000000000")))); - testUnwrap("timestamp(12)", "date(a) >= DATE '1981-06-22'", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-22 00:00:00.000000000000")))); + testUnwrap("timestamp(3)", "date(a) >= DATE '1981-06-22'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000")))); + testUnwrap("timestamp(6)", "date(a) >= DATE '1981-06-22'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-22 00:00:00.000000")))); + testUnwrap("timestamp(9)", "date(a) >= DATE '1981-06-22'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-22 00:00:00.000000000")))); + testUnwrap("timestamp(12)", "date(a) >= DATE '1981-06-22'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-22 00:00:00.000000000000")))); // is distinct - testUnwrap("timestamp(3)", "date(a) IS DISTINCT FROM DATE '1981-06-22'", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(createTimestampType(3), "a")), new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000"))), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))))); - testUnwrap("timestamp(6)", "date(a) IS DISTINCT FROM DATE '1981-06-22'", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(createTimestampType(6), "a")), new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-22 00:00:00.000000"))), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-23 00:00:00.000000")))))); - testUnwrap("timestamp(9)", "date(a) IS DISTINCT FROM DATE '1981-06-22'", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(createTimestampType(9), "a")), new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-22 00:00:00.000000000"))), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-23 00:00:00.000000000")))))); - testUnwrap("timestamp(12)", "date(a) IS DISTINCT FROM DATE '1981-06-22'", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(createTimestampType(12), "a")), new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-22 00:00:00.000000000000"))), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-23 00:00:00.000000000000")))))); + testUnwrap("timestamp(3)", "date(a) IS DISTINCT FROM DATE '1981-06-22'", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(3), "a")), new Comparison(LESS_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))))); + testUnwrap("timestamp(6)", "date(a) IS DISTINCT FROM DATE '1981-06-22'", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(6), "a")), new Comparison(LESS_THAN, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-22 00:00:00.000000"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-23 00:00:00.000000")))))); + testUnwrap("timestamp(9)", "date(a) IS DISTINCT FROM DATE '1981-06-22'", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(9), "a")), new Comparison(LESS_THAN, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-22 00:00:00.000000000"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-23 00:00:00.000000000")))))); + testUnwrap("timestamp(12)", "date(a) IS DISTINCT FROM DATE '1981-06-22'", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(12), "a")), new Comparison(LESS_THAN, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-22 00:00:00.000000000000"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-23 00:00:00.000000000000")))))); // is not distinct - testUnwrap("timestamp(3)", "date(a) IS NOT DISTINCT FROM DATE '1981-06-22'", new LogicalExpression(AND, ImmutableList.of(new NotExpression(new IsNullPredicate(new SymbolReference(createTimestampType(3), "a"))), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000"))), new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))))); - testUnwrap("timestamp(6)", "date(a) IS NOT DISTINCT FROM DATE '1981-06-22'", new LogicalExpression(AND, ImmutableList.of(new NotExpression(new IsNullPredicate(new SymbolReference(createTimestampType(6), "a"))), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-22 00:00:00.000000"))), new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-23 00:00:00.000000")))))); - testUnwrap("timestamp(9)", "date(a) IS NOT DISTINCT FROM DATE '1981-06-22'", new LogicalExpression(AND, ImmutableList.of(new NotExpression(new IsNullPredicate(new SymbolReference(createTimestampType(9), "a"))), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-22 00:00:00.000000000"))), new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-23 00:00:00.000000000")))))); - testUnwrap("timestamp(12)", "date(a) IS NOT DISTINCT FROM DATE '1981-06-22'", new LogicalExpression(AND, ImmutableList.of(new NotExpression(new IsNullPredicate(new SymbolReference(createTimestampType(12), "a"))), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-22 00:00:00.000000000000"))), new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-23 00:00:00.000000000000")))))); + testUnwrap("timestamp(3)", "date(a) IS NOT DISTINCT FROM DATE '1981-06-22'", new Logical(AND, ImmutableList.of(new Not(new IsNull(new Reference(createTimestampType(3), "a"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000"))), new Comparison(LESS_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))))); + testUnwrap("timestamp(6)", "date(a) IS NOT DISTINCT FROM DATE '1981-06-22'", new Logical(AND, ImmutableList.of(new Not(new IsNull(new Reference(createTimestampType(6), "a"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-22 00:00:00.000000"))), new Comparison(LESS_THAN, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-23 00:00:00.000000")))))); + testUnwrap("timestamp(9)", "date(a) IS NOT DISTINCT FROM DATE '1981-06-22'", new Logical(AND, ImmutableList.of(new Not(new IsNull(new Reference(createTimestampType(9), "a"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-22 00:00:00.000000000"))), new Comparison(LESS_THAN, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-23 00:00:00.000000000")))))); + testUnwrap("timestamp(12)", "date(a) IS NOT DISTINCT FROM DATE '1981-06-22'", new Logical(AND, ImmutableList.of(new Not(new IsNull(new Reference(createTimestampType(12), "a"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-22 00:00:00.000000000000"))), new Comparison(LESS_THAN, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-23 00:00:00.000000000000")))))); // null date literal testUnwrap("timestamp(3)", "date(a) = NULL", new Constant(BOOLEAN, null)); @@ -801,20 +801,20 @@ public void testUnwrapConvertTimestampToDate() testUnwrap("timestamp(3)", "date(a) <= NULL", new Constant(BOOLEAN, null)); testUnwrap("timestamp(3)", "date(a) > NULL", new Constant(BOOLEAN, null)); testUnwrap("timestamp(3)", "date(a) >= NULL", new Constant(BOOLEAN, null)); - testUnwrap("timestamp(3)", "date(a) IS DISTINCT FROM NULL", new NotExpression(new IsNullPredicate(new Cast(new SymbolReference(createTimestampType(3), "a"), DATE)))); + testUnwrap("timestamp(3)", "date(a) IS DISTINCT FROM NULL", new Not(new IsNull(new Cast(new Reference(createTimestampType(3), "a"), DATE)))); // non-optimized expression on the right - testUnwrap("timestamp(3)", "date(a) = DATE '1981-06-22' + INTERVAL '2' DAY", new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-24 00:00:00.000"))), new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-25 00:00:00.000")))))); + testUnwrap("timestamp(3)", "date(a) = DATE '1981-06-22' + INTERVAL '2' DAY", new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-24 00:00:00.000"))), new Comparison(LESS_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-25 00:00:00.000")))))); // cast on the right - testUnwrap("timestamp(3)", "DATE '1981-06-22' = date(a)", new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000"))), new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))))); + testUnwrap("timestamp(3)", "DATE '1981-06-22' = date(a)", new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000"))), new Comparison(LESS_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))))); } private void testRemoveFilter(String inputType, String inputPredicate) { assertPlan(format("SELECT * FROM (VALUES CAST(NULL AS %s)) t(a) WHERE %s AND rand() = 42", inputType, inputPredicate), output( - filter(new ComparisonExpression(EQUAL, new FunctionCall(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 42.0)), + filter(new Comparison(EQUAL, new Call(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 42.0)), values("a")))); } @@ -825,15 +825,15 @@ private void testUnwrap(String inputType, String inputPredicate, Expression expe private void testUnwrap(Session session, String inputType, String inputPredicate, Expression expected) { - Expression antiOptimization = new ComparisonExpression(EQUAL, new FunctionCall(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 42.0)); - if (expected instanceof LogicalExpression logical && logical.getOperator() == OR) { - expected = new LogicalExpression(OR, ImmutableList.builder() - .addAll(logical.getTerms()) + Expression antiOptimization = new Comparison(EQUAL, new Call(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 42.0)); + if (expected instanceof Logical logical && logical.operator() == OR) { + expected = new Logical(OR, ImmutableList.builder() + .addAll(logical.terms()) .add(antiOptimization) .build()); } else { - expected = new LogicalExpression(OR, ImmutableList.of(expected, antiOptimization)); + expected = new Logical(OR, ImmutableList.of(expected, antiOptimization)); } assertPlan(format("SELECT * FROM (VALUES CAST(NULL AS %s)) t(a) WHERE (%s) OR rand() = 42", inputType, inputPredicate), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestUnwrapYearInComparison.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestUnwrapYearInComparison.java index 9855e764e563..ebfa961eead9 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestUnwrapYearInComparison.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestUnwrapYearInComparison.java @@ -17,15 +17,15 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.type.LongTimestamp; -import io.trino.sql.ir.BetweenPredicate; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Between; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.assertions.BasePlanTest; import io.trino.type.DateTimes; import io.trino.util.DateTimeUtils; @@ -46,13 +46,13 @@ import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_SECOND; import static io.trino.spi.type.Timestamps.NANOSECONDS_PER_MICROSECOND; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; -import static io.trino.sql.ir.LogicalExpression.Operator.OR; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Logical.Operator.AND; +import static io.trino.sql.ir.Logical.Operator.OR; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.output; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -74,204 +74,204 @@ public class TestUnwrapYearInComparison @Test public void testEquals() { - testUnwrap("date", "year(a) = -0001", new BetweenPredicate(new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("-0001-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("-0001-12-31")))); - testUnwrap("date", "year(a) = 1960", new BetweenPredicate(new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1960-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("1960-12-31")))); - testUnwrap("date", "year(a) = 2022", new BetweenPredicate(new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2022-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("2022-12-31")))); - testUnwrap("date", "year(a) = 9999", new BetweenPredicate(new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("9999-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("9999-12-31")))); - - testUnwrap("timestamp", "year(a) = -0001", new BetweenPredicate(new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "-0001-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "-0001-12-31 23:59:59.999")))); - testUnwrap("timestamp", "year(a) = 1960", new BetweenPredicate(new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1960-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1960-12-31 23:59:59.999")))); - testUnwrap("timestamp", "year(a) = 2022", new BetweenPredicate(new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-12-31 23:59:59.999")))); - testUnwrap("timestamp", "year(a) = 9999", new BetweenPredicate(new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "9999-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "9999-12-31 23:59:59.999")))); - - testUnwrap("timestamp(0)", "year(a) = 2022", new BetweenPredicate(new SymbolReference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2022-01-01 00:00:00")), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2022-12-31 23:59:59")))); - testUnwrap("timestamp(1)", "year(a) = 2022", new BetweenPredicate(new SymbolReference(createTimestampType(1), "a"), new Constant(createTimestampType(1), DateTimes.parseTimestamp(1, "2022-01-01 00:00:00.0")), new Constant(createTimestampType(1), DateTimes.parseTimestamp(1, "2022-12-31 23:59:59.9")))); - testUnwrap("timestamp(2)", "year(a) = 2022", new BetweenPredicate(new SymbolReference(createTimestampType(2), "a"), new Constant(createTimestampType(2), DateTimes.parseTimestamp(2, "2022-01-01 00:00:00.00")), new Constant(createTimestampType(2), DateTimes.parseTimestamp(2, "2022-12-31 23:59:59.99")))); - testUnwrap("timestamp(3)", "year(a) = 2022", new BetweenPredicate(new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-12-31 23:59:59.999")))); - testUnwrap("timestamp(4)", "year(a) = 2022", new BetweenPredicate(new SymbolReference(createTimestampType(4), "a"), new Constant(createTimestampType(4), DateTimes.parseTimestamp(4, "2022-01-01 00:00:00.0000")), new Constant(createTimestampType(4), DateTimes.parseTimestamp(4, "2022-12-31 23:59:59.9999")))); - testUnwrap("timestamp(5)", "year(a) = 2022", new BetweenPredicate(new SymbolReference(createTimestampType(5), "a"), new Constant(createTimestampType(5), DateTimes.parseTimestamp(5, "2022-01-01 00:00:00.00000")), new Constant(createTimestampType(5), DateTimes.parseTimestamp(5, "2022-12-31 23:59:59.99999")))); - testUnwrap("timestamp(6)", "year(a) = 2022", new BetweenPredicate(new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2022-01-01 00:00:00.000000")), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2022-12-31 23:59:59.999999")))); - testUnwrap("timestamp(7)", "year(a) = 2022", new BetweenPredicate(new SymbolReference(createTimestampType(7), "a"), new Constant(createTimestampType(7), DateTimes.parseTimestamp(7, "2022-01-01 00:00:00.0000000")), new Constant(createTimestampType(7), DateTimes.parseTimestamp(7, "2022-12-31 23:59:59.9999999")))); - testUnwrap("timestamp(8)", "year(a) = 2022", new BetweenPredicate(new SymbolReference(createTimestampType(8), "a"), new Constant(createTimestampType(8), DateTimes.parseTimestamp(8, "2022-01-01 00:00:00.00000000")), new Constant(createTimestampType(8), DateTimes.parseTimestamp(8, "2022-12-31 23:59:59.99999999")))); - testUnwrap("timestamp(9)", "year(a) = 2022", new BetweenPredicate(new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2022-01-01 00:00:00.000000000")), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2022-12-31 23:59:59.999999999")))); - testUnwrap("timestamp(10)", "year(a) = 2022", new BetweenPredicate(new SymbolReference(createTimestampType(10), "a"), new Constant(createTimestampType(10), DateTimes.parseTimestamp(10, "2022-01-01 00:00:00.0000000000")), new Constant(createTimestampType(10), DateTimes.parseTimestamp(10, "2022-12-31 23:59:59.9999999999")))); - testUnwrap("timestamp(11)", "year(a) = 2022", new BetweenPredicate(new SymbolReference(createTimestampType(11), "a"), new Constant(createTimestampType(11), DateTimes.parseTimestamp(11, "2022-01-01 00:00:00.00000000000")), new Constant(createTimestampType(11), DateTimes.parseTimestamp(11, "2022-12-31 23:59:59.99999999999")))); - testUnwrap("timestamp(12)", "year(a) = 2022", new BetweenPredicate(new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2022-01-01 00:00:00.000000000000")), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2022-12-31 23:59:59.999999999999")))); + testUnwrap("date", "year(a) = -0001", new Between(new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("-0001-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("-0001-12-31")))); + testUnwrap("date", "year(a) = 1960", new Between(new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1960-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("1960-12-31")))); + testUnwrap("date", "year(a) = 2022", new Between(new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2022-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("2022-12-31")))); + testUnwrap("date", "year(a) = 9999", new Between(new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("9999-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("9999-12-31")))); + + testUnwrap("timestamp", "year(a) = -0001", new Between(new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "-0001-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "-0001-12-31 23:59:59.999")))); + testUnwrap("timestamp", "year(a) = 1960", new Between(new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1960-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1960-12-31 23:59:59.999")))); + testUnwrap("timestamp", "year(a) = 2022", new Between(new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-12-31 23:59:59.999")))); + testUnwrap("timestamp", "year(a) = 9999", new Between(new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "9999-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "9999-12-31 23:59:59.999")))); + + testUnwrap("timestamp(0)", "year(a) = 2022", new Between(new Reference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2022-01-01 00:00:00")), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2022-12-31 23:59:59")))); + testUnwrap("timestamp(1)", "year(a) = 2022", new Between(new Reference(createTimestampType(1), "a"), new Constant(createTimestampType(1), DateTimes.parseTimestamp(1, "2022-01-01 00:00:00.0")), new Constant(createTimestampType(1), DateTimes.parseTimestamp(1, "2022-12-31 23:59:59.9")))); + testUnwrap("timestamp(2)", "year(a) = 2022", new Between(new Reference(createTimestampType(2), "a"), new Constant(createTimestampType(2), DateTimes.parseTimestamp(2, "2022-01-01 00:00:00.00")), new Constant(createTimestampType(2), DateTimes.parseTimestamp(2, "2022-12-31 23:59:59.99")))); + testUnwrap("timestamp(3)", "year(a) = 2022", new Between(new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-12-31 23:59:59.999")))); + testUnwrap("timestamp(4)", "year(a) = 2022", new Between(new Reference(createTimestampType(4), "a"), new Constant(createTimestampType(4), DateTimes.parseTimestamp(4, "2022-01-01 00:00:00.0000")), new Constant(createTimestampType(4), DateTimes.parseTimestamp(4, "2022-12-31 23:59:59.9999")))); + testUnwrap("timestamp(5)", "year(a) = 2022", new Between(new Reference(createTimestampType(5), "a"), new Constant(createTimestampType(5), DateTimes.parseTimestamp(5, "2022-01-01 00:00:00.00000")), new Constant(createTimestampType(5), DateTimes.parseTimestamp(5, "2022-12-31 23:59:59.99999")))); + testUnwrap("timestamp(6)", "year(a) = 2022", new Between(new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2022-01-01 00:00:00.000000")), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2022-12-31 23:59:59.999999")))); + testUnwrap("timestamp(7)", "year(a) = 2022", new Between(new Reference(createTimestampType(7), "a"), new Constant(createTimestampType(7), DateTimes.parseTimestamp(7, "2022-01-01 00:00:00.0000000")), new Constant(createTimestampType(7), DateTimes.parseTimestamp(7, "2022-12-31 23:59:59.9999999")))); + testUnwrap("timestamp(8)", "year(a) = 2022", new Between(new Reference(createTimestampType(8), "a"), new Constant(createTimestampType(8), DateTimes.parseTimestamp(8, "2022-01-01 00:00:00.00000000")), new Constant(createTimestampType(8), DateTimes.parseTimestamp(8, "2022-12-31 23:59:59.99999999")))); + testUnwrap("timestamp(9)", "year(a) = 2022", new Between(new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2022-01-01 00:00:00.000000000")), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2022-12-31 23:59:59.999999999")))); + testUnwrap("timestamp(10)", "year(a) = 2022", new Between(new Reference(createTimestampType(10), "a"), new Constant(createTimestampType(10), DateTimes.parseTimestamp(10, "2022-01-01 00:00:00.0000000000")), new Constant(createTimestampType(10), DateTimes.parseTimestamp(10, "2022-12-31 23:59:59.9999999999")))); + testUnwrap("timestamp(11)", "year(a) = 2022", new Between(new Reference(createTimestampType(11), "a"), new Constant(createTimestampType(11), DateTimes.parseTimestamp(11, "2022-01-01 00:00:00.00000000000")), new Constant(createTimestampType(11), DateTimes.parseTimestamp(11, "2022-12-31 23:59:59.99999999999")))); + testUnwrap("timestamp(12)", "year(a) = 2022", new Between(new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2022-01-01 00:00:00.000000000000")), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2022-12-31 23:59:59.999999999999")))); } @Test public void testInPredicate() { - testUnwrap("date", "year(a) IN (1000, 1400, 1800)", new LogicalExpression(OR, ImmutableList.of(new BetweenPredicate(new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1000-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("1000-12-31"))), new BetweenPredicate(new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1400-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("1400-12-31"))), new BetweenPredicate(new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1800-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("1800-12-31")))))); - testUnwrap("timestamp", "year(a) IN (1000, 1400, 1800)", new LogicalExpression(OR, ImmutableList.of(new BetweenPredicate(new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1000-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1000-12-31 23:59:59.999"))), new BetweenPredicate(new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1400-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1400-12-31 23:59:59.999"))), new BetweenPredicate(new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1800-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1800-12-31 23:59:59.999")))))); + testUnwrap("date", "year(a) IN (1000, 1400, 1800)", new Logical(OR, ImmutableList.of(new Between(new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1000-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("1000-12-31"))), new Between(new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1400-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("1400-12-31"))), new Between(new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1800-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("1800-12-31")))))); + testUnwrap("timestamp", "year(a) IN (1000, 1400, 1800)", new Logical(OR, ImmutableList.of(new Between(new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1000-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1000-12-31 23:59:59.999"))), new Between(new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1400-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1400-12-31 23:59:59.999"))), new Between(new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1800-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1800-12-31 23:59:59.999")))))); } @Test public void testNotEquals() { - testUnwrap("date", "year(a) <> -0001", new NotExpression(new BetweenPredicate(new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("-0001-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("-0001-12-31"))))); - testUnwrap("date", "year(a) <> 1960", new NotExpression(new BetweenPredicate(new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1960-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("1960-12-31"))))); - testUnwrap("date", "year(a) <> 2022", new NotExpression(new BetweenPredicate(new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2022-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("2022-12-31"))))); - testUnwrap("date", "year(a) <> 9999", new NotExpression(new BetweenPredicate(new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("9999-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("9999-12-31"))))); - - testUnwrap("timestamp", "year(a) <> -0001", new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "-0001-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "-0001-12-31 23:59:59.999"))))); - testUnwrap("timestamp", "year(a) <> 1960", new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1960-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1960-12-31 23:59:59.999"))))); - testUnwrap("timestamp", "year(a) <> 2022", new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-12-31 23:59:59.999"))))); - testUnwrap("timestamp", "year(a) <> 9999", new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "9999-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "9999-12-31 23:59:59.999"))))); - - testUnwrap("timestamp(0)", "year(a) <> 2022", new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2022-01-01 00:00:00")), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2022-12-31 23:59:59"))))); - testUnwrap("timestamp(1)", "year(a) <> 2022", new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(1), "a"), new Constant(createTimestampType(1), DateTimes.parseTimestamp(1, "2022-01-01 00:00:00.0")), new Constant(createTimestampType(1), DateTimes.parseTimestamp(1, "2022-12-31 23:59:59.9"))))); - testUnwrap("timestamp(2)", "year(a) <> 2022", new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(2), "a"), new Constant(createTimestampType(2), DateTimes.parseTimestamp(2, "2022-01-01 00:00:00.00")), new Constant(createTimestampType(2), DateTimes.parseTimestamp(2, "2022-12-31 23:59:59.99"))))); - testUnwrap("timestamp(3)", "year(a) <> 2022", new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-12-31 23:59:59.999"))))); - testUnwrap("timestamp(4)", "year(a) <> 2022", new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(4), "a"), new Constant(createTimestampType(4), DateTimes.parseTimestamp(4, "2022-01-01 00:00:00.0000")), new Constant(createTimestampType(4), DateTimes.parseTimestamp(4, "2022-12-31 23:59:59.9999"))))); - testUnwrap("timestamp(5)", "year(a) <> 2022", new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(5), "a"), new Constant(createTimestampType(5), DateTimes.parseTimestamp(5, "2022-01-01 00:00:00.00000")), new Constant(createTimestampType(5), DateTimes.parseTimestamp(5, "2022-12-31 23:59:59.99999"))))); - testUnwrap("timestamp(6)", "year(a) <> 2022", new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2022-01-01 00:00:00.000000")), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2022-12-31 23:59:59.999999"))))); - testUnwrap("timestamp(7)", "year(a) <> 2022", new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(7), "a"), new Constant(createTimestampType(7), DateTimes.parseTimestamp(7, "2022-01-01 00:00:00.0000000")), new Constant(createTimestampType(7), DateTimes.parseTimestamp(7, "2022-12-31 23:59:59.9999999"))))); - testUnwrap("timestamp(8)", "year(a) <> 2022", new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(8), "a"), new Constant(createTimestampType(8), DateTimes.parseTimestamp(8, "2022-01-01 00:00:00.00000000")), new Constant(createTimestampType(8), DateTimes.parseTimestamp(8, "2022-12-31 23:59:59.99999999"))))); - testUnwrap("timestamp(9)", "year(a) <> 2022", new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2022-01-01 00:00:00.000000000")), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2022-12-31 23:59:59.999999999"))))); - testUnwrap("timestamp(10)", "year(a) <> 2022", new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(10), "a"), new Constant(createTimestampType(10), DateTimes.parseTimestamp(10, "2022-01-01 00:00:00.0000000000")), new Constant(createTimestampType(10), DateTimes.parseTimestamp(10, "2022-12-31 23:59:59.9999999999"))))); - testUnwrap("timestamp(11)", "year(a) <> 2022", new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(11), "a"), new Constant(createTimestampType(11), DateTimes.parseTimestamp(11, "2022-01-01 00:00:00.00000000000")), new Constant(createTimestampType(11), DateTimes.parseTimestamp(11, "2022-12-31 23:59:59.99999999999"))))); - testUnwrap("timestamp(12)", "year(a) <> 2022", new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2022-01-01 00:00:00.000000000000")), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2022-12-31 23:59:59.999999999999"))))); + testUnwrap("date", "year(a) <> -0001", new Not(new Between(new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("-0001-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("-0001-12-31"))))); + testUnwrap("date", "year(a) <> 1960", new Not(new Between(new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1960-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("1960-12-31"))))); + testUnwrap("date", "year(a) <> 2022", new Not(new Between(new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2022-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("2022-12-31"))))); + testUnwrap("date", "year(a) <> 9999", new Not(new Between(new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("9999-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("9999-12-31"))))); + + testUnwrap("timestamp", "year(a) <> -0001", new Not(new Between(new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "-0001-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "-0001-12-31 23:59:59.999"))))); + testUnwrap("timestamp", "year(a) <> 1960", new Not(new Between(new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1960-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1960-12-31 23:59:59.999"))))); + testUnwrap("timestamp", "year(a) <> 2022", new Not(new Between(new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-12-31 23:59:59.999"))))); + testUnwrap("timestamp", "year(a) <> 9999", new Not(new Between(new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "9999-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "9999-12-31 23:59:59.999"))))); + + testUnwrap("timestamp(0)", "year(a) <> 2022", new Not(new Between(new Reference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2022-01-01 00:00:00")), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2022-12-31 23:59:59"))))); + testUnwrap("timestamp(1)", "year(a) <> 2022", new Not(new Between(new Reference(createTimestampType(1), "a"), new Constant(createTimestampType(1), DateTimes.parseTimestamp(1, "2022-01-01 00:00:00.0")), new Constant(createTimestampType(1), DateTimes.parseTimestamp(1, "2022-12-31 23:59:59.9"))))); + testUnwrap("timestamp(2)", "year(a) <> 2022", new Not(new Between(new Reference(createTimestampType(2), "a"), new Constant(createTimestampType(2), DateTimes.parseTimestamp(2, "2022-01-01 00:00:00.00")), new Constant(createTimestampType(2), DateTimes.parseTimestamp(2, "2022-12-31 23:59:59.99"))))); + testUnwrap("timestamp(3)", "year(a) <> 2022", new Not(new Between(new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-12-31 23:59:59.999"))))); + testUnwrap("timestamp(4)", "year(a) <> 2022", new Not(new Between(new Reference(createTimestampType(4), "a"), new Constant(createTimestampType(4), DateTimes.parseTimestamp(4, "2022-01-01 00:00:00.0000")), new Constant(createTimestampType(4), DateTimes.parseTimestamp(4, "2022-12-31 23:59:59.9999"))))); + testUnwrap("timestamp(5)", "year(a) <> 2022", new Not(new Between(new Reference(createTimestampType(5), "a"), new Constant(createTimestampType(5), DateTimes.parseTimestamp(5, "2022-01-01 00:00:00.00000")), new Constant(createTimestampType(5), DateTimes.parseTimestamp(5, "2022-12-31 23:59:59.99999"))))); + testUnwrap("timestamp(6)", "year(a) <> 2022", new Not(new Between(new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2022-01-01 00:00:00.000000")), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2022-12-31 23:59:59.999999"))))); + testUnwrap("timestamp(7)", "year(a) <> 2022", new Not(new Between(new Reference(createTimestampType(7), "a"), new Constant(createTimestampType(7), DateTimes.parseTimestamp(7, "2022-01-01 00:00:00.0000000")), new Constant(createTimestampType(7), DateTimes.parseTimestamp(7, "2022-12-31 23:59:59.9999999"))))); + testUnwrap("timestamp(8)", "year(a) <> 2022", new Not(new Between(new Reference(createTimestampType(8), "a"), new Constant(createTimestampType(8), DateTimes.parseTimestamp(8, "2022-01-01 00:00:00.00000000")), new Constant(createTimestampType(8), DateTimes.parseTimestamp(8, "2022-12-31 23:59:59.99999999"))))); + testUnwrap("timestamp(9)", "year(a) <> 2022", new Not(new Between(new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2022-01-01 00:00:00.000000000")), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2022-12-31 23:59:59.999999999"))))); + testUnwrap("timestamp(10)", "year(a) <> 2022", new Not(new Between(new Reference(createTimestampType(10), "a"), new Constant(createTimestampType(10), DateTimes.parseTimestamp(10, "2022-01-01 00:00:00.0000000000")), new Constant(createTimestampType(10), DateTimes.parseTimestamp(10, "2022-12-31 23:59:59.9999999999"))))); + testUnwrap("timestamp(11)", "year(a) <> 2022", new Not(new Between(new Reference(createTimestampType(11), "a"), new Constant(createTimestampType(11), DateTimes.parseTimestamp(11, "2022-01-01 00:00:00.00000000000")), new Constant(createTimestampType(11), DateTimes.parseTimestamp(11, "2022-12-31 23:59:59.99999999999"))))); + testUnwrap("timestamp(12)", "year(a) <> 2022", new Not(new Between(new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2022-01-01 00:00:00.000000000000")), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2022-12-31 23:59:59.999999999999"))))); } @Test public void testLessThan() { - testUnwrap("date", "year(a) < -0001", new ComparisonExpression(LESS_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("-0001-01-01")))); - testUnwrap("date", "year(a) < 1960", new ComparisonExpression(LESS_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1960-01-01")))); - testUnwrap("date", "year(a) < 2022", new ComparisonExpression(LESS_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2022-01-01")))); - testUnwrap("date", "year(a) < 9999", new ComparisonExpression(LESS_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("9999-01-01")))); - - testUnwrap("timestamp", "year(a) < -0001", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "-0001-01-01 00:00:00.000")))); - testUnwrap("timestamp", "year(a) < 1960", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1960-01-01 00:00:00.000")))); - testUnwrap("timestamp", "year(a) < 2022", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-01-01 00:00:00.000")))); - testUnwrap("timestamp", "year(a) < 9999", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "9999-01-01 00:00:00.000")))); - - testUnwrap("timestamp(0)", "year(a) < 2022", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2022-01-01 00:00:00")))); - testUnwrap("timestamp(1)", "year(a) < 2022", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(1), "a"), new Constant(createTimestampType(1), DateTimes.parseTimestamp(1, "2022-01-01 00:00:00.0")))); - testUnwrap("timestamp(2)", "year(a) < 2022", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(2), "a"), new Constant(createTimestampType(2), DateTimes.parseTimestamp(2, "2022-01-01 00:00:00.00")))); - testUnwrap("timestamp(3)", "year(a) < 2022", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-01-01 00:00:00.000")))); - testUnwrap("timestamp(4)", "year(a) < 2022", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(4), "a"), new Constant(createTimestampType(4), DateTimes.parseTimestamp(4, "2022-01-01 00:00:00.0000")))); - testUnwrap("timestamp(5)", "year(a) < 2022", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(5), "a"), new Constant(createTimestampType(5), DateTimes.parseTimestamp(5, "2022-01-01 00:00:00.00000")))); - testUnwrap("timestamp(6)", "year(a) < 2022", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2022-01-01 00:00:00.000000")))); - testUnwrap("timestamp(7)", "year(a) < 2022", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(7), "a"), new Constant(createTimestampType(7), DateTimes.parseTimestamp(7, "2022-01-01 00:00:00.0000000")))); - testUnwrap("timestamp(8)", "year(a) < 2022", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(8), "a"), new Constant(createTimestampType(8), DateTimes.parseTimestamp(8, "2022-01-01 00:00:00.00000000")))); - testUnwrap("timestamp(9)", "year(a) < 2022", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2022-01-01 00:00:00.000000000")))); - testUnwrap("timestamp(10)", "year(a) < 2022", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(10), "a"), new Constant(createTimestampType(10), DateTimes.parseTimestamp(10, "2022-01-01 00:00:00.0000000000")))); - testUnwrap("timestamp(11)", "year(a) < 2022", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(11), "a"), new Constant(createTimestampType(11), DateTimes.parseTimestamp(11, "2022-01-01 00:00:00.00000000000")))); - testUnwrap("timestamp(12)", "year(a) < 2022", new ComparisonExpression(LESS_THAN, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2022-01-01 00:00:00.000000000000")))); + testUnwrap("date", "year(a) < -0001", new Comparison(LESS_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("-0001-01-01")))); + testUnwrap("date", "year(a) < 1960", new Comparison(LESS_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1960-01-01")))); + testUnwrap("date", "year(a) < 2022", new Comparison(LESS_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2022-01-01")))); + testUnwrap("date", "year(a) < 9999", new Comparison(LESS_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("9999-01-01")))); + + testUnwrap("timestamp", "year(a) < -0001", new Comparison(LESS_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "-0001-01-01 00:00:00.000")))); + testUnwrap("timestamp", "year(a) < 1960", new Comparison(LESS_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1960-01-01 00:00:00.000")))); + testUnwrap("timestamp", "year(a) < 2022", new Comparison(LESS_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-01-01 00:00:00.000")))); + testUnwrap("timestamp", "year(a) < 9999", new Comparison(LESS_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "9999-01-01 00:00:00.000")))); + + testUnwrap("timestamp(0)", "year(a) < 2022", new Comparison(LESS_THAN, new Reference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2022-01-01 00:00:00")))); + testUnwrap("timestamp(1)", "year(a) < 2022", new Comparison(LESS_THAN, new Reference(createTimestampType(1), "a"), new Constant(createTimestampType(1), DateTimes.parseTimestamp(1, "2022-01-01 00:00:00.0")))); + testUnwrap("timestamp(2)", "year(a) < 2022", new Comparison(LESS_THAN, new Reference(createTimestampType(2), "a"), new Constant(createTimestampType(2), DateTimes.parseTimestamp(2, "2022-01-01 00:00:00.00")))); + testUnwrap("timestamp(3)", "year(a) < 2022", new Comparison(LESS_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-01-01 00:00:00.000")))); + testUnwrap("timestamp(4)", "year(a) < 2022", new Comparison(LESS_THAN, new Reference(createTimestampType(4), "a"), new Constant(createTimestampType(4), DateTimes.parseTimestamp(4, "2022-01-01 00:00:00.0000")))); + testUnwrap("timestamp(5)", "year(a) < 2022", new Comparison(LESS_THAN, new Reference(createTimestampType(5), "a"), new Constant(createTimestampType(5), DateTimes.parseTimestamp(5, "2022-01-01 00:00:00.00000")))); + testUnwrap("timestamp(6)", "year(a) < 2022", new Comparison(LESS_THAN, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2022-01-01 00:00:00.000000")))); + testUnwrap("timestamp(7)", "year(a) < 2022", new Comparison(LESS_THAN, new Reference(createTimestampType(7), "a"), new Constant(createTimestampType(7), DateTimes.parseTimestamp(7, "2022-01-01 00:00:00.0000000")))); + testUnwrap("timestamp(8)", "year(a) < 2022", new Comparison(LESS_THAN, new Reference(createTimestampType(8), "a"), new Constant(createTimestampType(8), DateTimes.parseTimestamp(8, "2022-01-01 00:00:00.00000000")))); + testUnwrap("timestamp(9)", "year(a) < 2022", new Comparison(LESS_THAN, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2022-01-01 00:00:00.000000000")))); + testUnwrap("timestamp(10)", "year(a) < 2022", new Comparison(LESS_THAN, new Reference(createTimestampType(10), "a"), new Constant(createTimestampType(10), DateTimes.parseTimestamp(10, "2022-01-01 00:00:00.0000000000")))); + testUnwrap("timestamp(11)", "year(a) < 2022", new Comparison(LESS_THAN, new Reference(createTimestampType(11), "a"), new Constant(createTimestampType(11), DateTimes.parseTimestamp(11, "2022-01-01 00:00:00.00000000000")))); + testUnwrap("timestamp(12)", "year(a) < 2022", new Comparison(LESS_THAN, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2022-01-01 00:00:00.000000000000")))); } @Test public void testLessThanOrEqual() { - testUnwrap("date", "year(a) <= -0001", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("-0001-12-31")))); - testUnwrap("date", "year(a) <= 1960", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1960-12-31")))); - testUnwrap("date", "year(a) <= 2022", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2022-12-31")))); - testUnwrap("date", "year(a) <= 9999", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("9999-12-31")))); - - testUnwrap("timestamp", "year(a) <= -0001", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "-0001-12-31 23:59:59.999")))); - testUnwrap("timestamp", "year(a) <= 1960", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1960-12-31 23:59:59.999")))); - testUnwrap("timestamp", "year(a) <= 2022", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-12-31 23:59:59.999")))); - testUnwrap("timestamp", "year(a) <= 9999", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "9999-12-31 23:59:59.999")))); - - testUnwrap("timestamp(0)", "year(a) <= 2022", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2022-12-31 23:59:59")))); - testUnwrap("timestamp(1)", "year(a) <= 2022", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(createTimestampType(1), "a"), new Constant(createTimestampType(1), DateTimes.parseTimestamp(1, "2022-12-31 23:59:59.9")))); - testUnwrap("timestamp(2)", "year(a) <= 2022", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(createTimestampType(2), "a"), new Constant(createTimestampType(2), DateTimes.parseTimestamp(2, "2022-12-31 23:59:59.99")))); - testUnwrap("timestamp(3)", "year(a) <= 2022", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-12-31 23:59:59.999")))); - testUnwrap("timestamp(4)", "year(a) <= 2022", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(createTimestampType(4), "a"), new Constant(createTimestampType(4), DateTimes.parseTimestamp(4, "2022-12-31 23:59:59.9999")))); - testUnwrap("timestamp(5)", "year(a) <= 2022", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(createTimestampType(5), "a"), new Constant(createTimestampType(5), DateTimes.parseTimestamp(5, "2022-12-31 23:59:59.99999")))); - testUnwrap("timestamp(6)", "year(a) <= 2022", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2022-12-31 23:59:59.999999")))); - testUnwrap("timestamp(7)", "year(a) <= 2022", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(createTimestampType(7), "a"), new Constant(createTimestampType(7), DateTimes.parseTimestamp(7, "2022-12-31 23:59:59.9999999")))); - testUnwrap("timestamp(8)", "year(a) <= 2022", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(createTimestampType(8), "a"), new Constant(createTimestampType(8), DateTimes.parseTimestamp(8, "2022-12-31 23:59:59.99999999")))); - testUnwrap("timestamp(9)", "year(a) <= 2022", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2022-12-31 23:59:59.999999999")))); - testUnwrap("timestamp(10)", "year(a) <= 2022", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(createTimestampType(10), "a"), new Constant(createTimestampType(10), DateTimes.parseTimestamp(10, "2022-12-31 23:59:59.9999999999")))); - testUnwrap("timestamp(11)", "year(a) <= 2022", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(createTimestampType(11), "a"), new Constant(createTimestampType(11), DateTimes.parseTimestamp(11, "2022-12-31 23:59:59.99999999999")))); - testUnwrap("timestamp(12)", "year(a) <= 2022", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2022-12-31 23:59:59.999999999999")))); + testUnwrap("date", "year(a) <= -0001", new Comparison(LESS_THAN_OR_EQUAL, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("-0001-12-31")))); + testUnwrap("date", "year(a) <= 1960", new Comparison(LESS_THAN_OR_EQUAL, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1960-12-31")))); + testUnwrap("date", "year(a) <= 2022", new Comparison(LESS_THAN_OR_EQUAL, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2022-12-31")))); + testUnwrap("date", "year(a) <= 9999", new Comparison(LESS_THAN_OR_EQUAL, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("9999-12-31")))); + + testUnwrap("timestamp", "year(a) <= -0001", new Comparison(LESS_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "-0001-12-31 23:59:59.999")))); + testUnwrap("timestamp", "year(a) <= 1960", new Comparison(LESS_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1960-12-31 23:59:59.999")))); + testUnwrap("timestamp", "year(a) <= 2022", new Comparison(LESS_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-12-31 23:59:59.999")))); + testUnwrap("timestamp", "year(a) <= 9999", new Comparison(LESS_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "9999-12-31 23:59:59.999")))); + + testUnwrap("timestamp(0)", "year(a) <= 2022", new Comparison(LESS_THAN_OR_EQUAL, new Reference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2022-12-31 23:59:59")))); + testUnwrap("timestamp(1)", "year(a) <= 2022", new Comparison(LESS_THAN_OR_EQUAL, new Reference(createTimestampType(1), "a"), new Constant(createTimestampType(1), DateTimes.parseTimestamp(1, "2022-12-31 23:59:59.9")))); + testUnwrap("timestamp(2)", "year(a) <= 2022", new Comparison(LESS_THAN_OR_EQUAL, new Reference(createTimestampType(2), "a"), new Constant(createTimestampType(2), DateTimes.parseTimestamp(2, "2022-12-31 23:59:59.99")))); + testUnwrap("timestamp(3)", "year(a) <= 2022", new Comparison(LESS_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-12-31 23:59:59.999")))); + testUnwrap("timestamp(4)", "year(a) <= 2022", new Comparison(LESS_THAN_OR_EQUAL, new Reference(createTimestampType(4), "a"), new Constant(createTimestampType(4), DateTimes.parseTimestamp(4, "2022-12-31 23:59:59.9999")))); + testUnwrap("timestamp(5)", "year(a) <= 2022", new Comparison(LESS_THAN_OR_EQUAL, new Reference(createTimestampType(5), "a"), new Constant(createTimestampType(5), DateTimes.parseTimestamp(5, "2022-12-31 23:59:59.99999")))); + testUnwrap("timestamp(6)", "year(a) <= 2022", new Comparison(LESS_THAN_OR_EQUAL, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2022-12-31 23:59:59.999999")))); + testUnwrap("timestamp(7)", "year(a) <= 2022", new Comparison(LESS_THAN_OR_EQUAL, new Reference(createTimestampType(7), "a"), new Constant(createTimestampType(7), DateTimes.parseTimestamp(7, "2022-12-31 23:59:59.9999999")))); + testUnwrap("timestamp(8)", "year(a) <= 2022", new Comparison(LESS_THAN_OR_EQUAL, new Reference(createTimestampType(8), "a"), new Constant(createTimestampType(8), DateTimes.parseTimestamp(8, "2022-12-31 23:59:59.99999999")))); + testUnwrap("timestamp(9)", "year(a) <= 2022", new Comparison(LESS_THAN_OR_EQUAL, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2022-12-31 23:59:59.999999999")))); + testUnwrap("timestamp(10)", "year(a) <= 2022", new Comparison(LESS_THAN_OR_EQUAL, new Reference(createTimestampType(10), "a"), new Constant(createTimestampType(10), DateTimes.parseTimestamp(10, "2022-12-31 23:59:59.9999999999")))); + testUnwrap("timestamp(11)", "year(a) <= 2022", new Comparison(LESS_THAN_OR_EQUAL, new Reference(createTimestampType(11), "a"), new Constant(createTimestampType(11), DateTimes.parseTimestamp(11, "2022-12-31 23:59:59.99999999999")))); + testUnwrap("timestamp(12)", "year(a) <= 2022", new Comparison(LESS_THAN_OR_EQUAL, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2022-12-31 23:59:59.999999999999")))); } @Test public void testGreaterThan() { - testUnwrap("date", "year(a) > -0001", new ComparisonExpression(GREATER_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("-0001-12-31")))); - testUnwrap("date", "year(a) > 1960", new ComparisonExpression(GREATER_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1960-12-31")))); - testUnwrap("date", "year(a) > 2022", new ComparisonExpression(GREATER_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2022-12-31")))); - testUnwrap("date", "year(a) > 9999", new ComparisonExpression(GREATER_THAN, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("9999-12-31")))); - - testUnwrap("timestamp", "year(a) > -0001", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "-0001-12-31 23:59:59.999")))); - testUnwrap("timestamp", "year(a) > 1960", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1960-12-31 23:59:59.999")))); - testUnwrap("timestamp", "year(a) > 2022", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-12-31 23:59:59.999")))); - testUnwrap("timestamp", "year(a) > 9999", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "9999-12-31 23:59:59.999")))); - - testUnwrap("timestamp(0)", "year(a) > 2022", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2022-12-31 23:59:59")))); - testUnwrap("timestamp(1)", "year(a) > 2022", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(1), "a"), new Constant(createTimestampType(1), DateTimes.parseTimestamp(1, "2022-12-31 23:59:59.9")))); - testUnwrap("timestamp(2)", "year(a) > 2022", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(2), "a"), new Constant(createTimestampType(2), DateTimes.parseTimestamp(2, "2022-12-31 23:59:59.99")))); - testUnwrap("timestamp(3)", "year(a) > 2022", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-12-31 23:59:59.999")))); - testUnwrap("timestamp(4)", "year(a) > 2022", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(4), "a"), new Constant(createTimestampType(4), DateTimes.parseTimestamp(4, "2022-12-31 23:59:59.9999")))); - testUnwrap("timestamp(5)", "year(a) > 2022", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(5), "a"), new Constant(createTimestampType(5), DateTimes.parseTimestamp(5, "2022-12-31 23:59:59.99999")))); - testUnwrap("timestamp(6)", "year(a) > 2022", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2022-12-31 23:59:59.999999")))); - testUnwrap("timestamp(7)", "year(a) > 2022", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(7), "a"), new Constant(createTimestampType(7), DateTimes.parseTimestamp(7, "2022-12-31 23:59:59.9999999")))); - testUnwrap("timestamp(8)", "year(a) > 2022", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(8), "a"), new Constant(createTimestampType(8), DateTimes.parseTimestamp(8, "2022-12-31 23:59:59.99999999")))); - testUnwrap("timestamp(9)", "year(a) > 2022", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2022-12-31 23:59:59.999999999")))); - testUnwrap("timestamp(10)", "year(a) > 2022", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(10), "a"), new Constant(createTimestampType(10), DateTimes.parseTimestamp(10, "2022-12-31 23:59:59.9999999999")))); - testUnwrap("timestamp(11)", "year(a) > 2022", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(11), "a"), new Constant(createTimestampType(11), DateTimes.parseTimestamp(11, "2022-12-31 23:59:59.99999999999")))); - testUnwrap("timestamp(12)", "year(a) > 2022", new ComparisonExpression(GREATER_THAN, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2022-12-31 23:59:59.999999999999")))); + testUnwrap("date", "year(a) > -0001", new Comparison(GREATER_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("-0001-12-31")))); + testUnwrap("date", "year(a) > 1960", new Comparison(GREATER_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1960-12-31")))); + testUnwrap("date", "year(a) > 2022", new Comparison(GREATER_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2022-12-31")))); + testUnwrap("date", "year(a) > 9999", new Comparison(GREATER_THAN, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("9999-12-31")))); + + testUnwrap("timestamp", "year(a) > -0001", new Comparison(GREATER_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "-0001-12-31 23:59:59.999")))); + testUnwrap("timestamp", "year(a) > 1960", new Comparison(GREATER_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1960-12-31 23:59:59.999")))); + testUnwrap("timestamp", "year(a) > 2022", new Comparison(GREATER_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-12-31 23:59:59.999")))); + testUnwrap("timestamp", "year(a) > 9999", new Comparison(GREATER_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "9999-12-31 23:59:59.999")))); + + testUnwrap("timestamp(0)", "year(a) > 2022", new Comparison(GREATER_THAN, new Reference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2022-12-31 23:59:59")))); + testUnwrap("timestamp(1)", "year(a) > 2022", new Comparison(GREATER_THAN, new Reference(createTimestampType(1), "a"), new Constant(createTimestampType(1), DateTimes.parseTimestamp(1, "2022-12-31 23:59:59.9")))); + testUnwrap("timestamp(2)", "year(a) > 2022", new Comparison(GREATER_THAN, new Reference(createTimestampType(2), "a"), new Constant(createTimestampType(2), DateTimes.parseTimestamp(2, "2022-12-31 23:59:59.99")))); + testUnwrap("timestamp(3)", "year(a) > 2022", new Comparison(GREATER_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-12-31 23:59:59.999")))); + testUnwrap("timestamp(4)", "year(a) > 2022", new Comparison(GREATER_THAN, new Reference(createTimestampType(4), "a"), new Constant(createTimestampType(4), DateTimes.parseTimestamp(4, "2022-12-31 23:59:59.9999")))); + testUnwrap("timestamp(5)", "year(a) > 2022", new Comparison(GREATER_THAN, new Reference(createTimestampType(5), "a"), new Constant(createTimestampType(5), DateTimes.parseTimestamp(5, "2022-12-31 23:59:59.99999")))); + testUnwrap("timestamp(6)", "year(a) > 2022", new Comparison(GREATER_THAN, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2022-12-31 23:59:59.999999")))); + testUnwrap("timestamp(7)", "year(a) > 2022", new Comparison(GREATER_THAN, new Reference(createTimestampType(7), "a"), new Constant(createTimestampType(7), DateTimes.parseTimestamp(7, "2022-12-31 23:59:59.9999999")))); + testUnwrap("timestamp(8)", "year(a) > 2022", new Comparison(GREATER_THAN, new Reference(createTimestampType(8), "a"), new Constant(createTimestampType(8), DateTimes.parseTimestamp(8, "2022-12-31 23:59:59.99999999")))); + testUnwrap("timestamp(9)", "year(a) > 2022", new Comparison(GREATER_THAN, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2022-12-31 23:59:59.999999999")))); + testUnwrap("timestamp(10)", "year(a) > 2022", new Comparison(GREATER_THAN, new Reference(createTimestampType(10), "a"), new Constant(createTimestampType(10), DateTimes.parseTimestamp(10, "2022-12-31 23:59:59.9999999999")))); + testUnwrap("timestamp(11)", "year(a) > 2022", new Comparison(GREATER_THAN, new Reference(createTimestampType(11), "a"), new Constant(createTimestampType(11), DateTimes.parseTimestamp(11, "2022-12-31 23:59:59.99999999999")))); + testUnwrap("timestamp(12)", "year(a) > 2022", new Comparison(GREATER_THAN, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2022-12-31 23:59:59.999999999999")))); } @Test public void testGreaterThanOrEqual() { - testUnwrap("date", "year(a) >= -0001", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("-0001-01-01")))); - testUnwrap("date", "year(a) >= 1960", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1960-01-01")))); - testUnwrap("date", "year(a) >= 2022", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2022-01-01")))); - testUnwrap("date", "year(a) >= 9999", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("9999-01-01")))); - - testUnwrap("timestamp", "year(a) >= -0001", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "-0001-01-01 00:00:00.000")))); - testUnwrap("timestamp", "year(a) >= 1960", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1960-01-01 00:00:00.000")))); - testUnwrap("timestamp", "year(a) >= 2022", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-01-01 00:00:00.000")))); - testUnwrap("timestamp", "year(a) >= 9999", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "9999-01-01 00:00:00.000")))); - - testUnwrap("timestamp(0)", "year(a) >= 2022", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2022-01-01 00:00:00")))); - testUnwrap("timestamp(1)", "year(a) >= 2022", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(1), "a"), new Constant(createTimestampType(1), DateTimes.parseTimestamp(1, "2022-01-01 00:00:00.0")))); - testUnwrap("timestamp(2)", "year(a) >= 2022", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(2), "a"), new Constant(createTimestampType(2), DateTimes.parseTimestamp(2, "2022-01-01 00:00:00.00")))); - testUnwrap("timestamp(3)", "year(a) >= 2022", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-01-01 00:00:00.000")))); - testUnwrap("timestamp(4)", "year(a) >= 2022", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(4), "a"), new Constant(createTimestampType(4), DateTimes.parseTimestamp(4, "2022-01-01 00:00:00.0000")))); - testUnwrap("timestamp(5)", "year(a) >= 2022", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(5), "a"), new Constant(createTimestampType(5), DateTimes.parseTimestamp(5, "2022-01-01 00:00:00.00000")))); - testUnwrap("timestamp(6)", "year(a) >= 2022", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2022-01-01 00:00:00.000000")))); - testUnwrap("timestamp(7)", "year(a) >= 2022", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(7), "a"), new Constant(createTimestampType(7), DateTimes.parseTimestamp(7, "2022-01-01 00:00:00.0000000")))); - testUnwrap("timestamp(8)", "year(a) >= 2022", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(8), "a"), new Constant(createTimestampType(8), DateTimes.parseTimestamp(8, "2022-01-01 00:00:00.00000000")))); - testUnwrap("timestamp(9)", "year(a) >= 2022", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2022-01-01 00:00:00.000000000")))); - testUnwrap("timestamp(10)", "year(a) >= 2022", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(10), "a"), new Constant(createTimestampType(10), DateTimes.parseTimestamp(10, "2022-01-01 00:00:00.0000000000")))); - testUnwrap("timestamp(11)", "year(a) >= 2022", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(11), "a"), new Constant(createTimestampType(11), DateTimes.parseTimestamp(11, "2022-01-01 00:00:00.00000000000")))); - testUnwrap("timestamp(12)", "year(a) >= 2022", new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2022-01-01 00:00:00.000000000000")))); + testUnwrap("date", "year(a) >= -0001", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("-0001-01-01")))); + testUnwrap("date", "year(a) >= 1960", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1960-01-01")))); + testUnwrap("date", "year(a) >= 2022", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2022-01-01")))); + testUnwrap("date", "year(a) >= 9999", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("9999-01-01")))); + + testUnwrap("timestamp", "year(a) >= -0001", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "-0001-01-01 00:00:00.000")))); + testUnwrap("timestamp", "year(a) >= 1960", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1960-01-01 00:00:00.000")))); + testUnwrap("timestamp", "year(a) >= 2022", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-01-01 00:00:00.000")))); + testUnwrap("timestamp", "year(a) >= 9999", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "9999-01-01 00:00:00.000")))); + + testUnwrap("timestamp(0)", "year(a) >= 2022", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2022-01-01 00:00:00")))); + testUnwrap("timestamp(1)", "year(a) >= 2022", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(1), "a"), new Constant(createTimestampType(1), DateTimes.parseTimestamp(1, "2022-01-01 00:00:00.0")))); + testUnwrap("timestamp(2)", "year(a) >= 2022", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(2), "a"), new Constant(createTimestampType(2), DateTimes.parseTimestamp(2, "2022-01-01 00:00:00.00")))); + testUnwrap("timestamp(3)", "year(a) >= 2022", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-01-01 00:00:00.000")))); + testUnwrap("timestamp(4)", "year(a) >= 2022", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(4), "a"), new Constant(createTimestampType(4), DateTimes.parseTimestamp(4, "2022-01-01 00:00:00.0000")))); + testUnwrap("timestamp(5)", "year(a) >= 2022", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(5), "a"), new Constant(createTimestampType(5), DateTimes.parseTimestamp(5, "2022-01-01 00:00:00.00000")))); + testUnwrap("timestamp(6)", "year(a) >= 2022", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2022-01-01 00:00:00.000000")))); + testUnwrap("timestamp(7)", "year(a) >= 2022", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(7), "a"), new Constant(createTimestampType(7), DateTimes.parseTimestamp(7, "2022-01-01 00:00:00.0000000")))); + testUnwrap("timestamp(8)", "year(a) >= 2022", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(8), "a"), new Constant(createTimestampType(8), DateTimes.parseTimestamp(8, "2022-01-01 00:00:00.00000000")))); + testUnwrap("timestamp(9)", "year(a) >= 2022", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2022-01-01 00:00:00.000000000")))); + testUnwrap("timestamp(10)", "year(a) >= 2022", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(10), "a"), new Constant(createTimestampType(10), DateTimes.parseTimestamp(10, "2022-01-01 00:00:00.0000000000")))); + testUnwrap("timestamp(11)", "year(a) >= 2022", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(11), "a"), new Constant(createTimestampType(11), DateTimes.parseTimestamp(11, "2022-01-01 00:00:00.00000000000")))); + testUnwrap("timestamp(12)", "year(a) >= 2022", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2022-01-01 00:00:00.000000000000")))); } @Test public void testDistinctFrom() { - testUnwrap("date", "year(a) IS DISTINCT FROM -0001", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(DATE, "a")), new NotExpression(new BetweenPredicate(new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("-0001-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("-0001-12-31"))))))); - testUnwrap("date", "year(a) IS DISTINCT FROM 1960", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(DATE, "a")), new NotExpression(new BetweenPredicate(new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1960-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("1960-12-31"))))))); - testUnwrap("date", "year(a) IS DISTINCT FROM 2022", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(DATE, "a")), new NotExpression(new BetweenPredicate(new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2022-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("2022-12-31"))))))); - testUnwrap("date", "year(a) IS DISTINCT FROM 9999", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(DATE, "a")), new NotExpression(new BetweenPredicate(new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("9999-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("9999-12-31"))))))); - - testUnwrap("timestamp", "year(a) IS DISTINCT FROM -0001", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(createTimestampType(3), "a")), new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "-0001-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "-0001-12-31 23:59:59.999"))))))); - testUnwrap("timestamp", "year(a) IS DISTINCT FROM 1960", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(createTimestampType(3), "a")), new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1960-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1960-12-31 23:59:59.999"))))))); - testUnwrap("timestamp", "year(a) IS DISTINCT FROM 2022", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(createTimestampType(3), "a")), new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-12-31 23:59:59.999"))))))); - testUnwrap("timestamp", "year(a) IS DISTINCT FROM 9999", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(createTimestampType(3), "a")), new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "9999-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "9999-12-31 23:59:59.999"))))))); - - testUnwrap("timestamp(0)", "year(a) IS DISTINCT FROM 2022", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(createTimestampType(0), "a")), new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2022-01-01 00:00:00")), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2022-12-31 23:59:59"))))))); - testUnwrap("timestamp(1)", "year(a) IS DISTINCT FROM 2022", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(createTimestampType(1), "a")), new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(1), "a"), new Constant(createTimestampType(1), DateTimes.parseTimestamp(1, "2022-01-01 00:00:00.0")), new Constant(createTimestampType(1), DateTimes.parseTimestamp(1, "2022-12-31 23:59:59.9"))))))); - testUnwrap("timestamp(2)", "year(a) IS DISTINCT FROM 2022", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(createTimestampType(2), "a")), new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(2), "a"), new Constant(createTimestampType(2), DateTimes.parseTimestamp(2, "2022-01-01 00:00:00.00")), new Constant(createTimestampType(2), DateTimes.parseTimestamp(2, "2022-12-31 23:59:59.99"))))))); - testUnwrap("timestamp(3)", "year(a) IS DISTINCT FROM 2022", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(createTimestampType(3), "a")), new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-12-31 23:59:59.999"))))))); - testUnwrap("timestamp(4)", "year(a) IS DISTINCT FROM 2022", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(createTimestampType(4), "a")), new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(4), "a"), new Constant(createTimestampType(4), DateTimes.parseTimestamp(4, "2022-01-01 00:00:00.0000")), new Constant(createTimestampType(4), DateTimes.parseTimestamp(4, "2022-12-31 23:59:59.9999"))))))); - testUnwrap("timestamp(5)", "year(a) IS DISTINCT FROM 2022", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(createTimestampType(5), "a")), new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(5), "a"), new Constant(createTimestampType(5), DateTimes.parseTimestamp(5, "2022-01-01 00:00:00.00000")), new Constant(createTimestampType(5), DateTimes.parseTimestamp(5, "2022-12-31 23:59:59.99999"))))))); - testUnwrap("timestamp(6)", "year(a) IS DISTINCT FROM 2022", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(createTimestampType(6), "a")), new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2022-01-01 00:00:00.000000")), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2022-12-31 23:59:59.999999"))))))); - testUnwrap("timestamp(7)", "year(a) IS DISTINCT FROM 2022", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(createTimestampType(7), "a")), new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(7), "a"), new Constant(createTimestampType(7), DateTimes.parseTimestamp(7, "2022-01-01 00:00:00.0000000")), new Constant(createTimestampType(7), DateTimes.parseTimestamp(7, "2022-12-31 23:59:59.9999999"))))))); - testUnwrap("timestamp(8)", "year(a) IS DISTINCT FROM 2022", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(createTimestampType(8), "a")), new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(8), "a"), new Constant(createTimestampType(8), DateTimes.parseTimestamp(8, "2022-01-01 00:00:00.00000000")), new Constant(createTimestampType(8), DateTimes.parseTimestamp(8, "2022-12-31 23:59:59.99999999"))))))); - testUnwrap("timestamp(9)", "year(a) IS DISTINCT FROM 2022", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(createTimestampType(9), "a")), new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2022-01-01 00:00:00.000000000")), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2022-12-31 23:59:59.999999999"))))))); - testUnwrap("timestamp(10)", "year(a) IS DISTINCT FROM 2022", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(createTimestampType(10), "a")), new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(10), "a"), new Constant(createTimestampType(10), DateTimes.parseTimestamp(10, "2022-01-01 00:00:00.0000000000")), new Constant(createTimestampType(10), DateTimes.parseTimestamp(10, "2022-12-31 23:59:59.9999999999"))))))); - testUnwrap("timestamp(11)", "year(a) IS DISTINCT FROM 2022", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(createTimestampType(11), "a")), new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(11), "a"), new Constant(createTimestampType(11), DateTimes.parseTimestamp(11, "2022-01-01 00:00:00.00000000000")), new Constant(createTimestampType(11), DateTimes.parseTimestamp(11, "2022-12-31 23:59:59.99999999999"))))))); - testUnwrap("timestamp(12)", "year(a) IS DISTINCT FROM 2022", new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(createTimestampType(12), "a")), new NotExpression(new BetweenPredicate(new SymbolReference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2022-01-01 00:00:00.000000000000")), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2022-12-31 23:59:59.999999999999"))))))); + testUnwrap("date", "year(a) IS DISTINCT FROM -0001", new Logical(OR, ImmutableList.of(new IsNull(new Reference(DATE, "a")), new Not(new Between(new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("-0001-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("-0001-12-31"))))))); + testUnwrap("date", "year(a) IS DISTINCT FROM 1960", new Logical(OR, ImmutableList.of(new IsNull(new Reference(DATE, "a")), new Not(new Between(new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("1960-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("1960-12-31"))))))); + testUnwrap("date", "year(a) IS DISTINCT FROM 2022", new Logical(OR, ImmutableList.of(new IsNull(new Reference(DATE, "a")), new Not(new Between(new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2022-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("2022-12-31"))))))); + testUnwrap("date", "year(a) IS DISTINCT FROM 9999", new Logical(OR, ImmutableList.of(new IsNull(new Reference(DATE, "a")), new Not(new Between(new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("9999-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("9999-12-31"))))))); + + testUnwrap("timestamp", "year(a) IS DISTINCT FROM -0001", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(3), "a")), new Not(new Between(new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "-0001-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "-0001-12-31 23:59:59.999"))))))); + testUnwrap("timestamp", "year(a) IS DISTINCT FROM 1960", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(3), "a")), new Not(new Between(new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1960-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1960-12-31 23:59:59.999"))))))); + testUnwrap("timestamp", "year(a) IS DISTINCT FROM 2022", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(3), "a")), new Not(new Between(new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-12-31 23:59:59.999"))))))); + testUnwrap("timestamp", "year(a) IS DISTINCT FROM 9999", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(3), "a")), new Not(new Between(new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "9999-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "9999-12-31 23:59:59.999"))))))); + + testUnwrap("timestamp(0)", "year(a) IS DISTINCT FROM 2022", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(0), "a")), new Not(new Between(new Reference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2022-01-01 00:00:00")), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "2022-12-31 23:59:59"))))))); + testUnwrap("timestamp(1)", "year(a) IS DISTINCT FROM 2022", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(1), "a")), new Not(new Between(new Reference(createTimestampType(1), "a"), new Constant(createTimestampType(1), DateTimes.parseTimestamp(1, "2022-01-01 00:00:00.0")), new Constant(createTimestampType(1), DateTimes.parseTimestamp(1, "2022-12-31 23:59:59.9"))))))); + testUnwrap("timestamp(2)", "year(a) IS DISTINCT FROM 2022", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(2), "a")), new Not(new Between(new Reference(createTimestampType(2), "a"), new Constant(createTimestampType(2), DateTimes.parseTimestamp(2, "2022-01-01 00:00:00.00")), new Constant(createTimestampType(2), DateTimes.parseTimestamp(2, "2022-12-31 23:59:59.99"))))))); + testUnwrap("timestamp(3)", "year(a) IS DISTINCT FROM 2022", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(3), "a")), new Not(new Between(new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2022-12-31 23:59:59.999"))))))); + testUnwrap("timestamp(4)", "year(a) IS DISTINCT FROM 2022", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(4), "a")), new Not(new Between(new Reference(createTimestampType(4), "a"), new Constant(createTimestampType(4), DateTimes.parseTimestamp(4, "2022-01-01 00:00:00.0000")), new Constant(createTimestampType(4), DateTimes.parseTimestamp(4, "2022-12-31 23:59:59.9999"))))))); + testUnwrap("timestamp(5)", "year(a) IS DISTINCT FROM 2022", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(5), "a")), new Not(new Between(new Reference(createTimestampType(5), "a"), new Constant(createTimestampType(5), DateTimes.parseTimestamp(5, "2022-01-01 00:00:00.00000")), new Constant(createTimestampType(5), DateTimes.parseTimestamp(5, "2022-12-31 23:59:59.99999"))))))); + testUnwrap("timestamp(6)", "year(a) IS DISTINCT FROM 2022", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(6), "a")), new Not(new Between(new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2022-01-01 00:00:00.000000")), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "2022-12-31 23:59:59.999999"))))))); + testUnwrap("timestamp(7)", "year(a) IS DISTINCT FROM 2022", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(7), "a")), new Not(new Between(new Reference(createTimestampType(7), "a"), new Constant(createTimestampType(7), DateTimes.parseTimestamp(7, "2022-01-01 00:00:00.0000000")), new Constant(createTimestampType(7), DateTimes.parseTimestamp(7, "2022-12-31 23:59:59.9999999"))))))); + testUnwrap("timestamp(8)", "year(a) IS DISTINCT FROM 2022", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(8), "a")), new Not(new Between(new Reference(createTimestampType(8), "a"), new Constant(createTimestampType(8), DateTimes.parseTimestamp(8, "2022-01-01 00:00:00.00000000")), new Constant(createTimestampType(8), DateTimes.parseTimestamp(8, "2022-12-31 23:59:59.99999999"))))))); + testUnwrap("timestamp(9)", "year(a) IS DISTINCT FROM 2022", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(9), "a")), new Not(new Between(new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2022-01-01 00:00:00.000000000")), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2022-12-31 23:59:59.999999999"))))))); + testUnwrap("timestamp(10)", "year(a) IS DISTINCT FROM 2022", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(10), "a")), new Not(new Between(new Reference(createTimestampType(10), "a"), new Constant(createTimestampType(10), DateTimes.parseTimestamp(10, "2022-01-01 00:00:00.0000000000")), new Constant(createTimestampType(10), DateTimes.parseTimestamp(10, "2022-12-31 23:59:59.9999999999"))))))); + testUnwrap("timestamp(11)", "year(a) IS DISTINCT FROM 2022", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(11), "a")), new Not(new Between(new Reference(createTimestampType(11), "a"), new Constant(createTimestampType(11), DateTimes.parseTimestamp(11, "2022-01-01 00:00:00.00000000000")), new Constant(createTimestampType(11), DateTimes.parseTimestamp(11, "2022-12-31 23:59:59.99999999999"))))))); + testUnwrap("timestamp(12)", "year(a) IS DISTINCT FROM 2022", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(12), "a")), new Not(new Between(new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2022-01-01 00:00:00.000000000000")), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "2022-12-31 23:59:59.999999999999"))))))); } @Test @@ -284,8 +284,8 @@ public void testNull() @Test public void testNaN() { - testUnwrap("date", "year(a) = nan()", new LogicalExpression(AND, ImmutableList.of(new IsNullPredicate(new FunctionCall(YEAR_DATE, ImmutableList.of(new SymbolReference(DATE, "a")))), new Constant(BOOLEAN, null)))); - testUnwrap("timestamp", "year(a) = nan()", new LogicalExpression(AND, ImmutableList.of(new IsNullPredicate(new FunctionCall(YEAR_TIMESTAMP_3, ImmutableList.of(new SymbolReference(createTimestampType(3), "a")))), new Constant(BOOLEAN, null)))); + testUnwrap("date", "year(a) = nan()", new Logical(AND, ImmutableList.of(new IsNull(new Call(YEAR_DATE, ImmutableList.of(new Reference(DATE, "a")))), new Constant(BOOLEAN, null)))); + testUnwrap("timestamp", "year(a) = nan()", new Logical(AND, ImmutableList.of(new IsNull(new Call(YEAR_TIMESTAMP_3, ImmutableList.of(new Reference(createTimestampType(3), "a")))), new Constant(BOOLEAN, null)))); } @Test @@ -293,7 +293,7 @@ public void smokeTests() { // smoke tests for various type combinations for (String type : asList("SMALLINT", "INTEGER", "BIGINT", "REAL", "DOUBLE")) { - testUnwrap("date", format("year(a) = %s '2022'", type), new BetweenPredicate(new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2022-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("2022-12-31")))); + testUnwrap("date", format("year(a) = %s '2022'", type), new Between(new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2022-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("2022-12-31")))); } } @@ -305,15 +305,15 @@ public void testTermOrder() assertPlan("SELECT * FROM (VALUES DATE '2022-01-01') t(a) WHERE 2022 = year(a)", output( filter( - new BetweenPredicate(new SymbolReference(DATE, "A"), new Constant(DATE, (long) DateTimeUtils.parseDate("2022-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("2022-12-31"))), + new Between(new Reference(DATE, "A"), new Constant(DATE, (long) DateTimeUtils.parseDate("2022-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("2022-12-31"))), values("A")))); } @Test public void testLeapYear() { - testUnwrap("date", "year(a) = 2024", new BetweenPredicate(new SymbolReference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2024-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("2024-12-31")))); - testUnwrap("timestamp", "year(a) = 2024", new BetweenPredicate(new SymbolReference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2024-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2024-12-31 23:59:59.999")))); + testUnwrap("date", "year(a) = 2024", new Between(new Reference(DATE, "a"), new Constant(DATE, (long) DateTimeUtils.parseDate("2024-01-01")), new Constant(DATE, (long) DateTimeUtils.parseDate("2024-12-31")))); + testUnwrap("timestamp", "year(a) = 2024", new Between(new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2024-01-01 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "2024-12-31 23:59:59.999")))); } @Test @@ -342,15 +342,15 @@ private static long toEpochMicros(LocalDateTime localDateTime) private void testUnwrap(String inputType, String inputPredicate, Expression expected) { - Expression antiOptimization = new ComparisonExpression(EQUAL, new FunctionCall(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 42.0)); - if (expected instanceof LogicalExpression logical && logical.getOperator() == OR) { - expected = new LogicalExpression(OR, ImmutableList.builder() - .addAll(logical.getTerms()) + Expression antiOptimization = new Comparison(EQUAL, new Call(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 42.0)); + if (expected instanceof Logical logical && logical.operator() == OR) { + expected = new Logical(OR, ImmutableList.builder() + .addAll(logical.terms()) .add(antiOptimization) .build()); } else { - expected = new LogicalExpression(OR, ImmutableList.of(expected, antiOptimization)); + expected = new Logical(OR, ImmutableList.of(expected, antiOptimization)); } String sql = format("SELECT * FROM (VALUES CAST(NULL AS %s)) t(a) WHERE %s OR rand() = 42", inputType, inputPredicate); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestWindowClause.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestWindowClause.java index d3aa33612cfa..e3000df9d924 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestWindowClause.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestWindowClause.java @@ -19,12 +19,12 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.connector.SortOrder; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ArithmeticNegation; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Negation; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.assertions.BasePlanTest; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.plan.FilterNode; @@ -36,7 +36,7 @@ import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.planner.LogicalPlanner.Stage.CREATED; import static io.trino.sql.planner.assertions.PlanMatchPattern.any; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; @@ -81,7 +81,7 @@ public void testPreprojectExpression() "max_result", windowFunction("max", ImmutableList.of("b"), DEFAULT_FRAME)), anyTree(project( - ImmutableMap.of("expr", expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L)))), + ImmutableMap.of("expr", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L)))), anyTree(values("a", "b")))))); assertPlan(sql, CREATED, pattern); @@ -113,12 +113,12 @@ public void testPreprojectExpressions() Optional.empty(), Optional.empty()))), project( - ImmutableMap.of("frame_start", expression(new FunctionCall(SUBTRACT_INTEGER, ImmutableList.of(new SymbolReference(INTEGER, "expr_b"), new SymbolReference(INTEGER, "expr_c"))))), + ImmutableMap.of("frame_start", expression(new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "expr_b"), new Reference(INTEGER, "expr_c"))))), anyTree(project( ImmutableMap.of( - "expr_a", expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L))), - "expr_b", expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 2L))), - "expr_c", expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "c"), new Constant(INTEGER, 3L)))), + "expr_a", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))), + "expr_b", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "b"), new Constant(INTEGER, 2L))), + "expr_c", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "c"), new Constant(INTEGER, 3L)))), anyTree(values("a", "b", "c"))))))); assertPlan(sql, CREATED, pattern); @@ -141,9 +141,9 @@ public void testWindowFunctionsInSelectAndOrderBy() "max_result", windowFunction("max", ImmutableList.of("minus_a"), DEFAULT_FRAME)), any(project( - ImmutableMap.of("order_by_window_sortkey", expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "minus_a"), new Constant(INTEGER, 1L)))), + ImmutableMap.of("order_by_window_sortkey", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "minus_a"), new Constant(INTEGER, 1L)))), project( - ImmutableMap.of("minus_a", expression(new ArithmeticNegation(new SymbolReference(INTEGER, "a")))), + ImmutableMap.of("minus_a", expression(new Negation(new Reference(INTEGER, "a")))), window( windowMatcherBuilder -> windowMatcherBuilder .specification(specification( @@ -154,7 +154,7 @@ public void testWindowFunctionsInSelectAndOrderBy() "array_agg_result", windowFunction("array_agg", ImmutableList.of("a"), DEFAULT_FRAME)), anyTree(project( - ImmutableMap.of("select_window_sortkey", expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L)))), + ImmutableMap.of("select_window_sortkey", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L)))), anyTree(values("a")))))))))))); assertPlan(sql, CREATED, pattern); @@ -187,16 +187,16 @@ public void testWindowWithFrameCoercions() Optional.of(new Symbol(UNKNOWN, "frame_bound")), Optional.of(new Symbol(UNKNOWN, "coerced_sortkey"))))), project(// frame bound value computation - ImmutableMap.of("frame_bound", expression(new FunctionCall(ADD_DOUBLE, ImmutableList.of(new SymbolReference(DOUBLE, "coerced_sortkey"), new SymbolReference(INTEGER, "frame_offset"))))), + ImmutableMap.of("frame_bound", expression(new Call(ADD_DOUBLE, ImmutableList.of(new Reference(DOUBLE, "coerced_sortkey"), new Reference(INTEGER, "frame_offset"))))), project(// sort key coercion to frame bound type - ImmutableMap.of("coerced_sortkey", expression(new Cast(new SymbolReference(INTEGER, "sortkey"), DOUBLE))), + ImmutableMap.of("coerced_sortkey", expression(new Cast(new Reference(INTEGER, "sortkey"), DOUBLE))), node(FilterNode.class, project(project( ImmutableMap.of( // sort key based on "a" in source scope - "sortkey", expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L))), + "sortkey", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))), // frame offset based on "a" in output scope - "frame_offset", expression(new ArithmeticBinaryExpression(ADD_DOUBLE, ADD, new SymbolReference(DOUBLE, "new_a"), new Constant(DOUBLE, 1.0)))), + "frame_offset", expression(new Arithmetic(ADD_DOUBLE, ADD, new Reference(DOUBLE, "new_a"), new Constant(DOUBLE, 1.0)))), project(// output expression ImmutableMap.of("new_a", expression(new Constant(DOUBLE, 2E0))), project(project(values("a"))))))))))))); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestWindowFrameRange.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestWindowFrameRange.java index 84cc7eb75bce..f2c0059d7dca 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestWindowFrameRange.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestWindowFrameRange.java @@ -21,11 +21,11 @@ import io.trino.spi.connector.SortOrder; import io.trino.spi.function.OperatorType; import io.trino.spi.type.Decimals; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.assertions.BasePlanTest; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.plan.WindowNode; @@ -43,8 +43,8 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; import static io.trino.sql.ir.IrExpressions.ifExpression; import static io.trino.sql.planner.LogicalPlanner.Stage.CREATED; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; @@ -99,16 +99,16 @@ public void testFramePrecedingWithSortKeyCoercions() Optional.empty(), Optional.empty()))), project(// coerce sort key to compare sort key values with frame start values - ImmutableMap.of("key_for_frame_start_comparison", expression(new Cast(new SymbolReference(INTEGER, "key"), createDecimalType(12, 1)))), + ImmutableMap.of("key_for_frame_start_comparison", expression(new Cast(new Reference(INTEGER, "key"), createDecimalType(12, 1)))), project(// calculate frame start value (sort key - frame offset) - ImmutableMap.of("frame_start_value", expression(new FunctionCall(SUBTRACT_DECIMAL_10_0, ImmutableList.of(new SymbolReference(INTEGER, "key_for_frame_start_calculation"), new SymbolReference(DOUBLE, "x"))))), + ImmutableMap.of("frame_start_value", expression(new Call(SUBTRACT_DECIMAL_10_0, ImmutableList.of(new Reference(INTEGER, "key_for_frame_start_calculation"), new Reference(DOUBLE, "x"))))), project(// coerce sort key to calculate frame start values - ImmutableMap.of("key_for_frame_start_calculation", expression(new Cast(new SymbolReference(INTEGER, "key"), createDecimalType(10, 0)))), + ImmutableMap.of("key_for_frame_start_calculation", expression(new Cast(new Reference(INTEGER, "key"), createDecimalType(10, 0)))), filter(// validate offset values ifExpression( - new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createDecimalType(2, 1), "x"), new Constant(createDecimalType(2, 1), 0L)), - TRUE_LITERAL, - new Cast(new FunctionCall(FAIL, ImmutableList.of(new Constant(INTEGER, (long) INVALID_WINDOW_FRAME.toErrorCode().getCode()), new Constant(VARCHAR, Slices.utf8Slice("Window frame offset value must not be negative or null")))), BOOLEAN)), + new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createDecimalType(2, 1), "x"), new Constant(createDecimalType(2, 1), 0L)), + TRUE, + new Cast(new Call(FAIL, ImmutableList.of(new Constant(INTEGER, (long) INVALID_WINDOW_FRAME.toErrorCode().getCode()), new Constant(VARCHAR, Slices.utf8Slice("Window frame offset value must not be negative or null")))), BOOLEAN)), anyTree( values( ImmutableList.of("key", "x"), @@ -147,16 +147,16 @@ public void testFrameFollowingWithOffsetCoercion() Optional.of(new Symbol(UNKNOWN, "frame_end_value")), Optional.of(new Symbol(UNKNOWN, "key_for_frame_end_comparison"))))), project(// coerce sort key to compare sort key values with frame end values - ImmutableMap.of("key_for_frame_end_comparison", expression(new Cast(new SymbolReference(INTEGER, "key"), createDecimalType(12, 1)))), + ImmutableMap.of("key_for_frame_end_comparison", expression(new Cast(new Reference(INTEGER, "key"), createDecimalType(12, 1)))), project(// calculate frame end value (sort key + frame offset) - ImmutableMap.of("frame_end_value", expression(new FunctionCall(ADD_DECIMAL_10_0, ImmutableList.of(new SymbolReference(INTEGER, "key"), new SymbolReference(INTEGER, "offset"))))), + ImmutableMap.of("frame_end_value", expression(new Call(ADD_DECIMAL_10_0, ImmutableList.of(new Reference(INTEGER, "key"), new Reference(INTEGER, "offset"))))), filter(// validate offset values ifExpression( - new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(createDecimalType(10, 0), "offset"), new Constant(createDecimalType(10, 0), 0L)), - TRUE_LITERAL, - new Cast(new FunctionCall(FAIL, ImmutableList.of(new Constant(INTEGER, (long) INVALID_WINDOW_FRAME.toErrorCode().getCode()), new Constant(VARCHAR, Slices.utf8Slice("Window frame offset value must not be negative or null")))), BOOLEAN)), + new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createDecimalType(10, 0), "offset"), new Constant(createDecimalType(10, 0), 0L)), + TRUE, + new Cast(new Call(FAIL, ImmutableList.of(new Constant(INTEGER, (long) INVALID_WINDOW_FRAME.toErrorCode().getCode()), new Constant(VARCHAR, Slices.utf8Slice("Window frame offset value must not be negative or null")))), BOOLEAN)), project(// coerce offset value to calculate frame end values - ImmutableMap.of("offset", expression(new Cast(new SymbolReference(DOUBLE, "x"), createDecimalType(10, 0)))), + ImmutableMap.of("offset", expression(new Cast(new Reference(DOUBLE, "x"), createDecimalType(10, 0)))), anyTree( values( ImmutableList.of("key", "x"), @@ -195,19 +195,19 @@ public void testFramePrecedingFollowingNoCoercions() Optional.of(new Symbol(UNKNOWN, "frame_end_value")), Optional.of(new Symbol(UNKNOWN, "key"))))), project(// calculate frame end value (sort key + frame end offset) - ImmutableMap.of("frame_end_value", expression(new FunctionCall(ADD_INTEGER, ImmutableList.of(new SymbolReference(INTEGER, "key"), new SymbolReference(INTEGER, "y"))))), + ImmutableMap.of("frame_end_value", expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "key"), new Reference(INTEGER, "y"))))), filter(// validate frame end offset values ifExpression( - new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(INTEGER, "y"), new Constant(INTEGER, 0L)), - TRUE_LITERAL, - new Cast(new FunctionCall(FAIL, ImmutableList.of(new Constant(INTEGER, (long) INVALID_WINDOW_FRAME.toErrorCode().getCode()), new Constant(VARCHAR, Slices.utf8Slice("Window frame offset value must not be negative or null")))), BOOLEAN)), + new Comparison(GREATER_THAN_OR_EQUAL, new Reference(INTEGER, "y"), new Constant(INTEGER, 0L)), + TRUE, + new Cast(new Call(FAIL, ImmutableList.of(new Constant(INTEGER, (long) INVALID_WINDOW_FRAME.toErrorCode().getCode()), new Constant(VARCHAR, Slices.utf8Slice("Window frame offset value must not be negative or null")))), BOOLEAN)), project(// calculate frame start value (sort key - frame start offset) - ImmutableMap.of("frame_start_value", expression(new FunctionCall(SUBTRACT_INTEGER, ImmutableList.of(new SymbolReference(INTEGER, "key"), new SymbolReference(INTEGER, "x"))))), + ImmutableMap.of("frame_start_value", expression(new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "key"), new Reference(INTEGER, "x"))))), filter(// validate frame start offset values ifExpression( - new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(INTEGER, "x"), new Constant(INTEGER, 0L)), - TRUE_LITERAL, - new Cast(new FunctionCall(FAIL, ImmutableList.of(new Constant(INTEGER, (long) INVALID_WINDOW_FRAME.toErrorCode().getCode()), new Constant(VARCHAR, Slices.utf8Slice("Window frame offset value must not be negative or null")))), BOOLEAN)), + new Comparison(GREATER_THAN_OR_EQUAL, new Reference(INTEGER, "x"), new Constant(INTEGER, 0L)), + TRUE, + new Cast(new Call(FAIL, ImmutableList.of(new Constant(INTEGER, (long) INVALID_WINDOW_FRAME.toErrorCode().getCode()), new Constant(VARCHAR, Slices.utf8Slice("Window frame offset value must not be negative or null")))), BOOLEAN)), anyTree( values( ImmutableList.of("key", "x", "y"), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/AggregationFunctionProvider.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/AggregationFunctionProvider.java index 3d1db4ff9fa4..9544d31855ab 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/AggregationFunctionProvider.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/AggregationFunctionProvider.java @@ -69,7 +69,7 @@ public AggregationFunction getExpectedValue(SymbolAliases aliases) ImmutableMap.Builder orders = ImmutableMap.builder(); for (PlanMatchPattern.Ordering ordering : this.orderBy) { - Symbol symbol = new Symbol(UNKNOWN, aliases.get(ordering.getField()).getName()); + Symbol symbol = new Symbol(UNKNOWN, aliases.get(ordering.getField()).name()); fields.add(symbol); orders.put(symbol, ordering.getSortOrder()); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/AggregationMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/AggregationMatcher.java index 9f6a5ecdf6f8..d7d56fd2ac7e 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/AggregationMatcher.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/AggregationMatcher.java @@ -86,7 +86,7 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses .collect(toImmutableSet()); Set expectedMasks = masks.stream() - .map(name -> new Symbol(UNKNOWN, symbolAliases.get(name).getName())) + .map(name -> new Symbol(UNKNOWN, symbolAliases.get(name).name())) .collect(toImmutableSet()); if (!actualMasks.equals(expectedMasks)) { @@ -116,7 +116,7 @@ static boolean matches(Collection expectedAliases, Collection ac List expectedSymbols = expectedAliases .stream() - .map(alias -> new Symbol(UNKNOWN, symbolAliases.get(alias).getName())) + .map(alias -> new Symbol(UNKNOWN, symbolAliases.get(alias).name())) .collect(toImmutableList()); for (Symbol symbol : expectedSymbols) { if (!actualSymbols.contains(symbol)) { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionAndValuePointersMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionAndValuePointersMatcher.java index 485d93049ec6..5b17fa738de2 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionAndValuePointersMatcher.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionAndValuePointersMatcher.java @@ -34,8 +34,7 @@ public static boolean matches(ExpressionAndValuePointers expected, ExpressionAnd Assignment expectedAssignment = expected.getAssignments().get(i); boolean match = switch (actualAssignment.valuePointer()) { - case ScalarValuePointer actualPointer when expectedAssignment.valuePointer() instanceof ScalarValuePointer expectedPointer -> - aliases.get(expectedPointer.getInputSymbol().getName()).getName().equals(actualPointer.getInputSymbol().getName()); + case ScalarValuePointer actualPointer when expectedAssignment.valuePointer() instanceof ScalarValuePointer expectedPointer -> aliases.get(expectedPointer.getInputSymbol().getName()).name().equals(actualPointer.getInputSymbol().getName()); case AggregationValuePointer actualPointer when expectedAssignment.valuePointer() instanceof AggregationValuePointer expectedPointer -> { if (!expectedPointer.getFunction().equals(actualPointer.getFunction()) || !expectedPointer.getSetDescriptor().equals(actualPointer.getSetDescriptor()) || diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionVerifier.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionVerifier.java index 3dad3c69777b..bda8b37da656 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionVerifier.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionVerifier.java @@ -13,26 +13,26 @@ */ package io.trino.sql.planner.assertions; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ArithmeticNegation; -import io.trino.sql.ir.BetweenPredicate; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Between; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Case; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.CoalesceExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Coalesce; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.InPredicate; +import io.trino.sql.ir.In; import io.trino.sql.ir.IrVisitor; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.LambdaExpression; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.NotExpression; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Lambda; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Negation; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.ir.Row; -import io.trino.sql.ir.SearchedCaseExpression; -import io.trino.sql.ir.SimpleCaseExpression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Subscript; +import io.trino.sql.ir.Switch; import io.trino.sql.ir.WhenClause; import java.util.List; @@ -82,21 +82,21 @@ protected Boolean visitConstant(Constant actual, Expression expectedExpression) return false; } - return Objects.equals(actual.getValue(), expected.getValue()) && - actual.getType().equals(expected.getType()); + return Objects.equals(actual.value(), expected.value()) && + actual.type().equals(expected.type()); } @Override - protected Boolean visitSymbolReference(SymbolReference actual, Expression expectedExpression) + protected Boolean visitReference(Reference actual, Expression expectedExpression) { - if (!(expectedExpression instanceof SymbolReference expected)) { + if (!(expectedExpression instanceof Reference expected)) { return false; } // TODO: verify types. This is currently hard to do because planner tests // are either missing types, have the wrong types, or they are unable to // provide types due to limitations in the matcher infrastructure - return symbolAliases.get(expected.getName()).name().equals(actual.name()); + return symbolAliases.get(expected.name()).name().equals(actual.name()); } @Override @@ -110,109 +110,109 @@ protected Boolean visitCast(Cast actual, Expression expectedExpression) // Here we're trying to verify its IR counterpart, but the plan testing framework goes directly // from SQL text -> IR-like expressions without doing all the proper canonicalizations. So we cheat // here and normalize everything to the same case before comparing - if (!actual.getType().toString().equalsIgnoreCase(expected.getType().toString())) { + if (!actual.type().toString().equalsIgnoreCase(expected.type().toString())) { return false; } - return process(actual.getExpression(), expected.getExpression()); + return process(actual.expression(), expected.expression()); } @Override - protected Boolean visitIsNullPredicate(IsNullPredicate actual, Expression expectedExpression) + protected Boolean visitIsNull(IsNull actual, Expression expectedExpression) { - if (!(expectedExpression instanceof IsNullPredicate expected)) { + if (!(expectedExpression instanceof IsNull expected)) { return false; } - return process(actual.getValue(), expected.getValue()); + return process(actual.value(), expected.value()); } @Override - protected Boolean visitInPredicate(InPredicate actual, Expression expectedExpression) + protected Boolean visitIn(In actual, Expression expectedExpression) { - if (!(expectedExpression instanceof InPredicate expected)) { + if (!(expectedExpression instanceof In expected)) { return false; } - return process(actual.getValue(), expected.getValue()) && - process(actual.getValueList(), expected.getValueList()); + return process(actual.value(), expected.value()) && + process(actual.valueList(), expected.valueList()); } @Override - protected Boolean visitComparisonExpression(ComparisonExpression actual, Expression expectedExpression) + protected Boolean visitComparison(Comparison actual, Expression expectedExpression) { - if (!(expectedExpression instanceof ComparisonExpression expected)) { + if (!(expectedExpression instanceof Comparison expected)) { return false; } - if (actual.getOperator() == expected.getOperator() && - process(actual.getLeft(), expected.getLeft()) && - process(actual.getRight(), expected.getRight())) { + if (actual.operator() == expected.operator() && + process(actual.left(), expected.left()) && + process(actual.right(), expected.right())) { return true; } - return actual.getOperator() == expected.getOperator().flip() && - process(actual.getLeft(), expected.getRight()) && - process(actual.getRight(), expected.getLeft()); + return actual.operator() == expected.operator().flip() && + process(actual.left(), expected.right()) && + process(actual.right(), expected.left()); } @Override - protected Boolean visitBetweenPredicate(BetweenPredicate actual, Expression expectedExpression) + protected Boolean visitBetween(Between actual, Expression expectedExpression) { - if (!(expectedExpression instanceof BetweenPredicate expected)) { + if (!(expectedExpression instanceof Between expected)) { return false; } - return process(actual.getValue(), expected.getValue()) && - process(actual.getMin(), expected.getMin()) && - process(actual.getMax(), expected.getMax()); + return process(actual.value(), expected.value()) && + process(actual.min(), expected.min()) && + process(actual.max(), expected.max()); } @Override - protected Boolean visitArithmeticNegation(ArithmeticNegation actual, Expression expectedExpression) + protected Boolean visitNegation(Negation actual, Expression expectedExpression) { - if (!(expectedExpression instanceof ArithmeticNegation expected)) { + if (!(expectedExpression instanceof Negation expected)) { return false; } - return process(actual.getValue(), expected.getValue()); + return process(actual.value(), expected.value()); } @Override - protected Boolean visitArithmeticBinary(ArithmeticBinaryExpression actual, Expression expectedExpression) + protected Boolean visitArithmetic(Arithmetic actual, Expression expectedExpression) { - if (!(expectedExpression instanceof ArithmeticBinaryExpression expected)) { + if (!(expectedExpression instanceof Arithmetic expected)) { return false; } - return actual.getOperator() == expected.getOperator() && - process(actual.getLeft(), expected.getLeft()) && - process(actual.getRight(), expected.getRight()); + return actual.operator() == expected.operator() && + process(actual.left(), expected.left()) && + process(actual.right(), expected.right()); } @Override - protected Boolean visitNotExpression(NotExpression actual, Expression expectedExpression) + protected Boolean visitNot(Not actual, Expression expectedExpression) { - if (!(expectedExpression instanceof NotExpression expected)) { + if (!(expectedExpression instanceof Not expected)) { return false; } - return process(actual.getValue(), expected.getValue()); + return process(actual.value(), expected.value()); } @Override - protected Boolean visitLogicalExpression(LogicalExpression actual, Expression expectedExpression) + protected Boolean visitLogical(Logical actual, Expression expectedExpression) { - if (!(expectedExpression instanceof LogicalExpression expected)) { + if (!(expectedExpression instanceof Logical expected)) { return false; } - if (actual.getTerms().size() != expected.getTerms().size() || actual.getOperator() != expected.getOperator()) { + if (actual.terms().size() != expected.terms().size() || actual.operator() != expected.operator()) { return false; } - for (int i = 0; i < actual.getTerms().size(); i++) { - if (!process(actual.getTerms().get(i), expected.getTerms().get(i))) { + for (int i = 0; i < actual.terms().size(); i++) { + if (!process(actual.terms().get(i), expected.terms().get(i))) { return false; } } @@ -221,18 +221,18 @@ protected Boolean visitLogicalExpression(LogicalExpression actual, Expression ex } @Override - protected Boolean visitCoalesceExpression(CoalesceExpression actual, Expression expectedExpression) + protected Boolean visitCoalesce(Coalesce actual, Expression expectedExpression) { - if (!(expectedExpression instanceof CoalesceExpression expected)) { + if (!(expectedExpression instanceof Coalesce expected)) { return false; } - if (actual.getOperands().size() != expected.getOperands().size()) { + if (actual.operands().size() != expected.operands().size()) { return false; } - for (int i = 0; i < actual.getOperands().size(); i++) { - if (!process(actual.getOperands().get(i), expected.getOperands().get(i))) { + for (int i = 0; i < actual.operands().size(); i++) { + if (!process(actual.operands().get(i), expected.operands().get(i))) { return false; } } @@ -240,33 +240,33 @@ protected Boolean visitCoalesceExpression(CoalesceExpression actual, Expression } @Override - protected Boolean visitSimpleCaseExpression(SimpleCaseExpression actual, Expression expectedExpression) + protected Boolean visitSwitch(Switch actual, Expression expectedExpression) { - if (!(expectedExpression instanceof SimpleCaseExpression expected)) { + if (!(expectedExpression instanceof Switch expected)) { return false; } - return process(actual.getOperand(), expected.getOperand()) && - processWhenClauses(actual.getWhenClauses(), expected.getWhenClauses()) && - process(actual.getDefaultValue(), expected.getDefaultValue()); + return process(actual.operand(), expected.operand()) && + processWhenClauses(actual.whenClauses(), expected.whenClauses()) && + process(actual.defaultValue(), expected.defaultValue()); } @Override - protected Boolean visitSearchedCaseExpression(SearchedCaseExpression actual, Expression expected) + protected Boolean visitCase(Case actual, Expression expected) { - if (!(expected instanceof SearchedCaseExpression expectedCase)) { + if (!(expected instanceof Case expectedCase)) { return false; } - if (!processWhenClauses(actual.getWhenClauses(), expectedCase.getWhenClauses())) { + if (!processWhenClauses(actual.whenClauses(), expectedCase.whenClauses())) { return false; } - if (actual.getDefaultValue().isPresent() != expectedCase.getDefaultValue().isPresent()) { + if (actual.defaultValue().isPresent() != expectedCase.defaultValue().isPresent()) { return false; } - return process(actual.getDefaultValue(), expectedCase.getDefaultValue()); + return process(actual.defaultValue(), expectedCase.defaultValue()); } private boolean processWhenClauses(List actual, List expected) @@ -289,29 +289,29 @@ private boolean process(WhenClause actual, WhenClause expected) } @Override - protected Boolean visitFunctionCall(FunctionCall actual, Expression expectedExpression) + protected Boolean visitCall(Call actual, Expression expectedExpression) { - if (!(expectedExpression instanceof FunctionCall expected)) { + if (!(expectedExpression instanceof Call expected)) { return false; } - return actual.getFunction().getName().equals(expected.getFunction().getName()) && - process(actual.getArguments(), expected.getArguments()); + return actual.function().getName().equals(expected.function().getName()) && + process(actual.arguments(), expected.arguments()); } @Override - protected Boolean visitLambdaExpression(LambdaExpression actual, Expression expected) + protected Boolean visitLambda(Lambda actual, Expression expected) { - if (!(expected instanceof LambdaExpression lambdaExpression)) { + if (!(expected instanceof Lambda lambda)) { return false; } // todo this should allow the arguments to have different names - if (!actual.getArguments().equals(lambdaExpression.getArguments())) { + if (!actual.arguments().equals(lambda.arguments())) { return false; } - return process(actual.getBody(), lambdaExpression.getBody()); + return process(actual.body(), lambda.body()); } @Override @@ -321,17 +321,17 @@ protected Boolean visitRow(Row actual, Expression expectedExpression) return false; } - return process(actual.getItems(), expected.getItems()); + return process(actual.items(), expected.items()); } @Override - protected Boolean visitSubscriptExpression(SubscriptExpression actual, Expression expectedExpression) + protected Boolean visitSubscript(Subscript actual, Expression expectedExpression) { - if (!(expectedExpression instanceof SubscriptExpression expected)) { + if (!(expectedExpression instanceof Subscript expected)) { return false; } - return process(actual.getBase(), expected.getBase()) && process(actual.getIndex(), expected.getIndex()); + return process(actual.base(), expected.base()) && process(actual.index(), expected.index()); } private boolean process(List actuals, List expecteds) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/JoinMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/JoinMatcher.java index 813ac8555c4d..1edb36a31273 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/JoinMatcher.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/JoinMatcher.java @@ -21,10 +21,10 @@ import io.trino.metadata.Metadata; import io.trino.spi.type.Type; import io.trino.sql.DynamicFilters; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.plan.DynamicFilterId; import io.trino.sql.planner.plan.FilterNode; @@ -45,8 +45,8 @@ import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.operator.join.JoinUtils.getJoinDynamicFilters; import static io.trino.sql.DynamicFilters.extractDynamicFilters; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.IS_DISTINCT_FROM; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.IS_DISTINCT_FROM; import static io.trino.sql.planner.ExpressionExtractor.extractExpressions; import static io.trino.sql.planner.assertions.MatchResult.NO_MATCH; import static io.trino.sql.planner.assertions.PlanMatchPattern.DynamicFilterPattern; @@ -176,10 +176,10 @@ private boolean matchDynamicFilters(JoinNode joinNode, SymbolAliases symbolAlias } Expression expression; if (descriptor.isNullAllowed()) { - expression = new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, probe, build.toSymbolReference())); + expression = new Not(new Comparison(IS_DISTINCT_FROM, probe, build.toSymbolReference())); } else { - expression = new ComparisonExpression(descriptor.getOperator(), probe, build.toSymbolReference()); + expression = new Comparison(descriptor.getOperator(), probe, build.toSymbolReference()); } actual.add(expression); } @@ -258,7 +258,7 @@ public Builder dynamicFilter(Map expectedDynamicFilter) @CanIgnoreReturnValue public Builder dynamicFilter(Type type, String key, String value) { - this.dynamicFilter = Optional.of(ImmutableList.of(new PlanMatchPattern.DynamicFilterPattern(new SymbolReference(type, key), EQUAL, value))); + this.dynamicFilter = Optional.of(ImmutableList.of(new PlanMatchPattern.DynamicFilterPattern(new Reference(type, key), EQUAL, value))); return this; } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/MatchResult.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/MatchResult.java index 371de91099c4..d576f0a9ebb8 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/MatchResult.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/MatchResult.java @@ -13,7 +13,7 @@ */ package io.trino.sql.planner.assertions; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import static java.util.Objects.requireNonNull; @@ -29,10 +29,10 @@ public static MatchResult match() return new MatchResult(true, new SymbolAliases()); } - public static MatchResult match(String alias, SymbolReference symbolReference) + public static MatchResult match(String alias, Reference reference) { SymbolAliases newAliases = SymbolAliases.builder() - .put(alias, symbolReference) + .put(alias, reference) .build(); return new MatchResult(true, newAliases); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java index 8da9ea1ffb79..fe5f6df6f1c3 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java @@ -26,9 +26,9 @@ import io.trino.spi.connector.SortOrder; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.NotExpression; +import io.trino.sql.ir.Not; import io.trino.sql.ir.Row; import io.trino.sql.planner.PartitioningHandle; import io.trino.sql.planner.Symbol; @@ -92,7 +92,7 @@ import static io.trino.spi.connector.SortOrder.ASC_NULLS_LAST; import static io.trino.spi.connector.SortOrder.DESC_NULLS_FIRST; import static io.trino.spi.connector.SortOrder.DESC_NULLS_LAST; -import static io.trino.sql.ir.ComparisonExpression.Operator.IS_DISTINCT_FROM; +import static io.trino.sql.ir.Comparison.Operator.IS_DISTINCT_FROM; import static io.trino.sql.planner.assertions.MatchResult.NO_MATCH; import static io.trino.sql.planner.assertions.MatchResult.match; import static io.trino.sql.planner.assertions.StrictAssignedSymbolsMatcher.actualAssignments; @@ -1188,11 +1188,11 @@ public static GroupingSetDescriptor singleGroupingSet(List groupingKeys) public static class DynamicFilterPattern { private final Expression probe; - private final ComparisonExpression.Operator operator; + private final Comparison.Operator operator; private final SymbolAlias build; private final boolean nullAllowed; - public DynamicFilterPattern(Expression probe, ComparisonExpression.Operator operator, String buildAlias, boolean nullAllowed) + public DynamicFilterPattern(Expression probe, Comparison.Operator operator, String buildAlias, boolean nullAllowed) { this.probe = requireNonNull(probe, "probe is null"); this.operator = requireNonNull(operator, "operator is null"); @@ -1200,7 +1200,7 @@ public DynamicFilterPattern(Expression probe, ComparisonExpression.Operator oper this.nullAllowed = nullAllowed; } - public DynamicFilterPattern(Expression probe, ComparisonExpression.Operator operator, String buildAlias) + public DynamicFilterPattern(Expression probe, Comparison.Operator operator, String buildAlias) { this(probe, operator, buildAlias, false); } @@ -1209,13 +1209,13 @@ Expression getExpression(SymbolAliases aliases) { Expression probeMapped = symbolMapper(aliases).map(probe); if (nullAllowed) { - return new NotExpression( - new ComparisonExpression( + return new Not( + new Comparison( IS_DISTINCT_FROM, probeMapped, build.toSymbol(aliases).toSymbolReference())); } - return new ComparisonExpression( + return new Comparison( operator, probeMapped, build.toSymbol(aliases).toSymbolReference()); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/SetExpressionMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/SetExpressionMatcher.java index ff95fa24b354..d4165c56eef3 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/SetExpressionMatcher.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/SetExpressionMatcher.java @@ -74,7 +74,7 @@ public Optional getAssignedSymbol(PlanNode node, Session session, Metada private boolean matches(SymbolAliases aliases, Symbol expected, Symbol actual) { - return aliases.get(expected.getName()).getName().equals(actual.getName()); + return aliases.get(expected.getName()).name().equals(actual.getName()); } @Override diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/SymbolAliases.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/SymbolAliases.java index 5866fdbcc101..f7422481007d 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/SymbolAliases.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/SymbolAliases.java @@ -17,7 +17,7 @@ import io.trino.sql.ir.Expression; import io.trino.sql.ir.ExpressionRewriter; import io.trino.sql.ir.ExpressionTreeRewriter; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.plan.Assignments; @@ -32,14 +32,14 @@ public final class SymbolAliases { - private final Map map; + private final Map map; public SymbolAliases() { this.map = ImmutableMap.of(); } - private SymbolAliases(Map aliases) + private SymbolAliases(Map aliases) { this.map = ImmutableMap.copyOf(requireNonNull(aliases, "aliases is null")); } @@ -48,9 +48,9 @@ public Expression rewrite(Expression expression) { return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<>() { @Override - public Expression rewriteSymbolReference(SymbolReference node, Void context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteReference(Reference node, Void context, ExpressionTreeRewriter treeRewriter) { - return map.getOrDefault(node.getName(), node); + return map.getOrDefault(node.name(), node); } }, expression); } @@ -64,7 +64,7 @@ public SymbolAliases withNewAliases(SymbolAliases sourceAliases) { Builder builder = new Builder(this); - for (Map.Entry alias : sourceAliases.map.entrySet()) { + for (Map.Entry alias : sourceAliases.map.entrySet()) { builder.put(alias.getKey(), alias.getValue()); } @@ -76,7 +76,7 @@ public Symbol getSymbol(String alias) return Symbol.from(get(alias)); } - public SymbolReference get(String alias) + public Reference get(String alias) { /* * It's still kind of an open question if the right combination of anyTree() and @@ -92,17 +92,17 @@ public SymbolReference get(String alias) return getOptional(alias).orElseThrow(() -> new IllegalStateException(format("missing expression for alias %s", alias))); } - public Optional getOptional(String alias) + public Optional getOptional(String alias) { - SymbolReference result = map.get(alias); + Reference result = map.get(alias); return Optional.ofNullable(result); } - private Map getUpdatedAssignments(Assignments assignments) + private Map getUpdatedAssignments(Assignments assignments) { - ImmutableMap.Builder mapUpdate = ImmutableMap.builder(); + ImmutableMap.Builder mapUpdate = ImmutableMap.builder(); for (Map.Entry assignment : assignments.getMap().entrySet()) { - for (Map.Entry existingAlias : map.entrySet()) { + for (Map.Entry existingAlias : map.entrySet()) { if (assignment.getValue().equals(existingAlias.getValue())) { // Simple symbol rename mapUpdate.put(existingAlias.getKey(), assignment.getKey().toSymbolReference()); @@ -191,7 +191,7 @@ public String toString() public static class Builder { - Map bindings; + Map bindings; private Builder() { @@ -203,22 +203,22 @@ private Builder(SymbolAliases initialAliases) bindings = new HashMap<>(initialAliases.map); } - public Builder put(String alias, SymbolReference symbolReference) + public Builder put(String alias, Reference reference) { requireNonNull(alias, "alias is null"); - requireNonNull(symbolReference, "symbolReference is null"); + requireNonNull(reference, "symbolReference is null"); // Special case to allow identity binding (i.e. "ALIAS" -> expression("ALIAS")) - if (bindings.containsKey(alias) && bindings.get(alias).equals(symbolReference)) { + if (bindings.containsKey(alias) && bindings.get(alias).equals(reference)) { return this; } - checkState(!bindings.containsKey(alias), "Alias '%s' already bound to expression '%s'. Tried to rebind to '%s'", alias, bindings.get(alias), symbolReference); - bindings.put(alias, symbolReference); + checkState(!bindings.containsKey(alias), "Alias '%s' already bound to expression '%s'. Tried to rebind to '%s'", alias, bindings.get(alias), reference); + bindings.put(alias, reference); return this; } - public Builder putAll(Map aliases) + public Builder putAll(Map aliases) { aliases.entrySet() .forEach(entry -> put(entry.getKey(), entry.getValue())); @@ -230,7 +230,7 @@ public Builder putAll(Map aliases) * update existing bindings that have already been added. Unless you're * certain you want this behavior, you don't want it. */ - private Builder putUnchecked(Map aliases) + private Builder putUnchecked(Map aliases) { bindings.putAll(aliases); return this; diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionMatcher.java index 39002609af19..292d2abf6229 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionMatcher.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionMatcher.java @@ -24,7 +24,7 @@ import io.trino.spi.function.table.DescriptorArgument; import io.trino.spi.function.table.ScalarArgument; import io.trino.spi.function.table.TableArgument; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.PlanNode; @@ -128,10 +128,10 @@ else if (expected instanceof ScalarArgumentValue expectedScalar) { if (!specificationMatches) { return NO_MATCH; } - Set expectedPassThrough = expectedTableArgument.passThroughSymbols().stream() + Set expectedPassThrough = expectedTableArgument.passThroughSymbols().stream() .map(symbolAliases::get) .collect(toImmutableSet()); - Set actualPassThrough = argumentProperties.getPassThroughSpecification().columns().stream() + Set actualPassThrough = argumentProperties.getPassThroughSpecification().columns().stream() .map(PassThroughColumn::symbol) .map(Symbol::toSymbolReference) .collect(toImmutableSet()); @@ -149,7 +149,7 @@ else if (expected instanceof ScalarArgumentValue expectedScalar) { return NO_MATCH; } - ImmutableMap.Builder properOutputsMapping = ImmutableMap.builder(); + ImmutableMap.Builder properOutputsMapping = ImmutableMap.builder(); for (int i = 0; i < properOutputs.size(); i++) { properOutputsMapping.put(properOutputs.get(i), tableFunctionNode.getProperOutputs().get(i).toSymbolReference()); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionProcessorMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionProcessorMatcher.java index 6bb6264ee9f7..b0b6329d019b 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionProcessorMatcher.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionProcessorMatcher.java @@ -18,7 +18,7 @@ import io.trino.Session; import io.trino.cost.StatsProvider; import io.trino.metadata.Metadata; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.PlanNode; @@ -93,12 +93,12 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses return NO_MATCH; } - List> expectedPassThrough = passThroughSymbols.stream() + List> expectedPassThrough = passThroughSymbols.stream() .map(list -> list.stream() .map(symbolAliases::get) .collect(toImmutableList())) .collect(toImmutableList()); - List> actualPassThrough = tableFunctionProcessorNode.getPassThroughSpecifications().stream() + List> actualPassThrough = tableFunctionProcessorNode.getPassThroughSpecifications().stream() .map(PassThroughSpecification::columns) .map(list -> list.stream() .map(PassThroughColumn::symbol) @@ -109,12 +109,12 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses return NO_MATCH; } - List> expectedRequired = requiredSymbols.stream() + List> expectedRequired = requiredSymbols.stream() .map(list -> list.stream() .map(symbolAliases::get) .collect(toImmutableList())) .collect(toImmutableList()); - List> actualRequired = tableFunctionProcessorNode.getRequiredSymbols().stream() + List> actualRequired = tableFunctionProcessorNode.getRequiredSymbols().stream() .map(list -> list.stream() .map(Symbol::toSymbolReference) .collect(toImmutableList())) @@ -127,9 +127,9 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses return NO_MATCH; } if (markerSymbols.isPresent()) { - Map expectedMapping = markerSymbols.get().entrySet().stream() + Map expectedMapping = markerSymbols.get().entrySet().stream() .collect(toImmutableMap(entry -> symbolAliases.get(entry.getKey()), entry -> symbolAliases.get(entry.getValue()))); - Map actualMapping = tableFunctionProcessorNode.getMarkerSymbols().orElseThrow().entrySet().stream() + Map actualMapping = tableFunctionProcessorNode.getMarkerSymbols().orElseThrow().entrySet().stream() .collect(toImmutableMap(entry -> entry.getKey().toSymbolReference(), entry -> entry.getValue().toSymbolReference())); if (!expectedMapping.equals(actualMapping)) { return NO_MATCH; @@ -151,7 +151,7 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses } } - ImmutableMap.Builder properOutputsMapping = ImmutableMap.builder(); + ImmutableMap.Builder properOutputsMapping = ImmutableMap.builder(); for (int i = 0; i < properOutputs.size(); i++) { properOutputsMapping.put(properOutputs.get(i), tableFunctionProcessorNode.getProperOutputs().get(i).toSymbolReference()); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TestExpressionVerifier.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TestExpressionVerifier.java index 0021ce59b17e..e51ecb7cfe26 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TestExpressionVerifier.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TestExpressionVerifier.java @@ -15,27 +15,27 @@ import com.google.common.collect.ImmutableList; import io.airlift.slice.Slices; -import io.trino.sql.ir.BetweenPredicate; +import io.trino.sql.ir.Between; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import org.junit.jupiter.api.Test; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.IS_DISTINCT_FROM; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.NOT_EQUAL; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.IS_DISTINCT_FROM; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.NOT_EQUAL; +import static io.trino.sql.ir.Logical.Operator.AND; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -44,97 +44,97 @@ public class TestExpressionVerifier @Test public void test() { - Expression actual = new NotExpression(new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "orderkey"), new Constant(INTEGER, 3L)), new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "custkey"), new Constant(INTEGER, 3L)), new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "orderkey"), new Constant(INTEGER, 10L))))); + Expression actual = new Not(new Logical(AND, ImmutableList.of(new Comparison(EQUAL, new Reference(INTEGER, "orderkey"), new Constant(INTEGER, 3L)), new Comparison(EQUAL, new Reference(INTEGER, "custkey"), new Constant(INTEGER, 3L)), new Comparison(LESS_THAN, new Reference(INTEGER, "orderkey"), new Constant(INTEGER, 10L))))); SymbolAliases symbolAliases = SymbolAliases.builder() - .put("X", new SymbolReference(INTEGER, "orderkey")) - .put("Y", new SymbolReference(INTEGER, "custkey")) + .put("X", new Reference(INTEGER, "orderkey")) + .put("Y", new Reference(INTEGER, "custkey")) .build(); ExpressionVerifier verifier = new ExpressionVerifier(symbolAliases); - assertThat(verifier.process(actual, new NotExpression(new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "X"), new Constant(INTEGER, 3L)), new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "Y"), new Constant(INTEGER, 3L)), new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "X"), new Constant(INTEGER, 10L))))))).isTrue(); - assertThatThrownBy(() -> verifier.process(actual, new NotExpression(new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "X"), new Constant(INTEGER, 3L)), new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "Y"), new Constant(INTEGER, 3L)), new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "Z"), new Constant(INTEGER, 10L))))))) + assertThat(verifier.process(actual, new Not(new Logical(AND, ImmutableList.of(new Comparison(EQUAL, new Reference(INTEGER, "X"), new Constant(INTEGER, 3L)), new Comparison(EQUAL, new Reference(INTEGER, "Y"), new Constant(INTEGER, 3L)), new Comparison(LESS_THAN, new Reference(INTEGER, "X"), new Constant(INTEGER, 10L))))))).isTrue(); + assertThatThrownBy(() -> verifier.process(actual, new Not(new Logical(AND, ImmutableList.of(new Comparison(EQUAL, new Reference(INTEGER, "X"), new Constant(INTEGER, 3L)), new Comparison(EQUAL, new Reference(INTEGER, "Y"), new Constant(INTEGER, 3L)), new Comparison(LESS_THAN, new Reference(INTEGER, "Z"), new Constant(INTEGER, 10L))))))) .isInstanceOf(IllegalStateException.class) .hasMessage("missing expression for alias Z"); - assertThat(verifier.process(actual, new NotExpression(new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "X"), new Constant(INTEGER, 3L)), new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "X"), new Constant(INTEGER, 3L)), new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "X"), new Constant(INTEGER, 10L))))))).isFalse(); + assertThat(verifier.process(actual, new Not(new Logical(AND, ImmutableList.of(new Comparison(EQUAL, new Reference(INTEGER, "X"), new Constant(INTEGER, 3L)), new Comparison(EQUAL, new Reference(INTEGER, "X"), new Constant(INTEGER, 3L)), new Comparison(LESS_THAN, new Reference(INTEGER, "X"), new Constant(INTEGER, 10L))))))).isFalse(); } @Test public void testCast() { SymbolAliases aliases = SymbolAliases.builder() - .put("X", new SymbolReference(BIGINT, "orderkey")) + .put("X", new Reference(BIGINT, "orderkey")) .build(); ExpressionVerifier verifier = new ExpressionVerifier(aliases); assertThat(verifier.process(new Constant(VARCHAR, Slices.utf8Slice("2")), new Constant(VARCHAR, Slices.utf8Slice("2")))).isTrue(); assertThat(verifier.process(new Constant(VARCHAR, Slices.utf8Slice("2")), new Cast(new Constant(VARCHAR, Slices.utf8Slice("2")), BIGINT))).isFalse(); - assertThat(verifier.process(new Cast(new SymbolReference(BIGINT, "orderkey"), VARCHAR), new Cast(new SymbolReference(BIGINT, "X"), VARCHAR))).isTrue(); + assertThat(verifier.process(new Cast(new Reference(BIGINT, "orderkey"), VARCHAR), new Cast(new Reference(BIGINT, "X"), VARCHAR))).isTrue(); } @Test public void testBetween() { SymbolAliases symbolAliases = SymbolAliases.builder() - .put("X", new SymbolReference(BIGINT, "orderkey")) - .put("Y", new SymbolReference(BIGINT, "custkey")) + .put("X", new Reference(BIGINT, "orderkey")) + .put("Y", new Reference(BIGINT, "custkey")) .build(); ExpressionVerifier verifier = new ExpressionVerifier(symbolAliases); // Complete match - assertThat(verifier.process(new BetweenPredicate(new SymbolReference(BIGINT, "orderkey"), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)), new BetweenPredicate(new SymbolReference(INTEGER, "X"), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)))).isTrue(); + assertThat(verifier.process(new Between(new Reference(BIGINT, "orderkey"), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)), new Between(new Reference(INTEGER, "X"), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)))).isTrue(); // Different value - assertThat(verifier.process(new BetweenPredicate(new SymbolReference(BIGINT, "orderkey"), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)), new BetweenPredicate(new SymbolReference(BIGINT, "Y"), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)))).isFalse(); - assertThat(verifier.process(new BetweenPredicate(new SymbolReference(BIGINT, "custkey"), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)), new BetweenPredicate(new SymbolReference(BIGINT, "X"), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)))).isFalse(); + assertThat(verifier.process(new Between(new Reference(BIGINT, "orderkey"), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)), new Between(new Reference(BIGINT, "Y"), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)))).isFalse(); + assertThat(verifier.process(new Between(new Reference(BIGINT, "custkey"), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)), new Between(new Reference(BIGINT, "X"), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)))).isFalse(); // Different min or max - assertThat(verifier.process(new BetweenPredicate(new SymbolReference(BIGINT, "orderkey"), new Constant(INTEGER, 2L), new Constant(INTEGER, 4L)), new BetweenPredicate(new SymbolReference(BIGINT, "X"), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)))).isFalse(); - assertThat(verifier.process(new BetweenPredicate(new SymbolReference(BIGINT, "orderkey"), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)), new BetweenPredicate(new SymbolReference(BIGINT, "X"), new Constant(VARCHAR, Slices.utf8Slice("1")), new Constant(VARCHAR, Slices.utf8Slice("2"))))).isFalse(); - assertThat(verifier.process(new BetweenPredicate(new SymbolReference(BIGINT, "orderkey"), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)), new BetweenPredicate(new SymbolReference(BIGINT, "X"), new Constant(INTEGER, 4L), new Constant(INTEGER, 7L)))).isFalse(); + assertThat(verifier.process(new Between(new Reference(BIGINT, "orderkey"), new Constant(INTEGER, 2L), new Constant(INTEGER, 4L)), new Between(new Reference(BIGINT, "X"), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)))).isFalse(); + assertThat(verifier.process(new Between(new Reference(BIGINT, "orderkey"), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)), new Between(new Reference(BIGINT, "X"), new Constant(VARCHAR, Slices.utf8Slice("1")), new Constant(VARCHAR, Slices.utf8Slice("2"))))).isFalse(); + assertThat(verifier.process(new Between(new Reference(BIGINT, "orderkey"), new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)), new Between(new Reference(BIGINT, "X"), new Constant(INTEGER, 4L), new Constant(INTEGER, 7L)))).isFalse(); } @Test public void testSymmetry() { SymbolAliases symbolAliases = SymbolAliases.builder() - .put("a", new SymbolReference(BIGINT, "x")) - .put("b", new SymbolReference(BIGINT, "y")) + .put("a", new Reference(BIGINT, "x")) + .put("b", new Reference(BIGINT, "y")) .build(); ExpressionVerifier verifier = new ExpressionVerifier(symbolAliases); - assertThat(verifier.process(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b")))).isTrue(); - assertThat(verifier.process(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "a")))).isTrue(); - assertThat(verifier.process(new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "y"), new SymbolReference(BIGINT, "x")), new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b")))).isTrue(); - assertThat(verifier.process(new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "y"), new SymbolReference(BIGINT, "x")), new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "a")))).isTrue(); - - assertThat(verifier.process(new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b")))).isFalse(); - assertThat(verifier.process(new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "a")))).isFalse(); - assertThat(verifier.process(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "y"), new SymbolReference(BIGINT, "x")), new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b")))).isFalse(); - assertThat(verifier.process(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "y"), new SymbolReference(BIGINT, "x")), new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "a")))).isFalse(); - - assertThat(verifier.process(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b")))).isTrue(); - assertThat(verifier.process(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "a")))).isTrue(); - assertThat(verifier.process(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "y"), new SymbolReference(BIGINT, "x")), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b")))).isTrue(); - assertThat(verifier.process(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "y"), new SymbolReference(BIGINT, "x")), new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "a")))).isTrue(); - - assertThat(verifier.process(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b")))).isFalse(); - assertThat(verifier.process(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "a")))).isFalse(); - assertThat(verifier.process(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(BIGINT, "y"), new SymbolReference(BIGINT, "x")), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b")))).isFalse(); - assertThat(verifier.process(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(BIGINT, "y"), new SymbolReference(BIGINT, "x")), new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "a")))).isFalse(); - - assertThat(verifier.process(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b")))).isTrue(); - assertThat(verifier.process(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "a")))).isTrue(); - assertThat(verifier.process(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "y"), new SymbolReference(BIGINT, "x")), new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b")))).isTrue(); - assertThat(verifier.process(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "y"), new SymbolReference(BIGINT, "x")), new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "a")))).isTrue(); - assertThat(verifier.process(new ComparisonExpression(NOT_EQUAL, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), new ComparisonExpression(NOT_EQUAL, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b")))).isTrue(); - assertThat(verifier.process(new ComparisonExpression(NOT_EQUAL, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), new ComparisonExpression(NOT_EQUAL, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "a")))).isTrue(); - assertThat(verifier.process(new ComparisonExpression(NOT_EQUAL, new SymbolReference(BIGINT, "y"), new SymbolReference(BIGINT, "x")), new ComparisonExpression(NOT_EQUAL, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b")))).isTrue(); - assertThat(verifier.process(new ComparisonExpression(NOT_EQUAL, new SymbolReference(BIGINT, "y"), new SymbolReference(BIGINT, "x")), new ComparisonExpression(NOT_EQUAL, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "a")))).isTrue(); - - assertThat(verifier.process(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b")))).isTrue(); - assertThat(verifier.process(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y")), new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "a")))).isTrue(); - assertThat(verifier.process(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BIGINT, "y"), new SymbolReference(BIGINT, "x")), new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b")))).isTrue(); - assertThat(verifier.process(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BIGINT, "y"), new SymbolReference(BIGINT, "x")), new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "a")))).isTrue(); + assertThat(verifier.process(new Comparison(GREATER_THAN, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), new Comparison(GREATER_THAN, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))).isTrue(); + assertThat(verifier.process(new Comparison(GREATER_THAN, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), new Comparison(LESS_THAN, new Reference(BIGINT, "b"), new Reference(BIGINT, "a")))).isTrue(); + assertThat(verifier.process(new Comparison(LESS_THAN, new Reference(BIGINT, "y"), new Reference(BIGINT, "x")), new Comparison(GREATER_THAN, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))).isTrue(); + assertThat(verifier.process(new Comparison(LESS_THAN, new Reference(BIGINT, "y"), new Reference(BIGINT, "x")), new Comparison(LESS_THAN, new Reference(BIGINT, "b"), new Reference(BIGINT, "a")))).isTrue(); + + assertThat(verifier.process(new Comparison(LESS_THAN, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), new Comparison(GREATER_THAN, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))).isFalse(); + assertThat(verifier.process(new Comparison(LESS_THAN, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), new Comparison(LESS_THAN, new Reference(BIGINT, "b"), new Reference(BIGINT, "a")))).isFalse(); + assertThat(verifier.process(new Comparison(GREATER_THAN, new Reference(BIGINT, "y"), new Reference(BIGINT, "x")), new Comparison(GREATER_THAN, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))).isFalse(); + assertThat(verifier.process(new Comparison(GREATER_THAN, new Reference(BIGINT, "y"), new Reference(BIGINT, "x")), new Comparison(LESS_THAN, new Reference(BIGINT, "b"), new Reference(BIGINT, "a")))).isFalse(); + + assertThat(verifier.process(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))).isTrue(); + assertThat(verifier.process(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "b"), new Reference(BIGINT, "a")))).isTrue(); + assertThat(verifier.process(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "y"), new Reference(BIGINT, "x")), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))).isTrue(); + assertThat(verifier.process(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "y"), new Reference(BIGINT, "x")), new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "b"), new Reference(BIGINT, "a")))).isTrue(); + + assertThat(verifier.process(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))).isFalse(); + assertThat(verifier.process(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "b"), new Reference(BIGINT, "a")))).isFalse(); + assertThat(verifier.process(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(BIGINT, "y"), new Reference(BIGINT, "x")), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))).isFalse(); + assertThat(verifier.process(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(BIGINT, "y"), new Reference(BIGINT, "x")), new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "b"), new Reference(BIGINT, "a")))).isFalse(); + + assertThat(verifier.process(new Comparison(EQUAL, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), new Comparison(EQUAL, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))).isTrue(); + assertThat(verifier.process(new Comparison(EQUAL, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), new Comparison(EQUAL, new Reference(BIGINT, "b"), new Reference(BIGINT, "a")))).isTrue(); + assertThat(verifier.process(new Comparison(EQUAL, new Reference(BIGINT, "y"), new Reference(BIGINT, "x")), new Comparison(EQUAL, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))).isTrue(); + assertThat(verifier.process(new Comparison(EQUAL, new Reference(BIGINT, "y"), new Reference(BIGINT, "x")), new Comparison(EQUAL, new Reference(BIGINT, "b"), new Reference(BIGINT, "a")))).isTrue(); + assertThat(verifier.process(new Comparison(NOT_EQUAL, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), new Comparison(NOT_EQUAL, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))).isTrue(); + assertThat(verifier.process(new Comparison(NOT_EQUAL, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), new Comparison(NOT_EQUAL, new Reference(BIGINT, "b"), new Reference(BIGINT, "a")))).isTrue(); + assertThat(verifier.process(new Comparison(NOT_EQUAL, new Reference(BIGINT, "y"), new Reference(BIGINT, "x")), new Comparison(NOT_EQUAL, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))).isTrue(); + assertThat(verifier.process(new Comparison(NOT_EQUAL, new Reference(BIGINT, "y"), new Reference(BIGINT, "x")), new Comparison(NOT_EQUAL, new Reference(BIGINT, "b"), new Reference(BIGINT, "a")))).isTrue(); + + assertThat(verifier.process(new Comparison(IS_DISTINCT_FROM, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), new Comparison(IS_DISTINCT_FROM, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))).isTrue(); + assertThat(verifier.process(new Comparison(IS_DISTINCT_FROM, new Reference(BIGINT, "x"), new Reference(BIGINT, "y")), new Comparison(IS_DISTINCT_FROM, new Reference(BIGINT, "b"), new Reference(BIGINT, "a")))).isTrue(); + assertThat(verifier.process(new Comparison(IS_DISTINCT_FROM, new Reference(BIGINT, "y"), new Reference(BIGINT, "x")), new Comparison(IS_DISTINCT_FROM, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))).isTrue(); + assertThat(verifier.process(new Comparison(IS_DISTINCT_FROM, new Reference(BIGINT, "y"), new Reference(BIGINT, "x")), new Comparison(IS_DISTINCT_FROM, new Reference(BIGINT, "b"), new Reference(BIGINT, "a")))).isTrue(); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/WindowFrameMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/WindowFrameMatcher.java index a20bcb69122c..5e3ffab69218 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/WindowFrameMatcher.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/WindowFrameMatcher.java @@ -39,7 +39,7 @@ private static boolean matches(Optional expected, Optional actua return false; } - return expected.map(symbol -> aliases.get(symbol.getName()).getName().equals(actual.get().getName())) + return expected.map(symbol -> aliases.get(symbol.getName()).name().equals(actual.get().getName())) .orElse(true); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/TestRuleIndex.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/TestRuleIndex.java index b2b313897302..1a30fdab09ea 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/TestRuleIndex.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/TestRuleIndex.java @@ -17,7 +17,7 @@ import com.google.common.collect.ImmutableSet; import io.trino.matching.Captures; import io.trino.matching.Pattern; -import io.trino.sql.ir.BooleanLiteral; +import io.trino.sql.ir.Booleans; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; import io.trino.sql.planner.plan.Assignments; @@ -52,7 +52,7 @@ public void testWithPlanNodeHierarchy() .build(); ProjectNode projectNode = planBuilder.project(Assignments.of(), planBuilder.values()); - FilterNode filterNode = planBuilder.filter(BooleanLiteral.TRUE_LITERAL, planBuilder.values()); + FilterNode filterNode = planBuilder.filter(Booleans.TRUE, planBuilder.values()); ValuesNode valuesNode = planBuilder.values(); assertThat(ruleIndex.getCandidates(projectNode).collect(toSet())).isEqualTo(ImmutableSet.of(projectRule1, projectRule2, anyRule)); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestAddIntermediateAggregations.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestAddIntermediateAggregations.java index 1e27b581b6c1..4dcedebccd87 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestAddIntermediateAggregations.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestAddIntermediateAggregations.java @@ -14,7 +14,7 @@ package io.trino.sql.planner.iterative.rule; import com.google.common.collect.ImmutableList; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; import io.trino.sql.planner.plan.AggregationNode; @@ -37,13 +37,13 @@ public void testSessionDisable() .on(p -> p.aggregation(af -> { af.globalGrouping() .step(AggregationNode.Step.FINAL) - .addAggregation(p.symbol("c"), PlanBuilder.aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "b"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("c"), PlanBuilder.aggregation("count", ImmutableList.of(new Reference(BIGINT, "b"))), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE, p.aggregation(ap -> ap.globalGrouping() .step(AggregationNode.Step.PARTIAL) - .addAggregation(p.symbol("b"), PlanBuilder.aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "a"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("b"), PlanBuilder.aggregation("count", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .source( p.values(p.symbol("a")))))); })) @@ -59,13 +59,13 @@ public void testWithGroups() .on(p -> p.aggregation(af -> { af.singleGroupingSet(p.symbol("c")) .step(AggregationNode.Step.FINAL) - .addAggregation(p.symbol("c"), PlanBuilder.aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "b"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("c"), PlanBuilder.aggregation("count", ImmutableList.of(new Reference(BIGINT, "b"))), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE, p.aggregation(ap -> ap.singleGroupingSet(p.symbol("b")) .step(AggregationNode.Step.PARTIAL) - .addAggregation(p.symbol("b"), PlanBuilder.aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "a"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("b"), PlanBuilder.aggregation("count", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .source( p.values(p.symbol("a")))))); })) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestApplyTableScanRedirection.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestApplyTableScanRedirection.java index f0206b767253..b7033c93ff5e 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestApplyTableScanRedirection.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestApplyTableScanRedirection.java @@ -30,9 +30,9 @@ import io.trino.spi.connector.TableScanRedirectApplicationResult; import io.trino.spi.predicate.TupleDomain; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.RuleTester; import io.trino.testing.PlanTester; @@ -47,7 +47,7 @@ import static io.trino.spi.predicate.Domain.singleValue; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.VarcharType.VARCHAR; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; @@ -159,7 +159,7 @@ public void testMismatchedTypesWithCoercion() ImmutableMap.of(column, SOURCE_COLUMN_HANDLE_A)); }) .matches( - project(ImmutableMap.of("COL", expression(new Cast(new SymbolReference(BIGINT, "DEST_COL"), VARCHAR))), + project(ImmutableMap.of("COL", expression(new Cast(new Reference(BIGINT, "DEST_COL"), VARCHAR))), tableScan( new MockConnectorTableHandle(DESTINATION_TABLE)::equals, TupleDomain.all(), @@ -237,7 +237,7 @@ public void testApplyTableScanRedirectionWithFilter() }) .matches( filter( - new ComparisonExpression(EQUAL, new SymbolReference(VARCHAR, "DEST_COL"), new Constant(VARCHAR, utf8Slice("foo"))), + new Comparison(EQUAL, new Reference(VARCHAR, "DEST_COL"), new Constant(VARCHAR, utf8Slice("foo"))), tableScan( new MockConnectorTableHandle(DESTINATION_TABLE)::equals, TupleDomain.all(), @@ -255,9 +255,9 @@ public void testApplyTableScanRedirectionWithFilter() }) .matches( project( - ImmutableMap.of("expr", expression(new SymbolReference(BIGINT, "DEST_COL_B"))), + ImmutableMap.of("expr", expression(new Reference(BIGINT, "DEST_COL_B"))), filter( - new ComparisonExpression(EQUAL, new SymbolReference(VARCHAR, "DEST_COL_A"), new Constant(VARCHAR, utf8Slice("foo"))), + new Comparison(EQUAL, new Reference(VARCHAR, "DEST_COL_A"), new Constant(VARCHAR, utf8Slice("foo"))), tableScan( new MockConnectorTableHandle(DESTINATION_TABLE)::equals, TupleDomain.all(), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestArraySortAfterArrayDistinct.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestArraySortAfterArrayDistinct.java index 278787671902..6cd063704292 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestArraySortAfterArrayDistinct.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestArraySortAfterArrayDistinct.java @@ -18,10 +18,10 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.type.ArrayType; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.LambdaExpression; +import io.trino.sql.ir.Lambda; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -50,22 +50,22 @@ public class TestArraySortAfterArrayDistinct public void testArrayDistinctAfterArraySort() { test( - new FunctionCall(DISTINCT, ImmutableList.of(new FunctionCall(SORT, ImmutableList.of(new FunctionCall(ARRAY, ImmutableList.of(new Constant(VARCHAR, Slices.utf8Slice("a")))))))), - new FunctionCall(SORT, ImmutableList.of(new FunctionCall(DISTINCT, ImmutableList.of(new FunctionCall(ARRAY, ImmutableList.of(new Constant(VARCHAR, Slices.utf8Slice("a"))))))))); + new Call(DISTINCT, ImmutableList.of(new Call(SORT, ImmutableList.of(new Call(ARRAY, ImmutableList.of(new Constant(VARCHAR, Slices.utf8Slice("a")))))))), + new Call(SORT, ImmutableList.of(new Call(DISTINCT, ImmutableList.of(new Call(ARRAY, ImmutableList.of(new Constant(VARCHAR, Slices.utf8Slice("a"))))))))); } @Test public void testArrayDistinctAfterArraySortWithLambda() { test( - new FunctionCall(DISTINCT, ImmutableList.of( - new FunctionCall(SORT_WITH_LAMBDA, ImmutableList.of( - new FunctionCall(ARRAY, ImmutableList.of(new Constant(VARCHAR, Slices.utf8Slice("a")))), - new LambdaExpression(ImmutableList.of(new Symbol(INTEGER, "a"), new Symbol(INTEGER, "b")), new Constant(INTEGER, 1L)))))), - new FunctionCall(SORT_WITH_LAMBDA, ImmutableList.of( - new FunctionCall(DISTINCT, ImmutableList.of( - new FunctionCall(ARRAY, ImmutableList.of(new Constant(VARCHAR, Slices.utf8Slice("a")))))), - new LambdaExpression(ImmutableList.of(new Symbol(INTEGER, "a"), new Symbol(INTEGER, "b")), new Constant(INTEGER, 1L))))); + new Call(DISTINCT, ImmutableList.of( + new Call(SORT_WITH_LAMBDA, ImmutableList.of( + new Call(ARRAY, ImmutableList.of(new Constant(VARCHAR, Slices.utf8Slice("a")))), + new Lambda(ImmutableList.of(new Symbol(INTEGER, "a"), new Symbol(INTEGER, "b")), new Constant(INTEGER, 1L)))))), + new Call(SORT_WITH_LAMBDA, ImmutableList.of( + new Call(DISTINCT, ImmutableList.of( + new Call(ARRAY, ImmutableList.of(new Constant(VARCHAR, Slices.utf8Slice("a")))))), + new Lambda(ImmutableList.of(new Symbol(INTEGER, "a"), new Symbol(INTEGER, "b")), new Constant(INTEGER, 1L))))); } private void test(Expression original, Expression rewritten) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestCanonicalizeExpressionRewriter.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestCanonicalizeExpressionRewriter.java index 8e2e02745564..2bf8d484a57b 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestCanonicalizeExpressionRewriter.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestCanonicalizeExpressionRewriter.java @@ -20,16 +20,16 @@ import io.trino.spi.function.OperatorType; import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Case; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.SearchedCaseExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.ir.WhenClause; import io.trino.sql.planner.assertions.SymbolAliases; import io.trino.transaction.TransactionManager; @@ -47,15 +47,15 @@ import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.sql.ExpressionTestUtils.assertExpressionEquals; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MULTIPLY; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.IS_DISTINCT_FROM; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.NOT_EQUAL; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.IS_DISTINCT_FROM; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.NOT_EQUAL; import static io.trino.sql.ir.IrExpressions.ifExpression; import static io.trino.sql.planner.TestingPlannerContext.plannerContextBuilder; import static io.trino.sql.planner.iterative.rule.CanonicalizeExpressionRewriter.rewrite; @@ -78,96 +78,96 @@ public class TestCanonicalizeExpressionRewriter public void testRewriteIsNotNullPredicate() { assertRewritten( - new NotExpression(new IsNullPredicate(new SymbolReference(BIGINT, "x"))), - new NotExpression(new IsNullPredicate(new SymbolReference(BIGINT, "x")))); + new Not(new IsNull(new Reference(BIGINT, "x"))), + new Not(new IsNull(new Reference(BIGINT, "x")))); } @Test public void testRewriteIfExpression() { assertRewritten( - ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "x"), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L), new Constant(INTEGER, 1L)), - new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "x"), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L))), Optional.of(new Constant(INTEGER, 1L)))); + ifExpression(new Comparison(EQUAL, new Reference(INTEGER, "x"), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L), new Constant(INTEGER, 1L)), + new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(INTEGER, "x"), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L))), Optional.of(new Constant(INTEGER, 1L)))); } @Test public void testCanonicalizeArithmetic() { assertRewritten( - new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L)), - new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L))); + new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L)), + new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))); assertRewritten( - new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new Constant(INTEGER, 1L), new SymbolReference(INTEGER, "a")), - new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L))); + new Arithmetic(ADD_INTEGER, ADD, new Constant(INTEGER, 1L), new Reference(INTEGER, "a")), + new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))); assertRewritten( - new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L)), - new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L))); + new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L)), + new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))); assertRewritten( - new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new Constant(INTEGER, 1L), new SymbolReference(INTEGER, "a")), - new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L))); + new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Constant(INTEGER, 1L), new Reference(INTEGER, "a")), + new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))); } @Test public void testCanonicalizeComparison() { assertRewritten( - new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L)), - new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L))); + new Comparison(EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L)), + new Comparison(EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))); assertRewritten( - new ComparisonExpression(EQUAL, new Constant(INTEGER, 1L), new SymbolReference(INTEGER, "a")), - new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L))); + new Comparison(EQUAL, new Constant(INTEGER, 1L), new Reference(INTEGER, "a")), + new Comparison(EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))); assertRewritten( - new ComparisonExpression(NOT_EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L)), - new ComparisonExpression(NOT_EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L))); + new Comparison(NOT_EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L)), + new Comparison(NOT_EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))); assertRewritten( - new ComparisonExpression(NOT_EQUAL, new Constant(INTEGER, 1L), new SymbolReference(INTEGER, "a")), - new ComparisonExpression(NOT_EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L))); + new Comparison(NOT_EQUAL, new Constant(INTEGER, 1L), new Reference(INTEGER, "a")), + new Comparison(NOT_EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))); assertRewritten( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L)), - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L))); + new Comparison(GREATER_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))); assertRewritten( - new ComparisonExpression(GREATER_THAN, new Constant(INTEGER, 1L), new SymbolReference(INTEGER, "a")), - new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L))); + new Comparison(GREATER_THAN, new Constant(INTEGER, 1L), new Reference(INTEGER, "a")), + new Comparison(LESS_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))); assertRewritten( - new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L)), - new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L))); + new Comparison(LESS_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L)), + new Comparison(LESS_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))); assertRewritten( - new ComparisonExpression(LESS_THAN, new Constant(INTEGER, 1L), new SymbolReference(INTEGER, "a")), - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L))); + new Comparison(LESS_THAN, new Constant(INTEGER, 1L), new Reference(INTEGER, "a")), + new Comparison(GREATER_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))); assertRewritten( - new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L)), - new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L))); + new Comparison(GREATER_THAN_OR_EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L)), + new Comparison(GREATER_THAN_OR_EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))); assertRewritten( - new ComparisonExpression(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 1L), new SymbolReference(INTEGER, "a")), - new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L))); + new Comparison(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 1L), new Reference(INTEGER, "a")), + new Comparison(LESS_THAN_OR_EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))); assertRewritten( - new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L)), - new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L))); + new Comparison(LESS_THAN_OR_EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L)), + new Comparison(LESS_THAN_OR_EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))); assertRewritten( - new ComparisonExpression(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 1L), new SymbolReference(INTEGER, "a")), - new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L))); + new Comparison(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 1L), new Reference(INTEGER, "a")), + new Comparison(GREATER_THAN_OR_EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L))); assertRewritten( - new ComparisonExpression(IS_DISTINCT_FROM, new Constant(INTEGER, 1L), new SymbolReference(INTEGER, "a")), - new ComparisonExpression(IS_DISTINCT_FROM, new Constant(INTEGER, 1L), new SymbolReference(INTEGER, "a"))); + new Comparison(IS_DISTINCT_FROM, new Constant(INTEGER, 1L), new Reference(INTEGER, "a")), + new Comparison(IS_DISTINCT_FROM, new Constant(INTEGER, 1L), new Reference(INTEGER, "a"))); assertRewritten( - new ComparisonExpression(IS_DISTINCT_FROM, new Constant(INTEGER, 1L), new SymbolReference(INTEGER, "a")), - new ComparisonExpression(IS_DISTINCT_FROM, new Constant(INTEGER, 1L), new SymbolReference(INTEGER, "a"))); + new Comparison(IS_DISTINCT_FROM, new Constant(INTEGER, 1L), new Reference(INTEGER, "a")), + new Comparison(IS_DISTINCT_FROM, new Constant(INTEGER, 1L), new Reference(INTEGER, "a"))); } @Test @@ -180,10 +180,10 @@ public void testCanonicalizeRewriteDateFunctionToCast() private static void assertCanonicalizedDate(Type type, String symbolName) { - FunctionCall date = new FunctionCall( + Call date = new Call( PLANNER_CONTEXT.getMetadata().resolveBuiltinFunction("date", fromTypes(type)), - ImmutableList.of(new SymbolReference(VARCHAR, symbolName))); - assertRewritten(date, new Cast(new SymbolReference(VARCHAR, symbolName), DATE)); + ImmutableList.of(new Reference(VARCHAR, symbolName))); + assertRewritten(date, new Cast(new Reference(VARCHAR, symbolName), DATE)); } private static void assertRewritten(Expression from, Expression to) @@ -194,11 +194,11 @@ private static void assertRewritten(Expression from, Expression to) }), to, SymbolAliases.builder() - .put("x", new SymbolReference(BIGINT, "x")) - .put("a", new SymbolReference(BIGINT, "a")) - .put("ts", new SymbolReference(createTimestampType(3), "ts")) - .put("tstz", new SymbolReference(createTimestampWithTimeZoneType(3), "tstz")) - .put("v", new SymbolReference(createVarcharType(100), "v")) + .put("x", new Reference(BIGINT, "x")) + .put("a", new Reference(BIGINT, "a")) + .put("ts", new Reference(createTimestampType(3), "ts")) + .put("tstz", new Reference(createTimestampWithTimeZoneType(3), "tstz")) + .put("v", new Reference(createVarcharType(100), "v")) .build()); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestCanonicalizeExpressions.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestCanonicalizeExpressions.java index faa6881884ee..c46d7184f112 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestCanonicalizeExpressions.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestCanonicalizeExpressions.java @@ -16,7 +16,7 @@ import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import org.junit.jupiter.api.Test; -import static io.trino.sql.ir.BooleanLiteral.FALSE_LITERAL; +import static io.trino.sql.ir.Booleans.FALSE; import static io.trino.sql.planner.plan.JoinType.INNER; public class TestCanonicalizeExpressions @@ -27,7 +27,7 @@ public void testDoesNotFireForExpressionsInCanonicalForm() { CanonicalizeExpressions canonicalizeExpressions = new CanonicalizeExpressions(tester().getPlannerContext()); tester().assertThat(canonicalizeExpressions.filterExpressionRewrite()) - .on(p -> p.filter(FALSE_LITERAL, p.values())) + .on(p -> p.filter(FALSE, p.values())) .doesNotFire(); } @@ -45,7 +45,7 @@ public void testDoesNotFireForCanonicalExpressions() { CanonicalizeExpressions canonicalizeExpressions = new CanonicalizeExpressions(tester().getPlannerContext()); tester().assertThat(canonicalizeExpressions.joinExpressionRewrite()) - .on(p -> p.join(INNER, p.values(), p.values(), FALSE_LITERAL)) + .on(p -> p.join(INNER, p.values(), p.values(), FALSE)) .doesNotFire(); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateInnerUnnestWithGlobalAggregation.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateInnerUnnestWithGlobalAggregation.java index f0e2c469838f..f525223a0144 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateInnerUnnestWithGlobalAggregation.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateInnerUnnestWithGlobalAggregation.java @@ -19,14 +19,14 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ArithmeticNegation; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Negation; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; @@ -41,9 +41,9 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MODULUS; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; +import static io.trino.sql.ir.Logical.Operator.AND; import static io.trino.sql.planner.assertions.PlanMatchPattern.UnnestMapping.unnestMapping; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregationFunction; @@ -126,7 +126,7 @@ public void testTransformCorrelatedUnnest() p.values(p.symbol("corr")), p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "unnested_corr"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "unnested_corr"))), ImmutableList.of(BIGINT)) .source(p.unnest( ImmutableList.of(), ImmutableList.of(new UnnestNode.Mapping(p.symbol("corr"), ImmutableList.of(p.symbol("unnested_corr")))), @@ -143,7 +143,7 @@ public void testTransformCorrelatedUnnest() Optional.empty(), SINGLE, project( - ImmutableMap.of("mask", expression(new NotExpression(new IsNullPredicate(new SymbolReference(BIGINT, "ordinality"))))), + ImmutableMap.of("mask", expression(new Not(new IsNull(new Reference(BIGINT, "ordinality"))))), unnest( ImmutableList.of("corr", "unique"), ImmutableList.of(unnestMapping("corr", ImmutableList.of("unnested_corr"))), @@ -161,7 +161,7 @@ public void testPreexistingMask() p.values(p.symbol("corr"), p.symbol("old_masks")), p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "unnested_corr"))), ImmutableList.of(BIGINT), p.symbol("old_mask")) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "unnested_corr"))), ImmutableList.of(BIGINT), p.symbol("old_mask")) .source(p.unnest( ImmutableList.of(), ImmutableList.of( @@ -180,9 +180,9 @@ public void testPreexistingMask() Optional.empty(), SINGLE, project( - ImmutableMap.of("new_mask", expression(new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "old_mask"), new SymbolReference(BOOLEAN, "mask"))))), + ImmutableMap.of("new_mask", expression(new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "old_mask"), new Reference(BOOLEAN, "mask"))))), project( - ImmutableMap.of("mask", expression(new NotExpression(new IsNullPredicate(new SymbolReference(BIGINT, "ordinality"))))), + ImmutableMap.of("mask", expression(new Not(new IsNull(new Reference(BIGINT, "ordinality"))))), unnest( ImmutableList.of("corr", "old_masks", "unique"), ImmutableList.of( @@ -202,7 +202,7 @@ public void testWithPreexistingOrdinality() p.values(p.symbol("corr")), p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "unnested_corr"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "unnested_corr"))), ImmutableList.of(BIGINT)) .source(p.unnest( ImmutableList.of(), ImmutableList.of(new UnnestNode.Mapping(p.symbol("corr"), ImmutableList.of(p.symbol("unnested_corr")))), @@ -219,7 +219,7 @@ public void testWithPreexistingOrdinality() Optional.empty(), SINGLE, project( - ImmutableMap.of("mask", expression(new NotExpression(new IsNullPredicate(new SymbolReference(BIGINT, "ordinality"))))), + ImmutableMap.of("mask", expression(new Not(new IsNull(new Reference(BIGINT, "ordinality"))))), unnest( ImmutableList.of("corr", "unique"), ImmutableList.of(unnestMapping("corr", ImmutableList.of("unnested_corr"))), @@ -238,11 +238,11 @@ public void testMultipleGlobalAggregations() p.values(p.symbol("corr")), p.aggregation(outerBuilder -> outerBuilder .globalGrouping() - .addAggregation(p.symbol("arbitrary"), PlanBuilder.aggregation("arbitrary", ImmutableList.of(new SymbolReference(BIGINT, "sum"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("arbitrary"), PlanBuilder.aggregation("arbitrary", ImmutableList.of(new Reference(BIGINT, "sum"))), ImmutableList.of(BIGINT)) .source( p.aggregation(innerBuilder -> innerBuilder .globalGrouping() - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "unnested_corr"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "unnested_corr"))), ImmutableList.of(BIGINT)) .source(p.unnest( ImmutableList.of(), ImmutableList.of(new UnnestNode.Mapping(p.symbol("corr"), ImmutableList.of(p.symbol("unnested_corr")))), @@ -266,7 +266,7 @@ public void testMultipleGlobalAggregations() Optional.empty(), SINGLE, project( - ImmutableMap.of("mask", expression(new NotExpression(new IsNullPredicate(new SymbolReference(BIGINT, "ordinality"))))), + ImmutableMap.of("mask", expression(new Not(new IsNull(new Reference(BIGINT, "ordinality"))))), unnest( ImmutableList.of("corr", "unique"), ImmutableList.of(unnestMapping("corr", ImmutableList.of("unnested_corr"))), @@ -283,10 +283,10 @@ public void testProjectOverGlobalAggregation() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), p.project( - Assignments.of(p.symbol("sum_1"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "sum"), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("sum_1"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), p.aggregation(innerBuilder -> innerBuilder .globalGrouping() - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "unnested_corr"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "unnested_corr"))), ImmutableList.of(BIGINT)) .source(p.unnest( ImmutableList.of(), ImmutableList.of(new UnnestNode.Mapping(p.symbol("corr"), ImmutableList.of(p.symbol("unnested_corr")))), @@ -296,7 +296,7 @@ public void testProjectOverGlobalAggregation() .matches( project( strictProject( - ImmutableMap.of("corr", expression(new SymbolReference(BIGINT, "corr")), "unique", expression(new SymbolReference(BIGINT, "unique")), "sum_1", expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), + ImmutableMap.of("corr", expression(new Reference(BIGINT, "corr")), "unique", expression(new Reference(BIGINT, "unique")), "sum_1", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), aggregation( singleGroupingSet("unique", "corr"), ImmutableMap.of(Optional.of("sum"), aggregationFunction("sum", ImmutableList.of("unnested_corr"))), @@ -305,7 +305,7 @@ public void testProjectOverGlobalAggregation() Optional.empty(), SINGLE, project( - ImmutableMap.of("mask", expression(new NotExpression(new IsNullPredicate(new SymbolReference(BIGINT, "ordinality"))))), + ImmutableMap.of("mask", expression(new Not(new IsNull(new Reference(BIGINT, "ordinality"))))), unnest( ImmutableList.of("corr", "unique"), ImmutableList.of(unnestMapping("corr", ImmutableList.of("unnested_corr"))), @@ -320,7 +320,7 @@ public void testPreprojectUnnestSymbol() tester().assertThat(new DecorrelateInnerUnnestWithGlobalAggregation()) .on(p -> { Symbol corr = p.symbol("corr", VARCHAR); - FunctionCall regexpExtractAll = new FunctionCall( + Call regexpExtractAll = new Call( tester().getMetadata().resolveBuiltinFunction("regexp_extract_all", fromTypes(VARCHAR, VARCHAR)), ImmutableList.of(corr.toSymbolReference(), new Constant(VARCHAR, Slices.utf8Slice(".")))); @@ -329,7 +329,7 @@ public void testPreprojectUnnestSymbol() p.values(corr), p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("max"), PlanBuilder.aggregation("max", ImmutableList.of(new SymbolReference(BIGINT, "unnested_corr"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("max"), PlanBuilder.aggregation("max", ImmutableList.of(new Reference(BIGINT, "unnested_corr"))), ImmutableList.of(BIGINT)) .source(p.unnest( ImmutableList.of(), ImmutableList.of(new UnnestNode.Mapping(p.symbol("char_array"), ImmutableList.of(p.symbol("unnested_corr")))), @@ -349,14 +349,14 @@ public void testPreprojectUnnestSymbol() Optional.empty(), SINGLE, project( - ImmutableMap.of("mask", expression(new NotExpression(new IsNullPredicate(new SymbolReference(BIGINT, "ordinality"))))), + ImmutableMap.of("mask", expression(new Not(new IsNull(new Reference(BIGINT, "ordinality"))))), unnest( ImmutableList.of("corr", "unique", "char_array"), ImmutableList.of(unnestMapping("char_array", ImmutableList.of("unnested_corr"))), Optional.of("ordinality"), LEFT, project( - ImmutableMap.of("char_array", expression(new FunctionCall(REGEXP_EXTRACT_ALL, ImmutableList.of(new SymbolReference(BIGINT, "corr"), new Constant(VARCHAR, Slices.utf8Slice(".")))))), + ImmutableMap.of("char_array", expression(new Call(REGEXP_EXTRACT_ALL, ImmutableList.of(new Reference(BIGINT, "corr"), new Constant(VARCHAR, Slices.utf8Slice(".")))))), assignUniqueId("unique", values("corr")))))))); } @@ -370,22 +370,22 @@ public void testMultipleNodesOverUnnestInSubquery() ImmutableList.of(p.symbol("groups"), p.symbol("numbers")), p.values(p.symbol("groups"), p.symbol("numbers")), p.project( - Assignments.of(p.symbol("sum_1"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "sum"), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("sum_1"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), p.aggregation(globalBuilder -> globalBuilder .globalGrouping() - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "negate"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "negate"))), ImmutableList.of(BIGINT)) .source(p.project( Assignments.builder() - .put(p.symbol("negate"), new ArithmeticNegation(new SymbolReference(BIGINT, "max"))) + .put(p.symbol("negate"), new Negation(new Reference(BIGINT, "max"))) .build(), p.aggregation(groupedBuilder -> groupedBuilder .singleGroupingSet(p.symbol("group")) - .addAggregation(p.symbol("max"), PlanBuilder.aggregation("max", ImmutableList.of(new SymbolReference(BIGINT, "modulo"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("max"), PlanBuilder.aggregation("max", ImmutableList.of(new Reference(BIGINT, "modulo"))), ImmutableList.of(BIGINT)) .source( p.project( Assignments.builder() .putIdentities(ImmutableList.of(p.symbol("group"), p.symbol("number"))) - .put(p.symbol("modulo"), new ArithmeticBinaryExpression(MODULUS_INTEGER, MODULUS, new SymbolReference(INTEGER, "number"), new Constant(INTEGER, 10L))) + .put(p.symbol("modulo"), new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "number"), new Constant(INTEGER, 10L))) .build(), p.unnest( ImmutableList.of(), @@ -398,7 +398,7 @@ public void testMultipleNodesOverUnnestInSubquery() .matches( project( project( - ImmutableMap.of("sum_1", expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), + ImmutableMap.of("sum_1", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), aggregation( singleGroupingSet("groups", "numbers", "unique"), ImmutableMap.of(Optional.of("sum"), aggregationFunction("sum", ImmutableList.of("negated"))), @@ -407,7 +407,7 @@ public void testMultipleNodesOverUnnestInSubquery() Optional.empty(), SINGLE, project( - ImmutableMap.of("negated", expression(new ArithmeticNegation(new SymbolReference(BIGINT, "max")))), + ImmutableMap.of("negated", expression(new Negation(new Reference(BIGINT, "max")))), aggregation( singleGroupingSet("groups", "numbers", "unique", "mask", "group"), ImmutableMap.of(Optional.of("max"), aggregationFunction("max", ImmutableList.of("modulo"))), @@ -416,9 +416,9 @@ public void testMultipleNodesOverUnnestInSubquery() Optional.empty(), SINGLE, project( - ImmutableMap.of("modulo", expression(new ArithmeticBinaryExpression(MODULUS_INTEGER, MODULUS, new SymbolReference(INTEGER, "number"), new Constant(INTEGER, 10L)))), + ImmutableMap.of("modulo", expression(new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "number"), new Constant(INTEGER, 10L)))), project( - ImmutableMap.of("mask", expression(new NotExpression(new IsNullPredicate(new SymbolReference(BIGINT, "ordinality"))))), + ImmutableMap.of("mask", expression(new Not(new IsNull(new Reference(BIGINT, "ordinality"))))), unnest( ImmutableList.of("groups", "numbers", "unique"), ImmutableList.of( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateLeftUnnestWithGlobalAggregation.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateLeftUnnestWithGlobalAggregation.java index a9ebce91e7df..3d30566ab9ef 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateLeftUnnestWithGlobalAggregation.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateLeftUnnestWithGlobalAggregation.java @@ -19,11 +19,11 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ArithmeticNegation; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Negation; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; @@ -37,8 +37,8 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MODULUS; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; import static io.trino.sql.planner.assertions.PlanMatchPattern.UnnestMapping.unnestMapping; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregationFunction; @@ -120,7 +120,7 @@ public void testTransformCorrelatedUnnest() p.values(p.symbol("corr")), p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "unnested_corr"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "unnested_corr"))), ImmutableList.of(BIGINT)) .source(p.unnest( ImmutableList.of(), ImmutableList.of(new UnnestNode.Mapping(p.symbol("corr"), ImmutableList.of(p.symbol("unnested_corr")))), @@ -152,7 +152,7 @@ public void testWithMask() p.values(p.symbol("corr"), p.symbol("masks")), p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "unnested_corr"))), ImmutableList.of(BIGINT), p.symbol("mask")) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "unnested_corr"))), ImmutableList.of(BIGINT), p.symbol("mask")) .source(p.unnest( ImmutableList.of(), ImmutableList.of( @@ -189,7 +189,7 @@ public void testWithOrdinality() p.values(p.symbol("corr")), p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "unnested_corr"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "unnested_corr"))), ImmutableList.of(BIGINT)) .source(p.unnest( ImmutableList.of(), ImmutableList.of(new UnnestNode.Mapping(p.symbol("corr"), ImmutableList.of(p.symbol("unnested_corr")))), @@ -221,11 +221,11 @@ public void testMultipleGlobalAggregations() p.values(p.symbol("corr")), p.aggregation(outerBuilder -> outerBuilder .globalGrouping() - .addAggregation(p.symbol("arbitrary"), PlanBuilder.aggregation("arbitrary", ImmutableList.of(new SymbolReference(BIGINT, "sum"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("arbitrary"), PlanBuilder.aggregation("arbitrary", ImmutableList.of(new Reference(BIGINT, "sum"))), ImmutableList.of(BIGINT)) .source( p.aggregation(innerBuilder -> innerBuilder .globalGrouping() - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "unnested_corr"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "unnested_corr"))), ImmutableList.of(BIGINT)) .source(p.unnest( ImmutableList.of(), ImmutableList.of(new UnnestNode.Mapping(p.symbol("corr"), ImmutableList.of(p.symbol("unnested_corr")))), @@ -262,10 +262,10 @@ public void testProjectOverGlobalAggregation() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), p.project( - Assignments.of(p.symbol("sum_1"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "sum"), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("sum_1"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), p.aggregation(innerBuilder -> innerBuilder .globalGrouping() - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "unnested_corr"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "unnested_corr"))), ImmutableList.of(BIGINT)) .source(p.unnest( ImmutableList.of(), ImmutableList.of(new UnnestNode.Mapping(p.symbol("corr"), ImmutableList.of(p.symbol("unnested_corr")))), @@ -276,9 +276,9 @@ public void testProjectOverGlobalAggregation() project( strictProject( ImmutableMap.of( - "corr", expression(new SymbolReference(BIGINT, "corr")), - "unique", expression(new SymbolReference(BIGINT, "unique")), - "sum_1", expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), + "corr", expression(new Reference(BIGINT, "corr")), + "unique", expression(new Reference(BIGINT, "unique")), + "sum_1", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), aggregation( singleGroupingSet("unique", "corr"), ImmutableMap.of(Optional.of("sum"), aggregationFunction("sum", ImmutableList.of("unnested_corr"))), @@ -299,7 +299,7 @@ public void testPreprojectUnnestSymbol() tester().assertThat(new DecorrelateLeftUnnestWithGlobalAggregation()) .on(p -> { Symbol corr = p.symbol("corr", VARCHAR); - FunctionCall regexpExtractAll = new FunctionCall( + Call regexpExtractAll = new Call( tester().getMetadata().resolveBuiltinFunction("regexp_extract_all", fromTypes(VARCHAR, VARCHAR)), ImmutableList.of(corr.toSymbolReference(), new Constant(VARCHAR, Slices.utf8Slice(".")))); @@ -308,7 +308,7 @@ public void testPreprojectUnnestSymbol() p.values(corr), p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("max"), PlanBuilder.aggregation("max", ImmutableList.of(new SymbolReference(BIGINT, "unnested_char"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("max"), PlanBuilder.aggregation("max", ImmutableList.of(new Reference(BIGINT, "unnested_char"))), ImmutableList.of(BIGINT)) .source(p.unnest( ImmutableList.of(), ImmutableList.of(new UnnestNode.Mapping(p.symbol("char_array"), ImmutableList.of(p.symbol("unnested_char")))), @@ -332,7 +332,7 @@ public void testPreprojectUnnestSymbol() Optional.empty(), LEFT, project( - ImmutableMap.of("char_array", expression(new FunctionCall(REGEXP_EXTRACT_ALL, ImmutableList.of(new SymbolReference(BIGINT, "corr"), new Constant(VARCHAR, Slices.utf8Slice(".")))))), + ImmutableMap.of("char_array", expression(new Call(REGEXP_EXTRACT_ALL, ImmutableList.of(new Reference(BIGINT, "corr"), new Constant(VARCHAR, Slices.utf8Slice(".")))))), assignUniqueId("unique", values("corr"))))))); } @@ -346,22 +346,22 @@ public void testMultipleNodesOverUnnestInSubquery() ImmutableList.of(p.symbol("groups"), p.symbol("numbers")), p.values(p.symbol("groups"), p.symbol("numbers")), p.project( - Assignments.of(p.symbol("sum_1"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "sum"), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("sum_1"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), p.aggregation(globalBuilder -> globalBuilder .globalGrouping() - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "negate"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "negate"))), ImmutableList.of(BIGINT)) .source(p.project( Assignments.builder() - .put(p.symbol("negate"), new ArithmeticNegation(new SymbolReference(BIGINT, "max"))) + .put(p.symbol("negate"), new Negation(new Reference(BIGINT, "max"))) .build(), p.aggregation(groupedBuilder -> groupedBuilder .singleGroupingSet(p.symbol("group")) - .addAggregation(p.symbol("max"), PlanBuilder.aggregation("max", ImmutableList.of(new SymbolReference(BIGINT, "modulo"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("max"), PlanBuilder.aggregation("max", ImmutableList.of(new Reference(BIGINT, "modulo"))), ImmutableList.of(BIGINT)) .source( p.project( Assignments.builder() .putIdentities(ImmutableList.of(p.symbol("group"), p.symbol("number"))) - .put(p.symbol("modulo"), new ArithmeticBinaryExpression(MODULUS_INTEGER, MODULUS, new SymbolReference(INTEGER, "number"), new Constant(INTEGER, 10L))) + .put(p.symbol("modulo"), new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "number"), new Constant(INTEGER, 10L))) .build(), p.unnest( ImmutableList.of(), @@ -374,7 +374,7 @@ public void testMultipleNodesOverUnnestInSubquery() .matches( project( project( - ImmutableMap.of("sum_1", expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), + ImmutableMap.of("sum_1", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), aggregation( singleGroupingSet("groups", "numbers", "unique"), ImmutableMap.of(Optional.of("sum"), aggregationFunction("sum", ImmutableList.of("negated"))), @@ -382,7 +382,7 @@ public void testMultipleNodesOverUnnestInSubquery() Optional.empty(), SINGLE, project( - ImmutableMap.of("negated", expression(new ArithmeticNegation(new SymbolReference(BIGINT, "max")))), + ImmutableMap.of("negated", expression(new Negation(new Reference(BIGINT, "max")))), aggregation( singleGroupingSet("groups", "numbers", "unique", "group"), ImmutableMap.of(Optional.of("max"), aggregationFunction("max", ImmutableList.of("modulo"))), @@ -391,7 +391,7 @@ public void testMultipleNodesOverUnnestInSubquery() Optional.empty(), SINGLE, project( - ImmutableMap.of("modulo", expression(new ArithmeticBinaryExpression(MODULUS_INTEGER, MODULUS, new SymbolReference(INTEGER, "number"), new Constant(INTEGER, 10L)))), + ImmutableMap.of("modulo", expression(new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "number"), new Constant(INTEGER, 10L)))), unnest( ImmutableList.of("groups", "numbers", "unique"), ImmutableList.of( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateUnnest.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateUnnest.java index f477b8c91d33..e0ebcba67d39 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateUnnest.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateUnnest.java @@ -18,12 +18,12 @@ import io.airlift.slice.Slices; import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.RowNumberSymbolMatcher; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -40,9 +40,9 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.ir.IrExpressions.ifExpression; import static io.trino.sql.planner.assertions.PlanMatchPattern.UnnestMapping.unnestMapping; import static io.trino.sql.planner.assertions.PlanMatchPattern.assignUniqueId; @@ -100,7 +100,7 @@ public void testLeftCorrelatedJoinWithLeftUnnest() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), JoinType.LEFT, - TRUE_LITERAL, + TRUE, p.unnest( ImmutableList.of(), ImmutableList.of(new UnnestNode.Mapping(p.symbol("corr"), ImmutableList.of(p.symbol("unnested_corr")))), @@ -125,7 +125,7 @@ public void testInnerCorrelatedJoinWithLeftUnnest() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), JoinType.INNER, - TRUE_LITERAL, + TRUE, p.unnest( ImmutableList.of(), ImmutableList.of(new UnnestNode.Mapping(p.symbol("corr"), ImmutableList.of(p.symbol("unnested_corr")))), @@ -150,7 +150,7 @@ public void testInnerCorrelatedJoinWithInnerUnnest() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), JoinType.INNER, - TRUE_LITERAL, + TRUE, p.unnest( ImmutableList.of(), ImmutableList.of(new UnnestNode.Mapping(p.symbol("corr"), ImmutableList.of(p.symbol("unnested_corr")))), @@ -175,7 +175,7 @@ public void testLeftCorrelatedJoinWithInnerUnnest() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), JoinType.LEFT, - TRUE_LITERAL, + TRUE, p.unnest( ImmutableList.of(), ImmutableList.of(new UnnestNode.Mapping(p.symbol("corr"), ImmutableList.of(p.symbol("unnested_corr")))), @@ -185,8 +185,8 @@ public void testLeftCorrelatedJoinWithInnerUnnest() .matches( project( ImmutableMap.of( - "corr", expression(new SymbolReference(BIGINT, "corr")), - "unnested_corr", expression(ifExpression(new IsNullPredicate(new SymbolReference(BIGINT, "ordinality")), new Constant(BIGINT, null), new SymbolReference(BIGINT, "unnested_corr")))), + "corr", expression(new Reference(BIGINT, "corr")), + "unnested_corr", expression(ifExpression(new IsNull(new Reference(BIGINT, "ordinality")), new Constant(BIGINT, null), new Reference(BIGINT, "unnested_corr")))), unnest( ImmutableList.of("corr", "unique"), ImmutableList.of(unnestMapping("corr", ImmutableList.of("unnested_corr"))), @@ -203,7 +203,7 @@ public void testEnforceSingleRow() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), JoinType.INNER, - TRUE_LITERAL, + TRUE, p.enforceSingleRow( p.unnest( ImmutableList.of(), @@ -214,10 +214,10 @@ public void testEnforceSingleRow() .matches( project(// restore semantics of INNER unnest after it was rewritten to LEFT ImmutableMap.of( - "corr", expression(new SymbolReference(BIGINT, "corr")), - "unnested_corr", expression(ifExpression(new IsNullPredicate(new SymbolReference(BIGINT, "ordinality")), new Constant(BIGINT, null), new SymbolReference(BIGINT, "unnested_corr")))), + "corr", expression(new Reference(BIGINT, "corr")), + "unnested_corr", expression(ifExpression(new IsNull(new Reference(BIGINT, "ordinality")), new Constant(BIGINT, null), new Reference(BIGINT, "unnested_corr")))), filter( - ifExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "row_number"), new Constant(BIGINT, 1L)), new Cast(new FunctionCall(FAIL, ImmutableList.of(new Constant(INTEGER, 28L), new Constant(VARCHAR, Slices.utf8Slice("Scalar sub-query has returned multiple rows")))), BOOLEAN), TRUE_LITERAL), + ifExpression(new Comparison(GREATER_THAN, new Reference(BIGINT, "row_number"), new Constant(BIGINT, 1L)), new Cast(new Call(FAIL, ImmutableList.of(new Constant(INTEGER, 28L), new Constant(VARCHAR, Slices.utf8Slice("Scalar sub-query has returned multiple rows")))), BOOLEAN), TRUE), rowNumber( builder -> builder .partitionBy(ImmutableList.of("unique")) @@ -239,7 +239,7 @@ public void testLimit() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), JoinType.LEFT, - TRUE_LITERAL, + TRUE, p.limit( 5, p.unnest( @@ -251,7 +251,7 @@ public void testLimit() .matches( project( filter( - new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "row_number"), new Constant(BIGINT, 5L)), + new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "row_number"), new Constant(BIGINT, 5L)), rowNumber( builder -> builder .partitionBy(ImmutableList.of("unique")) @@ -273,7 +273,7 @@ public void testLimitWithTies() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), JoinType.LEFT, - TRUE_LITERAL, + TRUE, p.limit( 5, ImmutableList.of(p.symbol("unnested_corr")), @@ -286,7 +286,7 @@ public void testLimitWithTies() .matches( project( filter( - new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "rank_number"), new Constant(BIGINT, 5L)), + new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "rank_number"), new Constant(BIGINT, 5L)), window(builder -> builder .specification(specification( ImmutableList.of("unique"), @@ -309,7 +309,7 @@ public void testTopN() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), JoinType.LEFT, - TRUE_LITERAL, + TRUE, p.topN( 5, ImmutableList.of(p.symbol("unnested_corr")), @@ -322,7 +322,7 @@ public void testTopN() .matches( project( filter( - new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "row_number"), new Constant(BIGINT, 5L)), + new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "row_number"), new Constant(BIGINT, 5L)), window(builder -> builder .specification(specification( ImmutableList.of("unique"), @@ -345,9 +345,9 @@ public void testProject() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), JoinType.LEFT, - TRUE_LITERAL, + TRUE, p.project( - Assignments.of(p.symbol("boolean_result"), new IsNullPredicate(new SymbolReference(BIGINT, "unnested_corr"))), + Assignments.of(p.symbol("boolean_result"), new IsNull(new Reference(BIGINT, "unnested_corr"))), p.unnest( ImmutableList.of(), ImmutableList.of(new UnnestNode.Mapping(p.symbol("corr"), ImmutableList.of(p.symbol("unnested_corr")))), @@ -358,10 +358,10 @@ public void testProject() project( project( ImmutableMap.of( - "corr", expression(new SymbolReference(BIGINT, "corr")), - "unique", expression(new SymbolReference(BIGINT, "unique")), - "ordinality", expression(new SymbolReference(BIGINT, "ordinality")), - "boolean_result", expression(new IsNullPredicate(new SymbolReference(BIGINT, "unnested_corr")))), + "corr", expression(new Reference(BIGINT, "corr")), + "unique", expression(new Reference(BIGINT, "unique")), + "ordinality", expression(new Reference(BIGINT, "ordinality")), + "boolean_result", expression(new IsNull(new Reference(BIGINT, "unnested_corr")))), unnest( ImmutableList.of("corr", "unique"), ImmutableList.of(unnestMapping("corr", ImmutableList.of("unnested_corr"))), @@ -374,9 +374,9 @@ public void testProject() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), JoinType.LEFT, - TRUE_LITERAL, + TRUE, p.project( - Assignments.of(p.symbol("boolean_result"), new IsNullPredicate(new SymbolReference(BIGINT, "unnested_corr"))), + Assignments.of(p.symbol("boolean_result"), new IsNull(new Reference(BIGINT, "unnested_corr"))), p.unnest( ImmutableList.of(), ImmutableList.of(new UnnestNode.Mapping(p.symbol("corr"), ImmutableList.of(p.symbol("unnested_corr")))), @@ -386,14 +386,14 @@ public void testProject() .matches( project(// restore semantics of INNER unnest after it was rewritten to LEFT ImmutableMap.of( - "corr", expression(new SymbolReference(BIGINT, "corr")), - "boolean_result", expression(ifExpression(new IsNullPredicate(new SymbolReference(BIGINT, "ordinality")), new Constant(BIGINT, null), new SymbolReference(BOOLEAN, "boolean_result")))), + "corr", expression(new Reference(BIGINT, "corr")), + "boolean_result", expression(ifExpression(new IsNull(new Reference(BIGINT, "ordinality")), new Constant(BIGINT, null), new Reference(BOOLEAN, "boolean_result")))), project(// append projection from the subquery ImmutableMap.of( - "corr", expression(new SymbolReference(BIGINT, "corr")), - "unique", expression(new SymbolReference(BIGINT, "unique")), - "ordinality", expression(new SymbolReference(BIGINT, "ordinality")), - "boolean_result", expression(new IsNullPredicate(new SymbolReference(BIGINT, "unnested_corr")))), + "corr", expression(new Reference(BIGINT, "corr")), + "unique", expression(new Reference(BIGINT, "unique")), + "ordinality", expression(new Reference(BIGINT, "ordinality")), + "boolean_result", expression(new IsNull(new Reference(BIGINT, "unnested_corr")))), unnest( ImmutableList.of("corr", "unique"), ImmutableList.of(unnestMapping("corr", ImmutableList.of("unnested_corr"))), @@ -410,14 +410,14 @@ public void testDifferentNodesInSubquery() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), JoinType.LEFT, - TRUE_LITERAL, + TRUE, p.enforceSingleRow( p.project( - Assignments.of(p.symbol("integer_result"), ifExpression(new SymbolReference(BOOLEAN, "boolean_result"), new Constant(INTEGER, 1L), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("integer_result"), ifExpression(new Reference(BOOLEAN, "boolean_result"), new Constant(INTEGER, 1L), new Constant(INTEGER, 1L))), p.limit( 5, p.project( - Assignments.of(p.symbol("boolean_result"), new IsNullPredicate(new SymbolReference(BIGINT, "unnested_corr"))), + Assignments.of(p.symbol("boolean_result"), new IsNull(new Reference(BIGINT, "unnested_corr"))), p.topN( 10, ImmutableList.of(p.symbol("unnested_corr")), @@ -430,25 +430,25 @@ public void testDifferentNodesInSubquery() .matches( project( filter(// enforce single row - ifExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "row_number"), new Constant(BIGINT, 1L)), new Cast(new FunctionCall(FAIL, ImmutableList.of(new Constant(INTEGER, 28L), new Constant(VARCHAR, Slices.utf8Slice("Scalar sub-query has returned multiple rows")))), BOOLEAN), TRUE_LITERAL), + ifExpression(new Comparison(GREATER_THAN, new Reference(BIGINT, "row_number"), new Constant(BIGINT, 1L)), new Cast(new Call(FAIL, ImmutableList.of(new Constant(INTEGER, 28L), new Constant(VARCHAR, Slices.utf8Slice("Scalar sub-query has returned multiple rows")))), BOOLEAN), TRUE), project(// second projection ImmutableMap.of( - "corr", expression(new SymbolReference(BIGINT, "corr")), - "unique", expression(new SymbolReference(BIGINT, "unique")), - "ordinality", expression(new SymbolReference(BIGINT, "ordinality")), - "row_number", expression(new SymbolReference(BIGINT, "row_number")), - "integer_result", expression(ifExpression(new SymbolReference(BOOLEAN, "boolean_result"), new Constant(INTEGER, 1L), new Constant(INTEGER, 1L)))), + "corr", expression(new Reference(BIGINT, "corr")), + "unique", expression(new Reference(BIGINT, "unique")), + "ordinality", expression(new Reference(BIGINT, "ordinality")), + "row_number", expression(new Reference(BIGINT, "row_number")), + "integer_result", expression(ifExpression(new Reference(BOOLEAN, "boolean_result"), new Constant(INTEGER, 1L), new Constant(INTEGER, 1L)))), filter(// limit - new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "row_number"), new Constant(BIGINT, 5L)), + new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "row_number"), new Constant(BIGINT, 5L)), project(// first projection ImmutableMap.of( - "corr", expression(new SymbolReference(BIGINT, "corr")), - "unique", expression(new SymbolReference(BIGINT, "unique")), - "ordinality", expression(new SymbolReference(BIGINT, "ordinality")), - "row_number", expression(new SymbolReference(BIGINT, "row_number")), - "boolean_result", expression(new IsNullPredicate(new SymbolReference(BIGINT, "unnested_corr")))), + "corr", expression(new Reference(BIGINT, "corr")), + "unique", expression(new Reference(BIGINT, "unique")), + "ordinality", expression(new Reference(BIGINT, "ordinality")), + "row_number", expression(new Reference(BIGINT, "row_number")), + "boolean_result", expression(new IsNull(new Reference(BIGINT, "unnested_corr")))), filter(// topN - new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "row_number"), new Constant(BIGINT, 10L)), + new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "row_number"), new Constant(BIGINT, 10L)), window(builder -> builder .specification(specification( ImmutableList.of("unique"), @@ -471,7 +471,7 @@ public void testWithPreexistingOrdinality() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), JoinType.LEFT, - TRUE_LITERAL, + TRUE, p.unnest( ImmutableList.of(), ImmutableList.of(new UnnestNode.Mapping(p.symbol("corr"), ImmutableList.of(p.symbol("unnested_corr")))), @@ -481,8 +481,8 @@ public void testWithPreexistingOrdinality() .matches( project( ImmutableMap.of( - "corr", expression(new SymbolReference(BIGINT, "corr")), - "unnested_corr", expression(ifExpression(new IsNullPredicate(new SymbolReference(BIGINT, "ordinality")), new Constant(BIGINT, null), new SymbolReference(BIGINT, "unnested_corr")))), + "corr", expression(new Reference(BIGINT, "corr")), + "unnested_corr", expression(ifExpression(new IsNull(new Reference(BIGINT, "ordinality")), new Constant(BIGINT, null), new Reference(BIGINT, "unnested_corr")))), unnest( ImmutableList.of("corr", "unique"), ImmutableList.of(unnestMapping("corr", ImmutableList.of("unnested_corr"))), @@ -497,7 +497,7 @@ public void testPreprojectUnnestSymbol() tester().assertThat(new DecorrelateUnnest(tester().getMetadata())) .on(p -> { Symbol corr = p.symbol("corr", VARCHAR); - FunctionCall regexpExtractAll = new FunctionCall( + Call regexpExtractAll = new Call( tester().getMetadata().resolveBuiltinFunction("regexp_extract_all", fromTypes(VARCHAR, VARCHAR)), ImmutableList.of(corr.toSymbolReference(), new Constant(VARCHAR, Slices.utf8Slice(".")))); @@ -505,7 +505,7 @@ public void testPreprojectUnnestSymbol() ImmutableList.of(corr), p.values(corr), JoinType.LEFT, - TRUE_LITERAL, + TRUE, p.unnest( ImmutableList.of(), ImmutableList.of(new UnnestNode.Mapping(p.symbol("char_array"), ImmutableList.of(p.symbol("unnested_char")))), @@ -523,7 +523,7 @@ public void testPreprojectUnnestSymbol() Optional.of("ordinality"), LEFT, project( - ImmutableMap.of("char_array", expression(new FunctionCall(tester().getMetadata().resolveBuiltinFunction("regexp_extract_all", fromTypes(VARCHAR, VARCHAR)), ImmutableList.of(new SymbolReference(BIGINT, "corr"), new Constant(VARCHAR, Slices.utf8Slice(".")))))), + ImmutableMap.of("char_array", expression(new Call(tester().getMetadata().resolveBuiltinFunction("regexp_extract_all", fromTypes(VARCHAR, VARCHAR)), ImmutableList.of(new Reference(BIGINT, "corr"), new Constant(VARCHAR, Slices.utf8Slice(".")))))), assignUniqueId("unique", values("corr")))))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDetermineJoinDistributionType.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDetermineJoinDistributionType.java index 8463fc09c596..aeb9883f6e5a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDetermineJoinDistributionType.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDetermineJoinDistributionType.java @@ -24,10 +24,10 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; import io.trino.spi.type.VarcharType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.OptimizerConfig.JoinDistributionType; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; @@ -57,9 +57,9 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MULTIPLY; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.enforceSingleRow; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.join; @@ -216,10 +216,10 @@ private void testReplicateNoEquiCriteria(JoinType joinType) ImmutableList.of(), ImmutableList.of(p.symbol("A1", BIGINT)), ImmutableList.of(p.symbol("B1", BIGINT)), - Optional.of(new ComparisonExpression(GREATER_THAN, new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "A1"), new SymbolReference(INTEGER, "B1")), new Constant(INTEGER, 100L))))) + Optional.of(new Comparison(GREATER_THAN, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "A1"), new Reference(INTEGER, "B1")), new Constant(INTEGER, 100L))))) .matches( join(joinType, builder -> builder - .filter(new ComparisonExpression(GREATER_THAN, new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "A1"), new SymbolReference(INTEGER, "B1")), new Constant(INTEGER, 100L))) + .filter(new Comparison(GREATER_THAN, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "A1"), new Reference(INTEGER, "B1")), new Constant(INTEGER, 100L))) .distributionType(REPLICATED) .left(values(ImmutableMap.of("A1", 0))) .right(values(ImmutableMap.of("B1", 0))))); @@ -696,7 +696,7 @@ public void testReplicatesWhenSourceIsSmall() p.values(new PlanNodeId("valuesA"), aRows, a1), p.filter( new PlanNodeId("filterB"), - TRUE_LITERAL, + TRUE, p.values(new PlanNodeId("valuesB"), bRows, b1)), ImmutableList.of(new JoinNode.EquiJoinClause(a1, b1)), ImmutableList.of(a1), @@ -708,7 +708,7 @@ public void testReplicatesWhenSourceIsSmall() .equiCriteria("A1", "B1") .distributionType(REPLICATED) .left(values(ImmutableMap.of("A1", 0))) - .right(filter(TRUE_LITERAL, values(ImmutableMap.of("B1", 0)))))); + .right(filter(TRUE, values(ImmutableMap.of("B1", 0)))))); // same but with join sides reversed assertDetermineJoinDistributionType() @@ -724,7 +724,7 @@ public void testReplicatesWhenSourceIsSmall() INNER, p.filter( new PlanNodeId("filterB"), - TRUE_LITERAL, + TRUE, p.values(new PlanNodeId("valuesB"), bRows, b1)), p.values(new PlanNodeId("valuesA"), aRows, a1), ImmutableList.of(new JoinNode.EquiJoinClause(b1, a1)), @@ -737,7 +737,7 @@ public void testReplicatesWhenSourceIsSmall() .equiCriteria("A1", "B1") .distributionType(REPLICATED) .left(values(ImmutableMap.of("A1", 0))) - .right(filter(TRUE_LITERAL, values(ImmutableMap.of("B1", 0)))))); + .right(filter(TRUE, values(ImmutableMap.of("B1", 0)))))); // only probe side (with small tables) source stats are available, join sides should be flipped assertDetermineJoinDistributionType() @@ -753,7 +753,7 @@ public void testReplicatesWhenSourceIsSmall() LEFT, p.filter( new PlanNodeId("filterB"), - TRUE_LITERAL, + TRUE, p.values(new PlanNodeId("valuesB"), bRows, b1)), p.values(new PlanNodeId("valuesA"), aRows, a1), ImmutableList.of(new JoinNode.EquiJoinClause(b1, a1)), @@ -766,7 +766,7 @@ public void testReplicatesWhenSourceIsSmall() .equiCriteria("A1", "B1") .distributionType(PARTITIONED) .left(values(ImmutableMap.of("A1", 0))) - .right(filter(TRUE_LITERAL, values(ImmutableMap.of("B1", 0)))))); + .right(filter(TRUE, values(ImmutableMap.of("B1", 0)))))); } @Test @@ -803,7 +803,7 @@ public void testFlipWhenSizeDifferenceLarge() p.values(new PlanNodeId("valuesA"), aRows, a1), p.filter( new PlanNodeId("filterB"), - TRUE_LITERAL, + TRUE, p.values(new PlanNodeId("valuesB"), bRows, b1)), ImmutableList.of(new JoinNode.EquiJoinClause(a1, b1)), ImmutableList.of(a1), @@ -815,7 +815,7 @@ public void testFlipWhenSizeDifferenceLarge() .equiCriteria("A1", "B1") .distributionType(PARTITIONED) .left(values(ImmutableMap.of("A1", 0))) - .right(filter(TRUE_LITERAL, values(ImmutableMap.of("B1", 0)))))); + .right(filter(TRUE, values(ImmutableMap.of("B1", 0)))))); // same but with join sides reversed assertDetermineJoinDistributionType() @@ -831,7 +831,7 @@ public void testFlipWhenSizeDifferenceLarge() INNER, p.filter( new PlanNodeId("filterB"), - TRUE_LITERAL, + TRUE, p.values(new PlanNodeId("valuesB"), bRows, b1)), p.values(new PlanNodeId("valuesA"), aRows, a1), ImmutableList.of(new JoinNode.EquiJoinClause(b1, a1)), @@ -844,7 +844,7 @@ public void testFlipWhenSizeDifferenceLarge() .equiCriteria("A1", "B1") .distributionType(PARTITIONED) .left(values(ImmutableMap.of("A1", 0))) - .right(filter(TRUE_LITERAL, values(ImmutableMap.of("B1", 0)))))); + .right(filter(TRUE, values(ImmutableMap.of("B1", 0)))))); // Use REPLICATED join type for cross join assertDetermineJoinDistributionType() @@ -860,7 +860,7 @@ public void testFlipWhenSizeDifferenceLarge() INNER, p.filter( new PlanNodeId("filterB"), - TRUE_LITERAL, + TRUE, p.values(new PlanNodeId("valuesB"), bRows, b1)), p.values(new PlanNodeId("valuesA"), aRows, a1), ImmutableList.of(), @@ -871,7 +871,7 @@ public void testFlipWhenSizeDifferenceLarge() .matches( join(INNER, builder -> builder .distributionType(REPLICATED) - .left(filter(TRUE_LITERAL, values(ImmutableMap.of("B1", 0)))) + .left(filter(TRUE, values(ImmutableMap.of("B1", 0)))) .right(values(ImmutableMap.of("A1", 0))))); // Don't flip sides when both are similar in size @@ -892,7 +892,7 @@ public void testFlipWhenSizeDifferenceLarge() INNER, p.filter( new PlanNodeId("filterB"), - TRUE_LITERAL, + TRUE, p.values(new PlanNodeId("valuesB"), aRows, b1)), p.values(new PlanNodeId("valuesA"), aRows, a1), ImmutableList.of(new JoinNode.EquiJoinClause(b1, a1)), @@ -904,7 +904,7 @@ public void testFlipWhenSizeDifferenceLarge() join(INNER, builder -> builder .equiCriteria("B1", "A1") .distributionType(PARTITIONED) - .left(filter(TRUE_LITERAL, values(ImmutableMap.of("B1", 0)))) + .left(filter(TRUE, values(ImmutableMap.of("B1", 0)))) .right(values(ImmutableMap.of("A1", 0))))); } @@ -1011,7 +1011,7 @@ public void testGetApproximateSourceSizeInBytes() .build(), ImmutableList.of( planBuilder.filter( - TRUE_LITERAL, + TRUE, planBuilder.tableScan( ImmutableList.of(sourceSymbol1), ImmutableMap.of(sourceSymbol1, new TestingColumnHandle("col")))), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDetermineSemiJoinDistributionType.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDetermineSemiJoinDistributionType.java index ada3c44a2789..9202ff73cbfa 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDetermineSemiJoinDistributionType.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDetermineSemiJoinDistributionType.java @@ -39,7 +39,7 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.semiJoin; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -346,7 +346,7 @@ public void testReplicatesWhenSourceIsSmall() p.values(new PlanNodeId("valuesA"), aRows, a1), p.filter( new PlanNodeId("filterB"), - TRUE_LITERAL, + TRUE, p.values(new PlanNodeId("valuesB"), bRows, b1)), a1, b1, @@ -361,7 +361,7 @@ public void testReplicatesWhenSourceIsSmall() "output", Optional.of(REPLICATED), values(ImmutableMap.of("A1", 0)), - filter(TRUE_LITERAL, values(ImmutableMap.of("B1", 0))))); + filter(TRUE, values(ImmutableMap.of("B1", 0))))); } private RuleBuilder assertDetermineSemiJoinDistributionType() diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestEliminateCrossJoins.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestEliminateCrossJoins.java index 5a27ffabf54b..c9dc6f913825 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestEliminateCrossJoins.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestEliminateCrossJoins.java @@ -18,10 +18,10 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ArithmeticNegation; +import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Negation; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.PlanMatchPattern; @@ -47,7 +47,7 @@ import static io.trino.SystemSessionProperties.JOIN_REORDERING_STRATEGY; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.planner.assertions.PlanMatchPattern.any; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.join; @@ -229,14 +229,14 @@ public void testEliminateCrossJoinWithNonIdentityProjections() INNER, p.project( Assignments.of( - a2, new ArithmeticNegation(new SymbolReference(BIGINT, "a1")), - f, new SymbolReference(BIGINT, "f")), + a2, new Negation(new Reference(BIGINT, "a1")), + f, new Reference(BIGINT, "f")), p.join( INNER, p.project( Assignments.of( - a1, new SymbolReference(BIGINT, "a1"), - f, new ArithmeticNegation(new SymbolReference(BIGINT, "b"))), + a1, new Reference(BIGINT, "a1"), + f, new Negation(new Reference(BIGINT, "b"))), p.join( INNER, p.values(a1), @@ -259,18 +259,18 @@ f, new ArithmeticNegation(new SymbolReference(BIGINT, "b"))), .left( strictProject( ImmutableMap.of( - "a2", expression(new ArithmeticNegation(new SymbolReference(BIGINT, "a1"))), - "a1", expression(new SymbolReference(BIGINT, "a1"))), + "a2", expression(new Negation(new Reference(BIGINT, "a1"))), + "a1", expression(new Reference(BIGINT, "a1"))), PlanMatchPattern.values("a1"))) .right( strictProject( ImmutableMap.of( - "e", expression(new SymbolReference(BIGINT, "e"))), + "e", expression(new Reference(BIGINT, "e"))), PlanMatchPattern.values("e"))))) .right(any()))) .right( strictProject( - ImmutableMap.of("f", expression(new ArithmeticNegation(new SymbolReference(BIGINT, "b")))), + ImmutableMap.of("f", expression(new Negation(new Reference(BIGINT, "b")))), PlanMatchPattern.values("b")))))); } @@ -284,9 +284,9 @@ public void testGiveUpOnComplexProjections() values("a1"), values("b")), "a2", - new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "a1"), new SymbolReference(INTEGER, "b")), + new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a1"), new Reference(INTEGER, "b")), "b", - new SymbolReference(INTEGER, "b")), + new Reference(INTEGER, "b")), values("c"), "a2", "c", "b", "c"); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestEvaluateZeroSample.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestEvaluateZeroSample.java index f5faf5a00112..d87e590cdbab 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestEvaluateZeroSample.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestEvaluateZeroSample.java @@ -15,15 +15,15 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.SampleNode.Type; import org.junit.jupiter.api.Test; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; public class TestEvaluateZeroSample @@ -50,7 +50,7 @@ public void test() 0, Type.BERNOULLI, p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 5L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 5L)), p.values( ImmutableList.of(p.symbol("a"), p.symbol("b")), ImmutableList.of( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestExpressionRewriteRuleSet.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestExpressionRewriteRuleSet.java index 1c3105ee28a3..254b6d9df633 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestExpressionRewriteRuleSet.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestExpressionRewriteRuleSet.java @@ -19,10 +19,10 @@ import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.ir.ExpressionTreeRewriter; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.NotExpression; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.ir.Row; -import io.trino.sql.ir.SymbolReference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -33,7 +33,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.patternRecognition; @@ -57,7 +57,7 @@ protected Expression rewriteExpression(Expression node, Void context, Expression public Expression rewriteRow(Row node, Void context, ExpressionTreeRewriter treeRewriter) { // rewrite Row items to preserve Row structure of ValuesNode - return new Row(node.getItems().stream().map(item -> new Constant(INTEGER, 0L)).collect(toImmutableList())); + return new Row(node.items().stream().map(item -> new Constant(INTEGER, 0L)).collect(toImmutableList())); } }, expression)); @@ -66,7 +66,7 @@ public void testProjectionExpressionRewrite() { tester().assertThat(zeroRewriter.projectExpressionRewrite()) .on(p -> p.project( - Assignments.of(p.symbol("y"), new NotExpression(new IsNullPredicate(new SymbolReference(BIGINT, "x")))), + Assignments.of(p.symbol("y"), new Not(new IsNull(new Reference(BIGINT, "x")))), p.values(p.symbol("x")))) .matches( project(ImmutableMap.of("y", expression(new Constant(INTEGER, 0L))), values("x"))); @@ -85,13 +85,13 @@ public void testProjectionExpressionNotRewritten() @Test public void testAggregationExpressionRewrite() { - ExpressionRewriteRuleSet functionCallRewriter = new ExpressionRewriteRuleSet((expression, context) -> new SymbolReference(BIGINT, "y")); + ExpressionRewriteRuleSet functionCallRewriter = new ExpressionRewriteRuleSet((expression, context) -> new Reference(BIGINT, "y")); tester().assertThat(functionCallRewriter.aggregationExpressionRewrite()) .on(p -> p.aggregation(a -> a .globalGrouping() .addAggregation( p.symbol("count_1", BigintType.BIGINT), - PlanBuilder.aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "x"))), + PlanBuilder.aggregation("count", ImmutableList.of(new Reference(BIGINT, "x"))), ImmutableList.of(BigintType.BIGINT)) .source( p.values(p.symbol("x"), p.symbol("y"))))) @@ -149,7 +149,7 @@ public void testPatternRecognitionExpressionRewrite() builder -> builder .addMeasure(p.symbol("measure_1", INTEGER), new Constant(INTEGER, 1L)) .pattern(label("X")) - .addVariableDefinition(label("X"), TRUE_LITERAL) + .addVariableDefinition(label("X"), TRUE) .source(p.values(p.symbol("a", INTEGER))))) .matches( patternRecognition( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementExceptAll.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementExceptAll.java index afeb3d517721..173aa5b8f92c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementExceptAll.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementExceptAll.java @@ -19,11 +19,11 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.SetOperationOutputMatcher; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -35,9 +35,9 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.SUBTRACT; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; @@ -93,17 +93,17 @@ public void test() .matches( strictProject( ImmutableMap.of( - "a", expression(new SymbolReference(BIGINT, "a")), - "b", expression(new SymbolReference(BIGINT, "b"))), + "a", expression(new Reference(BIGINT, "a")), + "b", expression(new Reference(BIGINT, "b"))), filter( - new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "row_number"), new FunctionCall(GREATEST, ImmutableList.of(new ArithmeticBinaryExpression(SUBTRACT_BIGINT, SUBTRACT, new SymbolReference(BIGINT, "count_1"), new SymbolReference(BIGINT, "count_2")), new Constant(BIGINT, 0L)))), + new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "row_number"), new Call(GREATEST, ImmutableList.of(new Arithmetic(SUBTRACT_BIGINT, SUBTRACT, new Reference(BIGINT, "count_1"), new Reference(BIGINT, "count_2")), new Constant(BIGINT, 0L)))), strictProject( ImmutableMap.of( - "a", expression(new SymbolReference(BIGINT, "a")), - "b", expression(new SymbolReference(BIGINT, "b")), - "count_1", expression(new SymbolReference(BIGINT, "count_1")), - "count_2", expression(new SymbolReference(BIGINT, "count_2")), - "row_number", expression(new SymbolReference(BIGINT, "row_number"))), + "a", expression(new Reference(BIGINT, "a")), + "b", expression(new Reference(BIGINT, "b")), + "count_1", expression(new Reference(BIGINT, "count_1")), + "count_2", expression(new Reference(BIGINT, "count_2")), + "row_number", expression(new Reference(BIGINT, "row_number"))), window(builder -> builder .specification(specification( ImmutableList.of("a", "b"), @@ -121,17 +121,17 @@ public void test() union( project( ImmutableMap.of( - "a1", expression(new SymbolReference(BIGINT, "a_1")), - "b1", expression(new SymbolReference(BIGINT, "b_1")), - "marker_left_1", expression(TRUE_LITERAL), + "a1", expression(new Reference(BIGINT, "a_1")), + "b1", expression(new Reference(BIGINT, "b_1")), + "marker_left_1", expression(TRUE), "marker_left_2", expression(new Constant(BOOLEAN, null))), values("a_1", "b_1")), project( ImmutableMap.of( - "a2", expression(new SymbolReference(BIGINT, "a_2")), - "b2", expression(new SymbolReference(BIGINT, "b_2")), + "a2", expression(new Reference(BIGINT, "a_2")), + "b2", expression(new Reference(BIGINT, "b_2")), "marker_right_1", expression(new Constant(BOOLEAN, null)), - "marker_right_2", expression(TRUE_LITERAL)), + "marker_right_2", expression(TRUE)), values("a_2", "b_2"))) .withAlias("a", new SetOperationOutputMatcher(0)) .withAlias("b", new SetOperationOutputMatcher(1)) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementExceptDistinctAsUnion.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementExceptDistinctAsUnion.java index cc27ed46a36f..911ba560a33d 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementExceptDistinctAsUnion.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementExceptDistinctAsUnion.java @@ -17,7 +17,7 @@ import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.AggregationNode; @@ -26,7 +26,7 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.node; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; @@ -60,15 +60,15 @@ public void test() union( project( ImmutableMap.of( - "leftValue", expression(new SymbolReference(BIGINT, "a")), - "left_marker_1", expression(TRUE_LITERAL), + "leftValue", expression(new Reference(BIGINT, "a")), + "left_marker_1", expression(TRUE), "left_marker_2", expression(new Constant(BOOLEAN, null))), values("a")), project( ImmutableMap.of( - "rightValue", expression(new SymbolReference(BIGINT, "b")), + "rightValue", expression(new Reference(BIGINT, "b")), "right_marker_1", expression(new Constant(BOOLEAN, null)), - "right_marker_2", expression(TRUE_LITERAL)), + "right_marker_2", expression(TRUE)), values("b"))))))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementFilteredAggregations.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementFilteredAggregations.java index 284ed3569617..64a355e47b6a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementFilteredAggregations.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementFilteredAggregations.java @@ -15,8 +15,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; @@ -27,8 +27,8 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Logical.Operator.AND; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregationFunction; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; @@ -66,9 +66,9 @@ public void testFilterToMask() Optional.empty(), AggregationNode.Step.SINGLE, filter( - TRUE_LITERAL, + TRUE, project( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a")), "g", expression(new SymbolReference(BIGINT, "g")), "filter", expression(new SymbolReference(BOOLEAN, "filter"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a")), "g", expression(new Reference(BIGINT, "g")), "filter", expression(new Reference(BOOLEAN, "filter"))), values("a", "g", "filter"))))); } @@ -99,14 +99,14 @@ public void testCombineMaskAndFilter() Optional.empty(), AggregationNode.Step.SINGLE, filter( - TRUE_LITERAL, + TRUE, project( ImmutableMap.of( - "a", expression(new SymbolReference(BIGINT, "a")), - "g", expression(new SymbolReference(BIGINT, "g")), - "mask", expression(new SymbolReference(BOOLEAN, "mask")), - "filter", expression(new SymbolReference(BOOLEAN, "filter")), - "new_mask", expression(new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "mask"), new SymbolReference(BOOLEAN, "filter"))))), + "a", expression(new Reference(BIGINT, "a")), + "g", expression(new Reference(BIGINT, "g")), + "mask", expression(new Reference(BOOLEAN, "mask")), + "filter", expression(new Reference(BOOLEAN, "filter")), + "new_mask", expression(new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "mask"), new Reference(BOOLEAN, "filter"))))), values("a", "g", "mask", "filter"))))); } @@ -135,9 +135,9 @@ public void testWithFilterPushdown() Optional.empty(), AggregationNode.Step.SINGLE, filter( - new SymbolReference(BOOLEAN, "filter"), + new Reference(BOOLEAN, "filter"), project( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a")), "g", expression(new SymbolReference(BIGINT, "g")), "filter", expression(new SymbolReference(BOOLEAN, "filter"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a")), "g", expression(new Reference(BIGINT, "g")), "filter", expression(new Reference(BOOLEAN, "filter"))), values("a", "g", "filter"))))); } @@ -170,9 +170,9 @@ public void testWithMultipleAggregations() Optional.empty(), AggregationNode.Step.SINGLE, filter( - TRUE_LITERAL, + TRUE, project( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a")), "g", expression(new SymbolReference(BIGINT, "g")), "filter", expression(new SymbolReference(BOOLEAN, "filter"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a")), "g", expression(new Reference(BIGINT, "g")), "filter", expression(new Reference(BOOLEAN, "filter"))), values("a", "g", "filter"))))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementIntersectAll.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementIntersectAll.java index 236236895658..82b770708a27 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementIntersectAll.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementIntersectAll.java @@ -18,10 +18,10 @@ import com.google.common.collect.ImmutableMap; import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.SetOperationOutputMatcher; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -33,8 +33,8 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; @@ -89,17 +89,17 @@ public void test() .matches( strictProject( ImmutableMap.of( - "a", expression(new SymbolReference(BIGINT, "a")), - "b", expression(new SymbolReference(BIGINT, "b"))), + "a", expression(new Reference(BIGINT, "a")), + "b", expression(new Reference(BIGINT, "b"))), filter( - new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "row_number"), new FunctionCall(LEAST, ImmutableList.of(new SymbolReference(BIGINT, "count_1"), new SymbolReference(BIGINT, "count_2")))), + new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "row_number"), new Call(LEAST, ImmutableList.of(new Reference(BIGINT, "count_1"), new Reference(BIGINT, "count_2")))), strictProject( ImmutableMap.of( - "a", expression(new SymbolReference(BIGINT, "a")), - "b", expression(new SymbolReference(BIGINT, "b")), - "count_1", expression(new SymbolReference(BIGINT, "count_1")), - "count_2", expression(new SymbolReference(BIGINT, "count_2")), - "row_number", expression(new SymbolReference(BIGINT, "row_number"))), + "a", expression(new Reference(BIGINT, "a")), + "b", expression(new Reference(BIGINT, "b")), + "count_1", expression(new Reference(BIGINT, "count_1")), + "count_2", expression(new Reference(BIGINT, "count_2")), + "row_number", expression(new Reference(BIGINT, "row_number"))), window(builder -> builder .specification(specification( ImmutableList.of("a", "b"), @@ -117,17 +117,17 @@ public void test() union( project( ImmutableMap.of( - "a1", expression(new SymbolReference(BIGINT, "a_1")), - "b1", expression(new SymbolReference(BIGINT, "b_1")), - "marker_left_1", expression(TRUE_LITERAL), + "a1", expression(new Reference(BIGINT, "a_1")), + "b1", expression(new Reference(BIGINT, "b_1")), + "marker_left_1", expression(TRUE), "marker_left_2", expression(new Constant(BOOLEAN, null))), values("a_1", "b_1")), project( ImmutableMap.of( - "a2", expression(new SymbolReference(BIGINT, "a_2")), - "b2", expression(new SymbolReference(BIGINT, "b_2")), + "a2", expression(new Reference(BIGINT, "a_2")), + "b2", expression(new Reference(BIGINT, "b_2")), "marker_right_1", expression(new Constant(BOOLEAN, null)), - "marker_right_2", expression(TRUE_LITERAL)), + "marker_right_2", expression(TRUE)), values("a_2", "b_2"))) .withAlias("a", new SetOperationOutputMatcher(0)) .withAlias("b", new SetOperationOutputMatcher(1)) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementIntersectDistinctAsUnion.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementIntersectDistinctAsUnion.java index 62ea4881e31c..3ce2245386ef 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementIntersectDistinctAsUnion.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementIntersectDistinctAsUnion.java @@ -17,7 +17,7 @@ import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.AggregationNode; @@ -26,7 +26,7 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.node; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; @@ -60,15 +60,15 @@ public void test() union( project( ImmutableMap.of( - "leftValue", expression(new SymbolReference(BIGINT, "a")), - "left_marker_1", expression(TRUE_LITERAL), + "leftValue", expression(new Reference(BIGINT, "a")), + "left_marker_1", expression(TRUE), "left_marker_2", expression(new Constant(BOOLEAN, null))), values("a")), project( ImmutableMap.of( - "rightValue", expression(new SymbolReference(BIGINT, "b")), + "rightValue", expression(new Reference(BIGINT, "b")), "right_marker_1", expression(new Constant(BOOLEAN, null)), - "right_marker_2", expression(TRUE_LITERAL)), + "right_marker_2", expression(TRUE)), values("b"))))))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementLimitWithTies.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementLimitWithTies.java index 95452729b589..6a7e98e2b485 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementLimitWithTies.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementLimitWithTies.java @@ -16,15 +16,15 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.spi.connector.SortOrder; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import org.junit.jupiter.api.Test; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.specification; @@ -51,9 +51,9 @@ public void testReplaceLimitWithTies() }) .matches( strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a")), "b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a")), "b", expression(new Reference(BIGINT, "b"))), filter( - new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "rank_num"), new Constant(BIGINT, 2L)), + new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "rank_num"), new Constant(BIGINT, 2L)), window( windowMatcherBuilder -> windowMatcherBuilder .specification(specification( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementOffset.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementOffset.java index 387895e6fffb..71ed5e1fca5e 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementOffset.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementOffset.java @@ -15,16 +15,16 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.RowNumberSymbolMatcher; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import org.junit.jupiter.api.Test; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.rowNumber; @@ -50,9 +50,9 @@ public void testReplaceOffsetOverValues() }) .matches( strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a")), "b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a")), "b", expression(new Reference(BIGINT, "b"))), filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "row_num"), new Constant(BIGINT, 2L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "row_num"), new Constant(BIGINT, 2L)), rowNumber( pattern -> pattern .partitionBy(ImmutableList.of()), @@ -75,9 +75,9 @@ public void testReplaceOffsetOverSort() }) .matches( strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a")), "b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a")), "b", expression(new Reference(BIGINT, "b"))), filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "row_num"), new Constant(BIGINT, 2L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "row_num"), new Constant(BIGINT, 2L)), rowNumber( pattern -> pattern .partitionBy(ImmutableList.of()), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementTableFunctionSource.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementTableFunctionSource.java index 8415d550f0c9..4a183b1067b0 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementTableFunctionSource.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementTableFunctionSource.java @@ -16,12 +16,12 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.CoalesceExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Coalesce; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.PlanMatchPattern; @@ -41,12 +41,12 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.TinyintType.TINYINT; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.IS_DISTINCT_FROM; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.IS_DISTINCT_FROM; import static io.trino.sql.ir.IrExpressions.ifExpression; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; -import static io.trino.sql.ir.LogicalExpression.Operator.OR; +import static io.trino.sql.ir.Logical.Operator.AND; +import static io.trino.sql.ir.Logical.Operator.OR; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.join; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; @@ -280,23 +280,23 @@ public void testTwoSourcesWithSetSemantics() .specification(specification(ImmutableList.of("c"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), project(// append marker symbols ImmutableMap.of( - "marker_1", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "combined_row_number")), new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, null))), - "marker_2", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "combined_row_number")), new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, null)))), + "marker_1", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "combined_row_number")), new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, null))), + "marker_2", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "combined_row_number")), new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, null)))), project(// append helper symbols for joined nodes ImmutableMap.of( - "combined_row_number", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_row_number"))), - "combined_partition_size", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "input_1_partition_size"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_2_partition_size"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "input_1_partition_size"), new SymbolReference(BIGINT, "input_2_partition_size")))), + "combined_row_number", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_row_number"))), + "combined_partition_size", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "input_1_partition_size"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_2_partition_size"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "input_1_partition_size"), new Reference(BIGINT, "input_2_partition_size")))), join(// join nodes using helper symbols FULL, joinBuilder -> joinBuilder - .filter(new LogicalExpression(OR, ImmutableList.of( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_row_number")), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, 1L)))), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "input_1_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, 1L))))))) + .filter(new Logical(OR, ImmutableList.of( + new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_row_number")), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, 1L)))), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "input_1_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, 1L))))))) .left(window(// append helper symbols for source input_1 builder -> builder .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) @@ -372,39 +372,39 @@ public void testThreeSourcesWithSetSemantics() .specification(specification(ImmutableList.of("c"), ImmutableList.of("combined_row_number_1_2_3"), ImmutableMap.of("combined_row_number_1_2_3", ASC_NULLS_LAST))), project(// append marker symbols ImmutableMap.of( - "marker_1", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "combined_row_number_1_2_3")), new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, null))), - "marker_2", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "combined_row_number_1_2_3")), new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, null))), - "marker_3", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_3_row_number"), new SymbolReference(BIGINT, "combined_row_number_1_2_3")), new SymbolReference(BIGINT, "input_3_row_number"), new Constant(BIGINT, null)))), + "marker_1", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "combined_row_number_1_2_3")), new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, null))), + "marker_2", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "combined_row_number_1_2_3")), new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, null))), + "marker_3", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_3_row_number"), new Reference(BIGINT, "combined_row_number_1_2_3")), new Reference(BIGINT, "input_3_row_number"), new Constant(BIGINT, null)))), project(// append helper symbols for joined nodes ImmutableMap.of( - "combined_row_number_1_2_3", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "combined_row_number_1_2"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_3_row_number"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "combined_row_number_1_2"), new SymbolReference(BIGINT, "input_3_row_number"))), - "combined_partition_size_1_2_3", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "combined_partition_size_1_2"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_3_partition_size"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "combined_partition_size_1_2"), new SymbolReference(BIGINT, "input_3_partition_size")))), + "combined_row_number_1_2_3", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "combined_row_number_1_2"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_3_row_number"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "combined_row_number_1_2"), new Reference(BIGINT, "input_3_row_number"))), + "combined_partition_size_1_2_3", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "combined_partition_size_1_2"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_3_partition_size"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "combined_partition_size_1_2"), new Reference(BIGINT, "input_3_partition_size")))), join(// join nodes using helper symbols FULL, joinBuilder -> joinBuilder - .filter(new LogicalExpression(OR, ImmutableList.of( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "combined_row_number_1_2"), new SymbolReference(BIGINT, "input_3_row_number")), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "combined_row_number_1_2"), new SymbolReference(BIGINT, "input_3_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_3_row_number"), new Constant(BIGINT, 1L)))), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_3_row_number"), new SymbolReference(BIGINT, "combined_partition_size_1_2")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "combined_row_number_1_2"), new Constant(BIGINT, 1L))))))) + .filter(new Logical(OR, ImmutableList.of( + new Comparison(EQUAL, new Reference(BIGINT, "combined_row_number_1_2"), new Reference(BIGINT, "input_3_row_number")), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "combined_row_number_1_2"), new Reference(BIGINT, "input_3_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_3_row_number"), new Constant(BIGINT, 1L)))), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_3_row_number"), new Reference(BIGINT, "combined_partition_size_1_2")), + new Comparison(EQUAL, new Reference(BIGINT, "combined_row_number_1_2"), new Constant(BIGINT, 1L))))))) .left(project(// append helper symbols for joined nodes ImmutableMap.of( - "combined_row_number_1_2", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_row_number"))), - "combined_partition_size_1_2", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "input_1_partition_size"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_2_partition_size"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "input_1_partition_size"), new SymbolReference(BIGINT, "input_2_partition_size")))), + "combined_row_number_1_2", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_row_number"))), + "combined_partition_size_1_2", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "input_1_partition_size"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_2_partition_size"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "input_1_partition_size"), new Reference(BIGINT, "input_2_partition_size")))), join(// join nodes using helper symbols FULL, nestedJoinBuilder -> nestedJoinBuilder - .filter(new LogicalExpression(OR, ImmutableList.of( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_row_number")), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, 1L)))), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "input_1_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, 1L))))))) + .filter(new Logical(OR, ImmutableList.of( + new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_row_number")), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, 1L)))), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "input_1_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, 1L))))))) .left(window(// append helper symbols for source input_1 builder -> builder .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) @@ -475,26 +475,26 @@ public void testTwoCoPartitionedSources() .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), project(// append marker symbols ImmutableMap.of( - "marker_1", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "combined_row_number")), new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, null))), - "marker_2", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "combined_row_number")), new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, null)))), + "marker_1", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "combined_row_number")), new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, null))), + "marker_2", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "combined_row_number")), new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, null)))), project(// append helper and partitioning symbols for co-partitioned nodes ImmutableMap.of( - "combined_row_number", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_row_number"))), - "combined_partition_size", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "input_1_partition_size"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_2_partition_size"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "input_1_partition_size"), new SymbolReference(BIGINT, "input_2_partition_size"))), - "combined_partition_column", expression(new CoalesceExpression(new SymbolReference(BIGINT, "c"), new SymbolReference(BIGINT, "e")))), + "combined_row_number", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_row_number"))), + "combined_partition_size", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "input_1_partition_size"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_2_partition_size"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "input_1_partition_size"), new Reference(BIGINT, "input_2_partition_size"))), + "combined_partition_column", expression(new Coalesce(new Reference(BIGINT, "c"), new Reference(BIGINT, "e")))), join(// co-partition nodes LEFT, joinBuilder -> joinBuilder - .filter(new LogicalExpression(AND, ImmutableList.of( - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BIGINT, "c"), new SymbolReference(BIGINT, "e"))), - new LogicalExpression(OR, ImmutableList.of( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_row_number")), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, 1L)))), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "input_1_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, 1L))))))))) + .filter(new Logical(AND, ImmutableList.of( + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(BIGINT, "c"), new Reference(BIGINT, "e"))), + new Logical(OR, ImmutableList.of( + new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_row_number")), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, 1L)))), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "input_1_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, 1L))))))))) .left(window(// append helper symbols for source input_1 builder -> builder .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) @@ -555,26 +555,26 @@ public void testCoPartitionJoinTypes() .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), project(// append marker symbols ImmutableMap.of( - "marker_1", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "combined_row_number")), new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, null))), - "marker_2", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "combined_row_number")), new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, null)))), + "marker_1", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "combined_row_number")), new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, null))), + "marker_2", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "combined_row_number")), new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, null)))), project(// append helper and partitioning symbols for co-partitioned nodes ImmutableMap.of( - "combined_row_number", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_row_number"))), - "combined_partition_size", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "input_1_partition_size"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_2_partition_size"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "input_1_partition_size"), new SymbolReference(BIGINT, "input_2_partition_size"))), - "combined_partition_column", expression(new CoalesceExpression(new SymbolReference(BIGINT, "c"), new SymbolReference(BIGINT, "d")))), + "combined_row_number", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_row_number"))), + "combined_partition_size", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "input_1_partition_size"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_2_partition_size"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "input_1_partition_size"), new Reference(BIGINT, "input_2_partition_size"))), + "combined_partition_column", expression(new Coalesce(new Reference(BIGINT, "c"), new Reference(BIGINT, "d")))), join(// co-partition nodes INNER, joinBuilder -> joinBuilder - .filter(new LogicalExpression(AND, ImmutableList.of( - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BIGINT, "c"), new SymbolReference(BIGINT, "d"))), - new LogicalExpression(OR, ImmutableList.of( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_row_number")), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, 1L)))), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "input_1_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, 1L))))))))) + .filter(new Logical(AND, ImmutableList.of( + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(BIGINT, "c"), new Reference(BIGINT, "d"))), + new Logical(OR, ImmutableList.of( + new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_row_number")), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, 1L)))), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "input_1_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, 1L))))))))) .left(window(// append helper symbols for source input_1 builder -> builder .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) @@ -631,26 +631,26 @@ public void testCoPartitionJoinTypes() .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), project(// append marker symbols ImmutableMap.of( - "marker_1", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "combined_row_number")), new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, null))), - "marker_2", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "combined_row_number")), new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, null)))), + "marker_1", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "combined_row_number")), new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, null))), + "marker_2", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "combined_row_number")), new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, null)))), project(// append helper and partitioning symbols for co-partitioned nodes ImmutableMap.of( - "combined_row_number", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_row_number"))), - "combined_partition_size", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "input_1_partition_size"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_2_partition_size"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "input_1_partition_size"), new SymbolReference(BIGINT, "input_2_partition_size"))), - "combined_partition_column", expression(new CoalesceExpression(new SymbolReference(BIGINT, "c"), new SymbolReference(BIGINT, "d")))), + "combined_row_number", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_row_number"))), + "combined_partition_size", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "input_1_partition_size"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_2_partition_size"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "input_1_partition_size"), new Reference(BIGINT, "input_2_partition_size"))), + "combined_partition_column", expression(new Coalesce(new Reference(BIGINT, "c"), new Reference(BIGINT, "d")))), join(// co-partition nodes LEFT, joinBuilder -> joinBuilder - .filter(new LogicalExpression(AND, ImmutableList.of( - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BIGINT, "c"), new SymbolReference(BIGINT, "d"))), - new LogicalExpression(OR, ImmutableList.of( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_row_number")), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, 1L)))), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "input_1_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, 1L))))))))) + .filter(new Logical(AND, ImmutableList.of( + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(BIGINT, "c"), new Reference(BIGINT, "d"))), + new Logical(OR, ImmutableList.of( + new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_row_number")), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, 1L)))), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "input_1_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, 1L))))))))) .left(window(// append helper symbols for source input_1 builder -> builder .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) @@ -707,26 +707,26 @@ public void testCoPartitionJoinTypes() .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), project(// append marker symbols ImmutableMap.of( - "marker_1", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "combined_row_number")), new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, null))), - "marker_2", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "combined_row_number")), new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, null)))), + "marker_1", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "combined_row_number")), new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, null))), + "marker_2", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "combined_row_number")), new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, null)))), project(// append helper and partitioning symbols for co-partitioned nodes ImmutableMap.of( - "combined_row_number", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "input_1_row_number"))), - "combined_partition_size", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "input_2_partition_size"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_1_partition_size"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "input_2_partition_size"), new SymbolReference(BIGINT, "input_1_partition_size"))), - "combined_partition_column", expression(new CoalesceExpression(new SymbolReference(BIGINT, "d"), new SymbolReference(BIGINT, "c")))), + "combined_row_number", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "input_1_row_number"))), + "combined_partition_size", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "input_2_partition_size"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_1_partition_size"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "input_2_partition_size"), new Reference(BIGINT, "input_1_partition_size"))), + "combined_partition_column", expression(new Coalesce(new Reference(BIGINT, "d"), new Reference(BIGINT, "c")))), join(// co-partition nodes LEFT, joinBuilder -> joinBuilder - .filter(new LogicalExpression(AND, ImmutableList.of( - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BIGINT, "d"), new SymbolReference(BIGINT, "c"))), - new LogicalExpression(OR, ImmutableList.of( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "input_1_row_number")), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "input_1_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, 1L)))), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, 1L))))))))) + .filter(new Logical(AND, ImmutableList.of( + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(BIGINT, "d"), new Reference(BIGINT, "c"))), + new Logical(OR, ImmutableList.of( + new Comparison(EQUAL, new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "input_1_row_number")), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "input_1_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, 1L)))), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, 1L))))))))) .left(window(// append helper symbols for source input_2 builder -> builder .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) @@ -783,26 +783,26 @@ public void testCoPartitionJoinTypes() .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), project(// append marker symbols ImmutableMap.of( - "marker_1", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "combined_row_number")), new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, null))), - "marker_2", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "combined_row_number")), new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, null)))), + "marker_1", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "combined_row_number")), new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, null))), + "marker_2", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "combined_row_number")), new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, null)))), project(// append helper and partitioning symbols for co-partitioned nodes ImmutableMap.of( - "combined_row_number", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_row_number"))), - "combined_partition_size", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "input_1_partition_size"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_2_partition_size"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "input_1_partition_size"), new SymbolReference(BIGINT, "input_2_partition_size"))), - "combined_partition_column", expression(new CoalesceExpression(new SymbolReference(BIGINT, "c"), new SymbolReference(BIGINT, "d")))), + "combined_row_number", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_row_number"))), + "combined_partition_size", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "input_1_partition_size"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_2_partition_size"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "input_1_partition_size"), new Reference(BIGINT, "input_2_partition_size"))), + "combined_partition_column", expression(new Coalesce(new Reference(BIGINT, "c"), new Reference(BIGINT, "d")))), join(// co-partition nodes FULL, joinBuilder -> joinBuilder - .filter(new LogicalExpression(AND, ImmutableList.of( - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BIGINT, "c"), new SymbolReference(BIGINT, "d"))), - new LogicalExpression(OR, ImmutableList.of( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_row_number")), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, 1L)))), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "input_1_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, 1L))))))))) + .filter(new Logical(AND, ImmutableList.of( + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(BIGINT, "c"), new Reference(BIGINT, "d"))), + new Logical(OR, ImmutableList.of( + new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_row_number")), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, 1L)))), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "input_1_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, 1L))))))))) .left(window(// append helper symbols for source input_1 builder -> builder .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) @@ -872,45 +872,45 @@ public void testThreeCoPartitionedSources() .specification(specification(ImmutableList.of("combined_partition_column_1_2_3"), ImmutableList.of("combined_row_number_1_2_3"), ImmutableMap.of("combined_row_number_1_2_3", ASC_NULLS_LAST))), project(// append marker symbols ImmutableMap.of( - "marker_1", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "combined_row_number_1_2_3")), new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, null))), - "marker_2", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "combined_row_number_1_2_3")), new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, null))), - "marker_3", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_3_row_number"), new SymbolReference(BIGINT, "combined_row_number_1_2_3")), new SymbolReference(BIGINT, "input_3_row_number"), new Constant(BIGINT, null)))), + "marker_1", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "combined_row_number_1_2_3")), new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, null))), + "marker_2", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "combined_row_number_1_2_3")), new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, null))), + "marker_3", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_3_row_number"), new Reference(BIGINT, "combined_row_number_1_2_3")), new Reference(BIGINT, "input_3_row_number"), new Constant(BIGINT, null)))), project(// append helper and partitioning symbols for co-partitioned nodes ImmutableMap.of( - "combined_row_number_1_2_3", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "combined_row_number_1_2"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_3_row_number"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "combined_row_number_1_2"), new SymbolReference(BIGINT, "input_3_row_number"))), - "combined_partition_size_1_2_3", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "combined_partition_size_1_2"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_3_partition_size"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "combined_partition_size_1_2"), new SymbolReference(BIGINT, "input_3_partition_size"))), - "combined_partition_column_1_2_3", expression(new CoalesceExpression(new SymbolReference(BIGINT, "combined_partition_column_1_2"), new SymbolReference(BIGINT, "e")))), + "combined_row_number_1_2_3", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "combined_row_number_1_2"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_3_row_number"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "combined_row_number_1_2"), new Reference(BIGINT, "input_3_row_number"))), + "combined_partition_size_1_2_3", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "combined_partition_size_1_2"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_3_partition_size"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "combined_partition_size_1_2"), new Reference(BIGINT, "input_3_partition_size"))), + "combined_partition_column_1_2_3", expression(new Coalesce(new Reference(BIGINT, "combined_partition_column_1_2"), new Reference(BIGINT, "e")))), join(// co-partition nodes LEFT, joinBuilder -> joinBuilder - .filter(new LogicalExpression(AND, ImmutableList.of( - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BIGINT, "combined_partition_column_1_2"), new SymbolReference(BIGINT, "e"))), - new LogicalExpression(OR, ImmutableList.of( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "combined_row_number_1_2"), new SymbolReference(BIGINT, "input_3_row_number")), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "combined_row_number_1_2"), new SymbolReference(BIGINT, "input_3_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_3_row_number"), new Constant(BIGINT, 1L)))), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_3_row_number"), new SymbolReference(BIGINT, "combined_partition_size_1_2")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "combined_row_number_1_2"), new Constant(BIGINT, 1L))))))))) + .filter(new Logical(AND, ImmutableList.of( + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(BIGINT, "combined_partition_column_1_2"), new Reference(BIGINT, "e"))), + new Logical(OR, ImmutableList.of( + new Comparison(EQUAL, new Reference(BIGINT, "combined_row_number_1_2"), new Reference(BIGINT, "input_3_row_number")), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "combined_row_number_1_2"), new Reference(BIGINT, "input_3_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_3_row_number"), new Constant(BIGINT, 1L)))), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_3_row_number"), new Reference(BIGINT, "combined_partition_size_1_2")), + new Comparison(EQUAL, new Reference(BIGINT, "combined_row_number_1_2"), new Constant(BIGINT, 1L))))))))) .left(project(// append helper and partitioning symbols for co-partitioned nodes ImmutableMap.of( - "combined_row_number_1_2", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_row_number"))), - "combined_partition_size_1_2", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "input_1_partition_size"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_2_partition_size"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "input_1_partition_size"), new SymbolReference(BIGINT, "input_2_partition_size"))), - "combined_partition_column_1_2", expression(new CoalesceExpression(new SymbolReference(BIGINT, "c"), new SymbolReference(BIGINT, "d")))), + "combined_row_number_1_2", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_row_number"))), + "combined_partition_size_1_2", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "input_1_partition_size"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_2_partition_size"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "input_1_partition_size"), new Reference(BIGINT, "input_2_partition_size"))), + "combined_partition_column_1_2", expression(new Coalesce(new Reference(BIGINT, "c"), new Reference(BIGINT, "d")))), join(// co-partition nodes INNER, nestedJoinBuilder -> nestedJoinBuilder - .filter(new LogicalExpression(AND, ImmutableList.of( - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BIGINT, "c"), new SymbolReference(BIGINT, "d"))), - new LogicalExpression(OR, ImmutableList.of( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_row_number")), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, 1L)))), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "input_1_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, 1L))))))))) + .filter(new Logical(AND, ImmutableList.of( + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(BIGINT, "c"), new Reference(BIGINT, "d"))), + new Logical(OR, ImmutableList.of( + new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_row_number")), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, 1L)))), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "input_1_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, 1L))))))))) .left(window(// append helper symbols for source input_1 builder -> builder .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) @@ -1001,43 +1001,43 @@ public void testTwoCoPartitionLists() .specification(specification(ImmutableList.of("combined_partition_column_1_2", "combined_partition_column_3_4"), ImmutableList.of("combined_row_number_1_2_3_4"), ImmutableMap.of("combined_row_number_1_2_3_4", ASC_NULLS_LAST))), project(// append marker symbols ImmutableMap.of( - "marker_1", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "combined_row_number_1_2_3_4")), new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, null))), - "marker_2", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "combined_row_number_1_2_3_4")), new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, null))), - "marker_3", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_3_row_number"), new SymbolReference(BIGINT, "combined_row_number_1_2_3_4")), new SymbolReference(BIGINT, "input_3_row_number"), new Constant(BIGINT, null))), - "marker_4", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_4_row_number"), new SymbolReference(BIGINT, "combined_row_number_1_2_3_4")), new SymbolReference(BIGINT, "input_4_row_number"), new Constant(BIGINT, null)))), + "marker_1", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "combined_row_number_1_2_3_4")), new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, null))), + "marker_2", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "combined_row_number_1_2_3_4")), new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, null))), + "marker_3", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_3_row_number"), new Reference(BIGINT, "combined_row_number_1_2_3_4")), new Reference(BIGINT, "input_3_row_number"), new Constant(BIGINT, null))), + "marker_4", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_4_row_number"), new Reference(BIGINT, "combined_row_number_1_2_3_4")), new Reference(BIGINT, "input_4_row_number"), new Constant(BIGINT, null)))), project(// append helper symbols for joined nodes ImmutableMap.of( - "combined_row_number_1_2_3_4", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "combined_row_number_1_2"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "combined_row_number_3_4"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "combined_row_number_1_2"), new SymbolReference(BIGINT, "combined_row_number_3_4"))), - "combined_partition_size_1_2_3_4", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "combined_partition_size_1_2"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "combined_partition_size_3_4"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "combined_partition_size_1_2"), new SymbolReference(BIGINT, "combined_partition_size_3_4")))), + "combined_row_number_1_2_3_4", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "combined_row_number_1_2"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "combined_row_number_3_4"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "combined_row_number_1_2"), new Reference(BIGINT, "combined_row_number_3_4"))), + "combined_partition_size_1_2_3_4", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "combined_partition_size_1_2"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "combined_partition_size_3_4"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "combined_partition_size_1_2"), new Reference(BIGINT, "combined_partition_size_3_4")))), join(// join nodes using helper symbols LEFT, joinBuilder -> joinBuilder - .filter(new LogicalExpression(OR, ImmutableList.of( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "combined_row_number_1_2"), new SymbolReference(BIGINT, "combined_row_number_3_4")), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "combined_row_number_1_2"), new SymbolReference(BIGINT, "combined_partition_size_3_4")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "combined_row_number_3_4"), new Constant(BIGINT, 1L)))), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "combined_row_number_3_4"), new SymbolReference(BIGINT, "combined_partition_size_1_2")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "combined_row_number_1_2"), new Constant(BIGINT, 1L))))))) + .filter(new Logical(OR, ImmutableList.of( + new Comparison(EQUAL, new Reference(BIGINT, "combined_row_number_1_2"), new Reference(BIGINT, "combined_row_number_3_4")), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "combined_row_number_1_2"), new Reference(BIGINT, "combined_partition_size_3_4")), + new Comparison(EQUAL, new Reference(BIGINT, "combined_row_number_3_4"), new Constant(BIGINT, 1L)))), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "combined_row_number_3_4"), new Reference(BIGINT, "combined_partition_size_1_2")), + new Comparison(EQUAL, new Reference(BIGINT, "combined_row_number_1_2"), new Constant(BIGINT, 1L))))))) .left(project(// append helper and partitioning symbols for co-partitioned nodes ImmutableMap.of( - "combined_row_number_1_2", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_row_number"))), - "combined_partition_size_1_2", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "input_1_partition_size"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_2_partition_size"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "input_1_partition_size"), new SymbolReference(BIGINT, "input_2_partition_size"))), - "combined_partition_column_1_2", expression(new CoalesceExpression(new SymbolReference(BIGINT, "c"), new SymbolReference(BIGINT, "d")))), + "combined_row_number_1_2", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_row_number"))), + "combined_partition_size_1_2", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "input_1_partition_size"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_2_partition_size"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "input_1_partition_size"), new Reference(BIGINT, "input_2_partition_size"))), + "combined_partition_column_1_2", expression(new Coalesce(new Reference(BIGINT, "c"), new Reference(BIGINT, "d")))), join(// co-partition nodes INNER, nestedJoinBuilder -> nestedJoinBuilder - .filter(new LogicalExpression(AND, ImmutableList.of( - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BIGINT, "c"), new SymbolReference(BIGINT, "d"))), - new LogicalExpression(OR, ImmutableList.of( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_row_number")), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, 1L)))), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "input_1_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, 1L))))))))) + .filter(new Logical(AND, ImmutableList.of( + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(BIGINT, "c"), new Reference(BIGINT, "d"))), + new Logical(OR, ImmutableList.of( + new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_row_number")), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, 1L)))), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "input_1_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, 1L))))))))) .left(window(// append helper symbols for source input_1 builder -> builder .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) @@ -1054,22 +1054,22 @@ public void testTwoCoPartitionLists() values("d")))))) .right(project(// append helper and partitioning symbols for co-partitioned nodes ImmutableMap.of( - "combined_row_number_3_4", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "input_3_row_number"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_4_row_number"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "input_3_row_number"), new SymbolReference(BIGINT, "input_4_row_number"))), - "combined_partition_size_3_4", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "input_3_partition_size"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_4_partition_size"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "input_3_partition_size"), new SymbolReference(BIGINT, "input_4_partition_size"))), - "combined_partition_column_3_4", expression(new CoalesceExpression(new SymbolReference(BIGINT, "e"), new SymbolReference(BIGINT, "f")))), + "combined_row_number_3_4", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "input_3_row_number"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_4_row_number"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "input_3_row_number"), new Reference(BIGINT, "input_4_row_number"))), + "combined_partition_size_3_4", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "input_3_partition_size"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_4_partition_size"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "input_3_partition_size"), new Reference(BIGINT, "input_4_partition_size"))), + "combined_partition_column_3_4", expression(new Coalesce(new Reference(BIGINT, "e"), new Reference(BIGINT, "f")))), join(// co-partition nodes FULL, nestedJoinBuilder -> nestedJoinBuilder - .filter(new LogicalExpression(AND, ImmutableList.of( - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BIGINT, "e"), new SymbolReference(BIGINT, "f"))), - new LogicalExpression(OR, ImmutableList.of( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_3_row_number"), new SymbolReference(BIGINT, "input_4_row_number")), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_3_row_number"), new SymbolReference(BIGINT, "input_4_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_4_row_number"), new Constant(BIGINT, 1L)))), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_4_row_number"), new SymbolReference(BIGINT, "input_3_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_3_row_number"), new Constant(BIGINT, 1L))))))))) + .filter(new Logical(AND, ImmutableList.of( + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(BIGINT, "e"), new Reference(BIGINT, "f"))), + new Logical(OR, ImmutableList.of( + new Comparison(EQUAL, new Reference(BIGINT, "input_3_row_number"), new Reference(BIGINT, "input_4_row_number")), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_3_row_number"), new Reference(BIGINT, "input_4_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_4_row_number"), new Constant(BIGINT, 1L)))), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_4_row_number"), new Reference(BIGINT, "input_3_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_3_row_number"), new Constant(BIGINT, 1L))))))))) .left(window(// append helper symbols for source input_3 builder -> builder .specification(specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of())) @@ -1139,42 +1139,42 @@ public void testCoPartitionedAndNotCoPartitionedSources() .specification(specification(ImmutableList.of("combined_partition_column_2_3", "c"), ImmutableList.of("combined_row_number_2_3_1"), ImmutableMap.of("combined_row_number_2_3_1", ASC_NULLS_LAST))), project(// append marker symbols ImmutableMap.of( - "marker_1", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "combined_row_number_2_3_1")), new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, null))), - "marker_2", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "combined_row_number_2_3_1")), new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, null))), - "marker_3", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_3_row_number"), new SymbolReference(BIGINT, "combined_row_number_2_3_1")), new SymbolReference(BIGINT, "input_3_row_number"), new Constant(BIGINT, null)))), + "marker_1", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "combined_row_number_2_3_1")), new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, null))), + "marker_2", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "combined_row_number_2_3_1")), new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, null))), + "marker_3", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_3_row_number"), new Reference(BIGINT, "combined_row_number_2_3_1")), new Reference(BIGINT, "input_3_row_number"), new Constant(BIGINT, null)))), project(// append helper symbols for joined nodes ImmutableMap.of( - "combined_row_number_2_3_1", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "combined_row_number_2_3"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "combined_row_number_2_3"), new SymbolReference(BIGINT, "input_1_row_number"))), - "combined_partition_size_2_3_1", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "combined_partition_size_2_3"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_1_partition_size"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "combined_partition_size_2_3"), new SymbolReference(BIGINT, "input_1_partition_size")))), + "combined_row_number_2_3_1", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "combined_row_number_2_3"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "combined_row_number_2_3"), new Reference(BIGINT, "input_1_row_number"))), + "combined_partition_size_2_3_1", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "combined_partition_size_2_3"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_1_partition_size"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "combined_partition_size_2_3"), new Reference(BIGINT, "input_1_partition_size")))), join(// join nodes using helper symbols INNER, joinBuilder -> joinBuilder - .filter(new LogicalExpression(OR, ImmutableList.of( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "combined_row_number_2_3"), new SymbolReference(BIGINT, "input_1_row_number")), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "combined_row_number_2_3"), new SymbolReference(BIGINT, "input_1_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, 1L)))), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "combined_partition_size_2_3")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "combined_row_number_2_3"), new Constant(BIGINT, 1L))))))) + .filter(new Logical(OR, ImmutableList.of( + new Comparison(EQUAL, new Reference(BIGINT, "combined_row_number_2_3"), new Reference(BIGINT, "input_1_row_number")), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "combined_row_number_2_3"), new Reference(BIGINT, "input_1_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, 1L)))), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "combined_partition_size_2_3")), + new Comparison(EQUAL, new Reference(BIGINT, "combined_row_number_2_3"), new Constant(BIGINT, 1L))))))) .left(project(// append helper and partitioning symbols for co-partitioned nodes ImmutableMap.of( - "combined_row_number_2_3", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_3_row_number"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "input_3_row_number"))), - "combined_partition_size_2_3", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "input_2_partition_size"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_3_partition_size"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "input_2_partition_size"), new SymbolReference(BIGINT, "input_3_partition_size"))), - "combined_partition_column_2_3", expression(new CoalesceExpression(new SymbolReference(BIGINT, "d"), new SymbolReference(BIGINT, "e")))), + "combined_row_number_2_3", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_3_row_number"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "input_3_row_number"))), + "combined_partition_size_2_3", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "input_2_partition_size"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_3_partition_size"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "input_2_partition_size"), new Reference(BIGINT, "input_3_partition_size"))), + "combined_partition_column_2_3", expression(new Coalesce(new Reference(BIGINT, "d"), new Reference(BIGINT, "e")))), join(// co-partition nodes LEFT, nestedJoinBuilder -> nestedJoinBuilder - .filter(new LogicalExpression(AND, ImmutableList.of( - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BIGINT, "d"), new SymbolReference(BIGINT, "e"))), - new LogicalExpression(OR, ImmutableList.of( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "input_3_row_number")), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "input_3_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_3_row_number"), new Constant(BIGINT, 1L)))), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_3_row_number"), new SymbolReference(BIGINT, "input_2_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, 1L))))))))) + .filter(new Logical(AND, ImmutableList.of( + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(BIGINT, "d"), new Reference(BIGINT, "e"))), + new Logical(OR, ImmutableList.of( + new Comparison(EQUAL, new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "input_3_row_number")), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "input_3_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_3_row_number"), new Constant(BIGINT, 1L)))), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_3_row_number"), new Reference(BIGINT, "input_2_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, 1L))))))))) .left(window(// append helper symbols for source input_2 builder -> builder .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) @@ -1217,9 +1217,9 @@ public void testCoerceForCopartitioning() // coerce column c for co-partitioning p.project( Assignments.builder() - .put(c, new SymbolReference(BIGINT, "c")) - .put(d, new SymbolReference(BIGINT, "d")) - .put(cCoerced, new Cast(new SymbolReference(BIGINT, "c"), INTEGER)) + .put(c, new Reference(BIGINT, "c")) + .put(d, new Reference(BIGINT, "d")) + .put(cCoerced, new Cast(new Reference(BIGINT, "c"), INTEGER)) .build(), p.values(c, d)), p.values(e, f)), @@ -1254,26 +1254,26 @@ public void testCoerceForCopartitioning() .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), project(// append marker symbols ImmutableMap.of( - "marker_1", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "combined_row_number")), new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, null))), - "marker_2", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "combined_row_number")), new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, null)))), + "marker_1", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "combined_row_number")), new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, null))), + "marker_2", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "combined_row_number")), new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, null)))), project(// append helper and partitioning symbols for co-partitioned nodes ImmutableMap.of( - "combined_row_number", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_row_number"))), - "combined_partition_size", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "input_1_partition_size"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_2_partition_size"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "input_1_partition_size"), new SymbolReference(BIGINT, "input_2_partition_size"))), - "combined_partition_column", expression(new CoalesceExpression(new SymbolReference(BIGINT, "c_coerced"), new SymbolReference(BIGINT, "e")))), + "combined_row_number", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_row_number"))), + "combined_partition_size", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "input_1_partition_size"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_2_partition_size"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "input_1_partition_size"), new Reference(BIGINT, "input_2_partition_size"))), + "combined_partition_column", expression(new Coalesce(new Reference(BIGINT, "c_coerced"), new Reference(BIGINT, "e")))), join(// co-partition nodes LEFT, joinBuilder -> joinBuilder - .filter(new LogicalExpression(AND, ImmutableList.of( - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BIGINT, "c_coerced"), new SymbolReference(BIGINT, "e"))), - new LogicalExpression(OR, ImmutableList.of( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_row_number")), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, 1L)))), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "input_1_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, 1L))))))))) + .filter(new Logical(AND, ImmutableList.of( + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(BIGINT, "c_coerced"), new Reference(BIGINT, "e"))), + new Logical(OR, ImmutableList.of( + new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_row_number")), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, 1L)))), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "input_1_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, 1L))))))))) .left(window(// append helper symbols for source input_1 builder -> builder .specification(specification(ImmutableList.of("c_coerced"), ImmutableList.of(), ImmutableMap.of())) @@ -1281,7 +1281,7 @@ public void testCoerceForCopartitioning() .addFunction("input_1_partition_size", windowFunction("count", ImmutableList.of(), FULL_FRAME)), // input_1 project( - ImmutableMap.of("c_coerced", expression(new Cast(new SymbolReference(BIGINT, "c"), INTEGER))), + ImmutableMap.of("c_coerced", expression(new Cast(new Reference(BIGINT, "c"), INTEGER))), values("c", "d")))) .right(window(// append helper symbols for source input_2 builder -> builder @@ -1339,26 +1339,26 @@ public void testTwoCoPartitioningColumns() .specification(specification(ImmutableList.of("combined_partition_column_1", "combined_partition_column_2"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), project(// append marker symbols ImmutableMap.of( - "marker_1", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "combined_row_number")), new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, null))), - "marker_2", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "combined_row_number")), new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, null)))), + "marker_1", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "combined_row_number")), new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, null))), + "marker_2", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "combined_row_number")), new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, null)))), project(// append helper and partitioning symbols for co-partitioned nodes ImmutableMap.of( - "combined_row_number", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_row_number"))), - "combined_partition_size", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "input_1_partition_size"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_2_partition_size"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "input_1_partition_size"), new SymbolReference(BIGINT, "input_2_partition_size"))), - "combined_partition_column_1", expression(new CoalesceExpression(new SymbolReference(BIGINT, "c"), new SymbolReference(BIGINT, "e"))), - "combined_partition_column_2", expression(new CoalesceExpression(new SymbolReference(BIGINT, "d"), new SymbolReference(BIGINT, "f")))), + "combined_row_number", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_row_number"))), + "combined_partition_size", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "input_1_partition_size"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_2_partition_size"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "input_1_partition_size"), new Reference(BIGINT, "input_2_partition_size"))), + "combined_partition_column_1", expression(new Coalesce(new Reference(BIGINT, "c"), new Reference(BIGINT, "e"))), + "combined_partition_column_2", expression(new Coalesce(new Reference(BIGINT, "d"), new Reference(BIGINT, "f")))), join(// co-partition nodes LEFT, joinBuilder -> joinBuilder - .filter(new LogicalExpression(AND, ImmutableList.of( - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BIGINT, "c"), new SymbolReference(BIGINT, "e"))), - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BIGINT, "d"), new SymbolReference(BIGINT, "f"))), - new LogicalExpression(OR, ImmutableList.of( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_row_number")), - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_partition_size")), new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, 1L)))), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "input_1_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, 1L))))))))) + .filter(new Logical(AND, ImmutableList.of( + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(BIGINT, "c"), new Reference(BIGINT, "e"))), + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(BIGINT, "d"), new Reference(BIGINT, "f"))), + new Logical(OR, ImmutableList.of( + new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_row_number")), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_partition_size")), new Comparison(EQUAL, new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, 1L)))), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "input_1_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, 1L))))))))) .left(window(// append helper symbols for source input_1 builder -> builder .specification(specification(ImmutableList.of("c", "d"), ImmutableList.of(), ImmutableMap.of())) @@ -1422,23 +1422,23 @@ public void testTwoSourcesWithRowAndSetSemantics() .specification(specification(ImmutableList.of("c"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), project(// append marker symbols ImmutableMap.of( - "marker_1", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "combined_row_number")), new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, null))), - "marker_2", expression(ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "combined_row_number")), new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, null)))), + "marker_1", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "combined_row_number")), new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, null))), + "marker_2", expression(ifExpression(new Comparison(EQUAL, new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "combined_row_number")), new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, null)))), project(// append helper symbols for joined nodes ImmutableMap.of( - "combined_row_number", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_row_number"))), - "combined_partition_size", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new CoalesceExpression(new SymbolReference(BIGINT, "input_1_partition_size"), new Constant(BIGINT, -1L)), new CoalesceExpression(new SymbolReference(BIGINT, "input_2_partition_size"), new Constant(BIGINT, -1L))), new SymbolReference(BIGINT, "input_1_partition_size"), new SymbolReference(BIGINT, "input_2_partition_size")))), + "combined_row_number", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_row_number"))), + "combined_partition_size", expression(ifExpression(new Comparison(GREATER_THAN, new Coalesce(new Reference(BIGINT, "input_1_partition_size"), new Constant(BIGINT, -1L)), new Coalesce(new Reference(BIGINT, "input_2_partition_size"), new Constant(BIGINT, -1L))), new Reference(BIGINT, "input_1_partition_size"), new Reference(BIGINT, "input_2_partition_size")))), join(// join nodes using helper symbols FULL, joinBuilder -> joinBuilder - .filter(new LogicalExpression(OR, ImmutableList.of( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_row_number")), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_1_row_number"), new SymbolReference(BIGINT, "input_2_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_2_row_number"), new Constant(BIGINT, 1L)))), - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "input_2_row_number"), new SymbolReference(BIGINT, "input_1_partition_size")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "input_1_row_number"), new Constant(BIGINT, 1L))))))) + .filter(new Logical(OR, ImmutableList.of( + new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_row_number")), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_1_row_number"), new Reference(BIGINT, "input_2_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_2_row_number"), new Constant(BIGINT, 1L)))), + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input_2_row_number"), new Reference(BIGINT, "input_1_partition_size")), + new Comparison(EQUAL, new Reference(BIGINT, "input_1_row_number"), new Constant(BIGINT, 1L))))))) .left(window(// append helper symbols for source input_1 builder -> builder .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjectIntoFilter.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjectIntoFilter.java index 00d8295b4c33..176002411e15 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjectIntoFilter.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjectIntoFilter.java @@ -15,22 +15,22 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; import org.junit.jupiter.api.Test; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.BooleanLiteral.FALSE_LITERAL; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; -import static io.trino.sql.ir.LogicalExpression.Operator.OR; +import static io.trino.sql.ir.Booleans.FALSE; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Logical.Operator.AND; +import static io.trino.sql.ir.Logical.Operator.OR; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; @@ -45,17 +45,17 @@ public void testInlineProjection() { tester().assertThat(new InlineProjectIntoFilter()) .on(p -> p.filter( - new SymbolReference(INTEGER, "a"), + new Reference(INTEGER, "a"), p.project( - Assignments.of(p.symbol("a"), new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 0L))), + Assignments.of(p.symbol("a"), new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 0L))), p.values(p.symbol("b", INTEGER))))) .matches( project( - ImmutableMap.of("a", expression(TRUE_LITERAL)), + ImmutableMap.of("a", expression(TRUE)), filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 0L)), project( - ImmutableMap.of("b", expression(new SymbolReference(INTEGER, "b"))), + ImmutableMap.of("b", expression(new Reference(INTEGER, "b"))), values("b"))))); tester().assertThat(new InlineProjectIntoFilter()) @@ -65,22 +65,22 @@ public void testInlineProjection() Symbol c = p.symbol("c", INTEGER); Symbol d = p.symbol("d", INTEGER); return p.filter( - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(INTEGER, "a"), new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "b"), new SymbolReference(INTEGER, "c")))), + new Logical(AND, ImmutableList.of(new Reference(INTEGER, "a"), new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Reference(INTEGER, "c")))), p.project( Assignments.builder() - .put(a, new IsNullPredicate(new SymbolReference(INTEGER, "d"))) - .put(b, new SymbolReference(INTEGER, "b")) - .put(c, new SymbolReference(INTEGER, "c")) + .put(a, new IsNull(new Reference(INTEGER, "d"))) + .put(b, new Reference(INTEGER, "b")) + .put(c, new Reference(INTEGER, "c")) .build(), p.values(b, c, d))); }) .matches( project( - ImmutableMap.of("b", expression(new SymbolReference(INTEGER, "b")), "c", expression(new SymbolReference(INTEGER, "c")), "a", expression(TRUE_LITERAL)), + ImmutableMap.of("b", expression(new Reference(INTEGER, "b")), "c", expression(new Reference(INTEGER, "c")), "a", expression(TRUE)), filter( - new LogicalExpression(AND, ImmutableList.of(new IsNullPredicate(new SymbolReference(INTEGER, "d")), new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "b"), new SymbolReference(INTEGER, "c")))), + new Logical(AND, ImmutableList.of(new IsNull(new Reference(INTEGER, "d")), new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Reference(INTEGER, "c")))), project( - ImmutableMap.of("d", expression(new SymbolReference(INTEGER, "d")), "b", expression(new SymbolReference(INTEGER, "b")), "c", expression(new SymbolReference(INTEGER, "c"))), + ImmutableMap.of("d", expression(new Reference(INTEGER, "d")), "b", expression(new Reference(INTEGER, "b")), "c", expression(new Reference(INTEGER, "c"))), values("b", "c", "d"))))); } @@ -89,9 +89,9 @@ public void testNoSimpleConjuncts() { tester().assertThat(new InlineProjectIntoFilter()) .on(p -> p.filter( - new LogicalExpression(OR, ImmutableList.of(new SymbolReference(INTEGER, "a"), FALSE_LITERAL)), + new Logical(OR, ImmutableList.of(new Reference(INTEGER, "a"), FALSE)), p.project( - Assignments.of(p.symbol("a"), new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 0L))), + Assignments.of(p.symbol("a"), new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 0L))), p.values(p.symbol("b"))))) .doesNotFire(); } @@ -101,17 +101,17 @@ public void testMultipleReferencesToConjunct() { tester().assertThat(new InlineProjectIntoFilter()) .on(p -> p.filter( - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(INTEGER, "a"), new SymbolReference(INTEGER, "a"))), + new Logical(AND, ImmutableList.of(new Reference(INTEGER, "a"), new Reference(INTEGER, "a"))), p.project( - Assignments.of(p.symbol("a"), new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 0L))), + Assignments.of(p.symbol("a"), new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 0L))), p.values(p.symbol("b"))))) .doesNotFire(); tester().assertThat(new InlineProjectIntoFilter()) .on(p -> p.filter( - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(INTEGER, "a"), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(INTEGER, "a"), FALSE_LITERAL)))), + new Logical(AND, ImmutableList.of(new Reference(INTEGER, "a"), new Logical(OR, ImmutableList.of(new Reference(INTEGER, "a"), FALSE)))), p.project( - Assignments.of(p.symbol("a"), new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 0L))), + Assignments.of(p.symbol("a"), new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 0L))), p.values(p.symbol("b"))))) .doesNotFire(); } @@ -121,18 +121,18 @@ public void testInlineMultiple() { tester().assertThat(new InlineProjectIntoFilter()) .on(p -> p.filter( - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(INTEGER, "a"), new SymbolReference(INTEGER, "b"))), + new Logical(AND, ImmutableList.of(new Reference(INTEGER, "a"), new Reference(INTEGER, "b"))), p.project( Assignments.of( - p.symbol("a"), new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "c"), new Constant(INTEGER, 0L)), - p.symbol("b"), new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "c"), new Constant(INTEGER, 5L))), + p.symbol("a"), new Comparison(GREATER_THAN, new Reference(INTEGER, "c"), new Constant(INTEGER, 0L)), + p.symbol("b"), new Comparison(GREATER_THAN, new Reference(INTEGER, "c"), new Constant(INTEGER, 5L))), p.values(p.symbol("c", INTEGER))))) .matches( project( filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "c"), new Constant(INTEGER, 0L)), new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "c"), new Constant(INTEGER, 5L)))), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(INTEGER, "c"), new Constant(INTEGER, 0L)), new Comparison(GREATER_THAN, new Reference(INTEGER, "c"), new Constant(INTEGER, 5L)))), project( - ImmutableMap.of("c", expression(new SymbolReference(INTEGER, "c"))), + ImmutableMap.of("c", expression(new Reference(INTEGER, "c"))), values("c"))))); } @@ -141,21 +141,21 @@ public void testInlinePartially() { tester().assertThat(new InlineProjectIntoFilter()) .on(p -> p.filter( - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(INTEGER, "a"), new SymbolReference(INTEGER, "a"), new SymbolReference(INTEGER, "b"))), + new Logical(AND, ImmutableList.of(new Reference(INTEGER, "a"), new Reference(INTEGER, "a"), new Reference(INTEGER, "b"))), p.project( Assignments.of( - p.symbol("a"), new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "c"), new Constant(INTEGER, 0L)), - p.symbol("b"), new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "c"), new Constant(INTEGER, 5L))), + p.symbol("a"), new Comparison(GREATER_THAN, new Reference(INTEGER, "c"), new Constant(INTEGER, 0L)), + p.symbol("b"), new Comparison(GREATER_THAN, new Reference(INTEGER, "c"), new Constant(INTEGER, 5L))), p.values(p.symbol("c", INTEGER))))) .matches( project( - ImmutableMap.of("a", expression(new SymbolReference(INTEGER, "a")), "b", expression(TRUE_LITERAL)), + ImmutableMap.of("a", expression(new Reference(INTEGER, "a")), "b", expression(TRUE)), filter( - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(INTEGER, "a"), new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "c"), new Constant(INTEGER, 5L)))), // combineConjuncts() removed duplicate conjunct `a`. The predicate is now eligible for further inlining. + new Logical(AND, ImmutableList.of(new Reference(INTEGER, "a"), new Comparison(GREATER_THAN, new Reference(INTEGER, "c"), new Constant(INTEGER, 5L)))), // combineConjuncts() removed duplicate conjunct `a`. The predicate is now eligible for further inlining. project( ImmutableMap.of( - "a", expression(new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "c"), new Constant(INTEGER, 0L))), - "c", expression(new SymbolReference(INTEGER, "c"))), + "a", expression(new Comparison(GREATER_THAN, new Reference(INTEGER, "c"), new Constant(INTEGER, 0L))), + "c", expression(new Reference(INTEGER, "c"))), values("c"))))); } @@ -165,18 +165,18 @@ public void testTrivialProjection() // identity projection tester().assertThat(new InlineProjectIntoFilter()) .on(p -> p.filter( - new SymbolReference(INTEGER, "a"), + new Reference(INTEGER, "a"), p.project( - Assignments.of(p.symbol("a"), new SymbolReference(INTEGER, "a")), + Assignments.of(p.symbol("a"), new Reference(INTEGER, "a")), p.values(p.symbol("a"))))) .doesNotFire(); // renaming projection tester().assertThat(new InlineProjectIntoFilter()) .on(p -> p.filter( - new SymbolReference(INTEGER, "a"), + new Reference(INTEGER, "a"), p.project( - Assignments.of(p.symbol("a"), new SymbolReference(INTEGER, "b")), + Assignments.of(p.symbol("a"), new Reference(INTEGER, "b")), p.values(p.symbol("b"))))) .doesNotFire(); } @@ -186,9 +186,9 @@ public void testCorrelationSymbol() { tester().assertThat(new InlineProjectIntoFilter()) .on(p -> p.filter( - new SymbolReference(INTEGER, "corr"), + new Reference(INTEGER, "corr"), p.project( - Assignments.of(p.symbol("a"), new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 0L))), + Assignments.of(p.symbol("a"), new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 0L))), p.values(p.symbol("b"))))) .doesNotFire(); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjections.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjections.java index cf2cf5c32b4a..c689105f361f 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjections.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjections.java @@ -20,10 +20,10 @@ import io.trino.spi.function.OperatorType; import io.trino.spi.type.Decimals; import io.trino.spi.type.RowType; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.assertions.ExpressionMatcher; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -38,9 +38,9 @@ import static io.trino.spi.type.DecimalType.createDecimalType; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MULTIPLY; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.SUBTRACT; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; +import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -62,41 +62,41 @@ public void test() .on(p -> p.project( Assignments.builder() - .put(p.symbol("identity"), new SymbolReference(BIGINT, "symbol")) // identity - .put(p.symbol("multi_complex_1"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "complex"), new Constant(INTEGER, 1L))) // complex expression referenced multiple times - .put(p.symbol("multi_complex_2"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "complex"), new Constant(INTEGER, 2L))) // complex expression referenced multiple times - .put(p.symbol("multi_literal_1"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "literal"), new Constant(INTEGER, 1L))) // literal referenced multiple times - .put(p.symbol("multi_literal_2"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "literal"), new Constant(INTEGER, 2L))) // literal referenced multiple times - .put(p.symbol("single_complex"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "complex_2"), new Constant(INTEGER, 2L))) // complex expression reference only once - .put(p.symbol("msg_xx"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "z"), new Constant(INTEGER, 1L))) - .put(p.symbol("multi_symbol_reference"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "v"), new SymbolReference(INTEGER, "v"))) + .put(p.symbol("identity"), new Reference(BIGINT, "symbol")) // identity + .put(p.symbol("multi_complex_1"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "complex"), new Constant(INTEGER, 1L))) // complex expression referenced multiple times + .put(p.symbol("multi_complex_2"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "complex"), new Constant(INTEGER, 2L))) // complex expression referenced multiple times + .put(p.symbol("multi_literal_1"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "literal"), new Constant(INTEGER, 1L))) // literal referenced multiple times + .put(p.symbol("multi_literal_2"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "literal"), new Constant(INTEGER, 2L))) // literal referenced multiple times + .put(p.symbol("single_complex"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "complex_2"), new Constant(INTEGER, 2L))) // complex expression reference only once + .put(p.symbol("msg_xx"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "z"), new Constant(INTEGER, 1L))) + .put(p.symbol("multi_symbol_reference"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "v"), new Reference(INTEGER, "v"))) .build(), p.project(Assignments.builder() - .put(p.symbol("symbol"), new SymbolReference(BIGINT, "x")) - .put(p.symbol("complex"), new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "x"), new Constant(INTEGER, 2L))) + .put(p.symbol("symbol"), new Reference(BIGINT, "x")) + .put(p.symbol("complex"), new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "x"), new Constant(INTEGER, 2L))) .put(p.symbol("literal"), new Constant(INTEGER, 1L)) - .put(p.symbol("complex_2"), new ArithmeticBinaryExpression(SUBTRACT_INTEGER, SUBTRACT, new SymbolReference(INTEGER, "x"), new Constant(INTEGER, 1L))) - .put(p.symbol("z"), new SubscriptExpression(VARCHAR, new SymbolReference(MSG_TYPE, "msg"), new Constant(INTEGER, 1L))) - .put(p.symbol("v"), new SymbolReference(BIGINT, "x")) + .put(p.symbol("complex_2"), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "x"), new Constant(INTEGER, 1L))) + .put(p.symbol("z"), new Subscript(VARCHAR, new Reference(MSG_TYPE, "msg"), new Constant(INTEGER, 1L))) + .put(p.symbol("v"), new Reference(BIGINT, "x")) .build(), p.values(p.symbol("x"), p.symbol("msg", MSG_TYPE))))) .matches( project( ImmutableMap.builder() - .put("out1", PlanMatchPattern.expression(new SymbolReference(BIGINT, "x"))) - .put("out2", PlanMatchPattern.expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "y"), new Constant(INTEGER, 1L)))) - .put("out3", PlanMatchPattern.expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "y"), new Constant(INTEGER, 2L)))) - .put("out4", PlanMatchPattern.expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new Constant(INTEGER, 1L), new Constant(INTEGER, 1L)))) - .put("out5", PlanMatchPattern.expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)))) - .put("out6", PlanMatchPattern.expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new ArithmeticBinaryExpression(SUBTRACT_INTEGER, SUBTRACT, new SymbolReference(INTEGER, "x"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 2L)))) - .put("out8", PlanMatchPattern.expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "z"), new Constant(INTEGER, 1L)))) - .put("out10", PlanMatchPattern.expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "x"), new SymbolReference(INTEGER, "x")))) + .put("out1", PlanMatchPattern.expression(new Reference(BIGINT, "x"))) + .put("out2", PlanMatchPattern.expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "y"), new Constant(INTEGER, 1L)))) + .put("out3", PlanMatchPattern.expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "y"), new Constant(INTEGER, 2L)))) + .put("out4", PlanMatchPattern.expression(new Arithmetic(ADD_INTEGER, ADD, new Constant(INTEGER, 1L), new Constant(INTEGER, 1L)))) + .put("out5", PlanMatchPattern.expression(new Arithmetic(ADD_INTEGER, ADD, new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)))) + .put("out6", PlanMatchPattern.expression(new Arithmetic(ADD_INTEGER, ADD, new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "x"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 2L)))) + .put("out8", PlanMatchPattern.expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "z"), new Constant(INTEGER, 1L)))) + .put("out10", PlanMatchPattern.expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "x"), new Reference(INTEGER, "x")))) .buildOrThrow(), project( ImmutableMap.of( - "x", PlanMatchPattern.expression(new SymbolReference(BIGINT, "x")), - "y", PlanMatchPattern.expression(new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "x"), new Constant(INTEGER, 2L))), - "z", PlanMatchPattern.expression(new SubscriptExpression(VARCHAR, new SymbolReference(MSG_TYPE, "msg"), new Constant(INTEGER, 1L)))), + "x", PlanMatchPattern.expression(new Reference(BIGINT, "x")), + "y", PlanMatchPattern.expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "x"), new Constant(INTEGER, 2L))), + "z", PlanMatchPattern.expression(new Subscript(VARCHAR, new Reference(MSG_TYPE, "msg"), new Constant(INTEGER, 1L)))), values(ImmutableMap.of("x", 0, "msg", 1))))); } @@ -113,8 +113,8 @@ public void testInlineEffectivelyLiteral() p.project( Assignments.builder() // Use the literal-like expression multiple times. Single-use expression may be inlined regardless of whether it's a literal - .put(p.symbol("decimal_multiplication"), new ArithmeticBinaryExpression(MULTIPLY_DECIMAL_8_4, MULTIPLY, new SymbolReference(createDecimalType(8, 4), "decimal_literal"), new SymbolReference(createDecimalType(8, 4), "decimal_literal"))) - .put(p.symbol("decimal_addition"), new ArithmeticBinaryExpression(ADD_DECIMAL_8_4, ADD, new SymbolReference(createDecimalType(8, 4), "decimal_literal"), new SymbolReference(createDecimalType(8, 4), "decimal_literal"))) + .put(p.symbol("decimal_multiplication"), new Arithmetic(MULTIPLY_DECIMAL_8_4, MULTIPLY, new Reference(createDecimalType(8, 4), "decimal_literal"), new Reference(createDecimalType(8, 4), "decimal_literal"))) + .put(p.symbol("decimal_addition"), new Arithmetic(ADD_DECIMAL_8_4, ADD, new Reference(createDecimalType(8, 4), "decimal_literal"), new Reference(createDecimalType(8, 4), "decimal_literal"))) .build(), p.project(Assignments.builder() .put(p.symbol("decimal_literal", createDecimalType(8, 4)), new Constant(createDecimalType(8, 4), Decimals.valueOfShort(new BigDecimal("12.5")))) @@ -123,8 +123,8 @@ public void testInlineEffectivelyLiteral() .matches( project( Map.of( - "decimal_multiplication", PlanMatchPattern.expression(new ArithmeticBinaryExpression(MULTIPLY_DECIMAL_8_4, MULTIPLY, new Constant(createDecimalType(8, 4), Decimals.valueOfShort(new BigDecimal("12.5"))), new Constant(createDecimalType(8, 4), Decimals.valueOfShort(new BigDecimal("12.5"))))), - "decimal_addition", PlanMatchPattern.expression(new ArithmeticBinaryExpression(ADD_DECIMAL_8_4, ADD, new Constant(createDecimalType(8, 4), Decimals.valueOfShort(new BigDecimal("12.5"))), new Constant(createDecimalType(8, 4), Decimals.valueOfShort(new BigDecimal("12.5")))))), + "decimal_multiplication", PlanMatchPattern.expression(new Arithmetic(MULTIPLY_DECIMAL_8_4, MULTIPLY, new Constant(createDecimalType(8, 4), Decimals.valueOfShort(new BigDecimal("12.5"))), new Constant(createDecimalType(8, 4), Decimals.valueOfShort(new BigDecimal("12.5"))))), + "decimal_addition", PlanMatchPattern.expression(new Arithmetic(ADD_DECIMAL_8_4, ADD, new Constant(createDecimalType(8, 4), Decimals.valueOfShort(new BigDecimal("12.5"))), new Constant(createDecimalType(8, 4), Decimals.valueOfShort(new BigDecimal("12.5")))))), values(Map.of("x", 0)))); } @@ -135,15 +135,15 @@ public void testEliminatesIdentityProjection() .on(p -> p.project( Assignments.builder() - .put(p.symbol("single_complex", INTEGER), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "complex"), new Constant(INTEGER, 2L))) // complex expression referenced only once + .put(p.symbol("single_complex", INTEGER), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "complex"), new Constant(INTEGER, 2L))) // complex expression referenced only once .build(), p.project(Assignments.builder() - .put(p.symbol("complex", INTEGER), new ArithmeticBinaryExpression(SUBTRACT_INTEGER, SUBTRACT, new SymbolReference(INTEGER, "x"), new Constant(INTEGER, 1L))) + .put(p.symbol("complex", INTEGER), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "x"), new Constant(INTEGER, 1L))) .build(), p.values(p.symbol("x", INTEGER))))) .matches( project( - ImmutableMap.of("out1", PlanMatchPattern.expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new ArithmeticBinaryExpression(SUBTRACT_INTEGER, SUBTRACT, new SymbolReference(INTEGER, "x"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 2L)))), + ImmutableMap.of("out1", PlanMatchPattern.expression(new Arithmetic(ADD_INTEGER, ADD, new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "x"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 2L)))), values("x"))); } @@ -154,7 +154,7 @@ public void testIdentityProjections() tester().assertThat(new InlineProjections()) .on(p -> p.project( - Assignments.of(p.symbol("output"), new SymbolReference(BIGINT, "value")), + Assignments.of(p.symbol("output"), new Reference(BIGINT, "value")), p.project( Assignments.identity(p.symbol("value")), p.values(p.symbol("value"))))) @@ -170,7 +170,7 @@ public void testIdentityProjections() p.values(p.symbol("x"), p.symbol("y"))))) .matches( project( - ImmutableMap.of("x", PlanMatchPattern.expression(new SymbolReference(BIGINT, "x"))), + ImmutableMap.of("x", PlanMatchPattern.expression(new Reference(BIGINT, "x"))), values(ImmutableMap.of("x", 0, "y", 1)))); } @@ -195,7 +195,7 @@ public void testSubqueryProjections() p.project( Assignments.identity(p.symbol("fromOuterScope"), p.symbol("value_1")), p.project( - Assignments.of(p.symbol("value_1"), new ArithmeticBinaryExpression(SUBTRACT_INTEGER, SUBTRACT, new SymbolReference(INTEGER, "value"), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("value_1"), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "value"), new Constant(INTEGER, 1L))), p.values(p.symbol("value"))))) .matches( project( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinEnumerator.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinEnumerator.java index 611b117663ab..54766e02970a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinEnumerator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinEnumerator.java @@ -46,7 +46,7 @@ import java.util.Optional; import static io.airlift.testing.Closeables.closeAllRuntimeException; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.iterative.Lookup.noLookup; import static io.trino.sql.planner.iterative.rule.ReorderJoins.JoinEnumerator.generatePartitions; import static io.trino.testing.TestingSession.testSessionBuilder; @@ -100,7 +100,7 @@ public void testDoesNotCreateJoinWhenPartitionedOnCrossJoin() Symbol b1 = p.symbol("B1"); MultiJoinNode multiJoinNode = new MultiJoinNode( new LinkedHashSet<>(ImmutableList.of(p.values(a1), p.values(b1))), - TRUE_LITERAL, + TRUE, ImmutableList.of(a1, b1), false); JoinEnumerator joinEnumerator = new JoinEnumerator( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinNodeFlattener.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinNodeFlattener.java index 8efe0f3349b1..c4d8a694ae91 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinNodeFlattener.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinNodeFlattener.java @@ -18,11 +18,11 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ArithmeticNegation; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.Negation; import io.trino.sql.planner.Plan; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; @@ -53,12 +53,12 @@ import static io.trino.cost.StatsAndCosts.empty; import static io.trino.metadata.AbstractMockMetadata.dummyMetadata; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.SUBTRACT; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.NOT_EQUAL; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.NOT_EQUAL; import static io.trino.sql.ir.IrUtils.and; import static io.trino.sql.planner.assertions.PlanMatchPattern.node; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -166,7 +166,7 @@ public void testPushesProjectionsThroughJoin() JoinNode joinNode = p.join( INNER, p.project( - Assignments.of(d, new ArithmeticNegation(a.toSymbolReference())), + Assignments.of(d, new Negation(a.toSymbolReference())), p.join( INNER, valuesA, @@ -210,7 +210,7 @@ public void testDoesNotPushStraddlingProjection() JoinNode joinNode = p.join( INNER, p.project( - Assignments.of(d, new ArithmeticBinaryExpression(SUBTRACT_INTEGER, SUBTRACT, a.toSymbolReference(), b.toSymbolReference())), + Assignments.of(d, new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, a.toSymbolReference(), b.toSymbolReference())), p.join( INNER, valuesA, @@ -285,12 +285,12 @@ public void testCombinesCriteriaAndFilters() ValuesNode valuesB = p.values(b1, b2); ValuesNode valuesC = p.values(c1, c2); Expression bcFilter = and( - new ComparisonExpression(GREATER_THAN, c2.toSymbolReference(), new Constant(INTEGER, 0L)), - new ComparisonExpression(NOT_EQUAL, c2.toSymbolReference(), new Constant(INTEGER, 7L)), - new ComparisonExpression(GREATER_THAN, b2.toSymbolReference(), c2.toSymbolReference())); - ComparisonExpression abcFilter = new ComparisonExpression( + new Comparison(GREATER_THAN, c2.toSymbolReference(), new Constant(INTEGER, 0L)), + new Comparison(NOT_EQUAL, c2.toSymbolReference(), new Constant(INTEGER, 7L)), + new Comparison(GREATER_THAN, b2.toSymbolReference(), c2.toSymbolReference())); + Comparison abcFilter = new Comparison( LESS_THAN, - new ArithmeticBinaryExpression(ADD_INTEGER, ADD, a1.toSymbolReference(), c1.toSymbolReference()), + new Arithmetic(ADD_INTEGER, ADD, a1.toSymbolReference(), c1.toSymbolReference()), b1.toSymbolReference()); JoinNode joinNode = p.join( INNER, @@ -309,7 +309,7 @@ public void testCombinesCriteriaAndFilters() Optional.of(abcFilter)); MultiJoinNode expected = new MultiJoinNode( new LinkedHashSet<>(ImmutableList.of(valuesA, valuesB, valuesC)), - and(new ComparisonExpression(EQUAL, b1.toSymbolReference(), c1.toSymbolReference()), new ComparisonExpression(EQUAL, a1.toSymbolReference(), b1.toSymbolReference()), bcFilter, abcFilter), + and(new Comparison(EQUAL, b1.toSymbolReference(), c1.toSymbolReference()), new Comparison(EQUAL, a1.toSymbolReference(), b1.toSymbolReference()), bcFilter, abcFilter), ImmutableList.of(a1, b1, b2, c1, c2), false); assertThat(toMultiJoinNode(joinNode, noLookup(), planNodeIdAllocator, DEFAULT_JOIN_LIMIT, false, testSessionBuilder().build())).isEqualTo(expected); @@ -429,9 +429,9 @@ public void testMoreThanJoinLimit() assertThat(toMultiJoinNode(joinNode, noLookup(), planNodeIdAllocator, 2, false, testSessionBuilder().build())).isEqualTo(expected); } - private ComparisonExpression createEqualsExpression(Symbol left, Symbol right) + private Comparison createEqualsExpression(Symbol left, Symbol right) { - return new ComparisonExpression(EQUAL, left.toSymbolReference(), right.toSymbolReference()); + return new Comparison(EQUAL, left.toSymbolReference(), right.toSymbolReference()); } private EquiJoinClause equiJoinClause(Symbol symbol1, Symbol symbol2) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestLambdaCaptureDesugaringRewriter.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestLambdaCaptureDesugaringRewriter.java index bad42c963ae5..b1c26ae9d606 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestLambdaCaptureDesugaringRewriter.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestLambdaCaptureDesugaringRewriter.java @@ -17,17 +17,17 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.BindExpression; -import io.trino.sql.ir.LambdaExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Bind; +import io.trino.sql.ir.Lambda; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolAllocator; import org.junit.jupiter.api.Test; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.planner.iterative.rule.LambdaCaptureDesugaringRewriter.rewrite; import static org.assertj.core.api.Assertions.assertThat; @@ -43,12 +43,12 @@ public void testRewriteBasicLambda() assertThat( rewrite( - new LambdaExpression(ImmutableList.of(new Symbol(INTEGER, "x")), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "a"), new SymbolReference(INTEGER, "x"))), + new Lambda(ImmutableList.of(new Symbol(INTEGER, "x")), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a"), new Reference(INTEGER, "x"))), allocator)) - .isEqualTo(new BindExpression( - ImmutableList.of(new SymbolReference(INTEGER, "a")), - new LambdaExpression( + .isEqualTo(new Bind( + ImmutableList.of(new Reference(INTEGER, "a")), + new Lambda( ImmutableList.of(new Symbol(INTEGER, "a_0"), new Symbol(INTEGER, "x")), - new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "a_0"), new SymbolReference(INTEGER, "x"))))); + new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a_0"), new Reference(INTEGER, "x"))))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeAdjacentWindows.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeAdjacentWindows.java index 447285d75f4b..029cfa8a335d 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeAdjacentWindows.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeAdjacentWindows.java @@ -18,9 +18,9 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.ExpectedValueProvider; import io.trino.sql.planner.assertions.PlanMatchPattern; @@ -39,7 +39,7 @@ import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.specification; import static io.trino.sql.planner.assertions.PlanMatchPattern.strictProject; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -103,7 +103,7 @@ public void testIntermediateNonProjectNode() newWindowNodeSpecification(p, "a"), ImmutableMap.of(p.symbol("avg_2"), newWindowNodeFunction(AVG, "a")), p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 5L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 5L)), p.window( newWindowNodeSpecification(p, "a"), ImmutableMap.of(p.symbol("avg_1"), newWindowNodeFunction(AVG, "a")), @@ -188,10 +188,10 @@ public void testIntermediateProjectNodes() .matches( strictProject( ImmutableMap.of( - columnAAlias, PlanMatchPattern.expression(new SymbolReference(BIGINT, columnAAlias)), - oneAlias, PlanMatchPattern.expression(new SymbolReference(BIGINT, oneAlias)), - lagOutputAlias, PlanMatchPattern.expression(new SymbolReference(DOUBLE, lagOutputAlias)), - avgOutputAlias, PlanMatchPattern.expression(new SymbolReference(DOUBLE, avgOutputAlias))), + columnAAlias, PlanMatchPattern.expression(new Reference(BIGINT, columnAAlias)), + oneAlias, PlanMatchPattern.expression(new Reference(BIGINT, oneAlias)), + lagOutputAlias, PlanMatchPattern.expression(new Reference(DOUBLE, lagOutputAlias)), + avgOutputAlias, PlanMatchPattern.expression(new Reference(DOUBLE, avgOutputAlias))), window(windowMatcherBuilder -> windowMatcherBuilder .specification(specificationA) .addFunction(lagOutputAlias, windowFunction(LAG.getSignature().getName().getFunctionName(), ImmutableList.of(columnAAlias, oneAlias), DEFAULT_FRAME)) @@ -199,12 +199,12 @@ public void testIntermediateProjectNodes() strictProject( ImmutableMap.of( oneAlias, PlanMatchPattern.expression(new Cast(new Constant(INTEGER, 1L), BIGINT)), - columnAAlias, PlanMatchPattern.expression(new SymbolReference(BIGINT, columnAAlias)), - unusedAlias, PlanMatchPattern.expression(new SymbolReference(BIGINT, unusedAlias))), + columnAAlias, PlanMatchPattern.expression(new Reference(BIGINT, columnAAlias)), + unusedAlias, PlanMatchPattern.expression(new Reference(BIGINT, unusedAlias))), strictProject( ImmutableMap.of( - columnAAlias, PlanMatchPattern.expression(new SymbolReference(BIGINT, columnAAlias)), - unusedAlias, PlanMatchPattern.expression(new SymbolReference(BIGINT, unusedAlias))), + columnAAlias, PlanMatchPattern.expression(new Reference(BIGINT, columnAAlias)), + unusedAlias, PlanMatchPattern.expression(new Reference(BIGINT, unusedAlias))), values(columnAAlias, unusedAlias)))))); } @@ -234,7 +234,7 @@ private static WindowNode.Function newWindowNodeFunction(ResolvedFunction resolv return new WindowNode.Function( resolvedFunction, Arrays.stream(symbols) - .map(name -> new SymbolReference(DOUBLE, name)) + .map(name -> new Reference(DOUBLE, name)) .collect(Collectors.toList()), DEFAULT_FRAME, false); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeFilters.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeFilters.java index 8c97f485c724..3c247190e7e8 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeFilters.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeFilters.java @@ -15,17 +15,17 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import org.junit.jupiter.api.Test; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Logical.Operator.AND; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -38,12 +38,12 @@ public void test() tester().assertThat(new MergeFilters()) .on(p -> p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 44L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 44L)), p.filter( - new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 42L)), + new Comparison(LESS_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 42L)), p.values(p.symbol("a"), p.symbol("b"))))) .matches(filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 42L)), new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 44L)))), + new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 42L)), new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 44L)))), values(ImmutableMap.of("a", 0, "b", 1)))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeLimitOverProjectWithSort.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeLimitOverProjectWithSort.java index e0b33b0ab37d..2dc7742d528b 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeLimitOverProjectWithSort.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeLimitOverProjectWithSort.java @@ -15,7 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; @@ -52,7 +52,7 @@ public void testMergeLimitOverProjectWithSort() }) .matches( project( - ImmutableMap.of("b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("b", expression(new Reference(BIGINT, "b"))), topN( 1, ImmutableList.of(sort("a", ASCENDING, FIRST)), @@ -96,7 +96,7 @@ public void testLimitWithPreSortedInputs() }) .matches( project( - ImmutableMap.of("b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("b", expression(new Reference(BIGINT, "b"))), topN( 1, ImmutableList.of(sort("a", ASCENDING, FIRST)), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeLimitWithDistinct.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeLimitWithDistinct.java index 0c2ce750e2be..79ab7a6eaa3a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeLimitWithDistinct.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeLimitWithDistinct.java @@ -14,7 +14,7 @@ package io.trino.sql.planner.iterative.rule; import com.google.common.collect.ImmutableList; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; @@ -51,7 +51,7 @@ public void testDoesNotFire() p.limit( 1, p.aggregation(builder -> builder - .addAggregation(p.symbol("c"), PlanBuilder.aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "foo"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("c"), PlanBuilder.aggregation("count", ImmutableList.of(new Reference(BIGINT, "foo"))), ImmutableList.of(BIGINT)) .globalGrouping() .source(p.values(p.symbol("foo")))))) .doesNotFire(); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergePatternRecognitionNodes.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergePatternRecognitionNodes.java index 50dbffc85895..dd18c6b6e0ef 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergePatternRecognitionNodes.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergePatternRecognitionNodes.java @@ -19,11 +19,11 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.ExpressionMatcher; @@ -48,11 +48,11 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MULTIPLY; -import static io.trino.sql.ir.BooleanLiteral.FALSE_LITERAL; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; +import static io.trino.sql.ir.Booleans.FALSE; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.patternRecognition; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; @@ -91,22 +91,22 @@ public void testSpecificationsDoNotMatch() tester().assertThat(new MergePatternRecognitionNodesWithoutProject()) .on(p -> p.patternRecognition(parentBuilder -> parentBuilder .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.patternRecognition(childBuilder -> childBuilder .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), FALSE_LITERAL) + .addVariableDefinition(new IrLabel("X"), FALSE) .source(p.values(p.symbol("a"))))))) .doesNotFire(); tester().assertThat(new MergePatternRecognitionNodesWithProject()) .on(p -> p.patternRecognition(parentBuilder -> parentBuilder .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.project( Assignments.identity(p.symbol("a")), p.patternRecognition(childBuilder -> childBuilder .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), FALSE_LITERAL) + .addVariableDefinition(new IrLabel("X"), FALSE) .source(p.values(p.symbol("a")))))))) .doesNotFire(); @@ -117,12 +117,12 @@ public void testSpecificationsDoNotMatch() .pattern(new IrLabel("X")) .addVariableDefinition( new IrLabel("X"), - new ComparisonExpression(GREATER_THAN, new FunctionCall(count, ImmutableList.of(new SymbolReference(INTEGER, "a"))), new Constant(INTEGER, 5L))) + new Comparison(GREATER_THAN, new Call(count, ImmutableList.of(new Reference(INTEGER, "a"))), new Constant(INTEGER, 5L))) .source(p.patternRecognition(childBuilder -> childBuilder .pattern(new IrLabel("X")) .addVariableDefinition( new IrLabel("X"), - new ComparisonExpression(GREATER_THAN, new FunctionCall(count, ImmutableList.of(new SymbolReference(INTEGER, "b"))), new Constant(INTEGER, 5L))) + new Comparison(GREATER_THAN, new Call(count, ImmutableList.of(new Reference(INTEGER, "b"))), new Constant(INTEGER, 5L))) .source(p.values(p.symbol("a"), p.symbol("b"))))))) .doesNotFire(); } @@ -137,21 +137,21 @@ public void testParentDependsOnSourceCreatedOutputs() .on(p -> p.patternRecognition(parentBuilder -> parentBuilder .addMeasure( p.symbol("dependent"), - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), true, true, 0, 0), new Symbol(UNKNOWN, "measure")))) .rowsPerMatch(ALL_SHOW_EMPTY) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.patternRecognition(childBuilder -> childBuilder .addMeasure( p.symbol("measure"), - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new MatchNumberValuePointer())) .rowsPerMatch(ALL_SHOW_EMPTY) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a"))))))) .doesNotFire(); @@ -160,20 +160,20 @@ public void testParentDependsOnSourceCreatedOutputs() .on(p -> p.patternRecognition(parentBuilder -> parentBuilder .addMeasure( p.symbol("dependent"), - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), true, true, 0, 0), new Symbol(UNKNOWN, "function")))) .rowsPerMatch(WINDOW) .frame(new WindowNode.Frame(ROWS, CURRENT_ROW, Optional.empty(), Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty())) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.patternRecognition(childBuilder -> childBuilder .addWindowFunction(p.symbol("function"), new WindowNode.Function(lag, ImmutableList.of(p.symbol("a").toSymbolReference()), DEFAULT_FRAME, false)) .rowsPerMatch(WINDOW) .frame(new WindowNode.Frame(ROWS, CURRENT_ROW, Optional.empty(), Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty())) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a"))))))) .doesNotFire(); @@ -184,13 +184,13 @@ public void testParentDependsOnSourceCreatedOutputs() .rowsPerMatch(WINDOW) .frame(new WindowNode.Frame(ROWS, CURRENT_ROW, Optional.empty(), Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty())) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.patternRecognition(childBuilder -> childBuilder .addWindowFunction(p.symbol("function"), new WindowNode.Function(lag, ImmutableList.of(p.symbol("a").toSymbolReference()), DEFAULT_FRAME, false)) .rowsPerMatch(WINDOW) .frame(new WindowNode.Frame(ROWS, CURRENT_ROW, Optional.empty(), Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty())) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a"))))))) .doesNotFire(); @@ -201,16 +201,16 @@ public void testParentDependsOnSourceCreatedOutputs() .rowsPerMatch(WINDOW) .frame(new WindowNode.Frame(ROWS, CURRENT_ROW, Optional.empty(), Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty())) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.patternRecognition(childBuilder -> childBuilder .addMeasure( p.symbol("measure"), - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new MatchNumberValuePointer())) .rowsPerMatch(WINDOW) .frame(new WindowNode.Frame(ROWS, CURRENT_ROW, Optional.empty(), Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty())) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a"))))))) .doesNotFire(); } @@ -223,23 +223,23 @@ public void testParentDependsOnSourceCreatedOutputsWithProject() .on(p -> p.patternRecognition(parentBuilder -> parentBuilder .addMeasure( p.symbol("dependent"), - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), true, true, 0, 0), new Symbol(UNKNOWN, "measure")))) .rowsPerMatch(ALL_SHOW_EMPTY) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.project( Assignments.identity(p.symbol("measure")), p.patternRecognition(childBuilder -> childBuilder .addMeasure( p.symbol("measure"), - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new MatchNumberValuePointer())) .rowsPerMatch(ALL_SHOW_EMPTY) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a")))))))) .doesNotFire(); @@ -248,23 +248,23 @@ public void testParentDependsOnSourceCreatedOutputsWithProject() .on(p -> p.patternRecognition(parentBuilder -> parentBuilder .addMeasure( p.symbol("dependent"), - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), true, true, 0, 0), new Symbol(UNKNOWN, "renamed")))) .rowsPerMatch(ALL_SHOW_EMPTY) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.project( - Assignments.of(p.symbol("renamed"), new SymbolReference(BIGINT, "measure")), + Assignments.of(p.symbol("renamed"), new Reference(BIGINT, "measure")), p.patternRecognition(childBuilder -> childBuilder .addMeasure( p.symbol("measure"), - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new MatchNumberValuePointer())) .rowsPerMatch(ALL_SHOW_EMPTY) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a")))))))) .doesNotFire(); @@ -273,23 +273,23 @@ public void testParentDependsOnSourceCreatedOutputsWithProject() .on(p -> p.patternRecognition(parentBuilder -> parentBuilder .addMeasure( p.symbol("dependent"), - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), true, true, 0, 0), new Symbol(UNKNOWN, "projected")))) .rowsPerMatch(ALL_SHOW_EMPTY) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.project( - Assignments.of(p.symbol("projected"), new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "measure"))), + Assignments.of(p.symbol("projected"), new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "a"), new Reference(BIGINT, "measure"))), p.patternRecognition(childBuilder -> childBuilder .addMeasure( p.symbol("measure"), - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new MatchNumberValuePointer())) .rowsPerMatch(ALL_SHOW_EMPTY) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a")))))))) .doesNotFire(); } @@ -305,7 +305,7 @@ public void testMergeWithoutProject() .orderBy(new OrderingScheme(ImmutableList.of(p.symbol("d")), ImmutableMap.of(p.symbol("d"), ASC_NULLS_LAST))) .addMeasure( p.symbol("parent_measure"), - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), true, true, 0, 0), new Symbol(UNKNOWN, "b")))) @@ -316,14 +316,14 @@ public void testMergeWithoutProject() .seek() .addSubset(new IrLabel("U"), ImmutableSet.of(new IrLabel("X"))) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.patternRecognition(childBuilder -> { childBuilder .partitionBy(ImmutableList.of(p.symbol("c"))) .orderBy(new OrderingScheme(ImmutableList.of(p.symbol("d")), ImmutableMap.of(p.symbol("d"), ASC_NULLS_LAST))) .addMeasure( p.symbol("child_measure"), - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), false, true, 0, 0), new Symbol(UNKNOWN, "a")))) @@ -334,7 +334,7 @@ public void testMergeWithoutProject() .seek() .addSubset(new IrLabel("U"), ImmutableSet.of(new IrLabel("X"))) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a"), p.symbol("b"), p.symbol("c"), p.symbol("d"))); })))) .matches( @@ -342,14 +342,14 @@ public void testMergeWithoutProject() .specification(specification(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableMap.of("d", ASC_NULLS_LAST))) .addMeasure( "parent_measure", - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), true, true, 0, 0), new Symbol(UNKNOWN, "b"))), BIGINT) .addMeasure( "child_measure", - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), false, true, 0, 0), new Symbol(UNKNOWN, "a"))), @@ -362,7 +362,7 @@ public void testMergeWithoutProject() .seek() .addSubset(new IrLabel("U"), ImmutableSet.of(new IrLabel("X"))) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL), + .addVariableDefinition(new IrLabel("X"), TRUE), values("a", "b", "c", "d"))); } @@ -376,51 +376,51 @@ public void testMergeWithoutProjectAndPruneOutputs() .partitionBy(ImmutableList.of(p.symbol("c"))) .addMeasure( p.symbol("parent_measure"), - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), true, true, 0, 0), new Symbol(UNKNOWN, "b")))) .rowsPerMatch(ONE) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.patternRecognition(childBuilder -> { childBuilder .partitionBy(ImmutableList.of(p.symbol("c"))) .addMeasure( p.symbol("child_measure"), - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), false, true, 0, 0), new Symbol(UNKNOWN, "a")))) .rowsPerMatch(ONE) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a"), p.symbol("b"), p.symbol("c"))); })))) .matches( project( ImmutableMap.of( - "c", expression(new SymbolReference(BIGINT, "c")), - "parent_measure", expression(new SymbolReference(BIGINT, "parent_measure"))), + "c", expression(new Reference(BIGINT, "c")), + "parent_measure", expression(new Reference(BIGINT, "parent_measure"))), patternRecognition(builder -> builder .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) .addMeasure( "parent_measure", - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), true, true, 0, 0), new Symbol(UNKNOWN, "b"))), BIGINT) .addMeasure( "child_measure", - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), false, true, 0, 0), new Symbol(UNKNOWN, "a"))), BIGINT) .rowsPerMatch(ONE) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL), + .addVariableDefinition(new IrLabel("X"), TRUE), values("a", "b", "c")))); } @@ -433,61 +433,61 @@ public void testMergeWithProject() .on(p -> p.patternRecognition(parentBuilder -> parentBuilder .addMeasure( p.symbol("parent_measure"), - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), true, true, 0, 0), new Symbol(UNKNOWN, "a")))) .rowsPerMatch(ALL_SHOW_EMPTY) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.project( Assignments.of( - p.symbol("a"), new SymbolReference(BIGINT, "a"), - p.symbol("expression"), new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b"))), + p.symbol("a"), new Reference(BIGINT, "a"), + p.symbol("expression"), new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "a"), new Reference(BIGINT, "b"))), p.patternRecognition(childBuilder -> { childBuilder .addMeasure( p.symbol("child_measure"), - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), false, true, 0, 0), new Symbol(UNKNOWN, "b")))) .rowsPerMatch(ALL_SHOW_EMPTY) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a"), p.symbol("b"))); }))))) .matches( project( ImmutableMap.of( - "a", expression(new SymbolReference(BIGINT, "a")), - "parent_measure", expression(new SymbolReference(BIGINT, "parent_measure")), - "expression", expression(new SymbolReference(BIGINT, "expression"))), + "a", expression(new Reference(BIGINT, "a")), + "parent_measure", expression(new Reference(BIGINT, "parent_measure")), + "expression", expression(new Reference(BIGINT, "expression"))), project( ImmutableMap.of( - "a", expression(new SymbolReference(BIGINT, "a")), - "b", expression(new SymbolReference(BIGINT, "b")), - "parent_measure", expression(new SymbolReference(BIGINT, "parent_measure")), - "child_measure", expression(new SymbolReference(BIGINT, "child_measure")), - "expression", expression(new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b")))), + "a", expression(new Reference(BIGINT, "a")), + "b", expression(new Reference(BIGINT, "b")), + "parent_measure", expression(new Reference(BIGINT, "parent_measure")), + "child_measure", expression(new Reference(BIGINT, "child_measure")), + "expression", expression(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))), patternRecognition(builder -> builder .addMeasure( "parent_measure", - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), true, true, 0, 0), new Symbol(UNKNOWN, "a"))), BIGINT) .addMeasure( "child_measure", - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), false, true, 0, 0), new Symbol(UNKNOWN, "b"))), BIGINT) .rowsPerMatch(ALL_SHOW_EMPTY) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL), + .addVariableDefinition(new IrLabel("X"), TRUE), values("a", "b"))))); // project is based on symbols created by the child node @@ -496,61 +496,61 @@ public void testMergeWithProject() .on(p -> p.patternRecognition(parentBuilder -> parentBuilder .addMeasure( p.symbol("parent_measure"), - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), true, true, 0, 0), new Symbol(UNKNOWN, "a")))) .rowsPerMatch(ALL_SHOW_EMPTY) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.project( Assignments.of( - p.symbol("a"), new SymbolReference(BIGINT, "a"), - p.symbol("expression"), new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b")), new SymbolReference(BIGINT, "child_measure"))), + p.symbol("a"), new Reference(BIGINT, "a"), + p.symbol("expression"), new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")), new Reference(BIGINT, "child_measure"))), p.patternRecognition(childBuilder -> { childBuilder .addMeasure( p.symbol("child_measure"), - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), false, true, 0, 0), new Symbol(UNKNOWN, "b")))) .rowsPerMatch(ALL_SHOW_EMPTY) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a"), p.symbol("b"))); }))))) .matches( project( ImmutableMap.of( - "a", expression(new SymbolReference(BIGINT, "a")), - "parent_measure", expression(new SymbolReference(BIGINT, "parent_measure")), - "expression", expression(new SymbolReference(BIGINT, "expression"))), + "a", expression(new Reference(BIGINT, "a")), + "parent_measure", expression(new Reference(BIGINT, "parent_measure")), + "expression", expression(new Reference(BIGINT, "expression"))), project( ImmutableMap.of( - "a", expression(new SymbolReference(BIGINT, "a")), - "b", expression(new SymbolReference(BIGINT, "b")), - "parent_measure", expression(new SymbolReference(BIGINT, "parent_measure")), - "child_measure", expression(new SymbolReference(BIGINT, "child_measure")), - "expression", expression(new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b")), new SymbolReference(BIGINT, "child_measure")))), + "a", expression(new Reference(BIGINT, "a")), + "b", expression(new Reference(BIGINT, "b")), + "parent_measure", expression(new Reference(BIGINT, "parent_measure")), + "child_measure", expression(new Reference(BIGINT, "child_measure")), + "expression", expression(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")), new Reference(BIGINT, "child_measure")))), patternRecognition(builder -> builder .addMeasure( "parent_measure", - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), true, true, 0, 0), new Symbol(UNKNOWN, "a"))), BIGINT) .addMeasure( "child_measure", - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), false, true, 0, 0), new Symbol(UNKNOWN, "b"))), BIGINT) .rowsPerMatch(ALL_SHOW_EMPTY) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL), + .addVariableDefinition(new IrLabel("X"), TRUE), values("a", "b"))))); } @@ -564,71 +564,71 @@ public void testMergeWithParentDependingOnProject() .on(p -> p.patternRecognition(parentBuilder -> parentBuilder .addMeasure( p.symbol("parent_measure"), - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), true, true, 0, 0), new Symbol(UNKNOWN, "expression_1")))) .rowsPerMatch(ALL_SHOW_EMPTY) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.project( Assignments.builder() - .put(p.symbol("a"), new SymbolReference(BIGINT, "a")) - .put(p.symbol("expression_1"), new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b"))) - .put(p.symbol("expression_2"), new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b"))) + .put(p.symbol("a"), new Reference(BIGINT, "a")) + .put(p.symbol("expression_1"), new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "a"), new Reference(BIGINT, "b"))) + .put(p.symbol("expression_2"), new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Reference(BIGINT, "b"))) .build(), p.patternRecognition(childBuilder -> { childBuilder .addMeasure( p.symbol("child_measure"), - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), false, true, 0, 0), new Symbol(UNKNOWN, "b")))) .rowsPerMatch(ALL_SHOW_EMPTY) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a"), p.symbol("b"))); }))))) .matches( project( ImmutableMap.of( - "a", expression(new SymbolReference(BIGINT, "a")), - "parent_measure", expression(new SymbolReference(BIGINT, "parent_measure")), - "expression_1", expression(new SymbolReference(BIGINT, "expression_1")), - "expression_2", expression(new SymbolReference(BIGINT, "expression_2"))), + "a", expression(new Reference(BIGINT, "a")), + "parent_measure", expression(new Reference(BIGINT, "parent_measure")), + "expression_1", expression(new Reference(BIGINT, "expression_1")), + "expression_2", expression(new Reference(BIGINT, "expression_2"))), project( ImmutableMap.builder() - .put("a", expression(new SymbolReference(BIGINT, "a"))) - .put("b", expression(new SymbolReference(BIGINT, "b"))) - .put("parent_measure", expression(new SymbolReference(BIGINT, "parent_measure"))) - .put("child_measure", expression(new SymbolReference(BIGINT, "child_measure"))) - .put("expression_1", expression(new SymbolReference(BIGINT, "expression_1"))) - .put("expression_2", expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b")))) + .put("a", expression(new Reference(BIGINT, "a"))) + .put("b", expression(new Reference(BIGINT, "b"))) + .put("parent_measure", expression(new Reference(BIGINT, "parent_measure"))) + .put("child_measure", expression(new Reference(BIGINT, "child_measure"))) + .put("expression_1", expression(new Reference(BIGINT, "expression_1"))) + .put("expression_2", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))) .buildOrThrow(), patternRecognition(builder -> builder .addMeasure( "parent_measure", - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), true, true, 0, 0), new Symbol(UNKNOWN, "expression_1"))), BIGINT) .addMeasure( "child_measure", - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), false, true, 0, 0), new Symbol(UNKNOWN, "b"))), BIGINT) .rowsPerMatch(ALL_SHOW_EMPTY) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL), + .addVariableDefinition(new IrLabel("X"), TRUE), project( ImmutableMap.of( - "a", expression(new SymbolReference(BIGINT, "a")), - "b", expression(new SymbolReference(BIGINT, "b")), - "expression_1", expression(new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b")))), + "a", expression(new Reference(BIGINT, "a")), + "b", expression(new Reference(BIGINT, "b")), + "expression_1", expression(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))), values("a", "b")))))); } @@ -644,69 +644,69 @@ public void testOneRowPerMatchMergeWithParentDependingOnProject() .partitionBy(ImmutableList.of(p.symbol("a"))) .addMeasure( p.symbol("parent_measure"), - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), true, true, 0, 0), new Symbol(UNKNOWN, "expression_1")))) .rowsPerMatch(ONE) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.project( Assignments.builder() - .put(p.symbol("a"), new SymbolReference(BIGINT, "a")) - .put(p.symbol("child_measure"), new SymbolReference(BIGINT, "child_measure")) - .put(p.symbol("expression_1"), new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "a"))) - .put(p.symbol("expression_2"), new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "a"))) + .put(p.symbol("a"), new Reference(BIGINT, "a")) + .put(p.symbol("child_measure"), new Reference(BIGINT, "child_measure")) + .put(p.symbol("expression_1"), new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "a"), new Reference(BIGINT, "a"))) + .put(p.symbol("expression_2"), new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Reference(BIGINT, "a"))) .build(), p.patternRecognition(childBuilder -> { childBuilder .partitionBy(ImmutableList.of(p.symbol("a"))) .addMeasure( p.symbol("child_measure"), - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), false, true, 0, 0), new Symbol(UNKNOWN, "b")))) .rowsPerMatch(ONE) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a"), p.symbol("b"))); }))))) .matches( project( ImmutableMap.of( - "a", expression(new SymbolReference(BIGINT, "a")), - "parent_measure", expression(new SymbolReference(BIGINT, "parent_measure"))), + "a", expression(new Reference(BIGINT, "a")), + "parent_measure", expression(new Reference(BIGINT, "parent_measure"))), project( ImmutableMap.of( - "a", expression(new SymbolReference(BIGINT, "a")), - "parent_measure", expression(new SymbolReference(BIGINT, "parent_measure")), - "child_measure", expression(new SymbolReference(BIGINT, "child_measure")), - "expression_2", expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "a")))), + "a", expression(new Reference(BIGINT, "a")), + "parent_measure", expression(new Reference(BIGINT, "parent_measure")), + "child_measure", expression(new Reference(BIGINT, "child_measure")), + "expression_2", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Reference(BIGINT, "a")))), patternRecognition(builder -> builder .specification(specification(ImmutableList.of("a"), ImmutableList.of(), ImmutableMap.of())) .addMeasure( "parent_measure", - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), true, true, 0, 0), new Symbol(UNKNOWN, "expression_1"))), BIGINT) .addMeasure( "child_measure", - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), false, true, 0, 0), new Symbol(UNKNOWN, "b"))), BIGINT) .rowsPerMatch(ONE) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL), + .addVariableDefinition(new IrLabel("X"), TRUE), project( ImmutableMap.of( - "a", expression(new SymbolReference(BIGINT, "a")), - "b", expression(new SymbolReference(BIGINT, "b")), - "expression_1", PlanMatchPattern.expression(new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "a")))), + "a", expression(new Reference(BIGINT, "a")), + "b", expression(new Reference(BIGINT, "b")), + "expression_1", PlanMatchPattern.expression(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "a"), new Reference(BIGINT, "a")))), values("a", "b")))))); } @@ -719,22 +719,22 @@ public void testMergeWithAggregation() .pattern(new IrLabel("X")) .addVariableDefinition( new IrLabel("X"), - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "c"), new Constant(BIGINT, 5L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "c"), new Constant(BIGINT, 5L)), ImmutableMap.of(new Symbol(BIGINT, "c"), new AggregationValuePointer( count, new AggregatedSetDescriptor(ImmutableSet.of(), true), - ImmutableList.of(new SymbolReference(BIGINT, "a")), + ImmutableList.of(new Reference(BIGINT, "a")), Optional.empty(), Optional.empty()))) .source(p.patternRecognition(childBuilder -> childBuilder .pattern(new IrLabel("X")) .addVariableDefinition( new IrLabel("X"), - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "c"), new Constant(BIGINT, 5L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "c"), new Constant(BIGINT, 5L)), ImmutableMap.of(new Symbol(BIGINT, "c"), new AggregationValuePointer( count, new AggregatedSetDescriptor(ImmutableSet.of(), true), - ImmutableList.of(new SymbolReference(BIGINT, "a")), + ImmutableList.of(new Reference(BIGINT, "a")), Optional.empty(), Optional.empty()))) .source(p.values(p.symbol("a", BIGINT))))))) @@ -743,11 +743,11 @@ public void testMergeWithAggregation() .pattern(new IrLabel("X")) .addVariableDefinition( new IrLabel("X"), - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "c"), new Constant(BIGINT, 5L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "c"), new Constant(BIGINT, 5L)), ImmutableMap.of("c", new AggregationValuePointer( count, new AggregatedSetDescriptor(ImmutableSet.of(), true), - ImmutableList.of(new SymbolReference(BIGINT, "a")), + ImmutableList.of(new Reference(BIGINT, "a")), Optional.empty(), Optional.empty()))), values("a"))); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeProjectWithValues.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeProjectWithValues.java index a8dfd6db456b..92756a2fa515 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeProjectWithValues.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeProjectWithValues.java @@ -19,14 +19,14 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; import io.trino.spi.type.RowType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ArithmeticNegation; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.IsNullPredicate; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Negation; +import io.trino.sql.ir.Reference; import io.trino.sql.ir.Row; -import io.trino.sql.ir.SymbolReference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; @@ -38,9 +38,9 @@ import static io.trino.spi.type.CharType.createCharType; import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.BooleanLiteral.FALSE_LITERAL; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Booleans.FALSE; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.LogicalPlanner.failFunction; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -77,8 +77,8 @@ public void testProjectWithoutOutputSymbols() p.valuesOfExpressions( ImmutableList.of(p.symbol("a"), p.symbol("b")), ImmutableList.of( - new Row(ImmutableList.of(new Constant(createCharType(1), Slices.utf8Slice("x")), TRUE_LITERAL)), - new Row(ImmutableList.of(new Constant(createCharType(1), Slices.utf8Slice("y")), FALSE_LITERAL)))))) + new Row(ImmutableList.of(new Constant(createCharType(1), Slices.utf8Slice("x")), TRUE)), + new Row(ImmutableList.of(new Constant(createCharType(1), Slices.utf8Slice("y")), FALSE)))))) .matches(values(2)); // ValuesNode has no output symbols and two rows @@ -119,21 +119,21 @@ public void testValuesWithoutOutputSymbols() tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( - Assignments.of(p.symbol("a"), new Constant(createCharType(1), Slices.utf8Slice("x")), p.symbol("b"), TRUE_LITERAL), + Assignments.of(p.symbol("a"), new Constant(createCharType(1), Slices.utf8Slice("x")), p.symbol("b"), TRUE), p.values( ImmutableList.of(), ImmutableList.of(ImmutableList.of(), ImmutableList.of())))) .matches(values( ImmutableList.of("a", "b"), ImmutableList.of( - ImmutableList.of(new Constant(createCharType(1), Slices.utf8Slice("x")), TRUE_LITERAL), - ImmutableList.of(new Constant(createCharType(1), Slices.utf8Slice("x")), TRUE_LITERAL)))); + ImmutableList.of(new Constant(createCharType(1), Slices.utf8Slice("x")), TRUE), + ImmutableList.of(new Constant(createCharType(1), Slices.utf8Slice("x")), TRUE)))); // ValuesNode has no rows tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( - Assignments.of(p.symbol("a"), new Constant(createCharType(1), Slices.utf8Slice("x")), p.symbol("b"), TRUE_LITERAL), + Assignments.of(p.symbol("a"), new Constant(createCharType(1), Slices.utf8Slice("x")), p.symbol("b"), TRUE), p.values( ImmutableList.of(), ImmutableList.of()))) @@ -143,13 +143,13 @@ public void testValuesWithoutOutputSymbols() @Test public void testNonDeterministicValues() { - FunctionCall randomFunction = new FunctionCall( + Call randomFunction = new Call( tester().getMetadata().resolveBuiltinFunction("random", ImmutableList.of()), ImmutableList.of()); tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( - Assignments.of(p.symbol("rand", DOUBLE), new SymbolReference(DOUBLE, "rand")), + Assignments.of(p.symbol("rand", DOUBLE), new Reference(DOUBLE, "rand")), p.valuesOfExpressions( ImmutableList.of(p.symbol("rand", DOUBLE)), ImmutableList.of(new Row(ImmutableList.of(randomFunction)))))) @@ -161,54 +161,54 @@ public void testNonDeterministicValues() // ValuesNode has multiple rows tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( - Assignments.of(p.symbol("output", DOUBLE), new SymbolReference(DOUBLE, "value")), + Assignments.of(p.symbol("output", DOUBLE), new Reference(DOUBLE, "value")), p.valuesOfExpressions( ImmutableList.of(p.symbol("value", DOUBLE)), ImmutableList.of( new Row(ImmutableList.of(new Constant(DOUBLE, null))), new Row(ImmutableList.of(randomFunction)), - new Row(ImmutableList.of(new ArithmeticNegation(randomFunction))))))) + new Row(ImmutableList.of(new Negation(randomFunction))))))) .matches( values( ImmutableList.of("output"), ImmutableList.of( ImmutableList.of(new Constant(DOUBLE, null)), ImmutableList.of(randomFunction), - ImmutableList.of(new ArithmeticNegation(randomFunction))))); + ImmutableList.of(new Negation(randomFunction))))); // ValuesNode has multiple non-deterministic outputs tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( Assignments.of( - p.symbol("x"), new ArithmeticNegation(new SymbolReference(DOUBLE, "a")), - p.symbol("y"), new SymbolReference(DOUBLE, "b")), + p.symbol("x"), new Negation(new Reference(DOUBLE, "a")), + p.symbol("y"), new Reference(DOUBLE, "b")), p.valuesOfExpressions( ImmutableList.of(p.symbol("a", DOUBLE), p.symbol("b", DOUBLE)), ImmutableList.of( new Row(ImmutableList.of(new Constant(DOUBLE, 1e0), randomFunction)), new Row(ImmutableList.of(randomFunction, new Constant(DOUBLE, null))), - new Row(ImmutableList.of(new ArithmeticNegation(randomFunction), new Constant(DOUBLE, null))))))) + new Row(ImmutableList.of(new Negation(randomFunction), new Constant(DOUBLE, null))))))) .matches( values( ImmutableList.of("x", "y"), ImmutableList.of( - ImmutableList.of(new ArithmeticNegation(new Constant(DOUBLE, 1e0)), randomFunction), - ImmutableList.of(new ArithmeticNegation(randomFunction), new Constant(DOUBLE, null)), - ImmutableList.of(new ArithmeticNegation(new ArithmeticNegation(randomFunction)), new Constant(DOUBLE, null))))); + ImmutableList.of(new Negation(new Constant(DOUBLE, 1e0)), randomFunction), + ImmutableList.of(new Negation(randomFunction), new Constant(DOUBLE, null)), + ImmutableList.of(new Negation(new Negation(randomFunction)), new Constant(DOUBLE, null))))); } @Test public void testDoNotFireOnNonDeterministicValues() { - FunctionCall randomFunction = new FunctionCall( + Call randomFunction = new Call( tester().getMetadata().resolveBuiltinFunction("random", ImmutableList.of()), ImmutableList.of()); tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( Assignments.of( - p.symbol("x"), new SymbolReference(DOUBLE, "rand"), - p.symbol("y"), new SymbolReference(DOUBLE, "rand")), + p.symbol("x"), new Reference(DOUBLE, "rand"), + p.symbol("y"), new Reference(DOUBLE, "rand")), p.valuesOfExpressions( ImmutableList.of(p.symbol("rand")), ImmutableList.of(new Row(ImmutableList.of(randomFunction)))))) @@ -216,7 +216,7 @@ public void testDoNotFireOnNonDeterministicValues() tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( - Assignments.of(p.symbol("x"), new ArithmeticBinaryExpression(ADD_DOUBLE, ADD, new SymbolReference(DOUBLE, "rand"), new SymbolReference(DOUBLE, "rand"))), + Assignments.of(p.symbol("x"), new Arithmetic(ADD_DOUBLE, ADD, new Reference(DOUBLE, "rand"), new Reference(DOUBLE, "rand"))), p.valuesOfExpressions( ImmutableList.of(p.symbol("rand")), ImmutableList.of(new Row(ImmutableList.of(randomFunction)))))) @@ -229,20 +229,20 @@ public void testCorrelation() // correlation symbol in projection (note: the resulting plan is not yet supported in execution) tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( - Assignments.of(p.symbol("x", INTEGER), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "a"), new SymbolReference(INTEGER, "corr"))), + Assignments.of(p.symbol("x", INTEGER), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a"), new Reference(INTEGER, "corr"))), p.valuesOfExpressions( ImmutableList.of(p.symbol("a", INTEGER)), ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L))))))) - .matches(values(ImmutableList.of("x"), ImmutableList.of(ImmutableList.of(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new Constant(INTEGER, 1L), new SymbolReference(INTEGER, "corr")))))); + .matches(values(ImmutableList.of("x"), ImmutableList.of(ImmutableList.of(new Arithmetic(ADD_INTEGER, ADD, new Constant(INTEGER, 1L), new Reference(INTEGER, "corr")))))); // correlation symbol in values (note: the resulting plan is not yet supported in execution) tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( - Assignments.of(p.symbol("x"), new SymbolReference(BIGINT, "a")), + Assignments.of(p.symbol("x"), new Reference(BIGINT, "a")), p.valuesOfExpressions( ImmutableList.of(p.symbol("a")), - ImmutableList.of(new Row(ImmutableList.of(new SymbolReference(BIGINT, "corr"))))))) - .matches(values(ImmutableList.of("x"), ImmutableList.of(ImmutableList.of(new SymbolReference(BIGINT, "corr"))))); + ImmutableList.of(new Row(ImmutableList.of(new Reference(BIGINT, "corr"))))))) + .matches(values(ImmutableList.of("x"), ImmutableList.of(ImmutableList.of(new Reference(BIGINT, "corr"))))); // correlation symbol is not present in the resulting expression tester().assertThat(new MergeProjectWithValues()) @@ -250,14 +250,14 @@ public void testCorrelation() Assignments.of(p.symbol("x"), new Constant(INTEGER, 1L)), p.valuesOfExpressions( ImmutableList.of(p.symbol("a")), - ImmutableList.of(new Row(ImmutableList.of(new SymbolReference(INTEGER, "corr"))))))) + ImmutableList.of(new Row(ImmutableList.of(new Reference(INTEGER, "corr"))))))) .matches(values(ImmutableList.of("x"), ImmutableList.of(ImmutableList.of(new Constant(INTEGER, 1L))))); } @Test public void testFailingExpression() { - FunctionCall failFunction = failFunction(tester().getMetadata(), GENERIC_USER_ERROR, "message"); + Call failFunction = failFunction(tester().getMetadata(), GENERIC_USER_ERROR, "message"); tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( @@ -282,23 +282,23 @@ public void testMergeProjectWithValues() Assignments.Builder assignments = Assignments.builder(); assignments.putIdentity(a); // identity assignment assignments.put(d, b.toSymbolReference()); // renaming assignment - assignments.put(e, new IsNullPredicate(a.toSymbolReference())); // expression involving input symbol + assignments.put(e, new IsNull(a.toSymbolReference())); // expression involving input symbol assignments.put(f, new Constant(INTEGER, 1L)); // constant expression return p.project( assignments.build(), p.valuesOfExpressions( ImmutableList.of(a, b, c), ImmutableList.of( - new Row(ImmutableList.of(new Constant(createCharType(1), Slices.utf8Slice("x")), TRUE_LITERAL, new Constant(INTEGER, 1L))), - new Row(ImmutableList.of(new Constant(createCharType(1), Slices.utf8Slice("y")), FALSE_LITERAL, new Constant(INTEGER, 2L))), - new Row(ImmutableList.of(new Constant(createCharType(1), Slices.utf8Slice("z")), TRUE_LITERAL, new Constant(INTEGER, 3L)))))); + new Row(ImmutableList.of(new Constant(createCharType(1), Slices.utf8Slice("x")), TRUE, new Constant(INTEGER, 1L))), + new Row(ImmutableList.of(new Constant(createCharType(1), Slices.utf8Slice("y")), FALSE, new Constant(INTEGER, 2L))), + new Row(ImmutableList.of(new Constant(createCharType(1), Slices.utf8Slice("z")), TRUE, new Constant(INTEGER, 3L)))))); }) .matches(values( ImmutableList.of("a", "d", "e", "f"), ImmutableList.of( - ImmutableList.of(new Constant(createCharType(1), Slices.utf8Slice("x")), TRUE_LITERAL, new IsNullPredicate(new Constant(createCharType(1), Slices.utf8Slice("x"))), new Constant(INTEGER, 1L)), - ImmutableList.of(new Constant(createCharType(1), Slices.utf8Slice("y")), FALSE_LITERAL, new IsNullPredicate(new Constant(createCharType(1), Slices.utf8Slice("y"))), new Constant(INTEGER, 1L)), - ImmutableList.of(new Constant(createCharType(1), Slices.utf8Slice("z")), TRUE_LITERAL, new IsNullPredicate(new Constant(createCharType(1), Slices.utf8Slice("z"))), new Constant(INTEGER, 1L))))); + ImmutableList.of(new Constant(createCharType(1), Slices.utf8Slice("x")), TRUE, new IsNull(new Constant(createCharType(1), Slices.utf8Slice("x"))), new Constant(INTEGER, 1L)), + ImmutableList.of(new Constant(createCharType(1), Slices.utf8Slice("y")), FALSE, new IsNull(new Constant(createCharType(1), Slices.utf8Slice("y"))), new Constant(INTEGER, 1L)), + ImmutableList.of(new Constant(createCharType(1), Slices.utf8Slice("z")), TRUE, new IsNull(new Constant(createCharType(1), Slices.utf8Slice("z"))), new Constant(INTEGER, 1L))))); // ValuesNode has no rows tester().assertThat(new MergeProjectWithValues()) @@ -312,7 +312,7 @@ public void testMergeProjectWithValues() Assignments.Builder assignments = Assignments.builder(); assignments.putIdentity(a); // identity assignment assignments.put(d, b.toSymbolReference()); // renaming assignment - assignments.put(e, new IsNullPredicate(a.toSymbolReference())); // expression involving input symbol + assignments.put(e, new IsNull(a.toSymbolReference())); // expression involving input symbol assignments.put(f, new Constant(INTEGER, 1L)); // constant expression return p.project( assignments.build(), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMultipleDistinctAggregationToMarkDistinct.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMultipleDistinctAggregationToMarkDistinct.java index 799692659afb..603c5d2f8da5 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMultipleDistinctAggregationToMarkDistinct.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMultipleDistinctAggregationToMarkDistinct.java @@ -17,9 +17,9 @@ import com.google.common.collect.ImmutableMap; import io.trino.cost.PlanNodeStatsEstimate; import io.trino.cost.TaskCountEstimator; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -37,7 +37,7 @@ import static io.trino.SystemSessionProperties.TASK_CONCURRENCY; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregationFunction; import static io.trino.sql.planner.assertions.PlanMatchPattern.globalAggregation; @@ -59,8 +59,8 @@ public void testNoDistinct() tester().assertThat(new SingleDistinctAggregationToGroupBy()) .on(p -> p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("output1"), PlanBuilder.aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "input1"))), ImmutableList.of(BIGINT)) - .addAggregation(p.symbol("output2"), PlanBuilder.aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "input2"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output1"), PlanBuilder.aggregation("count", ImmutableList.of(new Reference(BIGINT, "input1"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2"), PlanBuilder.aggregation("count", ImmutableList.of(new Reference(BIGINT, "input2"))), ImmutableList.of(BIGINT)) .source( p.values( p.symbol("input1"), @@ -74,7 +74,7 @@ public void testSingleDistinct() tester().assertThat(new MultipleDistinctAggregationToMarkDistinct(TASK_COUNT_ESTIMATOR)) .on(p -> p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("output1"), PlanBuilder.aggregation("count", true, ImmutableList.of(new SymbolReference(BIGINT, "input1"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output1"), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1"))), ImmutableList.of(BIGINT)) .source( p.values( p.symbol("input1"), @@ -88,8 +88,8 @@ public void testMultipleAggregations() tester().assertThat(new MultipleDistinctAggregationToMarkDistinct(TASK_COUNT_ESTIMATOR)) .on(p -> p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("output1"), PlanBuilder.aggregation("count", true, ImmutableList.of(new SymbolReference(BIGINT, "input"))), ImmutableList.of(BIGINT)) - .addAggregation(p.symbol("output2"), PlanBuilder.aggregation("sum", true, ImmutableList.of(new SymbolReference(BIGINT, "input"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output1"), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2"), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input"))), ImmutableList.of(BIGINT)) .source( p.values(p.symbol("input"))))) .doesNotFire(); @@ -103,17 +103,17 @@ public void testDistinctWithFilter() .globalGrouping() .addAggregation( p.symbol("output1"), - PlanBuilder.aggregation("count", true, ImmutableList.of(new SymbolReference(BIGINT, "input1")), new Symbol(UNKNOWN, "filter1")), ImmutableList.of(BIGINT)) + PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1")), new Symbol(UNKNOWN, "filter1")), ImmutableList.of(BIGINT)) .addAggregation( p.symbol("output2"), - PlanBuilder.aggregation("count", true, ImmutableList.of(new SymbolReference(BIGINT, "input2")), new Symbol(UNKNOWN, "filter2")), ImmutableList.of(BIGINT)) + PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input2")), new Symbol(UNKNOWN, "filter2")), ImmutableList.of(BIGINT)) .source( p.project( Assignments.builder() .putIdentity(p.symbol("input1")) .putIdentity(p.symbol("input2")) - .put(p.symbol("filter1"), new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "input2"), new Constant(INTEGER, 0L))) - .put(p.symbol("filter2"), new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "input1"), new Constant(INTEGER, 0L))) + .put(p.symbol("filter1"), new Comparison(GREATER_THAN, new Reference(INTEGER, "input2"), new Constant(INTEGER, 0L))) + .put(p.symbol("filter2"), new Comparison(GREATER_THAN, new Reference(INTEGER, "input1"), new Constant(INTEGER, 0L))) .build(), p.values( p.symbol("input1"), @@ -123,15 +123,15 @@ public void testDistinctWithFilter() tester().assertThat(new MultipleDistinctAggregationToMarkDistinct(TASK_COUNT_ESTIMATOR)) .on(p -> p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("output1"), PlanBuilder.aggregation("count", true, ImmutableList.of(new SymbolReference(BIGINT, "input1")), new Symbol(UNKNOWN, "filter1")), ImmutableList.of(BIGINT)) - .addAggregation(p.symbol("output2"), PlanBuilder.aggregation("count", true, ImmutableList.of(new SymbolReference(BIGINT, "input2"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output1"), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1")), new Symbol(UNKNOWN, "filter1")), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2"), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input2"))), ImmutableList.of(BIGINT)) .source( p.project( Assignments.builder() .putIdentity(p.symbol("input1")) .putIdentity(p.symbol("input2")) - .put(p.symbol("filter1"), new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "input2"), new Constant(INTEGER, 0L))) - .put(p.symbol("filter2"), new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "input1"), new Constant(INTEGER, 0L))) + .put(p.symbol("filter1"), new Comparison(GREATER_THAN, new Reference(INTEGER, "input2"), new Constant(INTEGER, 0L))) + .put(p.symbol("filter2"), new Comparison(GREATER_THAN, new Reference(INTEGER, "input1"), new Constant(INTEGER, 0L))) .build(), p.values( p.symbol("input1"), @@ -145,8 +145,8 @@ public void testGlobalAggregation() tester().assertThat(new MultipleDistinctAggregationToMarkDistinct(TASK_COUNT_ESTIMATOR)) .on(p -> p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("output1"), PlanBuilder.aggregation("count", true, ImmutableList.of(new SymbolReference(BIGINT, "input1"))), ImmutableList.of(BIGINT)) - .addAggregation(p.symbol("output2"), PlanBuilder.aggregation("count", true, ImmutableList.of(new SymbolReference(BIGINT, "input2"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output1"), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2"), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input2"))), ImmutableList.of(BIGINT)) .source( p.values(p.symbol("input1"), p.symbol("input2"))))) .matches(aggregation( @@ -174,8 +174,8 @@ public void testAggregationNDV() Function plan = p -> p.aggregation(builder -> builder .nodeId(aggregationNodeId) .singleGroupingSet(p.symbol("key")) - .addAggregation(p.symbol("output1"), PlanBuilder.aggregation("count", true, ImmutableList.of(new SymbolReference(BIGINT, "input"))), ImmutableList.of(BIGINT)) - .addAggregation(p.symbol("output2"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "input"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output1"), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "input"))), ImmutableList.of(BIGINT)) .source( p.values(p.symbol("input"), p.symbol("key")))); PlanMatchPattern expectedMarkDistinct = aggregation( @@ -227,8 +227,8 @@ public void testAggregationNDV() .on(p -> p.aggregation(builder -> builder .nodeId(aggregationNodeId) .singleGroupingSet(p.symbol("key")) - .addAggregation(p.symbol("output1"), PlanBuilder.aggregation("count", true, ImmutableList.of(new SymbolReference(BIGINT, "input1"))), ImmutableList.of(BIGINT)) - .addAggregation(p.symbol("output2"), PlanBuilder.aggregation("count", true, ImmutableList.of(new SymbolReference(BIGINT, "input2"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output1"), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2"), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input2"))), ImmutableList.of(BIGINT)) .source( p.values(p.symbol("input1"), p.symbol("input2"), p.symbol("key"))))) .doesNotFire(); @@ -258,8 +258,8 @@ public void testAggregationNDV() .on(p -> p.aggregation(builder -> builder .nodeId(aggregationNodeId) .singleGroupingSet(p.symbol("key1"), p.symbol("key2")) - .addAggregation(p.symbol("output1"), PlanBuilder.aggregation("count", true, ImmutableList.of(new SymbolReference(BIGINT, "input"))), ImmutableList.of(BIGINT)) - .addAggregation(p.symbol("output2"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "input"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output1"), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "input"))), ImmutableList.of(BIGINT)) .source( p.values(p.symbol("input"), p.symbol("key1"), p.symbol("key2"))))) .matches(aggregation( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestOptimizeDuplicateInsensitiveJoins.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestOptimizeDuplicateInsensitiveJoins.java index 377680279348..188d772da3fe 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestOptimizeDuplicateInsensitiveJoins.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestOptimizeDuplicateInsensitiveJoins.java @@ -18,10 +18,10 @@ import com.google.common.collect.ImmutableMap; import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; @@ -32,7 +32,7 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.join; @@ -116,7 +116,7 @@ public void testNestedJoins() p.values(symbolA), p.project(identity(symbolB), p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 10L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 10L)), p.join( INNER, p.values(symbolB), @@ -128,7 +128,7 @@ public void testNestedJoins() .left(values("A")) .right(project( filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "B"), new Constant(INTEGER, 10L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "B"), new Constant(INTEGER, 10L)), join(INNER, rightJoinBuilder -> rightJoinBuilder .left(values("B")) .right(values("C"))) @@ -139,7 +139,7 @@ public void testNestedJoins() @Test public void testNondeterministicJoins() { - FunctionCall randomFunction = new FunctionCall( + Call randomFunction = new Call( tester().getMetadata().resolveBuiltinFunction("random", ImmutableList.of()), ImmutableList.of()); @@ -157,12 +157,12 @@ public void testNondeterministicJoins() INNER, p.values(symbolB), p.values(symbolC)), - new ComparisonExpression(GREATER_THAN, symbolB.toSymbolReference(), randomFunction)))); + new Comparison(GREATER_THAN, symbolB.toSymbolReference(), randomFunction)))); }) .matches( aggregation(ImmutableMap.of(), join(INNER, builder -> builder - .filter(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "B"), new FunctionCall(RANDOM, ImmutableList.of()))) + .filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "B"), new Call(RANDOM, ImmutableList.of()))) .left(values("A")) .right( join(INNER, rightJoinBuilder -> rightJoinBuilder @@ -175,7 +175,7 @@ public void testNondeterministicJoins() @Test public void testNondeterministicFilter() { - FunctionCall randomFunction = new FunctionCall( + Call randomFunction = new Call( tester().getMetadata().resolveBuiltinFunction("random", ImmutableList.of()), ImmutableList.of()); @@ -185,7 +185,7 @@ public void testNondeterministicFilter() Symbol symbolB = p.symbol("b"); return p.aggregation(a -> a .singleGroupingSet(symbolA) - .source(p.filter(new ComparisonExpression(GREATER_THAN, symbolB.toSymbolReference(), randomFunction), + .source(p.filter(new Comparison(GREATER_THAN, symbolB.toSymbolReference(), randomFunction), p.join( INNER, p.values(symbolA), @@ -197,7 +197,7 @@ public void testNondeterministicFilter() @Test public void testNondeterministicProjection() { - FunctionCall randomFunction = new FunctionCall( + Call randomFunction = new Call( tester().getMetadata().resolveBuiltinFunction("random", ImmutableList.of()), ImmutableList.of()); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPreAggregateCaseAggregations.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPreAggregateCaseAggregations.java index f6acec6e8743..e433a0569cb6 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPreAggregateCaseAggregations.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPreAggregateCaseAggregations.java @@ -25,14 +25,14 @@ import io.trino.spi.connector.SchemaTableName; import io.trino.spi.function.OperatorType; import io.trino.spi.type.Decimals; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Case; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.InPredicate; -import io.trino.sql.ir.SearchedCaseExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.In; +import io.trino.sql.ir.Reference; import io.trino.sql.ir.WhenClause; import io.trino.sql.planner.Plan; import io.trino.sql.planner.assertions.AggregationFunction; @@ -60,10 +60,10 @@ import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MODULUS; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MULTIPLY; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; +import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregationFunction; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; @@ -137,7 +137,7 @@ public void testPreAggregatesCaseAggregations() "GROUP BY (col_varchar || 'a')", anyTree( project( - ImmutableMap.of("SUM_2_CAST", expression(new Cast(new SymbolReference(BIGINT, "SUM_2"), createVarcharType(10)))), + ImmutableMap.of("SUM_2_CAST", expression(new Cast(new Reference(BIGINT, "SUM_2"), createVarcharType(10)))), aggregation( singleGroupingSet("KEY"), ImmutableMap., ExpectedValueProvider>builder() @@ -151,12 +151,12 @@ public void testPreAggregatesCaseAggregations() Optional.empty(), SINGLE, project(ImmutableMap.builder() - .put("SUM_1_INPUT", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new SymbolReference(BIGINT, "SUM_BIGINT"))), Optional.of(new Constant(BIGINT, 0L))))) - .put("SUM_2_INPUT", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new SymbolReference(BIGINT, "SUM_INT_CAST"))), Optional.of(new Constant(BIGINT, 0L))))) - .put("SUM_3_INPUT", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new SymbolReference(BIGINT, "SUM_BIGINT"))), Optional.empty()))) - .put("MIN_1_INPUT", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(GREATER_THAN, new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Constant(BIGINT, 1L)), new SymbolReference(BIGINT, "MIN_BIGINT"))), Optional.empty()))) - .put("SUM_4_INPUT", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 3L)), new SymbolReference(BIGINT, "SUM_DECIMAL"))), Optional.empty()))) - .put("SUM_5_INPUT", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 4L)), new SymbolReference(BIGINT, "SUM_DECIMAL_CAST"))), Optional.empty()))) + .put("SUM_1_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Reference(BIGINT, "SUM_BIGINT"))), Optional.of(new Constant(BIGINT, 0L))))) + .put("SUM_2_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Reference(BIGINT, "SUM_INT_CAST"))), Optional.of(new Constant(BIGINT, 0L))))) + .put("SUM_3_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Reference(BIGINT, "SUM_BIGINT"))), Optional.empty()))) + .put("MIN_1_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Constant(BIGINT, 1L)), new Reference(BIGINT, "MIN_BIGINT"))), Optional.empty()))) + .put("SUM_4_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 3L)), new Reference(BIGINT, "SUM_DECIMAL"))), Optional.empty()))) + .put("SUM_5_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 4L)), new Reference(BIGINT, "SUM_DECIMAL_CAST"))), Optional.empty()))) .buildOrThrow(), aggregation( singleGroupingSet("KEY", "COL_BIGINT"), @@ -170,11 +170,11 @@ public void testPreAggregatesCaseAggregations() SINGLE, exchange( project(ImmutableMap.of( - "KEY", expression(new FunctionCall(CONCAT, ImmutableList.of(new SymbolReference(VARCHAR, "COL_VARCHAR"), new Constant(VARCHAR, Slices.utf8Slice("a"))))), - "VALUE_BIGINT", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new InPredicate(new SymbolReference(BIGINT, "COL_BIGINT"), ImmutableList.of(new Constant(BIGINT, 1L), new Constant(BIGINT, 2L))), new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)))), Optional.empty())), - "VALUE_INT_CAST", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Cast(new Cast(new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), INTEGER), BIGINT))), Optional.empty())), - "VALUE_2_BIGINT", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(GREATER_THAN, new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Constant(BIGINT, 1L)), new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)))), Optional.empty())), - "VALUE_DECIMAL_CAST", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 4L)), new Cast(new ArithmeticBinaryExpression(MULTIPLY_DECIMAL_10_0, MULTIPLY, new SymbolReference(createDecimalType(10, 0), "COL_DECIMAL"), new Constant(createDecimalType(10, 0), Decimals.valueOfShort(new BigDecimal("2")))), BIGINT))), Optional.empty()))), + "KEY", expression(new Call(CONCAT, ImmutableList.of(new Reference(VARCHAR, "COL_VARCHAR"), new Constant(VARCHAR, Slices.utf8Slice("a"))))), + "VALUE_BIGINT", expression(new Case(ImmutableList.of(new WhenClause(new In(new Reference(BIGINT, "COL_BIGINT"), ImmutableList.of(new Constant(BIGINT, 1L), new Constant(BIGINT, 2L))), new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)))), Optional.empty())), + "VALUE_INT_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Cast(new Cast(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), INTEGER), BIGINT))), Optional.empty())), + "VALUE_2_BIGINT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Constant(BIGINT, 1L)), new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)))), Optional.empty())), + "VALUE_DECIMAL_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 4L)), new Cast(new Arithmetic(MULTIPLY_DECIMAL_10_0, MULTIPLY, new Reference(createDecimalType(10, 0), "COL_DECIMAL"), new Constant(createDecimalType(10, 0), Decimals.valueOfShort(new BigDecimal("2")))), BIGINT))), Optional.empty()))), tableScan( "t", ImmutableMap.of( @@ -197,7 +197,7 @@ public void testGlobalPreAggregatesCaseAggregations() "FROM t", anyTree( project( - ImmutableMap.of("SUM_2_CAST", expression(new Cast(new SymbolReference(BIGINT, "SUM_2"), createVarcharType(10)))), + ImmutableMap.of("SUM_2_CAST", expression(new Cast(new Reference(BIGINT, "SUM_2"), createVarcharType(10)))), aggregation( globalAggregation(), ImmutableMap., ExpectedValueProvider>builder() @@ -211,12 +211,12 @@ public void testGlobalPreAggregatesCaseAggregations() Optional.empty(), SINGLE, project(ImmutableMap.builder() - .put("SUM_1_INPUT", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new SymbolReference(BIGINT, "SUM_BIGINT"))), Optional.of(new Constant(BIGINT, 0L))))) - .put("SUM_2_INPUT", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new SymbolReference(BIGINT, "SUM_INT_CAST"))), Optional.of(new Constant(BIGINT, 0L))))) - .put("SUM_3_INPUT", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new SymbolReference(BIGINT, "SUM_BIGINT"))), Optional.empty()))) - .put("MIN_1_INPUT", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(GREATER_THAN, new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Constant(BIGINT, 1L)), new SymbolReference(BIGINT, "MIN_BIGINT"))), Optional.empty()))) - .put("SUM_4_INPUT", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 3L)), new SymbolReference(BIGINT, "SUM_DECIMAL"))), Optional.empty()))) - .put("SUM_5_INPUT", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 4L)), new SymbolReference(BIGINT, "SUM_DECIMAL_CAST"))), Optional.empty()))) + .put("SUM_1_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Reference(BIGINT, "SUM_BIGINT"))), Optional.of(new Constant(BIGINT, 0L))))) + .put("SUM_2_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Reference(BIGINT, "SUM_INT_CAST"))), Optional.of(new Constant(BIGINT, 0L))))) + .put("SUM_3_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Reference(BIGINT, "SUM_BIGINT"))), Optional.empty()))) + .put("MIN_1_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Constant(BIGINT, 1L)), new Reference(BIGINT, "MIN_BIGINT"))), Optional.empty()))) + .put("SUM_4_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 3L)), new Reference(BIGINT, "SUM_DECIMAL"))), Optional.empty()))) + .put("SUM_5_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 4L)), new Reference(BIGINT, "SUM_DECIMAL_CAST"))), Optional.empty()))) .buildOrThrow(), aggregation( singleGroupingSet("COL_BIGINT"), @@ -230,10 +230,10 @@ public void testGlobalPreAggregatesCaseAggregations() SINGLE, exchange( project(ImmutableMap.of( - "VALUE_BIGINT", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new InPredicate(new SymbolReference(BIGINT, "COL_BIGINT"), ImmutableList.of(new Constant(BIGINT, 1L), new Constant(BIGINT, 2L))), new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)))), Optional.empty())), - "VALUE_INT_CAST", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Cast(new Cast(new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), INTEGER), BIGINT))), Optional.empty())), - "VALUE_2_INT_CAST", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(GREATER_THAN, new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Constant(BIGINT, 1L)), new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)))), Optional.empty())), - "VALUE_DECIMAL_CAST", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 4L)), new Cast(new ArithmeticBinaryExpression(MULTIPLY_DECIMAL_10_0, MULTIPLY, new SymbolReference(createDecimalType(10, 0), "COL_DECIMAL"), new Constant(createDecimalType(10, 0), Decimals.valueOfShort(new BigDecimal("2")))), BIGINT))), Optional.empty()))), + "VALUE_BIGINT", expression(new Case(ImmutableList.of(new WhenClause(new In(new Reference(BIGINT, "COL_BIGINT"), ImmutableList.of(new Constant(BIGINT, 1L), new Constant(BIGINT, 2L))), new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)))), Optional.empty())), + "VALUE_INT_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Cast(new Cast(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), INTEGER), BIGINT))), Optional.empty())), + "VALUE_2_INT_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Constant(BIGINT, 1L)), new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)))), Optional.empty())), + "VALUE_DECIMAL_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 4L)), new Cast(new Arithmetic(MULTIPLY_DECIMAL_10_0, MULTIPLY, new Reference(createDecimalType(10, 0), "COL_DECIMAL"), new Constant(createDecimalType(10, 0), Decimals.valueOfShort(new BigDecimal("2")))), BIGINT))), Optional.empty()))), tableScan( "t", ImmutableMap.of( @@ -279,18 +279,18 @@ public void testPreAggregatesWithDefaultValues() Optional.empty(), SINGLE, project(ImmutableMap.builder() - .put("SUM_BIGINT_FINAL", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new SymbolReference(BIGINT, "SUM_BIGINT"))), Optional.empty()))) - .put("SUM_BIGINT_FINAL_DEFAULT", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new SymbolReference(BIGINT, "SUM_BIGINT"))), Optional.of(new Constant(BIGINT, 0L))))) - .put("SUM_INT_CAST_FINAL", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new SymbolReference(BIGINT, "SUM_INT_CAST"))), Optional.empty()))) - .put("SUM_INT_CAST_FINAL_DEFAULT", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new SymbolReference(BIGINT, "SUM_INT_CAST"))), Optional.of(new Constant(BIGINT, 0L))))) - .put("SUM_TINYINT_FINAL", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 3L)), new SymbolReference(TINYINT, "SUM_TINYINT"))), Optional.empty()))) - .put("SUM_TINYINT_FINAL_DEFAULT", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 3L)), new SymbolReference(TINYINT, "SUM_TINYINT"))), Optional.of(new Constant(BIGINT, 0L))))) - .put("SUM_DECIMAL_FINAL", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 4L)), new SymbolReference(BIGINT, "SUM_DECIMAL"))), Optional.empty()))) - .put("SUM_DECIMAL_FINAL_DEFAULT", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 4L)), new SymbolReference(BIGINT, "SUM_DECIMAL"))), Optional.of(new Constant(createDecimalType(38, 1), Decimals.valueOf(new BigDecimal("0.0"))))))) - .put("SUM_LONG_DECIMAL_FINAL", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 5L)), new SymbolReference(BIGINT, "SUM_LONG_DECIMAL"))), Optional.empty()))) - .put("SUM_LONG_DECIMAL_FINAL_DEFAULT", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 5L)), new SymbolReference(BIGINT, "SUM_LONG_DECIMAL"))), Optional.of(new Constant(createDecimalType(38, 18), Decimals.valueOf(new BigDecimal("0.000000000000000000"))))))) - .put("SUM_DOUBLE_FINAL", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 6L)), new SymbolReference(DOUBLE, "SUM_DOUBLE"))), Optional.empty()))) - .put("SUM_DOUBLE_FINAL_DEFAULT", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 6L)), new SymbolReference(DOUBLE, "SUM_DOUBLE"))), Optional.of(new Constant(DOUBLE, 0.0))))) + .put("SUM_BIGINT_FINAL", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Reference(BIGINT, "SUM_BIGINT"))), Optional.empty()))) + .put("SUM_BIGINT_FINAL_DEFAULT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Reference(BIGINT, "SUM_BIGINT"))), Optional.of(new Constant(BIGINT, 0L))))) + .put("SUM_INT_CAST_FINAL", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Reference(BIGINT, "SUM_INT_CAST"))), Optional.empty()))) + .put("SUM_INT_CAST_FINAL_DEFAULT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Reference(BIGINT, "SUM_INT_CAST"))), Optional.of(new Constant(BIGINT, 0L))))) + .put("SUM_TINYINT_FINAL", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 3L)), new Reference(TINYINT, "SUM_TINYINT"))), Optional.empty()))) + .put("SUM_TINYINT_FINAL_DEFAULT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 3L)), new Reference(TINYINT, "SUM_TINYINT"))), Optional.of(new Constant(BIGINT, 0L))))) + .put("SUM_DECIMAL_FINAL", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 4L)), new Reference(BIGINT, "SUM_DECIMAL"))), Optional.empty()))) + .put("SUM_DECIMAL_FINAL_DEFAULT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 4L)), new Reference(BIGINT, "SUM_DECIMAL"))), Optional.of(new Constant(createDecimalType(38, 1), Decimals.valueOf(new BigDecimal("0.0"))))))) + .put("SUM_LONG_DECIMAL_FINAL", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 5L)), new Reference(BIGINT, "SUM_LONG_DECIMAL"))), Optional.empty()))) + .put("SUM_LONG_DECIMAL_FINAL_DEFAULT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 5L)), new Reference(BIGINT, "SUM_LONG_DECIMAL"))), Optional.of(new Constant(createDecimalType(38, 18), Decimals.valueOf(new BigDecimal("0.000000000000000000"))))))) + .put("SUM_DOUBLE_FINAL", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 6L)), new Reference(DOUBLE, "SUM_DOUBLE"))), Optional.empty()))) + .put("SUM_DOUBLE_FINAL_DEFAULT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 6L)), new Reference(DOUBLE, "SUM_DOUBLE"))), Optional.of(new Constant(DOUBLE, 0.0))))) .buildOrThrow(), aggregation( singleGroupingSet("COL_BIGINT"), @@ -305,8 +305,8 @@ public void testPreAggregatesWithDefaultValues() SINGLE, exchange( project(ImmutableMap.of( - "VALUE_INT_CAST", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Cast(new Cast(new SymbolReference(BIGINT, "COL_BIGINT"), INTEGER), BIGINT))), Optional.empty())), - "VALUE_TINYINT_CAST", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 3L)), new Cast(new SymbolReference(TINYINT, "COL_TINYINT"), BIGINT))), Optional.empty()))), + "VALUE_INT_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Cast(new Cast(new Reference(BIGINT, "COL_BIGINT"), INTEGER), BIGINT))), Optional.empty())), + "VALUE_TINYINT_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 3L)), new Cast(new Reference(TINYINT, "COL_TINYINT"), BIGINT))), Optional.empty()))), tableScan( "t", ImmutableMap.of( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneAggregationColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneAggregationColumns.java index 9d5fcb29e0cb..4a23f7e9891c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneAggregationColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneAggregationColumns.java @@ -15,7 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; @@ -47,7 +47,7 @@ public void testNotAllInputsReferenced() .on(p -> buildProjectedAggregation(p, symbol -> symbol.getName().equals("b"))) .matches( strictProject( - ImmutableMap.of("b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("b", expression(new Reference(BIGINT, "b"))), aggregation( singleGroupingSet("key"), ImmutableMap.of( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneAggregationSourceColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneAggregationSourceColumns.java index 9e7530a68978..3f397f014adb 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneAggregationSourceColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneAggregationSourceColumns.java @@ -15,7 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; @@ -58,10 +58,10 @@ public void testNotAllInputsReferenced() SINGLE, strictProject( ImmutableMap.of( - "input", expression(new SymbolReference(BIGINT, "input")), - "key", expression(new SymbolReference(BIGINT, "key")), - "keyHash", expression(new SymbolReference(BIGINT, "keyHash")), - "mask", expression(new SymbolReference(BOOLEAN, "mask"))), + "input", expression(new Reference(BIGINT, "input")), + "key", expression(new Reference(BIGINT, "key")), + "keyHash", expression(new Reference(BIGINT, "keyHash")), + "mask", expression(new Reference(BOOLEAN, "mask"))), values("input", "key", "keyHash", "mask", "unused")))); } @@ -84,7 +84,7 @@ private AggregationNode buildAggregation(PlanBuilder planBuilder, Predicate sourceSymbols = ImmutableList.of(input, key, keyHash, mask, unused); return planBuilder.aggregation(aggregationBuilder -> aggregationBuilder .singleGroupingSet(key) - .addAggregation(avg, PlanBuilder.aggregation("avg", ImmutableList.of(new SymbolReference(BIGINT, "input"))), ImmutableList.of(BIGINT), mask) + .addAggregation(avg, PlanBuilder.aggregation("avg", ImmutableList.of(new Reference(BIGINT, "input"))), ImmutableList.of(BIGINT), mask) .hashSymbol(keyHash) .source( planBuilder.values( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneApplyColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneApplyColumns.java index 82da51a4b581..d158837db3ce 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneApplyColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneApplyColumns.java @@ -15,8 +15,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.ComparisonExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Comparison; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -27,7 +27,7 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.apply; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.node; @@ -55,12 +55,12 @@ public void testRemoveUnusedApplyNode() ImmutableList.of(correlationSymbol), p.values(a, correlationSymbol), p.filter( - new ComparisonExpression(GREATER_THAN, subquerySymbol.toSymbolReference(), correlationSymbol.toSymbolReference()), + new Comparison(GREATER_THAN, subquerySymbol.toSymbolReference(), correlationSymbol.toSymbolReference()), p.values(subquerySymbol)))); }) .matches( project( - ImmutableMap.of("a", PlanMatchPattern.expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", PlanMatchPattern.expression(new Reference(BIGINT, "a"))), values("a", "correlationSymbol"))); } @@ -85,17 +85,17 @@ public void testRemoveUnreferencedAssignments() ImmutableList.of(correlationSymbol), p.values(a, b, correlationSymbol), p.filter( - new ComparisonExpression(GREATER_THAN, subquerySymbol.toSymbolReference(), correlationSymbol.toSymbolReference()), + new Comparison(GREATER_THAN, subquerySymbol.toSymbolReference(), correlationSymbol.toSymbolReference()), p.values(subquerySymbol)))); }) .matches( project( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a")), "in_result_1", expression(new SymbolReference(BOOLEAN, "in_result_1"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a")), "in_result_1", expression(new Reference(BOOLEAN, "in_result_1"))), apply( ImmutableList.of("correlation_symbol"), ImmutableMap.of("in_result_1", setExpression(new ApplyNode.In(new Symbol(UNKNOWN, "a"), new Symbol(UNKNOWN, "subquery_symbol")))), project( - ImmutableMap.of("a", PlanMatchPattern.expression(new SymbolReference(BIGINT, "a")), "correlation_symbol", PlanMatchPattern.expression(new SymbolReference(BIGINT, "correlation_symbol"))), + ImmutableMap.of("a", PlanMatchPattern.expression(new Reference(BIGINT, "a")), "correlation_symbol", PlanMatchPattern.expression(new Reference(BIGINT, "correlation_symbol"))), values("a", "b", "correlation_symbol")), node( FilterNode.class, @@ -119,18 +119,18 @@ public void testRemoveUnreferencedAssignments() ImmutableList.of(correlationSymbol), p.values(a, correlationSymbol), p.filter( - new ComparisonExpression(GREATER_THAN, subquerySymbol1.toSymbolReference(), correlationSymbol.toSymbolReference()), + new Comparison(GREATER_THAN, subquerySymbol1.toSymbolReference(), correlationSymbol.toSymbolReference()), p.values(subquerySymbol1, subquerySymbol2)))); }) .matches( project( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a")), "in_result_1", expression(new SymbolReference(BOOLEAN, "in_result_1"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a")), "in_result_1", expression(new Reference(BOOLEAN, "in_result_1"))), apply( ImmutableList.of("correlation_symbol"), ImmutableMap.of("in_result_1", setExpression(new ApplyNode.In(new Symbol(UNKNOWN, "a"), new Symbol(UNKNOWN, "subquery_symbol_1")))), values("a", "correlation_symbol"), project( - ImmutableMap.of("subquery_symbol_1", expression(new SymbolReference(BIGINT, "subquery_symbol_1"))), + ImmutableMap.of("subquery_symbol_1", expression(new Reference(BIGINT, "subquery_symbol_1"))), node( FilterNode.class, values("subquery_symbol_1", "subquery_symbol_2")))))); @@ -153,18 +153,18 @@ public void testPruneUnreferencedSubquerySymbol() ImmutableList.of(correlationSymbol), p.values(a, correlationSymbol), p.filter( - new ComparisonExpression(GREATER_THAN, unreferenced.toSymbolReference(), correlationSymbol.toSymbolReference()), + new Comparison(GREATER_THAN, unreferenced.toSymbolReference(), correlationSymbol.toSymbolReference()), p.values(unreferenced, subquerySymbol)))); }) .matches( project( - ImmutableMap.of("correlation_symbol", PlanMatchPattern.expression(new SymbolReference(BIGINT, "correlation_symbol")), "in_result", PlanMatchPattern.expression(new SymbolReference(BOOLEAN, "in_result"))), + ImmutableMap.of("correlation_symbol", PlanMatchPattern.expression(new Reference(BIGINT, "correlation_symbol")), "in_result", PlanMatchPattern.expression(new Reference(BOOLEAN, "in_result"))), apply( ImmutableList.of("correlation_symbol"), ImmutableMap.of("in_result", setExpression(new ApplyNode.In(new Symbol(UNKNOWN, "a"), new Symbol(UNKNOWN, "subquery_symbol")))), values("a", "correlation_symbol"), project( - ImmutableMap.of("subquery_symbol", PlanMatchPattern.expression(new SymbolReference(BIGINT, "subquery_symbol"))), + ImmutableMap.of("subquery_symbol", PlanMatchPattern.expression(new Reference(BIGINT, "subquery_symbol"))), node( FilterNode.class, values("unreferenced", "subquery_symbol")))))); @@ -187,17 +187,17 @@ public void testPruneUnreferencedInputSymbol() ImmutableList.of(correlationSymbol), p.values(a, unreferenced, correlationSymbol), p.filter( - new ComparisonExpression(GREATER_THAN, subquerySymbol.toSymbolReference(), correlationSymbol.toSymbolReference()), + new Comparison(GREATER_THAN, subquerySymbol.toSymbolReference(), correlationSymbol.toSymbolReference()), p.values(subquerySymbol)))); }) .matches( project( - ImmutableMap.of("correlation_symbol", PlanMatchPattern.expression(new SymbolReference(BIGINT, "correlation_symbol")), "in_result", PlanMatchPattern.expression(new SymbolReference(BOOLEAN, "in_result"))), + ImmutableMap.of("correlation_symbol", PlanMatchPattern.expression(new Reference(BIGINT, "correlation_symbol")), "in_result", PlanMatchPattern.expression(new Reference(BOOLEAN, "in_result"))), apply( ImmutableList.of("correlation_symbol"), ImmutableMap.of("in_result", setExpression(new ApplyNode.In(new Symbol(UNKNOWN, "a"), new Symbol(UNKNOWN, "subquery_symbol")))), project( - ImmutableMap.of("a", PlanMatchPattern.expression(new SymbolReference(BIGINT, "a")), "correlation_symbol", PlanMatchPattern.expression(new SymbolReference(BIGINT, "correlation_symbol"))), + ImmutableMap.of("a", PlanMatchPattern.expression(new Reference(BIGINT, "a")), "correlation_symbol", PlanMatchPattern.expression(new Reference(BIGINT, "correlation_symbol"))), values("a", "unreferenced", "correlation_symbol")), node( FilterNode.class, @@ -220,7 +220,7 @@ public void testDoNotPruneUnreferencedUsedCorrelationSymbol() ImmutableList.of(correlationSymbol), p.values(a, correlationSymbol), p.filter( - new ComparisonExpression(GREATER_THAN, subquerySymbol.toSymbolReference(), correlationSymbol.toSymbolReference()), + new Comparison(GREATER_THAN, subquerySymbol.toSymbolReference(), correlationSymbol.toSymbolReference()), p.values(subquerySymbol)))); }) .doesNotFire(); @@ -262,7 +262,7 @@ public void testAllOutputsReferenced() ImmutableList.of(correlationSymbol), p.values(a, correlationSymbol), p.filter( - new ComparisonExpression(GREATER_THAN, subquerySymbol.toSymbolReference(), correlationSymbol.toSymbolReference()), + new Comparison(GREATER_THAN, subquerySymbol.toSymbolReference(), correlationSymbol.toSymbolReference()), p.values(subquerySymbol)))); }) .doesNotFire(); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneApplyCorrelation.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneApplyCorrelation.java index eeb26f8b6836..4a0d8e1f4c6f 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneApplyCorrelation.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneApplyCorrelation.java @@ -15,13 +15,13 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.ApplyNode; import org.junit.jupiter.api.Test; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.apply; import static io.trino.sql.planner.assertions.PlanMatchPattern.setExpression; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -67,7 +67,7 @@ public void testAllCorrelationSymbolsReferencedInSubquery() ImmutableList.of(inputSymbol), p.values(a, inputSymbol), p.filter( - new ComparisonExpression(GREATER_THAN, subquerySymbol.toSymbolReference(), inputSymbol.toSymbolReference()), + new Comparison(GREATER_THAN, subquerySymbol.toSymbolReference(), inputSymbol.toSymbolReference()), p.values(subquerySymbol))); }) .doesNotFire(); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneApplySourceColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneApplySourceColumns.java index fc03835b9f19..263bd16e44f2 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneApplySourceColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneApplySourceColumns.java @@ -15,7 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.ApplyNode; @@ -53,7 +53,7 @@ public void testNotAllSubquerySymbolsReferenced() ImmutableMap.of("in_result", setExpression(new ApplyNode.In(new Symbol(UNKNOWN, "a"), new Symbol(UNKNOWN, "subquery_symbol_1")))), values("a"), project( - ImmutableMap.of("subquery_symbol_1", expression(new SymbolReference(BIGINT, "subquery_symbol_1"))), + ImmutableMap.of("subquery_symbol_1", expression(new Reference(BIGINT, "subquery_symbol_1"))), values("subquery_symbol_1", "subquery_symbol_2")))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneAssignUniqueIdColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneAssignUniqueIdColumns.java index 9efdf8b48d9e..66401648dd41 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneAssignUniqueIdColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneAssignUniqueIdColumns.java @@ -14,7 +14,7 @@ package io.trino.sql.planner.iterative.rule; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -45,7 +45,7 @@ public void testRemoveUnusedAssignUniqueId() }) .matches( strictProject( - ImmutableMap.of("a", PlanMatchPattern.expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", PlanMatchPattern.expression(new Reference(BIGINT, "a"))), values("a", "b"))); } @@ -65,11 +65,11 @@ public void testNotAllInputsReferenced() }) .matches( strictProject( - ImmutableMap.of("a", PlanMatchPattern.expression(new SymbolReference(BIGINT, "a")), "unique_id", PlanMatchPattern.expression(new SymbolReference(BIGINT, "unique_id"))), + ImmutableMap.of("a", PlanMatchPattern.expression(new Reference(BIGINT, "a")), "unique_id", PlanMatchPattern.expression(new Reference(BIGINT, "unique_id"))), assignUniqueId( "unique_id", strictProject( - ImmutableMap.of("a", PlanMatchPattern.expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", PlanMatchPattern.expression(new Reference(BIGINT, "a"))), values("a", "b"))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneCorrelatedJoinColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneCorrelatedJoinColumns.java index 0787a4febddb..54d70c6b07dc 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneCorrelatedJoinColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneCorrelatedJoinColumns.java @@ -15,8 +15,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.ComparisonExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Comparison; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -25,9 +25,9 @@ import org.junit.jupiter.api.Test; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; import static io.trino.sql.planner.assertions.PlanMatchPattern.correlatedJoin; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.node; @@ -57,7 +57,7 @@ public void testRemoveUnusedCorrelatedJoinNode() }) .matches( project( - ImmutableMap.of("a", PlanMatchPattern.expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", PlanMatchPattern.expression(new Reference(BIGINT, "a"))), values("a", "correlationSymbol"))); // retain input of LEFT join @@ -72,12 +72,12 @@ public void testRemoveUnusedCorrelatedJoinNode() ImmutableList.of(correlationSymbol), p.values(a, correlationSymbol), LEFT, - new ComparisonExpression(GREATER_THAN, b.toSymbolReference(), correlationSymbol.toSymbolReference()), + new Comparison(GREATER_THAN, b.toSymbolReference(), correlationSymbol.toSymbolReference()), p.values(1, b))); }) .matches( project( - ImmutableMap.of("a", PlanMatchPattern.expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", PlanMatchPattern.expression(new Reference(BIGINT, "a"))), values("a", "correlationSymbol"))); // retain subquery of INNER join @@ -94,7 +94,7 @@ public void testRemoveUnusedCorrelatedJoinNode() }) .matches( project( - ImmutableMap.of("b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("b", expression(new Reference(BIGINT, "b"))), values("b"))); // retain subquery of RIGHT join @@ -108,12 +108,12 @@ public void testRemoveUnusedCorrelatedJoinNode() ImmutableList.of(), p.values(1, a), RIGHT, - new ComparisonExpression(GREATER_THAN, b.toSymbolReference(), a.toSymbolReference()), + new Comparison(GREATER_THAN, b.toSymbolReference(), a.toSymbolReference()), p.values(b))); }) .matches( project( - ImmutableMap.of("b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("b", expression(new Reference(BIGINT, "b"))), values("b"))); } @@ -132,19 +132,19 @@ public void testPruneUnreferencedSubquerySymbol() ImmutableList.of(correlationSymbol), p.values(a, correlationSymbol), LEFT, - new ComparisonExpression(GREATER_THAN, b.toSymbolReference(), a.toSymbolReference()), + new Comparison(GREATER_THAN, b.toSymbolReference(), a.toSymbolReference()), p.filter( - new ComparisonExpression(GREATER_THAN, b.toSymbolReference(), correlationSymbol.toSymbolReference()), + new Comparison(GREATER_THAN, b.toSymbolReference(), correlationSymbol.toSymbolReference()), p.values(5, b, c)))); }) .matches( project( - ImmutableMap.of("a", PlanMatchPattern.expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", PlanMatchPattern.expression(new Reference(BIGINT, "a"))), correlatedJoin( ImmutableList.of("correlation_symbol"), values("a", "correlation_symbol"), project( - ImmutableMap.of("b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("b", expression(new Reference(BIGINT, "b"))), node( FilterNode.class, values("b", "c")))))); @@ -164,18 +164,18 @@ public void testPruneUnreferencedInputSymbol() ImmutableList.of(correlationSymbol), p.values(a, correlationSymbol), LEFT, - new ComparisonExpression(GREATER_THAN, b.toSymbolReference(), correlationSymbol.toSymbolReference()), + new Comparison(GREATER_THAN, b.toSymbolReference(), correlationSymbol.toSymbolReference()), p.filter( - new ComparisonExpression(GREATER_THAN_OR_EQUAL, b.toSymbolReference(), correlationSymbol.toSymbolReference()), + new Comparison(GREATER_THAN_OR_EQUAL, b.toSymbolReference(), correlationSymbol.toSymbolReference()), p.values(b)))); }) .matches( project( - ImmutableMap.of("b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("b", expression(new Reference(BIGINT, "b"))), correlatedJoin( ImmutableList.of("correlation_symbol"), project( - ImmutableMap.of("correlation_symbol", PlanMatchPattern.expression(new SymbolReference(BIGINT, "correlation_symbol"))), + ImmutableMap.of("correlation_symbol", PlanMatchPattern.expression(new Reference(BIGINT, "correlation_symbol"))), values("a", "correlation_symbol")), node( FilterNode.class, @@ -196,7 +196,7 @@ public void testDoNotPruneUnreferencedCorrelationSymbol() ImmutableList.of(correlationSymbol), p.values(a, correlationSymbol), LEFT, - TRUE_LITERAL, + TRUE, p.values(b))); }) .doesNotFire(); @@ -216,9 +216,9 @@ public void testAllOutputsReferenced() ImmutableList.of(correlationSymbol), p.values(a, correlationSymbol), LEFT, - TRUE_LITERAL, + TRUE, p.filter( - new ComparisonExpression(GREATER_THAN_OR_EQUAL, b.toSymbolReference(), correlationSymbol.toSymbolReference()), + new Comparison(GREATER_THAN_OR_EQUAL, b.toSymbolReference(), correlationSymbol.toSymbolReference()), p.values(b)))); }) .doesNotFire(); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneCorrelatedJoinCorrelation.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneCorrelatedJoinCorrelation.java index 56262f156cac..02d58a9c4456 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneCorrelatedJoinCorrelation.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneCorrelatedJoinCorrelation.java @@ -14,12 +14,12 @@ package io.trino.sql.planner.iterative.rule; import com.google.common.collect.ImmutableList; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import org.junit.jupiter.api.Test; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.correlatedJoin; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -56,7 +56,7 @@ public void testAllCorrelationSymbolsReferencedInSubquery() ImmutableList.of(inputSymbol), p.values(inputSymbol), p.filter( - new ComparisonExpression(GREATER_THAN, subquerySymbol.toSymbolReference(), inputSymbol.toSymbolReference()), + new Comparison(GREATER_THAN, subquerySymbol.toSymbolReference(), inputSymbol.toSymbolReference()), p.values(subquerySymbol))); }) .doesNotFire(); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneCountAggregationOverScalar.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneCountAggregationOverScalar.java index 99de3000aa57..6ec29aa6f2a4 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneCountAggregationOverScalar.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneCountAggregationOverScalar.java @@ -21,7 +21,7 @@ import io.trino.plugin.tpch.TpchTransactionHandle; import io.trino.spi.type.BigintType; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; @@ -139,7 +139,7 @@ public void testDoesNotFireOnNestedNonCountAggregate() AggregationNode inner = p.aggregation((a) -> a .addAggregation( totalPrice, - PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(DOUBLE, "totalprice"))), + PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(DOUBLE, "totalprice"))), ImmutableList.of(DOUBLE)) .globalGrouping() .source( @@ -156,7 +156,7 @@ public void testDoesNotFireOnNestedNonCountAggregate() return p.aggregation((a) -> a .addAggregation( p.symbol("sum_outer", DOUBLE), - PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "sum_inner"))), + PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "sum_inner"))), ImmutableList.of(DOUBLE)) .globalGrouping() .source(inner)); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneDistinctAggregation.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneDistinctAggregation.java index d30faac1e7c9..f4be945f2f92 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneDistinctAggregation.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneDistinctAggregation.java @@ -15,7 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; @@ -54,7 +54,7 @@ public void testNonPruning() Symbol a = p.symbol("a"); AggregationNode child = p.aggregation(aggregationBuilder -> aggregationBuilder.globalGrouping() - .addAggregation(p.symbol("sum", BIGINT), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "a"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum", BIGINT), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .source(p.values(1, a))); return p.aggregation(aggregationBuilder -> aggregationBuilder.globalGrouping() diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneDistinctLimitSourceColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneDistinctLimitSourceColumns.java index 4240521df0b9..ba23b0e281d7 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneDistinctLimitSourceColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneDistinctLimitSourceColumns.java @@ -15,7 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import org.junit.jupiter.api.Test; @@ -48,7 +48,7 @@ public void testPruneInputColumn() 5, ImmutableList.of("a"), strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), values("a", "b")))); tester().assertThat(new PruneDistinctLimitSourceColumns()) @@ -68,7 +68,7 @@ public void testPruneInputColumn() ImmutableList.of("a"), "hash_symbol", strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a")), "hash_symbol", expression(new SymbolReference(BIGINT, "hash_symbol"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a")), "hash_symbol", expression(new Reference(BIGINT, "hash_symbol"))), values("a", "b", "hash_symbol")))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneEnforceSingleRowColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneEnforceSingleRowColumns.java index 9347478a19fb..eb17725eedfb 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneEnforceSingleRowColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneEnforceSingleRowColumns.java @@ -14,7 +14,7 @@ package io.trino.sql.planner.iterative.rule; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; @@ -42,10 +42,10 @@ public void testPruneInputColumn() }) .matches( strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), enforceSingleRow( strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), values("a", "b"))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneExceptSourceColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneExceptSourceColumns.java index 95ecc56ce833..75a21f33ab6c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneExceptSourceColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneExceptSourceColumns.java @@ -16,7 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import org.junit.jupiter.api.Test; @@ -47,7 +47,7 @@ public void testPruneOneChild() }) .matches(except( strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), values("a", "b")), values("c"))); } @@ -70,10 +70,10 @@ public void testPruneAllChildren() }) .matches(except( strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), values("a", "b")), strictProject( - ImmutableMap.of("c", expression(new SymbolReference(BIGINT, "c"))), + ImmutableMap.of("c", expression(new Reference(BIGINT, "c"))), values("c", "d")))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneExchangeColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneExchangeColumns.java index 9ea996a485d5..1062bb5de288 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneExchangeColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneExchangeColumns.java @@ -17,7 +17,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import io.trino.spi.connector.SortOrder; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -129,7 +129,7 @@ public void testPruneUnreferencedSymbol() }) .matches( project( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), exchange( REMOTE, GATHER, diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneExchangeSourceColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneExchangeSourceColumns.java index 1355eb4f5b7f..7b38edfc4747 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneExchangeSourceColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneExchangeSourceColumns.java @@ -15,7 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import org.junit.jupiter.api.Test; @@ -49,7 +49,7 @@ public void testPruneOneChild() exchange( values(ImmutableList.of("b")), strictProject( - ImmutableMap.of("c_1", expression(new SymbolReference(BIGINT, "c_1"))), + ImmutableMap.of("c_1", expression(new Reference(BIGINT, "c_1"))), values(ImmutableList.of("c_1", "c_2"))))); } @@ -73,10 +73,10 @@ public void testPruneAllChildren() .matches( exchange( strictProject( - ImmutableMap.of("b_1", expression(new SymbolReference(BIGINT, "b_1"))), + ImmutableMap.of("b_1", expression(new Reference(BIGINT, "b_1"))), values(ImmutableList.of("b_1", "b_2"))), strictProject( - ImmutableMap.of("c_1", expression(new SymbolReference(BIGINT, "c_1"))), + ImmutableMap.of("c_1", expression(new Reference(BIGINT, "c_1"))), values(ImmutableList.of("c_1", "c_2"))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneExplainAnalyzeSourceColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneExplainAnalyzeSourceColumns.java index fbf04081a1dc..dfc2757302d5 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneExplainAnalyzeSourceColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneExplainAnalyzeSourceColumns.java @@ -15,7 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.ExplainAnalyzeNode; @@ -46,7 +46,7 @@ public void testNotAllOutputsReferenced() .matches( node(ExplainAnalyzeNode.class, strictProject( - ImmutableMap.of("b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("b", expression(new Reference(BIGINT, "b"))), values("a", "b")))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneFilterColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneFilterColumns.java index 84f8d365665a..2407aff36b38 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneFilterColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneFilterColumns.java @@ -14,9 +14,9 @@ package io.trino.sql.planner.iterative.rule; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; @@ -30,7 +30,7 @@ import static com.google.common.base.Predicates.alwaysTrue; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.strictProject; @@ -46,11 +46,11 @@ public void testNotAllInputsReferenced() .on(p -> buildProjectedFilter(p, symbol -> symbol.getName().equals("b"))) .matches( strictProject( - ImmutableMap.of("b", expression(new SymbolReference(INTEGER, "b"))), + ImmutableMap.of("b", expression(new Reference(INTEGER, "b"))), filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 5L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 5L)), strictProject( - ImmutableMap.of("b", expression(new SymbolReference(INTEGER, "b"))), + ImmutableMap.of("b", expression(new Reference(INTEGER, "b"))), values("a", "b"))))); } @@ -77,7 +77,7 @@ private ProjectNode buildProjectedFilter(PlanBuilder planBuilder, Predicate buildProjectedIndexSource(p, symbol -> symbol.getName().equals("orderkey"))) .matches( strictProject( - ImmutableMap.of("x", expression(new SymbolReference(BIGINT, "orderkey"))), + ImmutableMap.of("x", expression(new Reference(BIGINT, "orderkey"))), constrainedIndexSource( "orders", ImmutableMap.of("orderkey", "orderkey")))); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneIntersectSourceColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneIntersectSourceColumns.java index 678d4c2b854e..2d6f3aa58a40 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneIntersectSourceColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneIntersectSourceColumns.java @@ -16,7 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import org.junit.jupiter.api.Test; @@ -49,7 +49,7 @@ public void testPruneOneChild() }) .matches(intersect( strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), values("a", "b")), values("c"), values("d"))); @@ -76,13 +76,13 @@ public void testPruneAllChildren() }) .matches(intersect( strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), values("a", "b")), strictProject( - ImmutableMap.of("c", expression(new SymbolReference(BIGINT, "c"))), + ImmutableMap.of("c", expression(new Reference(BIGINT, "c"))), values("c", "d")), strictProject( - ImmutableMap.of("e", expression(new SymbolReference(BIGINT, "e"))), + ImmutableMap.of("e", expression(new Reference(BIGINT, "e"))), values("e", "f")))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneJoinChildrenColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneJoinChildrenColumns.java index 5dd1cbfdc547..204d1eb7af31 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneJoinChildrenColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneJoinChildrenColumns.java @@ -16,9 +16,9 @@ import com.google.common.base.Predicates; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -34,7 +34,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.join; import static io.trino.sql.planner.assertions.PlanMatchPattern.strictProject; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -51,13 +51,13 @@ public void testNotAllInputsReferenced() .matches( join(INNER, builder -> builder .equiCriteria("leftKey", "rightKey") - .filter(new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "leftValue"), new Constant(INTEGER, 5L))) + .filter(new Comparison(GREATER_THAN, new Reference(INTEGER, "leftValue"), new Constant(INTEGER, 5L))) .left(values("leftKey", "leftKeyHash", "leftValue")) .right( strictProject( ImmutableMap.of( - "rightKey", PlanMatchPattern.expression(new SymbolReference(BIGINT, "rightKey")), - "rightKeyHash", PlanMatchPattern.expression(new SymbolReference(BIGINT, "rightKeyHash"))), + "rightKey", PlanMatchPattern.expression(new Reference(BIGINT, "rightKey")), + "rightKeyHash", PlanMatchPattern.expression(new Reference(BIGINT, "rightKeyHash"))), values("rightKey", "rightKeyHash", "rightValue"))))); } @@ -117,7 +117,7 @@ private static PlanNode buildJoin(PlanBuilder p, Predicate joinOutputFil rightOutputs.stream() .filter(joinOutputFilter) .collect(toImmutableList()), - Optional.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "leftValue"), new Constant(INTEGER, 5L))), + Optional.of(new Comparison(GREATER_THAN, new Reference(INTEGER, "leftValue"), new Constant(INTEGER, 5L))), Optional.of(leftKeyHash), Optional.of(rightKeyHash)); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneJoinColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneJoinColumns.java index cdda96c06f38..3ea472350701 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneJoinColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneJoinColumns.java @@ -16,7 +16,7 @@ import com.google.common.base.Predicates; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; @@ -47,7 +47,7 @@ public void testNotAllOutputsReferenced() .on(p -> buildProjectedJoin(p, symbol -> symbol.getName().equals("rightValue"))) .matches( strictProject( - ImmutableMap.of("rightValue", expression(new SymbolReference(BIGINT, "rightValue"))), + ImmutableMap.of("rightValue", expression(new Reference(BIGINT, "rightValue"))), join(INNER, builder -> builder .equiCriteria("leftKey", "rightKey") .left(values(ImmutableList.of("leftKey", "leftValue"))) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneLimitColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneLimitColumns.java index 3fd8709e8e12..fb2db5aef4b1 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneLimitColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneLimitColumns.java @@ -15,7 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; @@ -48,11 +48,11 @@ public void testNotAllInputsReferenced() .on(p -> buildProjectedLimit(p, symbol -> symbol.getName().equals("b"))) .matches( strictProject( - ImmutableMap.of("b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("b", expression(new Reference(BIGINT, "b"))), limit( 1, strictProject( - ImmutableMap.of("b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("b", expression(new Reference(BIGINT, "b"))), values("a", "b"))))); } @@ -82,7 +82,7 @@ public void testDoNotPruneTiesResolvingSymbols() 1, ImmutableList.of(sort("a", ASCENDING, FIRST)), strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), values("a", "b"))))); } @@ -107,7 +107,7 @@ public void testDoNotPrunePreSortedInputSymbols() false, ImmutableList.of("a"), strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), values("a", "b"))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneMarkDistinctColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneMarkDistinctColumns.java index fe055ab5377c..2fd737e0f357 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneMarkDistinctColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneMarkDistinctColumns.java @@ -15,7 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; @@ -46,7 +46,7 @@ public void testMarkerSymbolNotReferenced() }) .matches( strictProject( - ImmutableMap.of("key2", expression(new SymbolReference(BIGINT, "key"))), + ImmutableMap.of("key2", expression(new Reference(BIGINT, "key"))), values(ImmutableList.of("key", "unused")))); } @@ -69,12 +69,12 @@ public void testSourceSymbolNotReferenced() }) .matches( strictProject( - ImmutableMap.of("mark", expression(new SymbolReference(BOOLEAN, "mark"))), + ImmutableMap.of("mark", expression(new Reference(BOOLEAN, "mark"))), markDistinct("mark", ImmutableList.of("key"), "hash", strictProject( ImmutableMap.of( - "key", expression(new SymbolReference(BIGINT, "key")), - "hash", expression(new SymbolReference(BIGINT, "hash"))), + "key", expression(new Reference(BIGINT, "key")), + "hash", expression(new Reference(BIGINT, "hash"))), values(ImmutableList.of("key", "hash", "unused")))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneMergeSourceColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneMergeSourceColumns.java index 3bfc50a703be..179b54f3f2fe 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneMergeSourceColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneMergeSourceColumns.java @@ -16,7 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.spi.connector.SchemaTableName; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.MergeWriterNode; @@ -53,8 +53,8 @@ public void testPruneInputColumn() MergeWriterNode.class, strictProject( ImmutableMap.of( - "row_id", expression(new SymbolReference(BIGINT, "row_id")), - "merge_row", expression(new SymbolReference(BIGINT, "merge_row"))), + "row_id", expression(new Reference(BIGINT, "row_id")), + "merge_row", expression(new Reference(BIGINT, "merge_row"))), values("a", "merge_row", "row_id")))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneOffsetColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneOffsetColumns.java index ad261413c03d..297a4101474e 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneOffsetColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneOffsetColumns.java @@ -14,7 +14,7 @@ package io.trino.sql.planner.iterative.rule; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; @@ -43,11 +43,11 @@ public void testNotAllInputsReferenced() .on(p -> buildProjectedOffset(p, symbol -> symbol.getName().equals("b"))) .matches( strictProject( - ImmutableMap.of("b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("b", expression(new Reference(BIGINT, "b"))), offset( 1, strictProject( - ImmutableMap.of("b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("b", expression(new Reference(BIGINT, "b"))), values("a", "b"))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneOrderByInAggregation.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneOrderByInAggregation.java index 855d3840a2c3..ffc46e71da6f 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneOrderByInAggregation.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneOrderByInAggregation.java @@ -17,7 +17,7 @@ import com.google.common.collect.ImmutableMap; import io.trino.metadata.Metadata; import io.trino.spi.connector.SortOrder; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -79,7 +79,7 @@ private AggregationNode buildAggregation(PlanBuilder planBuilder) .singleGroupingSet(key) .addAggregation(avg, PlanBuilder.aggregation( "avg", - ImmutableList.of(new SymbolReference(BIGINT, "input")), + ImmutableList.of(new Reference(BIGINT, "input")), new OrderingScheme( ImmutableList.of(new Symbol(UNKNOWN, "input")), ImmutableMap.of(new Symbol(UNKNOWN, "input"), SortOrder.ASC_NULLS_LAST))), @@ -87,7 +87,7 @@ private AggregationNode buildAggregation(PlanBuilder planBuilder) mask) .addAggregation(arrayAgg, PlanBuilder.aggregation( "array_agg", - ImmutableList.of(new SymbolReference(BIGINT, "input")), + ImmutableList.of(new Reference(BIGINT, "input")), new OrderingScheme( ImmutableList.of(new Symbol(UNKNOWN, "input")), ImmutableMap.of(new Symbol(UNKNOWN, "input"), SortOrder.ASC_NULLS_LAST))), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneOutputSourceColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneOutputSourceColumns.java index b91455166525..785a86463cb4 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneOutputSourceColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneOutputSourceColumns.java @@ -15,7 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import org.junit.jupiter.api.Test; @@ -45,7 +45,7 @@ public void testNotAllOutputsReferenced() strictOutput( ImmutableList.of("b"), strictProject( - ImmutableMap.of("b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("b", expression(new Reference(BIGINT, "b"))), values("a", "b")))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPrunePattenRecognitionColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPrunePattenRecognitionColumns.java index 690a363ad502..b7e5c19b714c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPrunePattenRecognitionColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPrunePattenRecognitionColumns.java @@ -17,9 +17,9 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import io.trino.metadata.ResolvedFunction; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -39,8 +39,8 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.patternRecognition; import static io.trino.sql.planner.assertions.PlanMatchPattern.specification; @@ -75,11 +75,11 @@ public void testRemovePatternRecognitionNode() .rowsPerMatch(ALL_WITH_UNMATCHED) .skipTo(PAST_LAST) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a"), p.symbol("b")))))) .matches( strictProject( - ImmutableMap.of("b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("b", expression(new Reference(BIGINT, "b"))), values("a", "b"))); // pattern recognition in window @@ -91,11 +91,11 @@ public void testRemovePatternRecognitionNode() .frame(new WindowNode.Frame(ROWS, CURRENT_ROW, Optional.empty(), Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty())) .skipTo(NEXT) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a"), p.symbol("b")))))) .matches( strictProject( - ImmutableMap.of("b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("b", expression(new Reference(BIGINT, "b"))), values("a", "b"))); // unreferenced window functions and measures @@ -106,7 +106,7 @@ public void testRemovePatternRecognitionNode() .addWindowFunction(p.symbol("rank"), new WindowNode.Function(rank, ImmutableList.of(), DEFAULT_FRAME, false)) .addMeasure( p.symbol("measure"), - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), true, true, 0, 0), new Symbol(UNKNOWN, "a")))) @@ -114,11 +114,11 @@ public void testRemovePatternRecognitionNode() .frame(new WindowNode.Frame(ROWS, CURRENT_ROW, Optional.empty(), Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty())) .skipTo(NEXT) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a"), p.symbol("b")))))) .matches( strictProject( - ImmutableMap.of("b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("b", expression(new Reference(BIGINT, "b"))), values("a", "b"))); } @@ -135,7 +135,7 @@ public void testPruneUnreferencedWindowFunctionAndSources() .addWindowFunction(p.symbol("lag", BIGINT), new WindowNode.Function(lag, ImmutableList.of(p.symbol("b", BIGINT).toSymbolReference()), DEFAULT_FRAME, false)) .addMeasure( p.symbol("measure", BIGINT), - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), true, true, 0, 0), new Symbol(BIGINT, "a")))) @@ -143,15 +143,15 @@ public void testPruneUnreferencedWindowFunctionAndSources() .frame(new WindowNode.Frame(ROWS, CURRENT_ROW, Optional.empty(), Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty())) .skipTo(NEXT) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a", BIGINT), p.symbol("b", BIGINT)))))) .matches( strictProject( - ImmutableMap.of("measure", expression(new SymbolReference(BIGINT, "measure"))), + ImmutableMap.of("measure", expression(new Reference(BIGINT, "measure"))), patternRecognition(builder -> builder .addMeasure( "measure", - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), true, true, 0, 0), new Symbol(BIGINT, "a"))), @@ -167,9 +167,9 @@ public void testPruneUnreferencedWindowFunctionAndSources() Optional.empty())) .skipTo(NEXT) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL), + .addVariableDefinition(new IrLabel("X"), TRUE), strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), values("a", "b"))))); } @@ -194,7 +194,7 @@ public void testPruneUnreferencedMeasureAndSources() .addWindowFunction(p.symbol("lag"), new WindowNode.Function(lag, ImmutableList.of(p.symbol("b").toSymbolReference()), frame, false)) .addMeasure( p.symbol("measure"), - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), true, true, 0, 0), new Symbol(UNKNOWN, "a")))) @@ -202,20 +202,20 @@ public void testPruneUnreferencedMeasureAndSources() .frame(frame) .skipTo(NEXT) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a"), p.symbol("b")))))) .matches( strictProject( - ImmutableMap.of("lag", expression(new SymbolReference(BIGINT, "lag"))), + ImmutableMap.of("lag", expression(new Reference(BIGINT, "lag"))), patternRecognition(builder -> builder .addFunction("lag", windowFunction("lag", ImmutableList.of("b"), frame)) .rowsPerMatch(WINDOW) .frame(frame) .skipTo(NEXT) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL), + .addVariableDefinition(new IrLabel("X"), TRUE), strictProject( - ImmutableMap.of("b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("b", expression(new Reference(BIGINT, "b"))), values("a", "b"))))); } @@ -231,7 +231,7 @@ public void testDoNotPruneVariableDefinitionSources() .pattern(new IrLabel("X")) .addVariableDefinition( new IrLabel("X"), - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "pointer"), new Constant(BIGINT, 0L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "pointer"), new Constant(BIGINT, 0L)), ImmutableMap.of(new Symbol(BIGINT, "pointer"), new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), true, true, 0, 0), new Symbol(BIGINT, "a")))) @@ -243,12 +243,12 @@ public void testDoNotPruneVariableDefinitionSources() .pattern(new IrLabel("X")) .addVariableDefinition( new IrLabel("X"), - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "pointer"), new Constant(BIGINT, 0L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "pointer"), new Constant(BIGINT, 0L)), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), true, true, 0, 0), new Symbol(BIGINT, "a")))), strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), values("a", "b"))))); // inputs "a", "b" are used as aggregation arguments @@ -261,11 +261,11 @@ public void testDoNotPruneVariableDefinitionSources() .pattern(new IrLabel("X")) .addVariableDefinition( new IrLabel("X"), - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "agg"), new Constant(BIGINT, 5L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "agg"), new Constant(BIGINT, 5L)), ImmutableMap.of(new Symbol(BIGINT, "agg"), new AggregationValuePointer( maxBy, new AggregatedSetDescriptor(ImmutableSet.of(), true), - ImmutableList.of(new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b")), + ImmutableList.of(new Reference(BIGINT, "a"), new Reference(BIGINT, "b")), Optional.empty(), Optional.empty()))) .source(p.values(p.symbol("a", BIGINT), p.symbol("b", BIGINT), p.symbol("c", BIGINT)))))) @@ -276,15 +276,15 @@ public void testDoNotPruneVariableDefinitionSources() .pattern(new IrLabel("X")) .addVariableDefinition( new IrLabel("X"), - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "agg"), new Constant(BIGINT, 5L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "agg"), new Constant(BIGINT, 5L)), ImmutableMap.of("agg", new AggregationValuePointer( maxBy, new AggregatedSetDescriptor(ImmutableSet.of(), true), - ImmutableList.of(new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b")), + ImmutableList.of(new Reference(BIGINT, "a"), new Reference(BIGINT, "b")), Optional.empty(), Optional.empty()))), strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a")), "b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a")), "b", expression(new Reference(BIGINT, "b"))), values("a", "b", "c"))))); } @@ -298,17 +298,17 @@ public void testDoNotPruneReferencedInputs() p.patternRecognition(builder -> builder .rowsPerMatch(ALL_SHOW_EMPTY) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a"), p.symbol("b")))))) .matches( strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), patternRecognition(builder -> builder .rowsPerMatch(ALL_SHOW_EMPTY) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL), + .addVariableDefinition(new IrLabel("X"), TRUE), strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), values("a", "b"))))); } @@ -322,7 +322,7 @@ public void testDoNotPrunePartitionBySymbols() p.patternRecognition(builder -> builder .partitionBy(ImmutableList.of(p.symbol("a"))) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a"), p.symbol("b")))))) .matches( strictProject( @@ -330,9 +330,9 @@ public void testDoNotPrunePartitionBySymbols() patternRecognition(builder -> builder .specification(specification(ImmutableList.of("a"), ImmutableList.of(), ImmutableMap.of())) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL), + .addVariableDefinition(new IrLabel("X"), TRUE), strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), values("a", "b"))))); } @@ -347,7 +347,7 @@ public void testDoNotPruneOrderBySymbols() .orderBy(new OrderingScheme(ImmutableList.of(p.symbol("a")), ImmutableMap.of(p.symbol("a"), ASC_NULLS_LAST))) .rowsPerMatch(ALL_SHOW_EMPTY) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a"), p.symbol("b")))))) .matches( strictProject( @@ -356,9 +356,9 @@ public void testDoNotPruneOrderBySymbols() .specification(specification(ImmutableList.of(), ImmutableList.of("a"), ImmutableMap.of("a", ASC_NULLS_LAST))) .rowsPerMatch(ALL_SHOW_EMPTY) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL), + .addVariableDefinition(new IrLabel("X"), TRUE), strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), values("a", "b"))))); } @@ -374,11 +374,11 @@ public void testDoNotPruneCommonBaseFrameSymbols() .rowsPerMatch(WINDOW) .frame(new WindowNode.Frame(ROWS, CURRENT_ROW, Optional.empty(), Optional.empty(), FOLLOWING, Optional.of(p.symbol("a")), Optional.empty())) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a"), p.symbol("b")))))) .matches( strictProject( - ImmutableMap.of("measure", expression(new SymbolReference(BIGINT, "measure"))), + ImmutableMap.of("measure", expression(new Reference(BIGINT, "measure"))), patternRecognition(builder -> builder .addMeasure("measure", new Constant(INTEGER, 1L), BIGINT) .rowsPerMatch(WINDOW) @@ -391,9 +391,9 @@ public void testDoNotPruneCommonBaseFrameSymbols() Optional.of(new Symbol(UNKNOWN, "a")), Optional.empty())) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL), + .addVariableDefinition(new IrLabel("X"), TRUE), strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), values("a", "b"))))); } @@ -410,7 +410,7 @@ public void testDoNotPruneUnreferencedUsedOutputs() .pattern(new IrLabel("X")) .addVariableDefinition( new IrLabel("X"), - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "value"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "value"), new Constant(INTEGER, 0L)), ImmutableMap.of(new Symbol(INTEGER, "value"), new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(), true, true, 0, 0), new Symbol(UNKNOWN, "a")))) @@ -430,15 +430,15 @@ public void testPruneAndMeasures() .addMeasure(p.symbol("measure"), new Constant(INTEGER, 1L)) .rowsPerMatch(ALL_SHOW_EMPTY) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a"), p.symbol("b")))))) .matches( strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a")), "b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a")), "b", expression(new Reference(BIGINT, "b"))), patternRecognition(builder -> builder .rowsPerMatch(ALL_SHOW_EMPTY) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL), + .addVariableDefinition(new IrLabel("X"), TRUE), values("a", "b")))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPrunePatternRecognitionSourceColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPrunePatternRecognitionSourceColumns.java index c0a09bbec04f..c24ada9f0c50 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPrunePatternRecognitionSourceColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPrunePatternRecognitionSourceColumns.java @@ -16,9 +16,9 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -30,8 +30,8 @@ import static io.trino.spi.connector.SortOrder.ASC_NULLS_LAST; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.patternRecognition; import static io.trino.sql.planner.assertions.PlanMatchPattern.strictProject; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -49,13 +49,13 @@ public void testPruneUnreferencedInput() .on(p -> p.patternRecognition(builder -> builder .rowsPerMatch(ONE) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a"))))) .matches( patternRecognition(builder -> builder .rowsPerMatch(ONE) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL), + .addVariableDefinition(new IrLabel("X"), TRUE), strictProject( ImmutableMap.of(), values("a")))); @@ -68,7 +68,7 @@ public void testDoNotPruneInputsWithAllRowsPerMatch() .on(p -> p.patternRecognition(builder -> builder .rowsPerMatch(ALL_SHOW_EMPTY) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a"))))) .doesNotFire(); } @@ -81,7 +81,7 @@ public void testDoNotPrunePartitionByInputs() .partitionBy(ImmutableList.of(p.symbol("a"))) .rowsPerMatch(ONE) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a"))))) .doesNotFire(); } @@ -94,7 +94,7 @@ public void testDoNotPruneOrderByInputs() .orderBy(new OrderingScheme(ImmutableList.of(p.symbol("a")), ImmutableMap.of(p.symbol("a"), ASC_NULLS_LAST))) .rowsPerMatch(ONE) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a"))))) .doesNotFire(); } @@ -106,13 +106,13 @@ public void testDoNotPruneMeasureInputs() .on(p -> p.patternRecognition(builder -> builder .addMeasure( p.symbol("measure"), - new SymbolReference(BIGINT, "pointer"), + new Reference(BIGINT, "pointer"), ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), true, true, 0, 0), new Symbol(UNKNOWN, "a")))) .rowsPerMatch(ONE) .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a"))))) .doesNotFire(); } @@ -126,7 +126,7 @@ public void testDoNotPruneVariableDefinitionInputs() .pattern(new IrLabel("X")) .addVariableDefinition( new IrLabel("X"), - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "pointer"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "pointer"), new Constant(INTEGER, 0L)), ImmutableMap.of(new Symbol(INTEGER, "pointer"), new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), true, true, 0, 0), new Symbol(INTEGER, "a")))) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneProjectColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneProjectColumns.java index f5997d55d6c8..f85b88bb2058 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneProjectColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneProjectColumns.java @@ -14,7 +14,7 @@ package io.trino.sql.planner.iterative.rule; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; @@ -43,9 +43,9 @@ public void testNotAllOutputsReferenced() }) .matches( strictProject( - ImmutableMap.of("b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("b", expression(new Reference(BIGINT, "b"))), strictProject( - ImmutableMap.of("b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("b", expression(new Reference(BIGINT, "b"))), values("a", "b")))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneRowNumberColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneRowNumberColumns.java index a33a7ebe6af2..15cc786a1945 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneRowNumberColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneRowNumberColumns.java @@ -15,7 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.RowNumberSymbolMatcher; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -48,7 +48,7 @@ public void testRowNumberSymbolNotReferenced() }) .matches( strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), values(ImmutableList.of("a")))); // partitioning is present, no limit per partition @@ -62,7 +62,7 @@ public void testRowNumberSymbolNotReferenced() }) .matches( strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), values(ImmutableList.of("a")))); // no partitioning, limit per partition is present @@ -76,7 +76,7 @@ public void testRowNumberSymbolNotReferenced() }) .matches( strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), limit( 5, values(ImmutableList.of("a"))))); @@ -93,13 +93,13 @@ public void testRowNumberSymbolNotReferenced() }) .matches( strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), rowNumber( pattern -> pattern .partitionBy(ImmutableList.of("a")) .maxRowCountPerPartition(Optional.of(5)), strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), values(ImmutableList.of("a", "b")))))); } @@ -145,7 +145,7 @@ public void testSourceSymbolNotReferenced() }) .matches( strictProject( - ImmutableMap.of("row_number", expression(new SymbolReference(BIGINT, "row_number"))), + ImmutableMap.of("row_number", expression(new Reference(BIGINT, "row_number"))), rowNumber( pattern -> pattern .partitionBy(ImmutableList.of()), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneSampleColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneSampleColumns.java index 8b19bd36625c..4d344723bc51 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneSampleColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneSampleColumns.java @@ -14,7 +14,7 @@ package io.trino.sql.planner.iterative.rule; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.SampleNode; @@ -42,10 +42,10 @@ public void testNotAllInputsReferenced() p.values(p.symbol("a"), p.symbol("b"))))) .matches( strictProject( - ImmutableMap.of("b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("b", expression(new Reference(BIGINT, "b"))), node(SampleNode.class, strictProject( - ImmutableMap.of("b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("b", expression(new Reference(BIGINT, "b"))), values("a", "b"))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneSemiJoinColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneSemiJoinColumns.java index 97869aa0293a..357d79cadc22 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneSemiJoinColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneSemiJoinColumns.java @@ -15,7 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; @@ -45,7 +45,7 @@ public void testSemiJoinNotNeeded() .on(p -> buildProjectedSemiJoin(p, symbol -> symbol.getName().equals("leftValue"))) .matches( strictProject( - ImmutableMap.of("leftValue", expression(new SymbolReference(BIGINT, "leftValue"))), + ImmutableMap.of("leftValue", expression(new Reference(BIGINT, "leftValue"))), values("leftKey", "leftKeyHash", "leftValue"))); } @@ -72,12 +72,12 @@ public void testValueNotNeeded() .on(p -> buildProjectedSemiJoin(p, symbol -> symbol.getName().equals("match"))) .matches( strictProject( - ImmutableMap.of("match", expression(new SymbolReference(BOOLEAN, "match"))), + ImmutableMap.of("match", expression(new Reference(BOOLEAN, "match"))), semiJoin("leftKey", "rightKey", "match", strictProject( ImmutableMap.of( - "leftKey", expression(new SymbolReference(BIGINT, "leftKey")), - "leftKeyHash", expression(new SymbolReference(BIGINT, "leftKeyHash"))), + "leftKey", expression(new Reference(BIGINT, "leftKey")), + "leftKeyHash", expression(new Reference(BIGINT, "leftKeyHash"))), values("leftKey", "leftKeyHash", "leftValue")), values("rightKey")))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneSemiJoinFilteringSourceColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneSemiJoinFilteringSourceColumns.java index 1dc7cc4c036d..fbedfa66d432 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneSemiJoinFilteringSourceColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneSemiJoinFilteringSourceColumns.java @@ -15,7 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; @@ -46,8 +46,8 @@ public void testNotAllColumnsReferenced() values("leftKey"), strictProject( ImmutableMap.of( - "rightKey", expression(new SymbolReference(BIGINT, "rightKey")), - "rightKeyHash", expression(new SymbolReference(BIGINT, "rightKeyHash"))), + "rightKey", expression(new Reference(BIGINT, "rightKey")), + "rightKeyHash", expression(new Reference(BIGINT, "rightKeyHash"))), values("rightKey", "rightKeyHash", "rightValue")))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneSortColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneSortColumns.java index bac3169072b5..9d4f39fdad0c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneSortColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneSortColumns.java @@ -15,7 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; @@ -58,11 +58,11 @@ public void testNotAllInputsReferenced() }) .matches( strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), sort( ImmutableList.of(sort("a", ASCENDING, FIRST)), strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), values("a", "b"))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTableExecuteSourceColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTableExecuteSourceColumns.java index 812ebea435a5..febba3e59bac 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTableExecuteSourceColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTableExecuteSourceColumns.java @@ -15,7 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import org.junit.jupiter.api.Test; @@ -48,7 +48,7 @@ public void testNotAllInputsReferenced() ImmutableList.of("a"), ImmutableList.of("column_a"), strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), values("a", "b")))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTableFunctionProcessorColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTableFunctionProcessorColumns.java index f7dc14cb8f6d..1f68b0497f11 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTableFunctionProcessorColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTableFunctionProcessorColumns.java @@ -15,7 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; @@ -125,7 +125,7 @@ public void testReferencedPassThroughOutputs() .source(p.values(a, b)))); }) .matches(project( - ImmutableMap.of("y", expression(new SymbolReference(BIGINT, "y")), "b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("y", expression(new Reference(BIGINT, "y")), "b", expression(new Reference(BIGINT, "b"))), tableFunctionProcessor(builder -> builder .name("test_function") .properOutputs(ImmutableList.of("x", "y")) @@ -213,7 +213,7 @@ public void testMultipleTableArguments() .source(p.values(a, b, c, d)))); }) .matches(project( - ImmutableMap.of("b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("b", expression(new Reference(BIGINT, "b"))), tableFunctionProcessor(builder -> builder .name("test_function") .properOutputs(ImmutableList.of("proper")) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTableFunctionProcessorSourceColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTableFunctionProcessorSourceColumns.java index 29a283a1d8bc..719a5c09076d 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTableFunctionProcessorSourceColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTableFunctionProcessorSourceColumns.java @@ -15,7 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -82,12 +82,12 @@ public void testPruneUnreferencedSymbol() .hashSymbol("hash"), project( ImmutableMap.of( - "a", expression(new SymbolReference(BIGINT, "a")), - "b", expression(new SymbolReference(BIGINT, "b")), - "c", expression(new SymbolReference(BIGINT, "c")), - "d", expression(new SymbolReference(BIGINT, "d")), - "hash", expression(new SymbolReference(BIGINT, "hash")), - "marker", expression(new SymbolReference(BIGINT, "marker"))), + "a", expression(new Reference(BIGINT, "a")), + "b", expression(new Reference(BIGINT, "b")), + "c", expression(new Reference(BIGINT, "c")), + "d", expression(new Reference(BIGINT, "d")), + "hash", expression(new Reference(BIGINT, "hash")), + "marker", expression(new Reference(BIGINT, "marker"))), values("a", "b", "c", "d", "unreferenced", "hash", "marker")))); } @@ -168,15 +168,15 @@ public void testMultipleSources() "f", "marker3")), project( ImmutableMap.of( - "a", expression(new SymbolReference(BIGINT, "a")), - "b", expression(new SymbolReference(BIGINT, "b")), - "c", expression(new SymbolReference(BIGINT, "c")), - "d", expression(new SymbolReference(BIGINT, "d")), - "e", expression(new SymbolReference(BIGINT, "e")), - "f", expression(new SymbolReference(BIGINT, "f")), - "marker1", expression(new SymbolReference(BIGINT, "marker1")), - "marker2", expression(new SymbolReference(BIGINT, "marker2")), - "marker3", expression(new SymbolReference(BIGINT, "marker3"))), + "a", expression(new Reference(BIGINT, "a")), + "b", expression(new Reference(BIGINT, "b")), + "c", expression(new Reference(BIGINT, "c")), + "d", expression(new Reference(BIGINT, "d")), + "e", expression(new Reference(BIGINT, "e")), + "f", expression(new Reference(BIGINT, "f")), + "marker1", expression(new Reference(BIGINT, "marker1")), + "marker2", expression(new Reference(BIGINT, "marker2")), + "marker3", expression(new Reference(BIGINT, "marker3"))), values("a", "b", "c", "d", "e", "f", "marker1", "marker2", "marker3", "unreferenced")))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTableScanColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTableScanColumns.java index 21ff27f198bf..fe85b75174e6 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTableScanColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTableScanColumns.java @@ -30,7 +30,7 @@ import io.trino.spi.expression.Variable; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -76,7 +76,7 @@ orderdate, new TpchColumnHandle(orderdate.getName(), DATE), }) .matches( strictProject( - ImmutableMap.of("x_", PlanMatchPattern.expression(new SymbolReference(DOUBLE, "totalprice_"))), + ImmutableMap.of("x_", PlanMatchPattern.expression(new Reference(DOUBLE, "totalprice_"))), strictTableScan("orders", ImmutableMap.of("totalprice_", "totalprice")))); } @@ -103,7 +103,7 @@ public void testPruneEnforcedConstraint() }) .matches( strictProject( - Map.of("X", PlanMatchPattern.expression(new SymbolReference(DOUBLE, "TOTALPRICE"))), + Map.of("X", PlanMatchPattern.expression(new Reference(DOUBLE, "TOTALPRICE"))), strictConstrainedTableScan( "orders", Map.of("TOTALPRICE", "totalprice"), @@ -117,7 +117,7 @@ public void testAllOutputsReferenced() tester().assertThat(new PruneTableScanColumns(tester().getMetadata())) .on(p -> p.project( - Assignments.of(p.symbol("y"), new SymbolReference(BIGINT, "x")), + Assignments.of(p.symbol("y"), new Reference(BIGINT, "x")), p.tableScan( ImmutableList.of(p.symbol("x")), ImmutableMap.of(p.symbol("x"), new TestingColumnHandle("x"))))) @@ -162,7 +162,7 @@ public void testPushColumnPruningProjection() }) .matches( strictProject( - ImmutableMap.of("expr", PlanMatchPattern.expression(new SymbolReference(BIGINT, "COLB"))), + ImmutableMap.of("expr", PlanMatchPattern.expression(new Reference(BIGINT, "COLB"))), tableScan( new MockConnectorTableHandle( testSchemaTable, diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTableWriterSourceColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTableWriterSourceColumns.java index 1e072769efd1..7fb8aa9b868d 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTableWriterSourceColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTableWriterSourceColumns.java @@ -15,7 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; @@ -50,7 +50,7 @@ public void testNotAllInputsReferenced() ImmutableList.of("a"), ImmutableList.of("column_a"), strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), values("a", "b")))); } @@ -106,7 +106,7 @@ public void testDoNotPruneStatisticAggregationSymbols() Optional.empty(), Optional.of( p.statisticAggregations( - ImmutableMap.of(aggregation, p.aggregation(PlanBuilder.aggregation("avg", ImmutableList.of(new SymbolReference(BIGINT, "argument"))), ImmutableList.of(BIGINT))), + ImmutableMap.of(aggregation, p.aggregation(PlanBuilder.aggregation("avg", ImmutableList.of(new Reference(BIGINT, "argument"))), ImmutableList.of(BIGINT))), ImmutableList.of(group))), Optional.of(StatisticAggregationsDescriptor.empty()), p.values(a, group, argument)); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTopNColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTopNColumns.java index dea92c5d289b..0fcea6d499b7 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTopNColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTopNColumns.java @@ -16,7 +16,7 @@ import com.google.common.base.Predicates; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; @@ -48,12 +48,12 @@ public void testNotAllInputsReferenced() .on(p -> buildProjectedTopN(p, symbol -> symbol.getName().equals("b"))) .matches( strictProject( - ImmutableMap.of("b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("b", expression(new Reference(BIGINT, "b"))), topN( COUNT, ImmutableList.of(sort("b", ASCENDING, FIRST)), strictProject( - ImmutableMap.of("b", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("b", expression(new Reference(BIGINT, "b"))), values("a", "b"))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTopNRankingColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTopNRankingColumns.java index 1cef8e5b3208..28216993db95 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTopNRankingColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTopNRankingColumns.java @@ -16,7 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.spi.connector.SortOrder; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.TopNRankingSymbolMatcher; @@ -127,7 +127,7 @@ public void testSourceSymbolNotReferenced() }) .matches( strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a")), "ranking", expression(new SymbolReference(BIGINT, "ranking"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a")), "ranking", expression(new Reference(BIGINT, "ranking"))), topNRanking( pattern -> pattern .specification( @@ -137,7 +137,7 @@ public void testSourceSymbolNotReferenced() .rankingType(ROW_NUMBER) .maxRankingPerPartition(5), strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), values("a", "b"))) .withAlias("ranking", new TopNRankingSymbolMatcher()))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneUnionSourceColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneUnionSourceColumns.java index 8c02e79cc825..99f5492417bb 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneUnionSourceColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneUnionSourceColumns.java @@ -16,7 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import org.junit.jupiter.api.Test; @@ -49,7 +49,7 @@ public void testPruneOneChild() }) .matches(union( strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), values("a", "b")), values("c"), values("d"))); @@ -76,13 +76,13 @@ public void testPruneAllChildren() }) .matches(union( strictProject( - ImmutableMap.of("a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", expression(new Reference(BIGINT, "a"))), values("a", "b")), strictProject( - ImmutableMap.of("c", expression(new SymbolReference(BIGINT, "c"))), + ImmutableMap.of("c", expression(new Reference(BIGINT, "c"))), values("c", "d")), strictProject( - ImmutableMap.of("e", expression(new SymbolReference(BIGINT, "e"))), + ImmutableMap.of("e", expression(new Reference(BIGINT, "e"))), values("e", "f")))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneUnnestColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneUnnestColumns.java index b4a58505fefa..8db8f9026c6e 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneUnnestColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneUnnestColumns.java @@ -15,7 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; @@ -55,7 +55,7 @@ public void testPruneOrdinalitySymbol() }) .matches( strictProject( - ImmutableMap.of("replicate_symbol", expression(new SymbolReference(BIGINT, "replicate_symbol")), "unnested_symbol", expression(new SymbolReference(BIGINT, "unnested_symbol"))), + ImmutableMap.of("replicate_symbol", expression(new Reference(BIGINT, "replicate_symbol")), "unnested_symbol", expression(new Reference(BIGINT, "unnested_symbol"))), unnest( ImmutableList.of("replicate_symbol"), ImmutableList.of(unnestMapping("unnest_symbol", ImmutableList.of("unnested_symbol"))), @@ -84,7 +84,7 @@ public void testPruneReplicateSymbol() }) .matches( strictProject( - ImmutableMap.of("unnested_symbol", expression(new SymbolReference(BIGINT, "unnested_symbol")), "ordinality_symbol", expression(new SymbolReference(BIGINT, "ordinality_symbol"))), + ImmutableMap.of("unnested_symbol", expression(new Reference(BIGINT, "unnested_symbol")), "ordinality_symbol", expression(new Reference(BIGINT, "ordinality_symbol"))), unnest( ImmutableList.of(), ImmutableList.of(unnestMapping("unnest_symbol", ImmutableList.of("unnested_symbol"))), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneUnnestSourceColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneUnnestSourceColumns.java index e6881d207553..b07b9e46338a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneUnnestSourceColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneUnnestSourceColumns.java @@ -15,7 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.UnnestNode.Mapping; @@ -50,7 +50,7 @@ public void testNotAllInputsReferenced() ImmutableList.of("replicate_symbol"), ImmutableList.of(unnestMapping("unnest_symbol", ImmutableList.of("unnested_symbol"))), strictProject( - ImmutableMap.of("replicate_symbol", expression(new SymbolReference(BIGINT, "replicate_symbol")), "unnest_symbol", expression(new SymbolReference(BIGINT, "unnest_symbol"))), + ImmutableMap.of("replicate_symbol", expression(new Reference(BIGINT, "replicate_symbol")), "unnest_symbol", expression(new Reference(BIGINT, "unnest_symbol"))), values("replicate_symbol", "unnest_symbol", "unused_symbol")))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneValuesColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneValuesColumns.java index a81a64bd4d45..830026d3cbe7 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneValuesColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneValuesColumns.java @@ -19,8 +19,8 @@ import io.trino.spi.type.VarcharType; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Constant; +import io.trino.sql.ir.Reference; import io.trino.sql.ir.Row; -import io.trino.sql.ir.SymbolReference; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; @@ -42,7 +42,7 @@ public void testNotAllOutputsReferenced() tester().assertThat(new PruneValuesColumns()) .on(p -> p.project( - Assignments.of(p.symbol("y"), new SymbolReference(INTEGER, "x")), + Assignments.of(p.symbol("y"), new Reference(INTEGER, "x")), p.values( ImmutableList.of(p.symbol("unused"), p.symbol("x")), ImmutableList.of( @@ -50,7 +50,7 @@ public void testNotAllOutputsReferenced() ImmutableList.of(new Constant(INTEGER, 3L), new Constant(INTEGER, 4L)))))) .matches( project( - ImmutableMap.of("y", PlanMatchPattern.expression(new SymbolReference(INTEGER, "x"))), + ImmutableMap.of("y", PlanMatchPattern.expression(new Reference(INTEGER, "x"))), values( ImmutableList.of("x"), ImmutableList.of( @@ -64,7 +64,7 @@ public void testAllOutputsReferenced() tester().assertThat(new PruneValuesColumns()) .on(p -> p.project( - Assignments.of(p.symbol("y"), new SymbolReference(BIGINT, "x")), + Assignments.of(p.symbol("y"), new Reference(BIGINT, "x")), p.values(p.symbol("x")))) .doesNotFire(); } @@ -105,7 +105,7 @@ public void testDoNotPruneWhenValuesExpressionIsNotRow() tester().assertThat(new PruneValuesColumns()) .on(p -> p.project( - Assignments.of(p.symbol("x"), new SymbolReference(INTEGER, "x")), + Assignments.of(p.symbol("x"), new Reference(INTEGER, "x")), p.valuesOfExpressions( ImmutableList.of(p.symbol("x"), p.symbol("y")), ImmutableList.of(new Cast(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VarcharType.VARCHAR, Slices.utf8Slice("a")))), anonymousRow(BIGINT, createCharType(2))))))) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneWindowColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneWindowColumns.java index 95b3b1568fb4..52718273d680 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneWindowColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneWindowColumns.java @@ -21,7 +21,7 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.connector.SortOrder; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -85,7 +85,7 @@ public void testWindowNotNeeded() .on(p -> buildProjectedWindow(p, symbol -> inputSymbolNameSet.contains(symbol.getName()), alwaysTrue())) .matches( strictProject( - Maps.asMap(inputSymbolNameSet, symbol -> expression(new SymbolReference(BIGINT, symbol))), + Maps.asMap(inputSymbolNameSet, symbol -> expression(new Reference(BIGINT, symbol))), values(inputSymbolNameList))); } @@ -99,8 +99,8 @@ public void testOneFunctionNotNeeded() .matches( strictProject( ImmutableMap.of( - "output2", expression(new SymbolReference(BIGINT, "output2")), - "unused", expression(new SymbolReference(BIGINT, "unused"))), + "output2", expression(new Reference(BIGINT, "output2")), + "unused", expression(new Reference(BIGINT, "unused"))), window(windowBuilder -> windowBuilder .prePartitionedInputs(ImmutableSet.of()) .specification( @@ -113,7 +113,7 @@ public void testOneFunctionNotNeeded() strictProject( Maps.asMap( Sets.difference(inputSymbolNameSet, ImmutableSet.of("input1", "startValue1", "endValue1")), - symbol -> expression(new SymbolReference(BIGINT, symbol))), + symbol -> expression(new Reference(BIGINT, symbol))), values(inputSymbolNameList))))); } @@ -151,8 +151,8 @@ public void testUnusedInputNotNeeded() .matches( strictProject( ImmutableMap.of( - "output1", expression(new SymbolReference(BIGINT, "output1")), - "output2", expression(new SymbolReference(BIGINT, "output2"))), + "output1", expression(new Reference(BIGINT, "output1")), + "output2", expression(new Reference(BIGINT, "output2"))), window(windowBuilder -> windowBuilder .prePartitionedInputs(ImmutableSet.of()) .specification( @@ -166,7 +166,7 @@ public void testUnusedInputNotNeeded() strictProject( Maps.asMap( Sets.filter(inputSymbolNameSet, symbolName -> !symbolName.equals("unused")), - symbol -> expression(new SymbolReference(BIGINT, symbol))), + symbol -> expression(new Reference(BIGINT, symbol))), values(inputSymbolNameList))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushAggregationThroughOuterJoin.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushAggregationThroughOuterJoin.java index fbcf18d9be39..10c543ba198b 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushAggregationThroughOuterJoin.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushAggregationThroughOuterJoin.java @@ -17,9 +17,9 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import io.trino.sql.ir.CoalesceExpression; +import io.trino.sql.ir.Coalesce; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; @@ -67,12 +67,12 @@ public void testPushesAggregationThroughLeftJoin() Optional.empty(), Optional.empty(), Optional.empty())) - .addAggregation(p.symbol("AVG", DOUBLE), PlanBuilder.aggregation("avg", ImmutableList.of(new SymbolReference(DOUBLE, "COL2"))), ImmutableList.of(DOUBLE)) + .addAggregation(p.symbol("AVG", DOUBLE), PlanBuilder.aggregation("avg", ImmutableList.of(new Reference(DOUBLE, "COL2"))), ImmutableList.of(DOUBLE)) .singleGroupingSet(p.symbol("COL1", DOUBLE)))) .matches( project(ImmutableMap.of( - "COL1", expression(new SymbolReference(DOUBLE, "COL1")), - "COALESCE", expression(new CoalesceExpression(new SymbolReference(DOUBLE, "AVG"), new SymbolReference(DOUBLE, "AVG_NULL")))), + "COL1", expression(new Reference(DOUBLE, "COL1")), + "COALESCE", expression(new Coalesce(new Reference(DOUBLE, "AVG"), new Reference(DOUBLE, "AVG_NULL")))), join(INNER, builder -> builder .left( join(LEFT, leftJoinBuilder -> leftJoinBuilder @@ -108,12 +108,12 @@ public void testPushesAggregationThroughRightJoin() Optional.empty(), Optional.empty(), Optional.empty())) - .addAggregation(p.symbol("AVG", DOUBLE), PlanBuilder.aggregation("avg", ImmutableList.of(new SymbolReference(DOUBLE, "COL2"))), ImmutableList.of(DOUBLE)) + .addAggregation(p.symbol("AVG", DOUBLE), PlanBuilder.aggregation("avg", ImmutableList.of(new Reference(DOUBLE, "COL2"))), ImmutableList.of(DOUBLE)) .singleGroupingSet(p.symbol("COL1", DOUBLE)))) .matches( project(ImmutableMap.of( - "COALESCE", expression(new CoalesceExpression(new SymbolReference(DOUBLE, "AVG"), new SymbolReference(DOUBLE, "AVG_NULL"))), - "COL1", expression(new SymbolReference(DOUBLE, "COL1"))), + "COALESCE", expression(new Coalesce(new Reference(DOUBLE, "AVG"), new Reference(DOUBLE, "AVG_NULL"))), + "COL1", expression(new Reference(DOUBLE, "COL1"))), join(INNER, builder -> builder .left( join(RIGHT, leftJoinBuilder -> leftJoinBuilder @@ -153,14 +153,14 @@ public void testPushesAggregationWithMask() Optional.empty())) .addAggregation( p.symbol("AVG", DOUBLE), - PlanBuilder.aggregation("avg", ImmutableList.of(new SymbolReference(DOUBLE, "COL2"))), + PlanBuilder.aggregation("avg", ImmutableList.of(new Reference(DOUBLE, "COL2"))), ImmutableList.of(DOUBLE), p.symbol("MASK", BOOLEAN)) .singleGroupingSet(p.symbol("COL1", DOUBLE)))) .matches( project(ImmutableMap.of( - "COL1", expression(new SymbolReference(DOUBLE, "COL1")), - "COALESCE", expression(new CoalesceExpression(new SymbolReference(DOUBLE, "AVG"), new SymbolReference(DOUBLE, "AVG_NULL")))), + "COL1", expression(new Reference(DOUBLE, "COL1")), + "COALESCE", expression(new Coalesce(new Reference(DOUBLE, "AVG"), new Reference(DOUBLE, "AVG_NULL")))), join(INNER, builder -> builder .left( join(LEFT, leftJoinBuilder -> leftJoinBuilder @@ -206,8 +206,8 @@ public void testPushCountAllAggregation() .singleGroupingSet(p.symbol("COL1")))) .matches( project(ImmutableMap.of( - "COL1", expression(new SymbolReference(DOUBLE, "COL1")), - "COALESCE", expression(new CoalesceExpression(new SymbolReference(BIGINT, "COUNT"), new SymbolReference(BIGINT, "COUNT_NULL")))), + "COL1", expression(new Reference(DOUBLE, "COL1")), + "COALESCE", expression(new Coalesce(new Reference(BIGINT, "COUNT"), new Reference(BIGINT, "COUNT_NULL")))), join(INNER, builder -> builder .left( join(LEFT, leftJoinBuilder -> leftJoinBuilder @@ -269,7 +269,7 @@ public void testDoesNotFireWhenNotDistinct() Optional.empty(), Optional.empty(), Optional.empty())) - .addAggregation(new Symbol(UNKNOWN, "AVG"), PlanBuilder.aggregation("avg", ImmutableList.of(new SymbolReference(DOUBLE, "COL2"))), ImmutableList.of(DOUBLE)) + .addAggregation(new Symbol(UNKNOWN, "AVG"), PlanBuilder.aggregation("avg", ImmutableList.of(new Reference(DOUBLE, "COL2"))), ImmutableList.of(DOUBLE)) .singleGroupingSet(new Symbol(UNKNOWN, "COL1")))) .doesNotFire(); @@ -295,7 +295,7 @@ public void testDoesNotFireWhenNotDistinct() Optional.empty(), Optional.empty(), Optional.empty())) - .addAggregation(p.symbol("AVG", DOUBLE), PlanBuilder.aggregation("avg", ImmutableList.of(new SymbolReference(DOUBLE, "COL2"))), ImmutableList.of(DOUBLE)) + .addAggregation(p.symbol("AVG", DOUBLE), PlanBuilder.aggregation("avg", ImmutableList.of(new Reference(DOUBLE, "COL2"))), ImmutableList.of(DOUBLE)) .singleGroupingSet(p.symbol("COL1")))) .doesNotFire(); } @@ -314,7 +314,7 @@ public void testDoesNotFireWhenGroupingOnInner() Optional.empty(), Optional.empty(), Optional.empty())) - .addAggregation(new Symbol(UNKNOWN, "AVG"), PlanBuilder.aggregation("avg", ImmutableList.of(new SymbolReference(DOUBLE, "COL2"))), ImmutableList.of(DOUBLE)) + .addAggregation(new Symbol(UNKNOWN, "AVG"), PlanBuilder.aggregation("avg", ImmutableList.of(new Reference(DOUBLE, "COL2"))), ImmutableList.of(DOUBLE)) .singleGroupingSet(new Symbol(UNKNOWN, "COL1"), new Symbol(UNKNOWN, "COL3")))) .doesNotFire(); } @@ -334,7 +334,7 @@ public void testDoesNotFireWhenAggregationDoesNotHaveSymbols() Optional.empty(), Optional.empty(), Optional.empty())) - .addAggregation(new Symbol(UNKNOWN, "SUM"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(DOUBLE, "COL1"))), ImmutableList.of(DOUBLE)) + .addAggregation(new Symbol(UNKNOWN, "SUM"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(DOUBLE, "COL1"))), ImmutableList.of(DOUBLE)) .singleGroupingSet(new Symbol(UNKNOWN, "COL1")))) .doesNotFire(); } @@ -354,7 +354,7 @@ public void testDoesNotFireWhenAggregationOnMultipleSymbolsDoesNotHaveSomeSymbol Optional.empty(), Optional.empty(), Optional.empty())) - .addAggregation(new Symbol(UNKNOWN, "MIN_BY"), PlanBuilder.aggregation("min_by", ImmutableList.of(new SymbolReference(DOUBLE, "COL2"), new SymbolReference(DOUBLE, "COL1"))), ImmutableList.of(DOUBLE, DOUBLE)) + .addAggregation(new Symbol(UNKNOWN, "MIN_BY"), PlanBuilder.aggregation("min_by", ImmutableList.of(new Reference(DOUBLE, "COL2"), new Reference(DOUBLE, "COL1"))), ImmutableList.of(DOUBLE, DOUBLE)) .singleGroupingSet(new Symbol(UNKNOWN, "COL1")))) .doesNotFire(); @@ -370,9 +370,9 @@ public void testDoesNotFireWhenAggregationOnMultipleSymbolsDoesNotHaveSomeSymbol Optional.empty(), Optional.empty(), Optional.empty())) - .addAggregation(new Symbol(UNKNOWN, "SUM"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(DOUBLE, "COL2"))), ImmutableList.of(DOUBLE)) - .addAggregation(new Symbol(UNKNOWN, "MIN_BY"), PlanBuilder.aggregation("min_by", ImmutableList.of(new SymbolReference(DOUBLE, "COL2"), new SymbolReference(DOUBLE, "COL3"))), ImmutableList.of(DOUBLE, DOUBLE)) - .addAggregation(new Symbol(UNKNOWN, "MAX_BY"), PlanBuilder.aggregation("max_by", ImmutableList.of(new SymbolReference(DOUBLE, "COL2"), new SymbolReference(DOUBLE, "COL1"))), ImmutableList.of(DOUBLE, DOUBLE)) + .addAggregation(new Symbol(UNKNOWN, "SUM"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(DOUBLE, "COL2"))), ImmutableList.of(DOUBLE)) + .addAggregation(new Symbol(UNKNOWN, "MIN_BY"), PlanBuilder.aggregation("min_by", ImmutableList.of(new Reference(DOUBLE, "COL2"), new Reference(DOUBLE, "COL3"))), ImmutableList.of(DOUBLE, DOUBLE)) + .addAggregation(new Symbol(UNKNOWN, "MAX_BY"), PlanBuilder.aggregation("max_by", ImmutableList.of(new Reference(DOUBLE, "COL2"), new Reference(DOUBLE, "COL1"))), ImmutableList.of(DOUBLE, DOUBLE)) .singleGroupingSet(new Symbol(UNKNOWN, "COL1")))) .doesNotFire(); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushCastIntoRow.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushCastIntoRow.java index 5040105bbae7..c12102805f99 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushCastIntoRow.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushCastIntoRow.java @@ -19,7 +19,7 @@ import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.ir.Row; -import io.trino.sql.ir.SubscriptExpression; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; @@ -65,8 +65,8 @@ public void test() // expression nested in another unrelated expression test( - new SubscriptExpression(BIGINT, new Cast(new Row(ImmutableList.of(new Constant(INTEGER, 1L))), anonymousRow(BIGINT)), new Constant(INTEGER, 1L)), - new SubscriptExpression(BIGINT, new Row(ImmutableList.of(new Cast(new Constant(INTEGER, 1L), BIGINT))), new Constant(INTEGER, 1L))); + new Subscript(BIGINT, new Cast(new Row(ImmutableList.of(new Constant(INTEGER, 1L))), anonymousRow(BIGINT)), new Constant(INTEGER, 1L)), + new Subscript(BIGINT, new Row(ImmutableList.of(new Cast(new Constant(INTEGER, 1L), BIGINT))), new Constant(INTEGER, 1L))); // don't insert CAST(x AS unknown) test( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownDereferencesRules.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownDereferencesRules.java index a9152504cbd5..f07ea28fa7a5 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownDereferencesRules.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownDereferencesRules.java @@ -26,16 +26,16 @@ import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.ArrayType; import io.trino.spi.type.RowType; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.NotExpression; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.ir.Row; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.assertions.ExpressionMatcher; import io.trino.sql.planner.assertions.PlanMatchPattern; @@ -58,11 +58,11 @@ import static io.trino.spi.type.RowType.rowType; import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.NOT_EQUAL; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.NOT_EQUAL; +import static io.trino.sql.ir.Logical.Operator.AND; import static io.trino.sql.planner.assertions.PlanMatchPattern.UnnestMapping.unnestMapping; import static io.trino.sql.planner.assertions.PlanMatchPattern.assignUniqueId; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; @@ -108,7 +108,7 @@ public void testDoesNotFire() tester().assertThat(new PushDownDereferenceThroughFilter()) .on(p -> p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "x"), new Constant(BIGINT, 5L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "x"), new Constant(BIGINT, 5L)), p.values(p.symbol("x")))) .doesNotFire(); @@ -117,12 +117,12 @@ public void testDoesNotFire() .on(p -> p.project( Assignments.of( - p.symbol("expr_1"), new SubscriptExpression(rowType(field("x", BIGINT), field("y", BIGINT)), new Cast(new Row(ImmutableList.of(new SymbolReference(ROW_TYPE, "a"), new SymbolReference(BIGINT, "b"))), rowType(field("f1", rowType(field("x", BIGINT), field("y", BIGINT))), field("f2", BIGINT))), new Constant(INTEGER, 1L)), - p.symbol("expr_2"), new SubscriptExpression(BIGINT, new SubscriptExpression(rowType(field("x", BIGINT), field("y", BIGINT)), new Cast(new Row(ImmutableList.of(new SymbolReference(ROW_TYPE, "a"), new SymbolReference(BIGINT, "b"))), rowType(field("f1", rowType(field("x", BIGINT), field("y", BIGINT))), field("f2", BIGINT))), new Constant(INTEGER, 1L)), new Constant(INTEGER, 2L))), + p.symbol("expr_1"), new Subscript(rowType(field("x", BIGINT), field("y", BIGINT)), new Cast(new Row(ImmutableList.of(new Reference(ROW_TYPE, "a"), new Reference(BIGINT, "b"))), rowType(field("f1", rowType(field("x", BIGINT), field("y", BIGINT))), field("f2", BIGINT))), new Constant(INTEGER, 1L)), + p.symbol("expr_2"), new Subscript(BIGINT, new Subscript(rowType(field("x", BIGINT), field("y", BIGINT)), new Cast(new Row(ImmutableList.of(new Reference(ROW_TYPE, "a"), new Reference(BIGINT, "b"))), rowType(field("f1", rowType(field("x", BIGINT), field("y", BIGINT))), field("f2", BIGINT))), new Constant(INTEGER, 1L)), new Constant(INTEGER, 2L))), p.project( Assignments.of( - p.symbol("a", ROW_TYPE), new SymbolReference(ROW_TYPE, "a"), - p.symbol("b"), new SymbolReference(BIGINT, "b")), + p.symbol("a", ROW_TYPE), new Reference(ROW_TYPE, "a"), + p.symbol("b"), new Reference(BIGINT, "b")), p.values(p.symbol("a", ROW_TYPE), p.symbol("b"))))) .doesNotFire(); @@ -130,9 +130,9 @@ public void testDoesNotFire() tester().assertThat(new PushDownDereferenceThroughProject()) .on(p -> p.project( - Assignments.of(p.symbol("expr", ROW_TYPE), new SymbolReference(BIGINT, "a"), p.symbol("a_x"), new SubscriptExpression(BIGINT, new SymbolReference(BIGINT, "a"), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("expr", ROW_TYPE), new Reference(BIGINT, "a"), p.symbol("a_x"), new Subscript(BIGINT, new Reference(BIGINT, "a"), new Constant(INTEGER, 1L))), p.project( - Assignments.of(p.symbol("a", ROW_TYPE), new SymbolReference(ROW_TYPE, "a")), + Assignments.of(p.symbol("a", ROW_TYPE), new Reference(ROW_TYPE, "a")), p.values(p.symbol("a", ROW_TYPE))))) .doesNotFire(); } @@ -143,20 +143,20 @@ public void testPushdownDereferenceThroughProject() tester().assertThat(new PushDownDereferenceThroughProject()) .on(p -> p.project( - Assignments.of(p.symbol("x"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("x"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L))), p.project( Assignments.of( - p.symbol("y"), new SymbolReference(BIGINT, "y"), - p.symbol("msg", ROW_TYPE), new SymbolReference(ROW_TYPE, "msg")), + p.symbol("y"), new Reference(BIGINT, "y"), + p.symbol("msg", ROW_TYPE), new Reference(ROW_TYPE, "msg")), p.values(p.symbol("msg", ROW_TYPE), p.symbol("y"))))) .matches( strictProject( - ImmutableMap.of("x", expression(new SymbolReference(BIGINT, "msg_x"))), + ImmutableMap.of("x", expression(new Reference(BIGINT, "msg_x"))), strictProject( ImmutableMap.of( - "msg_x", expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L))), - "y", expression(new SymbolReference(BIGINT, "y")), - "msg", expression(new SymbolReference(BIGINT, "msg"))), + "msg_x", expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L))), + "y", expression(new Reference(BIGINT, "y")), + "msg", expression(new Reference(BIGINT, "msg"))), values("msg", "y")))); } @@ -167,9 +167,9 @@ public void testPushDownDereferenceThroughJoin() .on(p -> p.project( Assignments.builder() - .put(p.symbol("left_x"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) - .put(p.symbol("right_y"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg2"), new Constant(INTEGER, 2L))) - .put(p.symbol("z"), new SymbolReference(BIGINT, "z")) + .put(p.symbol("left_x"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) + .put(p.symbol("right_y"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 2L))) + .put(p.symbol("z"), new Reference(BIGINT, "z")) .build(), p.join(INNER, p.values(p.symbol("msg1", ROW_TYPE), p.symbol("unreferenced_symbol")), @@ -177,24 +177,24 @@ public void testPushDownDereferenceThroughJoin() .matches( strictProject( ImmutableMap.builder() - .put("left_x", expression(new SymbolReference(BIGINT, "x"))) - .put("right_y", expression(new SymbolReference(BIGINT, "y"))) - .put("z", expression(new SymbolReference(BIGINT, "z"))) + .put("left_x", expression(new Reference(BIGINT, "x"))) + .put("right_y", expression(new Reference(BIGINT, "y"))) + .put("z", expression(new Reference(BIGINT, "z"))) .buildOrThrow(), join(INNER, builder -> builder .left( strictProject( ImmutableMap.of( - "x", expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))), - "msg1", expression(new SymbolReference(ROW_TYPE, "msg1")), - "unreferenced_symbol", expression(new SymbolReference(BIGINT, "unreferenced_symbol"))), + "x", expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))), + "msg1", expression(new Reference(ROW_TYPE, "msg1")), + "unreferenced_symbol", expression(new Reference(BIGINT, "unreferenced_symbol"))), values("msg1", "unreferenced_symbol"))) .right( strictProject( ImmutableMap.builder() - .put("y", expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg2"), new Constant(INTEGER, 2L)))) - .put("z", expression(new SymbolReference(BIGINT, "z"))) - .put("msg2", expression(new SymbolReference(ROW_TYPE, "msg2"))) + .put("y", expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 2L)))) + .put("z", expression(new Reference(BIGINT, "z"))) + .put("msg2", expression(new Reference(ROW_TYPE, "msg2"))) .buildOrThrow(), values("msg2", "z")))))); @@ -203,24 +203,24 @@ public void testPushDownDereferenceThroughJoin() .on(p -> p.project( Assignments.of( - p.symbol("expr"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)), - p.symbol("expr_2"), new SymbolReference(ROW_TYPE, "msg2")), + p.symbol("expr"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)), + p.symbol("expr_2"), new Reference(ROW_TYPE, "msg2")), p.join(INNER, p.values(p.symbol("msg1", ROW_TYPE)), p.values(p.symbol("msg2", ROW_TYPE)), - new ComparisonExpression(GREATER_THAN, new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg2"), new Constant(INTEGER, 2L))), new Constant(BIGINT, 10L))))) + new Comparison(GREATER_THAN, new Arithmetic(ADD_INTEGER, ADD, new Subscript(BIGINT, new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 2L))), new Constant(BIGINT, 10L))))) .matches( project( ImmutableMap.of( - "expr", expression(new SymbolReference(BIGINT, "msg1_x")), - "expr_2", expression(new SymbolReference(ROW_TYPE, "msg2"))), + "expr", expression(new Reference(BIGINT, "msg1_x")), + "expr_2", expression(new Reference(ROW_TYPE, "msg2"))), join(INNER, builder -> builder - .filter(new ComparisonExpression(GREATER_THAN, new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(BIGINT, "msg1_x"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg2"), new Constant(INTEGER, 2L))), new Constant(BIGINT, 10L))) + .filter(new Comparison(GREATER_THAN, new Arithmetic(ADD_INTEGER, ADD, new Reference(BIGINT, "msg1_x"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 2L))), new Constant(BIGINT, 10L))) .left( strictProject( ImmutableMap.of( - "msg1_x", expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))), - "msg1", expression(new SymbolReference(ROW_TYPE, "msg1"))), + "msg1_x", expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))), + "msg1", expression(new Reference(ROW_TYPE, "msg1"))), values("msg1"))) .right(values("msg2"))))); } @@ -232,8 +232,8 @@ public void testPushdownDereferencesThroughSemiJoin() .on(p -> p.project( Assignments.builder() - .put(p.symbol("msg1_x"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) - .put(p.symbol("msg2_x"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L))) + .put(p.symbol("msg1_x"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) + .put(p.symbol("msg2_x"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L))) .build(), p.semiJoin( p.symbol("msg2", ROW_TYPE), @@ -246,8 +246,8 @@ public void testPushdownDereferencesThroughSemiJoin() .matches( strictProject( ImmutableMap.builder() - .put("msg1_x", PlanMatchPattern.expression(new SymbolReference(BIGINT, "expr"))) - .put("msg2_x", PlanMatchPattern.expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L)))) // Not pushed down because msg2 is sourceJoinSymbol + .put("msg1_x", PlanMatchPattern.expression(new Reference(BIGINT, "expr"))) + .put("msg2_x", PlanMatchPattern.expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L)))) // Not pushed down because msg2 is sourceJoinSymbol .buildOrThrow(), semiJoin( "msg2", @@ -255,9 +255,9 @@ public void testPushdownDereferencesThroughSemiJoin() "match", strictProject( ImmutableMap.of( - "expr", PlanMatchPattern.expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))), - "msg1", PlanMatchPattern.expression(new SymbolReference(ROW_TYPE, "msg1")), - "msg2", PlanMatchPattern.expression(new SymbolReference(ROW_TYPE, "msg2"))), + "expr", PlanMatchPattern.expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))), + "msg1", PlanMatchPattern.expression(new Reference(ROW_TYPE, "msg1")), + "msg2", PlanMatchPattern.expression(new Reference(ROW_TYPE, "msg2"))), values("msg1", "msg2")), values("filtering_msg")))); } @@ -269,7 +269,7 @@ public void testPushdownDereferencesThroughUnnest() tester().assertThat(new PushDownDereferenceThroughUnnest()) .on(p -> p.project( - Assignments.of(p.symbol("x"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("x"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L))), p.unnest( ImmutableList.of(p.symbol("msg", ROW_TYPE)), ImmutableList.of(new UnnestNode.Mapping(p.symbol("arr", arrayType), ImmutableList.of(p.symbol("field")))), @@ -278,13 +278,13 @@ public void testPushdownDereferencesThroughUnnest() p.values(p.symbol("msg", ROW_TYPE), p.symbol("arr", arrayType))))) .matches( strictProject( - ImmutableMap.of("x", expression(new SymbolReference(BIGINT, "msg_x"))), + ImmutableMap.of("x", expression(new Reference(BIGINT, "msg_x"))), unnest( strictProject( ImmutableMap.of( - "msg_x", expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L))), - "msg", expression(new SymbolReference(ROW_TYPE, "msg")), - "arr", expression(new SymbolReference(arrayType, "arr"))), + "msg_x", expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L))), + "msg", expression(new Reference(ROW_TYPE, "msg")), + "arr", expression(new Reference(arrayType, "arr"))), values("msg", "arr"))))); // Test with dereferences on unnested column @@ -295,8 +295,8 @@ public void testPushdownDereferencesThroughUnnest() .on(p -> p.project( Assignments.of( - p.symbol("deref_replicate", BIGINT), new SubscriptExpression(BIGINT, new SymbolReference(rowType, "replicate"), new Constant(INTEGER, 2L)), - p.symbol("deref_unnest", BIGINT), new SubscriptExpression(BIGINT, new SymbolReference(nestedColumnType, "unnested_row"), new Constant(INTEGER, 2L))), + p.symbol("deref_replicate", BIGINT), new Subscript(BIGINT, new Reference(rowType, "replicate"), new Constant(INTEGER, 2L)), + p.symbol("deref_unnest", BIGINT), new Subscript(BIGINT, new Reference(nestedColumnType, "unnested_row"), new Constant(INTEGER, 2L))), p.unnest( ImmutableList.of(p.symbol("replicate", rowType)), ImmutableList.of( @@ -307,16 +307,16 @@ public void testPushdownDereferencesThroughUnnest() .matches( strictProject( ImmutableMap.of( - "deref_replicate", expression(new SymbolReference(BIGINT, "symbol")), - "deref_unnest", expression(new SubscriptExpression(rowType, new SymbolReference(nestedColumnType, "unnested_row"), new Constant(INTEGER, 2L)))), // not pushed down + "deref_replicate", expression(new Reference(BIGINT, "symbol")), + "deref_unnest", expression(new Subscript(rowType, new Reference(nestedColumnType, "unnested_row"), new Constant(INTEGER, 2L)))), // not pushed down unnest( ImmutableList.of("replicate", "symbol"), ImmutableList.of(unnestMapping("nested", ImmutableList.of("unnested_bigint", "unnested_row"))), strictProject( ImmutableMap.of( - "symbol", expression(new SubscriptExpression(BIGINT, new SymbolReference(rowType, "replicate"), new Constant(INTEGER, 2L))), - "replicate", expression(new SymbolReference(rowType, "replicate")), - "nested", expression(new SymbolReference(nestedColumnType, "nested"))), + "symbol", expression(new Subscript(BIGINT, new Reference(rowType, "replicate"), new Constant(INTEGER, 2L))), + "replicate", expression(new Reference(rowType, "replicate")), + "nested", expression(new Reference(nestedColumnType, "nested"))), values("replicate", "nested"))))); } @@ -332,10 +332,10 @@ public void testExtractDereferencesFromFilterAboveScan() tester().assertThat(new ExtractDereferencesFromFilterAboveScan()) .on(p -> p.filter( - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(NOT_EQUAL, new SubscriptExpression(BIGINT, new SubscriptExpression(ROW_TYPE, new SymbolReference(nestedRowType, "a"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L)), new Constant(INTEGER, 5L)), - new ComparisonExpression(EQUAL, new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "b"), new Constant(INTEGER, 2L)), new Constant(INTEGER, 2L)), - new NotExpression(new IsNullPredicate(new Cast(new SubscriptExpression(ROW_TYPE, new SymbolReference(nestedRowType, "a"), new Constant(INTEGER, 1L)), JSON))))), + new Logical(AND, ImmutableList.of( + new Comparison(NOT_EQUAL, new Subscript(BIGINT, new Subscript(ROW_TYPE, new Reference(nestedRowType, "a"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L)), new Constant(INTEGER, 5L)), + new Comparison(EQUAL, new Subscript(BIGINT, new Reference(ROW_TYPE, "b"), new Constant(INTEGER, 2L)), new Constant(INTEGER, 2L)), + new Not(new IsNull(new Cast(new Subscript(ROW_TYPE, new Reference(nestedRowType, "a"), new Constant(INTEGER, 1L)), JSON))))), p.tableScan( testTable, ImmutableList.of(p.symbol("a", nestedRowType), p.symbol("b", ROW_TYPE)), @@ -344,14 +344,14 @@ public void testExtractDereferencesFromFilterAboveScan() p.symbol("b", ROW_TYPE), new TpchColumnHandle("b", ROW_TYPE))))) .matches(project( filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(NOT_EQUAL, new SymbolReference(INTEGER, "expr"), new Constant(INTEGER, 5L)), new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "expr_0"), new Constant(INTEGER, 2L)), new NotExpression(new IsNullPredicate(new Cast(new SymbolReference(ROW_TYPE, "expr_1"), JSON))))), + new Logical(AND, ImmutableList.of(new Comparison(NOT_EQUAL, new Reference(INTEGER, "expr"), new Constant(INTEGER, 5L)), new Comparison(EQUAL, new Reference(INTEGER, "expr_0"), new Constant(INTEGER, 2L)), new Not(new IsNull(new Cast(new Reference(ROW_TYPE, "expr_1"), JSON))))), strictProject( ImmutableMap.of( - "expr", PlanMatchPattern.expression(new SubscriptExpression(BIGINT, new SubscriptExpression(ROW_TYPE, new SymbolReference(nestedRowType, "a"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L))), - "expr_0", expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "b"), new Constant(INTEGER, 2L))), - "expr_1", expression(new SubscriptExpression(ROW_TYPE, new SymbolReference(nestedRowType, "a"), new Constant(INTEGER, 1L))), - "a", expression(new SymbolReference(nestedRowType, "a")), - "b", expression(new SymbolReference(ROW_TYPE, "b"))), + "expr", PlanMatchPattern.expression(new Subscript(BIGINT, new Subscript(ROW_TYPE, new Reference(nestedRowType, "a"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L))), + "expr_0", expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "b"), new Constant(INTEGER, 2L))), + "expr_1", expression(new Subscript(ROW_TYPE, new Reference(nestedRowType, "a"), new Constant(INTEGER, 1L))), + "a", expression(new Reference(nestedRowType, "a")), + "b", expression(new Reference(ROW_TYPE, "b"))), tableScan( testTable.getConnectorHandle()::equals, TupleDomain.all(), @@ -367,23 +367,23 @@ public void testPushdownDereferenceThroughFilter() .on(p -> p.project( Assignments.of( - p.symbol("expr", BIGINT), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L)), - p.symbol("expr_2", BIGINT), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L))), + p.symbol("expr", BIGINT), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L)), + p.symbol("expr_2", BIGINT), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L))), p.filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(NOT_EQUAL, new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L)), new Constant(createVarcharType(3), Slices.utf8Slice("foo"))), new NotExpression(new IsNullPredicate(new SymbolReference(ROW_TYPE, "msg2"))))), + new Logical(AND, ImmutableList.of(new Comparison(NOT_EQUAL, new Subscript(BIGINT, new Reference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L)), new Constant(createVarcharType(3), Slices.utf8Slice("foo"))), new Not(new IsNull(new Reference(ROW_TYPE, "msg2"))))), p.values(p.symbol("msg", ROW_TYPE), p.symbol("msg2", ROW_TYPE))))) .matches( strictProject( ImmutableMap.of( - "expr", expression(new SymbolReference(BIGINT, "msg_x")), - "expr_2", expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L)))), // not pushed down since predicate contains msg2 reference + "expr", expression(new Reference(BIGINT, "msg_x")), + "expr_2", expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L)))), // not pushed down since predicate contains msg2 reference filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(NOT_EQUAL, new SymbolReference(createVarcharType(3), "msg_x"), new Constant(createVarcharType(3), Slices.utf8Slice("foo"))), new NotExpression(new IsNullPredicate(new SymbolReference(ROW_TYPE, "msg2"))))), + new Logical(AND, ImmutableList.of(new Comparison(NOT_EQUAL, new Reference(createVarcharType(3), "msg_x"), new Constant(createVarcharType(3), Slices.utf8Slice("foo"))), new Not(new IsNull(new Reference(ROW_TYPE, "msg2"))))), strictProject( ImmutableMap.of( - "msg_x", expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L))), - "msg", expression(new SymbolReference(ROW_TYPE, "msg")), - "msg2", expression(new SymbolReference(ROW_TYPE, "msg2"))), + "msg_x", expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L))), + "msg", expression(new Reference(ROW_TYPE, "msg")), + "msg2", expression(new Reference(ROW_TYPE, "msg2"))), values("msg", "msg2"))))); } @@ -394,9 +394,9 @@ public void testPushDownDereferenceThroughLimit() .on(p -> p.project( Assignments.builder() - .put(p.symbol("msg1_x"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) - .put(p.symbol("msg2_y"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg2"), new Constant(INTEGER, 2L))) - .put(p.symbol("z"), new SymbolReference(BIGINT, "z")) + .put(p.symbol("msg1_x"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) + .put(p.symbol("msg2_y"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 2L))) + .put(p.symbol("z"), new Reference(BIGINT, "z")) .build(), p.limit(10, ImmutableList.of(p.symbol("msg2", ROW_TYPE)), @@ -404,19 +404,19 @@ public void testPushDownDereferenceThroughLimit() .matches( strictProject( ImmutableMap.builder() - .put("msg1_x", expression(new SymbolReference(BIGINT, "x"))) - .put("msg2_y", expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg2"), new Constant(INTEGER, 2L)))) - .put("z", expression(new SymbolReference(BIGINT, "z"))) + .put("msg1_x", expression(new Reference(BIGINT, "x"))) + .put("msg2_y", expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 2L)))) + .put("z", expression(new Reference(BIGINT, "z"))) .buildOrThrow(), limit( 10, ImmutableList.of(sort("msg2", ASCENDING, FIRST)), strictProject( ImmutableMap.builder() - .put("x", expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)))) - .put("z", expression(new SymbolReference(BIGINT, "z"))) - .put("msg1", expression(new SymbolReference(ROW_TYPE, "msg1"))) - .put("msg2", expression(new SymbolReference(ROW_TYPE, "msg2"))) + .put("x", expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)))) + .put("z", expression(new Reference(BIGINT, "z"))) + .put("msg1", expression(new Reference(ROW_TYPE, "msg1"))) + .put("msg2", expression(new Reference(ROW_TYPE, "msg2"))) .buildOrThrow(), values("msg1", "msg2", "z"))))); } @@ -427,9 +427,9 @@ public void testPushDownDereferenceThroughLimitWithPreSortedInputs() tester().assertThat(new PushDownDereferencesThroughLimit()) .on(p -> p.project( Assignments.builder() - .put(p.symbol("msg1_x"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) - .put(p.symbol("msg2_y"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg2"), new Constant(INTEGER, 2L))) - .put(p.symbol("z"), new SymbolReference(BIGINT, "z")) + .put(p.symbol("msg1_x"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) + .put(p.symbol("msg2_y"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 2L))) + .put(p.symbol("z"), new Reference(BIGINT, "z")) .build(), p.limit( 10, @@ -439,9 +439,9 @@ public void testPushDownDereferenceThroughLimitWithPreSortedInputs() .matches( strictProject( ImmutableMap.builder() - .put("msg1_x", expression(new SymbolReference(BIGINT, "x"))) - .put("msg2_y", expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg2"), new Constant(INTEGER, 2L)))) - .put("z", expression(new SymbolReference(BIGINT, "z"))) + .put("msg1_x", expression(new Reference(BIGINT, "x"))) + .put("msg2_y", expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 2L)))) + .put("z", expression(new Reference(BIGINT, "z"))) .buildOrThrow(), limit( 10, @@ -450,10 +450,10 @@ public void testPushDownDereferenceThroughLimitWithPreSortedInputs() ImmutableList.of("msg2"), strictProject( ImmutableMap.builder() - .put("x", expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)))) - .put("z", expression(new SymbolReference(BIGINT, "z"))) - .put("msg1", expression(new SymbolReference(ROW_TYPE, "msg1"))) - .put("msg2", expression(new SymbolReference(ROW_TYPE, "msg2"))) + .put("x", expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)))) + .put("z", expression(new Reference(BIGINT, "z"))) + .put("msg1", expression(new Reference(ROW_TYPE, "msg1"))) + .put("msg2", expression(new Reference(ROW_TYPE, "msg2"))) .buildOrThrow(), values("msg1", "msg2", "z"))))); } @@ -466,9 +466,9 @@ public void testPushDownDereferenceThroughSort() .on(p -> p.project( Assignments.builder() - .put(p.symbol("msg_x"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L))) - .put(p.symbol("msg_y"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg"), new Constant(INTEGER, 2L))) - .put(p.symbol("z"), new SymbolReference(BIGINT, "z")) + .put(p.symbol("msg_x"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L))) + .put(p.symbol("msg_y"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg"), new Constant(INTEGER, 2L))) + .put(p.symbol("z"), new Reference(BIGINT, "z")) .build(), p.sort( ImmutableList.of(p.symbol("z"), p.symbol("msg", ROW_TYPE)), @@ -479,8 +479,8 @@ public void testPushDownDereferenceThroughSort() .on(p -> p.project( Assignments.builder() - .put(p.symbol("msg_x"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L))) - .put(p.symbol("z"), new SymbolReference(BIGINT, "z")) + .put(p.symbol("msg_x"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L))) + .put(p.symbol("z"), new Reference(BIGINT, "z")) .build(), p.sort( ImmutableList.of(p.symbol("z")), @@ -488,15 +488,15 @@ public void testPushDownDereferenceThroughSort() .matches( strictProject( ImmutableMap.builder() - .put("msg_x", expression(new SymbolReference(BIGINT, "x"))) - .put("z", expression(new SymbolReference(BIGINT, "z"))) + .put("msg_x", expression(new Reference(BIGINT, "x"))) + .put("z", expression(new Reference(BIGINT, "z"))) .buildOrThrow(), sort(ImmutableList.of(sort("z", ASCENDING, FIRST)), strictProject( ImmutableMap.builder() - .put("x", expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L)))) - .put("z", expression(new SymbolReference(BIGINT, "z"))) - .put("msg", expression(new SymbolReference(ROW_TYPE, "msg"))) + .put("x", expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L)))) + .put("z", expression(new Reference(BIGINT, "z"))) + .put("msg", expression(new Reference(ROW_TYPE, "msg"))) .buildOrThrow(), values("msg", "z"))))); } @@ -508,8 +508,8 @@ public void testPushdownDereferenceThroughRowNumber() .on(p -> p.project( Assignments.builder() - .put(p.symbol("msg1_x"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) - .put(p.symbol("msg2_x"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L))) + .put(p.symbol("msg1_x"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) + .put(p.symbol("msg2_x"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L))) .build(), p.rowNumber( ImmutableList.of(p.symbol("msg1", ROW_TYPE)), @@ -519,17 +519,17 @@ public void testPushdownDereferenceThroughRowNumber() .matches( strictProject( ImmutableMap.builder() - .put("msg1_x", expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)))) - .put("msg2_x", expression(new SymbolReference(BIGINT, "expr"))) + .put("msg1_x", expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)))) + .put("msg2_x", expression(new Reference(BIGINT, "expr"))) .buildOrThrow(), rowNumber( pattern -> pattern .partitionBy(ImmutableList.of("msg1")), strictProject( ImmutableMap.builder() - .put("expr", expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L)))) - .put("msg1", expression(new SymbolReference(ROW_TYPE, "msg1"))) - .put("msg2", expression(new SymbolReference(ROW_TYPE, "msg2"))) + .put("expr", expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L)))) + .put("msg1", expression(new Reference(ROW_TYPE, "msg1"))) + .put("msg2", expression(new Reference(ROW_TYPE, "msg2"))) .buildOrThrow(), values("msg1", "msg2"))))); } @@ -541,9 +541,9 @@ public void testPushdownDereferenceThroughTopNRanking() .on(p -> p.project( Assignments.builder() - .put(p.symbol("msg1_x"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) - .put(p.symbol("msg2_x"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L))) - .put(p.symbol("msg3_x"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg3"), new Constant(INTEGER, 1L))) + .put(p.symbol("msg1_x"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) + .put(p.symbol("msg2_x"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L))) + .put(p.symbol("msg3_x"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg3"), new Constant(INTEGER, 1L))) .build(), p.topNRanking( new DataOrganizationSpecification( @@ -559,18 +559,18 @@ public void testPushdownDereferenceThroughTopNRanking() .matches( strictProject( ImmutableMap.builder() - .put("msg1_x", expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)))) - .put("msg2_x", expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L)))) - .put("msg3_x", expression(new SymbolReference(BIGINT, "expr"))) + .put("msg1_x", expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)))) + .put("msg2_x", expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L)))) + .put("msg3_x", expression(new Reference(BIGINT, "expr"))) .buildOrThrow(), topNRanking( pattern -> pattern.specification(singletonList("msg1"), singletonList("msg2"), ImmutableMap.of("msg2", ASC_NULLS_FIRST)), strictProject( ImmutableMap.builder() - .put("expr", expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg3"), new Constant(INTEGER, 1L)))) - .put("msg1", expression(new SymbolReference(ROW_TYPE, "msg1"))) - .put("msg2", expression(new SymbolReference(ROW_TYPE, "msg2"))) - .put("msg3", expression(new SymbolReference(ROW_TYPE, "msg3"))) + .put("expr", expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg3"), new Constant(INTEGER, 1L)))) + .put("msg1", expression(new Reference(ROW_TYPE, "msg1"))) + .put("msg2", expression(new Reference(ROW_TYPE, "msg2"))) + .put("msg3", expression(new Reference(ROW_TYPE, "msg3"))) .buildOrThrow(), values("msg1", "msg2", "msg3"))))); } @@ -582,23 +582,23 @@ public void testPushdownDereferenceThroughTopN() .on(p -> p.project( Assignments.builder() - .put(p.symbol("msg1_x"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) - .put(p.symbol("msg2_x"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L))) + .put(p.symbol("msg1_x"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) + .put(p.symbol("msg2_x"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L))) .build(), p.topN(5, ImmutableList.of(p.symbol("msg1", ROW_TYPE)), p.values(p.symbol("msg1", ROW_TYPE), p.symbol("msg2", ROW_TYPE))))) .matches( strictProject( ImmutableMap.builder() - .put("msg1_x", expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)))) - .put("msg2_x", expression(new SymbolReference(BIGINT, "expr"))) + .put("msg1_x", expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)))) + .put("msg2_x", expression(new Reference(BIGINT, "expr"))) .buildOrThrow(), topN(5, ImmutableList.of(sort("msg1", ASCENDING, FIRST)), strictProject( ImmutableMap.builder() - .put("expr", expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L)))) - .put("msg1", expression(new SymbolReference(ROW_TYPE, "msg1"))) - .put("msg2", expression(new SymbolReference(ROW_TYPE, "msg2"))) + .put("expr", expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L)))) + .put("msg1", expression(new Reference(ROW_TYPE, "msg1"))) + .put("msg2", expression(new Reference(ROW_TYPE, "msg2"))) .buildOrThrow(), values("msg1", "msg2"))))); } @@ -610,11 +610,11 @@ public void testPushdownDereferenceThroughWindow() .on(p -> p.project( Assignments.builder() - .put(p.symbol("msg1_x"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) - .put(p.symbol("msg2_x"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L))) - .put(p.symbol("msg3_x"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg3"), new Constant(INTEGER, 1L))) - .put(p.symbol("msg4_x"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg4"), new Constant(INTEGER, 1L))) - .put(p.symbol("msg5_x"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg5"), new Constant(INTEGER, 1L))) + .put(p.symbol("msg1_x"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) + .put(p.symbol("msg2_x"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L))) + .put(p.symbol("msg3_x"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg3"), new Constant(INTEGER, 1L))) + .put(p.symbol("msg4_x"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg4"), new Constant(INTEGER, 1L))) + .put(p.symbol("msg5_x"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg5"), new Constant(INTEGER, 1L))) .build(), p.window( new DataOrganizationSpecification( @@ -646,11 +646,11 @@ public void testPushdownDereferenceThroughWindow() .matches( strictProject( ImmutableMap.builder() - .put("msg1_x", expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)))) // not pushed down because used in partitionBy - .put("msg2_x", expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L)))) // not pushed down because used in orderBy - .put("msg3_x", expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg3"), new Constant(INTEGER, 1L)))) // not pushed down because the whole column is used in windowNode function - .put("msg4_x", expression(new SymbolReference(BIGINT, "expr"))) // pushed down because msg4[1] is being used in the function - .put("msg5_x", expression(new SymbolReference(BIGINT, "expr2"))) // pushed down because not referenced in windowNode + .put("msg1_x", expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)))) // not pushed down because used in partitionBy + .put("msg2_x", expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L)))) // not pushed down because used in orderBy + .put("msg3_x", expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg3"), new Constant(INTEGER, 1L)))) // not pushed down because the whole column is used in windowNode function + .put("msg4_x", expression(new Reference(BIGINT, "expr"))) // pushed down because msg4[1] is being used in the function + .put("msg5_x", expression(new Reference(BIGINT, "expr2"))) // pushed down because not referenced in windowNode .buildOrThrow(), window( windowMatcherBuilder -> windowMatcherBuilder @@ -658,13 +658,13 @@ public void testPushdownDereferenceThroughWindow() .addFunction(windowFunction("min", singletonList("msg3"), DEFAULT_FRAME)), strictProject( ImmutableMap.builder() - .put("msg1", expression(new SymbolReference(ROW_TYPE, "msg1"))) - .put("msg2", expression(new SymbolReference(ROW_TYPE, "msg2"))) - .put("msg3", expression(new SymbolReference(ROW_TYPE, "msg3"))) - .put("msg4", expression(new SymbolReference(ROW_TYPE, "msg4"))) - .put("msg5", expression(new SymbolReference(ROW_TYPE, "msg5"))) - .put("expr", expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg4"), new Constant(INTEGER, 1L)))) - .put("expr2", expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg5"), new Constant(INTEGER, 1L)))) + .put("msg1", expression(new Reference(ROW_TYPE, "msg1"))) + .put("msg2", expression(new Reference(ROW_TYPE, "msg2"))) + .put("msg3", expression(new Reference(ROW_TYPE, "msg3"))) + .put("msg4", expression(new Reference(ROW_TYPE, "msg4"))) + .put("msg5", expression(new Reference(ROW_TYPE, "msg5"))) + .put("expr", expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg4"), new Constant(INTEGER, 1L)))) + .put("expr2", expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg5"), new Constant(INTEGER, 1L)))) .buildOrThrow(), values("msg1", "msg2", "msg3", "msg4", "msg5"))))); } @@ -676,20 +676,20 @@ public void testPushdownDereferenceThroughAssignUniqueId() .on(p -> p.project( Assignments.builder() - .put(p.symbol("expr"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) + .put(p.symbol("expr"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) .build(), p.assignUniqueId( p.symbol("unique"), p.values(p.symbol("msg1", ROW_TYPE))))) .matches( strictProject( - ImmutableMap.of("expr", expression(new SymbolReference(BIGINT, "msg1_x"))), + ImmutableMap.of("expr", expression(new Reference(BIGINT, "msg1_x"))), assignUniqueId( "unique", strictProject( ImmutableMap.builder() - .put("msg1", expression(new SymbolReference(ROW_TYPE, "msg1"))) - .put("msg1_x", expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)))) + .put("msg1", expression(new Reference(ROW_TYPE, "msg1"))) + .put("msg1_x", expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)))) .buildOrThrow(), values("msg1"))))); } @@ -701,8 +701,8 @@ public void testPushdownDereferenceThroughMarkDistinct() .on(p -> p.project( Assignments.builder() - .put(p.symbol("msg1_x"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) - .put(p.symbol("msg2_x"), new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L))) + .put(p.symbol("msg1_x"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) + .put(p.symbol("msg2_x"), new Subscript(BIGINT, new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L))) .build(), p.markDistinct( p.symbol("is_distinct", BOOLEAN), @@ -711,16 +711,16 @@ public void testPushdownDereferenceThroughMarkDistinct() .matches( strictProject( ImmutableMap.of( - "msg1_x", expression(new SymbolReference(BIGINT, "expr")), // pushed down - "msg2_x", expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L)))), // not pushed down because used in markDistinct + "msg1_x", expression(new Reference(BIGINT, "expr")), // pushed down + "msg2_x", expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L)))), // not pushed down because used in markDistinct markDistinct( "is_distinct", singletonList("msg2"), strictProject( ImmutableMap.builder() - .put("msg1", expression(new SymbolReference(ROW_TYPE, "msg1"))) - .put("msg2", expression(new SymbolReference(ROW_TYPE, "msg2"))) - .put("expr", expression(new SubscriptExpression(BIGINT, new SymbolReference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)))) + .put("msg1", expression(new Reference(ROW_TYPE, "msg1"))) + .put("msg2", expression(new Reference(ROW_TYPE, "msg2"))) + .put("expr", expression(new Subscript(BIGINT, new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)))) .buildOrThrow(), values("msg1", "msg2"))))); } @@ -733,23 +733,23 @@ public void testMultiLevelPushdown() .on(p -> p.project( Assignments.of( - p.symbol("expr_1"), new SubscriptExpression(rowType(field("f1", BIGINT), field("f2", BIGINT)), new SymbolReference(complexType, "a"), new Constant(INTEGER, 1L)), - p.symbol("expr_2"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SubscriptExpression(BIGINT, new SubscriptExpression(rowType(field("f1", BIGINT), field("f2", BIGINT)), new SymbolReference(complexType, "a"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L)), new Constant(INTEGER, 2L)), new SubscriptExpression(BIGINT, new SubscriptExpression(rowType(field("f1", BIGINT), field("f2", BIGINT)), new SymbolReference(complexType, "b"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L))), new SubscriptExpression(BIGINT, new SubscriptExpression(rowType(field("f1", BIGINT), field("f2", BIGINT)), new SymbolReference(complexType, "b"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 2L)))), + p.symbol("expr_1"), new Subscript(rowType(field("f1", BIGINT), field("f2", BIGINT)), new Reference(complexType, "a"), new Constant(INTEGER, 1L)), + p.symbol("expr_2"), new Arithmetic(ADD_INTEGER, ADD, new Arithmetic(ADD_INTEGER, ADD, new Arithmetic(ADD_INTEGER, ADD, new Subscript(BIGINT, new Subscript(rowType(field("f1", BIGINT), field("f2", BIGINT)), new Reference(complexType, "a"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L)), new Constant(INTEGER, 2L)), new Subscript(BIGINT, new Subscript(rowType(field("f1", BIGINT), field("f2", BIGINT)), new Reference(complexType, "b"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L))), new Subscript(BIGINT, new Subscript(rowType(field("f1", BIGINT), field("f2", BIGINT)), new Reference(complexType, "b"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 2L)))), p.project( Assignments.identity(ImmutableList.of(p.symbol("a", complexType), p.symbol("b", complexType))), p.values(p.symbol("a", complexType), p.symbol("b", complexType))))) .matches( strictProject( ImmutableMap.of( - "expr_1", expression(new SymbolReference(complexType.getFields().get(0).getType(), "a_f1")), - "expr_2", PlanMatchPattern.expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SubscriptExpression(BIGINT, new SymbolReference(complexType.getFields().get(0).getType(), "a_f1"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 2L)), new SymbolReference(BIGINT, "b_f1_f1")), new SymbolReference(BIGINT, "b_f1_f2")))), + "expr_1", expression(new Reference(complexType.getFields().get(0).getType(), "a_f1")), + "expr_2", PlanMatchPattern.expression(new Arithmetic(ADD_INTEGER, ADD, new Arithmetic(ADD_INTEGER, ADD, new Arithmetic(ADD_INTEGER, ADD, new Subscript(BIGINT, new Reference(complexType.getFields().get(0).getType(), "a_f1"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 2L)), new Reference(BIGINT, "b_f1_f1")), new Reference(BIGINT, "b_f1_f2")))), strictProject( ImmutableMap.of( - "a", expression(new SymbolReference(complexType, "a")), - "b", expression(new SymbolReference(complexType, "b")), - "a_f1", expression(new SubscriptExpression(rowType(field("f1", BIGINT), field("f2", BIGINT)), new SymbolReference(complexType, "a"), new Constant(INTEGER, 1L))), - "b_f1_f1", PlanMatchPattern.expression(new SubscriptExpression(BIGINT, new SubscriptExpression(rowType(field("f1", BIGINT), field("f2", BIGINT)), new SymbolReference(complexType, "b"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L))), - "b_f1_f2", PlanMatchPattern.expression(new SubscriptExpression(BIGINT, new SubscriptExpression(rowType(field("f1", BIGINT), field("f2", BIGINT)), new SymbolReference(complexType, "b"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 2L)))), + "a", expression(new Reference(complexType, "a")), + "b", expression(new Reference(complexType, "b")), + "a_f1", expression(new Subscript(rowType(field("f1", BIGINT), field("f2", BIGINT)), new Reference(complexType, "a"), new Constant(INTEGER, 1L))), + "b_f1_f1", PlanMatchPattern.expression(new Subscript(BIGINT, new Subscript(rowType(field("f1", BIGINT), field("f2", BIGINT)), new Reference(complexType, "b"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L))), + "b_f1_f2", PlanMatchPattern.expression(new Subscript(BIGINT, new Subscript(rowType(field("f1", BIGINT), field("f2", BIGINT)), new Reference(complexType, "b"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 2L)))), values("a", "b")))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownProjectionsFromPatternRecognition.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownProjectionsFromPatternRecognition.java index fd9b28004222..93d92281bff3 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownProjectionsFromPatternRecognition.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownProjectionsFromPatternRecognition.java @@ -20,11 +20,11 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -43,11 +43,11 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MULTIPLY; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.patternRecognition; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -69,7 +69,7 @@ public void testNoAggregations() tester().assertThat(new PushDownProjectionsFromPatternRecognition()) .on(p -> p.patternRecognition(builder -> builder .pattern(new IrLabel("X")) - .addVariableDefinition(new IrLabel("X"), TRUE_LITERAL) + .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.values(p.symbol("a"))))) .doesNotFire(); } @@ -82,9 +82,9 @@ public void testDoNotPushRuntimeEvaluatedArguments() .pattern(new IrLabel("X")) .addVariableDefinition( new IrLabel("X"), - new ComparisonExpression(GREATER_THAN, new FunctionCall(MAX_BY, ImmutableList.of( - new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new Constant(INTEGER, 1L), new SymbolReference(INTEGER, "match")), - new FunctionCall(CONCAT, ImmutableList.of(new Constant(VARCHAR, Slices.utf8Slice("x")), new SymbolReference(VARCHAR, "classifier"))))), + new Comparison(GREATER_THAN, new Call(MAX_BY, ImmutableList.of( + new Arithmetic(ADD_INTEGER, ADD, new Constant(INTEGER, 1L), new Reference(INTEGER, "match")), + new Call(CONCAT, ImmutableList.of(new Constant(VARCHAR, Slices.utf8Slice("x")), new Reference(VARCHAR, "classifier"))))), new Constant(INTEGER, 5L)), ImmutableMap.of( new Symbol(VARCHAR, "classifier"), new ClassifierValuePointer(new LogicalIndexPointer(ImmutableSet.of(), true, true, 0, 0)), @@ -101,7 +101,7 @@ public void testDoNotPushSymbolReferences() .pattern(new IrLabel("X")) .addVariableDefinition( new IrLabel("X"), - new ComparisonExpression(GREATER_THAN, new FunctionCall(MAX_BY, ImmutableList.of(new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b"))), new Constant(INTEGER, 5L))) + new Comparison(GREATER_THAN, new Call(MAX_BY, ImmutableList.of(new Reference(BIGINT, "a"), new Reference(BIGINT, "b"))), new Constant(INTEGER, 5L))) .source(p.values(p.symbol("a"), p.symbol("b"))))) .doesNotFire(); } @@ -115,11 +115,11 @@ public void testPreProjectArguments() .pattern(new IrLabel("X")) .addVariableDefinition( new IrLabel("X"), - new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "agg"), new Constant(BIGINT, 5L)), + new Comparison(LESS_THAN, new Reference(BIGINT, "agg"), new Constant(BIGINT, 5L)), ImmutableMap.of(new Symbol(BIGINT, "agg"), new AggregationValuePointer( maxBy, new AggregatedSetDescriptor(ImmutableSet.of(), true), - ImmutableList.of(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 1L)), new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new SymbolReference(BIGINT, "b"), new Constant(BIGINT, 2L))), + ImmutableList.of(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Constant(BIGINT, 1L)), new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "b"), new Constant(BIGINT, 2L))), Optional.empty(), Optional.empty()))) .source(p.values(p.symbol("a", BIGINT), p.symbol("b", BIGINT))))) @@ -128,19 +128,19 @@ public void testPreProjectArguments() .pattern(new IrLabel("X")) .addVariableDefinition( new IrLabel("X"), - new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "agg"), new Constant(BIGINT, 5L)), + new Comparison(LESS_THAN, new Reference(BIGINT, "agg"), new Constant(BIGINT, 5L)), ImmutableMap.of("agg", new AggregationValuePointer( maxBy, new AggregatedSetDescriptor(ImmutableSet.of(), true), - ImmutableList.of(new SymbolReference(BIGINT, "expr_1"), new SymbolReference(BIGINT, "expr_2")), + ImmutableList.of(new Reference(BIGINT, "expr_1"), new Reference(BIGINT, "expr_2")), Optional.empty(), Optional.empty()))), project( ImmutableMap.of( - "expr_1", PlanMatchPattern.expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 1L))), - "expr_2", PlanMatchPattern.expression(new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new SymbolReference(BIGINT, "b"), new Constant(BIGINT, 2L))), - "a", PlanMatchPattern.expression(new SymbolReference(BIGINT, "a")), - "b", PlanMatchPattern.expression(new SymbolReference(BIGINT, "b"))), + "expr_1", PlanMatchPattern.expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Constant(BIGINT, 1L))), + "expr_2", PlanMatchPattern.expression(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "b"), new Constant(BIGINT, 2L))), + "a", PlanMatchPattern.expression(new Reference(BIGINT, "a")), + "b", PlanMatchPattern.expression(new Reference(BIGINT, "b"))), values("a", "b")))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushFilterThroughCountAggregation.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushFilterThroughCountAggregation.java index 06cd282cc3e3..81174ef36c21 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushFilterThroughCountAggregation.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushFilterThroughCountAggregation.java @@ -18,11 +18,11 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.PushFilterThroughCountAggregation.PushFilterThroughCountAggregationWithProject; import io.trino.sql.planner.iterative.rule.PushFilterThroughCountAggregation.PushFilterThroughCountAggregationWithoutProject; @@ -34,14 +34,14 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MODULUS; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; -import static io.trino.sql.ir.LogicalExpression.Operator.OR; +import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Logical.Operator.AND; +import static io.trino.sql.ir.Logical.Operator.OR; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregationFunction; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; @@ -64,7 +64,7 @@ public void testDoesNotFireWithNonGroupedAggregation() Symbol mask = p.symbol("mask"); Symbol count = p.symbol("count"); return p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "count"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "count"), new Constant(INTEGER, 0L)), p.aggregation(builder -> builder .globalGrouping() .addAggregation(count, PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of(), mask) @@ -83,11 +83,11 @@ public void testDoesNotFireWithMultipleAggregations() Symbol count = p.symbol("count"); Symbol avg = p.symbol("avg"); return p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "count"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "count"), new Constant(INTEGER, 0L)), p.aggregation(builder -> builder .singleGroupingSet(g) .addAggregation(count, PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of(), mask) - .addAggregation(avg, PlanBuilder.aggregation("avg", ImmutableList.of(new SymbolReference(BIGINT, "g"))), ImmutableList.of(BIGINT), mask) + .addAggregation(avg, PlanBuilder.aggregation("avg", ImmutableList.of(new Reference(BIGINT, "g"))), ImmutableList.of(BIGINT), mask) .source(p.values(g, mask)))); }) .doesNotFire(); @@ -101,7 +101,7 @@ public void testDoesNotFireWithNoAggregations() Symbol g = p.symbol("g"); Symbol mask = p.symbol("mask"); return p.filter( - TRUE_LITERAL, + TRUE, p.aggregation(builder -> builder .singleGroupingSet(g) .source(p.values(g, mask)))); @@ -117,7 +117,7 @@ public void testDoesNotFireWithNoMask() Symbol g = p.symbol("g"); Symbol count = p.symbol("count"); return p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "count"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "count"), new Constant(INTEGER, 0L)), p.aggregation(builder -> builder .singleGroupingSet(g) .addAggregation(count, PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of()) @@ -135,10 +135,10 @@ public void testDoesNotFireWithNoCountAggregation() Symbol mask = p.symbol("mask"); Symbol count = p.symbol("count"); return p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "count"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "count"), new Constant(INTEGER, 0L)), p.aggregation(builder -> builder .singleGroupingSet(g) - .addAggregation(count, PlanBuilder.aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "g"))), ImmutableList.of(BIGINT), mask) + .addAggregation(count, PlanBuilder.aggregation("count", ImmutableList.of(new Reference(BIGINT, "g"))), ImmutableList.of(BIGINT), mask) .source(p.values(g, mask)))); }) .doesNotFire(); @@ -149,10 +149,10 @@ public void testDoesNotFireWithNoCountAggregation() Symbol mask = p.symbol("mask"); Symbol avg = p.symbol("avg"); return p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "avg"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "avg"), new Constant(INTEGER, 0L)), p.aggregation(builder -> builder .singleGroupingSet(g) - .addAggregation(avg, PlanBuilder.aggregation("avg", ImmutableList.of(new SymbolReference(BIGINT, "g"))), ImmutableList.of(BIGINT), mask) + .addAggregation(avg, PlanBuilder.aggregation("avg", ImmutableList.of(new Reference(BIGINT, "g"))), ImmutableList.of(BIGINT), mask) .source(p.values(g, mask)))); }) .doesNotFire(); @@ -167,7 +167,7 @@ public void testFilterPredicateFalse() Symbol mask = p.symbol("mask"); Symbol count = p.symbol("count"); return p.filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "count"), new Constant(BIGINT, 0L)), new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "count"), new Constant(BIGINT, 0L)))), + new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN, new Reference(BIGINT, "count"), new Constant(BIGINT, 0L)), new Comparison(GREATER_THAN, new Reference(BIGINT, "count"), new Constant(BIGINT, 0L)))), p.aggregation(builder -> builder .singleGroupingSet(g) .addAggregation(count, PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of(), mask) @@ -186,7 +186,7 @@ public void testDoesNotFireWhenFilterPredicateTrue() Symbol mask = p.symbol("mask"); Symbol count = p.symbol("count"); return p.filter( - TRUE_LITERAL, + TRUE, p.aggregation(builder -> builder .singleGroupingSet(g) .addAggregation(count, PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of(), mask) @@ -204,7 +204,7 @@ public void testDoesNotFireWhenFilterPredicateSatisfiedByAllCountValues() Symbol mask = p.symbol("mask"); Symbol count = p.symbol("count"); return p.filter( - new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "count"), new Constant(BIGINT, 0L)), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(BIGINT, "count"), new Constant(BIGINT, 0L)))), new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "g"), new Constant(BIGINT, 5L)))), + new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new Comparison(LESS_THAN, new Reference(BIGINT, "count"), new Constant(BIGINT, 0L)), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(BIGINT, "count"), new Constant(BIGINT, 0L)))), new Comparison(EQUAL, new Reference(BIGINT, "g"), new Constant(BIGINT, 5L)))), p.aggregation(builder -> builder .singleGroupingSet(g) .addAggregation(count, PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of(), mask) @@ -222,7 +222,7 @@ public void testPushDownMaskAndRemoveFilter() Symbol mask = p.symbol("mask"); Symbol count = p.symbol("count"); return p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "count"), new Constant(BIGINT, 0L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "count"), new Constant(BIGINT, 0L)), p.aggregation(builder -> builder .singleGroupingSet(g) .addAggregation(count, PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of(), mask) @@ -232,7 +232,7 @@ public void testPushDownMaskAndRemoveFilter() aggregation( ImmutableMap.of("count", aggregationFunction("count", ImmutableList.of())), filter( - new SymbolReference(BOOLEAN, "mask"), + new Reference(BOOLEAN, "mask"), values("g", "mask")))); } @@ -245,7 +245,7 @@ public void testPushDownMaskAndSimplifyFilter() Symbol mask = p.symbol("mask"); Symbol count = p.symbol("count"); return p.filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "count"), new Constant(BIGINT, 0L)), new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "g"), new Constant(BIGINT, 5L)))), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "count"), new Constant(BIGINT, 0L)), new Comparison(GREATER_THAN, new Reference(BIGINT, "g"), new Constant(BIGINT, 5L)))), p.aggregation(builder -> builder .singleGroupingSet(g) .addAggregation(count, PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of(), mask) @@ -253,11 +253,11 @@ public void testPushDownMaskAndSimplifyFilter() }) .matches( filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "g"), new Constant(BIGINT, 5L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "g"), new Constant(BIGINT, 5L)), aggregation( ImmutableMap.of("count", aggregationFunction("count", ImmutableList.of())), filter( - new SymbolReference(BOOLEAN, "mask"), + new Reference(BOOLEAN, "mask"), values("g", "mask"))))); tester().assertThat(new PushFilterThroughCountAggregationWithoutProject(tester().getPlannerContext())) @@ -266,7 +266,7 @@ public void testPushDownMaskAndSimplifyFilter() Symbol mask = p.symbol("mask"); Symbol count = p.symbol("count"); return p.filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "count"), new Constant(BIGINT, 0L)), new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(MODULUS_INTEGER, MODULUS, new SymbolReference(INTEGER, "count"), new Constant(INTEGER, 2L)), new Constant(BIGINT, 0L)))), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "count"), new Constant(BIGINT, 0L)), new Comparison(EQUAL, new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "count"), new Constant(INTEGER, 2L)), new Constant(BIGINT, 0L)))), p.aggregation(builder -> builder .singleGroupingSet(g) .addAggregation(count, PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of(), mask) @@ -274,11 +274,11 @@ public void testPushDownMaskAndSimplifyFilter() }) .matches( filter( - new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(MODULUS_INTEGER, MODULUS, new SymbolReference(INTEGER, "count"), new Constant(INTEGER, 2L)), new Constant(BIGINT, 0L)), + new Comparison(EQUAL, new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "count"), new Constant(INTEGER, 2L)), new Constant(BIGINT, 0L)), aggregation( ImmutableMap.of("count", aggregationFunction("count", ImmutableList.of())), filter( - new SymbolReference(BOOLEAN, "mask"), + new Reference(BOOLEAN, "mask"), values("g", "mask"))))); } @@ -291,7 +291,7 @@ public void testPushDownMaskAndRetainFilter() Symbol mask = p.symbol("mask"); Symbol count = p.symbol("count"); return p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "count"), new Constant(BIGINT, 5L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "count"), new Constant(BIGINT, 5L)), p.aggregation(builder -> builder .singleGroupingSet(g) .addAggregation(count, PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of(), mask) @@ -299,11 +299,11 @@ public void testPushDownMaskAndRetainFilter() }) .matches( filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "count"), new Constant(BIGINT, 5L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "count"), new Constant(BIGINT, 5L)), aggregation( ImmutableMap.of("count", aggregationFunction("count", ImmutableList.of())), filter( - new SymbolReference(BOOLEAN, "mask"), + new Reference(BOOLEAN, "mask"), values("g", "mask"))))); } @@ -316,7 +316,7 @@ public void testWithProject() Symbol mask = p.symbol("mask"); Symbol count = p.symbol("count"); return p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "count"), new Constant(BIGINT, 0L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "count"), new Constant(BIGINT, 0L)), p.project( Assignments.identity(count), p.aggregation(builder -> builder @@ -326,11 +326,11 @@ public void testWithProject() }) .matches( project( - ImmutableMap.of("count", expression(new SymbolReference(BIGINT, "count"))), + ImmutableMap.of("count", expression(new Reference(BIGINT, "count"))), aggregation( ImmutableMap.of("count", aggregationFunction("count", ImmutableList.of())), filter( - new SymbolReference(BOOLEAN, "mask"), + new Reference(BOOLEAN, "mask"), values("g", "mask"))))); tester().assertThat(new PushFilterThroughCountAggregationWithProject(tester().getPlannerContext())) @@ -339,7 +339,7 @@ public void testWithProject() Symbol mask = p.symbol("mask"); Symbol count = p.symbol("count"); return p.filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "count"), new Constant(BIGINT, 0L)), new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "g"), new Constant(BIGINT, 5L)))), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "count"), new Constant(BIGINT, 0L)), new Comparison(GREATER_THAN, new Reference(BIGINT, "g"), new Constant(BIGINT, 5L)))), p.project( Assignments.identity(count, g), p.aggregation(builder -> builder @@ -349,13 +349,13 @@ public void testWithProject() }) .matches( filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "g"), new Constant(BIGINT, 5L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "g"), new Constant(BIGINT, 5L)), project( - ImmutableMap.of("count", expression(new SymbolReference(BIGINT, "count")), "g", expression(new SymbolReference(BIGINT, "g"))), + ImmutableMap.of("count", expression(new Reference(BIGINT, "count")), "g", expression(new Reference(BIGINT, "g"))), aggregation( ImmutableMap.of("count", aggregationFunction("count", ImmutableList.of())), filter( - new SymbolReference(BOOLEAN, "mask"), + new Reference(BOOLEAN, "mask"), values("g", "mask")))))); tester().assertThat(new PushFilterThroughCountAggregationWithProject(tester().getPlannerContext())) @@ -364,7 +364,7 @@ public void testWithProject() Symbol mask = p.symbol("mask"); Symbol count = p.symbol("count"); return p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "count"), new Constant(BIGINT, 5L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "count"), new Constant(BIGINT, 5L)), p.project( Assignments.identity(count), p.aggregation(builder -> builder @@ -374,13 +374,13 @@ public void testWithProject() }) .matches( filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "count"), new Constant(BIGINT, 5L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "count"), new Constant(BIGINT, 5L)), project( - ImmutableMap.of("count", expression(new SymbolReference(BIGINT, "count"))), + ImmutableMap.of("count", expression(new Reference(BIGINT, "count"))), aggregation( ImmutableMap.of("count", aggregationFunction("count", ImmutableList.of())), filter( - new SymbolReference(BOOLEAN, "mask"), + new Reference(BOOLEAN, "mask"), values("g", "mask")))))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushInequalityFilterExpressionBelowJoinRuleSet.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushInequalityFilterExpressionBelowJoinRuleSet.java index 366abe8516f6..17369236c89b 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushInequalityFilterExpressionBelowJoinRuleSet.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushInequalityFilterExpressionBelowJoinRuleSet.java @@ -18,25 +18,25 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ComparisonExpression.Operator; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.IS_DISTINCT_FROM; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; -import static io.trino.sql.ir.LogicalExpression.and; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Comparison.Operator; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.IS_DISTINCT_FROM; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Logical.Operator.AND; +import static io.trino.sql.ir.Logical.and; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.join; @@ -89,10 +89,10 @@ public void testJoinFilterExpressionPushedDownToRightJoinSource() }) .matches( join(INNER, builder -> builder - .filter(new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "expr"), new SymbolReference(BIGINT, "a"))) + .filter(new Comparison(LESS_THAN, new Reference(BIGINT, "expr"), new Reference(BIGINT, "a"))) .left(values("a")) .right(project( - ImmutableMap.of("expr", expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "b"), new Constant(BIGINT, 1L)))), + ImmutableMap.of("expr", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "b"), new Constant(BIGINT, 1L)))), values("b"))))); } @@ -113,13 +113,13 @@ public void testManyJoinFilterExpressionsPushedDownToRightJoinSource() }) .matches( join(INNER, builder -> builder - .filter(new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "expr_less"), new SymbolReference(BIGINT, "a")), new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "expr_greater"), new SymbolReference(BIGINT, "a"))))) + .filter(new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN, new Reference(BIGINT, "expr_less"), new Reference(BIGINT, "a")), new Comparison(GREATER_THAN, new Reference(BIGINT, "expr_greater"), new Reference(BIGINT, "a"))))) .left(values("a")) .right( project( ImmutableMap.of( - "expr_less", expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "b"), new Constant(BIGINT, 1L))), - "expr_greater", expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "b"), new Constant(BIGINT, 10L)))), + "expr_less", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "b"), new Constant(BIGINT, 1L))), + "expr_greater", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "b"), new Constant(BIGINT, 10L)))), values("b"))))); } @@ -138,11 +138,11 @@ public void testOnlyRightJoinFilterExpressionPushedDownToRightJoinSource() }) .matches( join(INNER, builder -> builder - .filter(new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "expr"), new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 2L)))) + .filter(new Comparison(LESS_THAN, new Reference(BIGINT, "expr"), new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Constant(BIGINT, 2L)))) .left(values("a")) .right( project( - ImmutableMap.of("expr", expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "b"), new Constant(BIGINT, 1L)))), + ImmutableMap.of("expr", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "b"), new Constant(BIGINT, 1L)))), values("b"))))); } @@ -180,13 +180,13 @@ public void testParentFilterExpressionPushedDownToRightJoinSource() .matches( project( filter( - new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "expr"), new SymbolReference(BIGINT, "a")), + new Comparison(LESS_THAN, new Reference(BIGINT, "expr"), new Reference(BIGINT, "a")), join(INNER, builder -> builder .left( values("a")) .right( project( - ImmutableMap.of("expr", expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "b"), new Constant(BIGINT, 1L)))), + ImmutableMap.of("expr", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "b"), new Constant(BIGINT, 1L)))), values("b"))))))); } @@ -208,14 +208,14 @@ public void testManyParentFilterExpressionsPushedDownToRightJoinSource() }) .matches( project( - filter(new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "expr_less"), new SymbolReference(BIGINT, "a")), new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "expr_greater"), new SymbolReference(BIGINT, "a")))), + filter(new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN, new Reference(BIGINT, "expr_less"), new Reference(BIGINT, "a")), new Comparison(GREATER_THAN, new Reference(BIGINT, "expr_greater"), new Reference(BIGINT, "a")))), join(INNER, builder -> builder .left(values("a")) .right( project( ImmutableMap.of( - "expr_less", expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "b"), new Constant(BIGINT, 1L))), - "expr_greater", expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "b"), new Constant(BIGINT, 10L)))), + "expr_less", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "b"), new Constant(BIGINT, 1L))), + "expr_greater", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "b"), new Constant(BIGINT, 10L)))), values("b"))))))); } @@ -237,15 +237,15 @@ public void testOnlyParentFilterExpressionExposedInaJoin() .matches( project( filter( - new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "parent_expression"), new SymbolReference(BIGINT, "a")), + new Comparison(LESS_THAN, new Reference(BIGINT, "parent_expression"), new Reference(BIGINT, "a")), join(INNER, builder -> builder - .filter(new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "join_expression"), new SymbolReference(BIGINT, "a"))) + .filter(new Comparison(LESS_THAN, new Reference(BIGINT, "join_expression"), new Reference(BIGINT, "a"))) .left(values("a")) .right( project( ImmutableMap.of( - "join_expression", expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "b"), new Constant(BIGINT, 2L))), - "parent_expression", expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "b"), new Constant(BIGINT, 1L)))), + "join_expression", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "b"), new Constant(BIGINT, 2L))), + "parent_expression", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "b"), new Constant(BIGINT, 1L)))), values("b")))) .withExactOutputs("a", "b", "parent_expression")))); } @@ -280,14 +280,14 @@ public void testNotSupportedExpression() }).doesNotFire(); } - private static ComparisonExpression comparison(Operator operator, Expression left, Expression right) + private static Comparison comparison(Operator operator, Expression left, Expression right) { - return new ComparisonExpression(operator, left, right); + return new Comparison(operator, left, right); } - private ArithmeticBinaryExpression add(Symbol symbol, long value) + private Arithmetic add(Symbol symbol, long value) { - return new ArithmeticBinaryExpression( + return new Arithmetic( ADD_BIGINT, ADD, symbol.toSymbolReference(), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushJoinIntoTableScan.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushJoinIntoTableScan.java index 5b59977cb781..a2ac73cf8b12 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushJoinIntoTableScan.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushJoinIntoTableScan.java @@ -37,8 +37,8 @@ import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.NullableValue; import io.trino.spi.predicate.TupleDomain; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.RuleTester; @@ -60,7 +60,7 @@ import static io.trino.spi.expression.StandardFunctions.MULTIPLY_FUNCTION_NAME; import static io.trino.spi.predicate.Domain.onlyNull; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MULTIPLY; +import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan; @@ -136,7 +136,7 @@ public class TestPushJoinIntoTableScan @ParameterizedTest @MethodSource("testPushJoinIntoTableScanParams") - public void testPushJoinIntoTableScan(io.trino.sql.planner.plan.JoinType joinType, Optional filterComparisonOperator) + public void testPushJoinIntoTableScan(io.trino.sql.planner.plan.JoinType joinType, Optional filterComparisonOperator) { MockConnectorFactory connectorFactory = createMockConnectorFactory((session, applyJoinType, left, right, joinConditions, leftAssignments, rightAssignments) -> { assertThat(((MockConnectorTableHandle) left).getTableName()).isEqualTo(TABLE_A_SCHEMA_TABLE_NAME); @@ -180,7 +180,7 @@ public void testPushJoinIntoTableScan(io.trino.sql.planner.plan.JoinType joinTyp joinType, left, right, - new ComparisonExpression(filterComparisonOperator.get(), columnA1Symbol.toSymbolReference(), columnB1Symbol.toSymbolReference())); + new Comparison(filterComparisonOperator.get(), columnA1Symbol.toSymbolReference(), columnB1Symbol.toSymbolReference())); }) .matches( project( @@ -192,40 +192,40 @@ public static Stream testPushJoinIntoTableScanParams() { return Stream.of( Arguments.of(INNER, Optional.empty()), - Arguments.of(INNER, Optional.of(ComparisonExpression.Operator.EQUAL)), - Arguments.of(INNER, Optional.of(ComparisonExpression.Operator.LESS_THAN)), - Arguments.of(INNER, Optional.of(ComparisonExpression.Operator.LESS_THAN_OR_EQUAL)), - Arguments.of(INNER, Optional.of(ComparisonExpression.Operator.GREATER_THAN)), - Arguments.of(INNER, Optional.of(ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL)), - Arguments.of(INNER, Optional.of(ComparisonExpression.Operator.NOT_EQUAL)), - Arguments.of(INNER, Optional.of(ComparisonExpression.Operator.IS_DISTINCT_FROM)), + Arguments.of(INNER, Optional.of(Comparison.Operator.EQUAL)), + Arguments.of(INNER, Optional.of(Comparison.Operator.LESS_THAN)), + Arguments.of(INNER, Optional.of(Comparison.Operator.LESS_THAN_OR_EQUAL)), + Arguments.of(INNER, Optional.of(Comparison.Operator.GREATER_THAN)), + Arguments.of(INNER, Optional.of(Comparison.Operator.GREATER_THAN_OR_EQUAL)), + Arguments.of(INNER, Optional.of(Comparison.Operator.NOT_EQUAL)), + Arguments.of(INNER, Optional.of(Comparison.Operator.IS_DISTINCT_FROM)), Arguments.of(io.trino.sql.planner.plan.JoinType.LEFT, Optional.empty()), - Arguments.of(io.trino.sql.planner.plan.JoinType.LEFT, Optional.of(ComparisonExpression.Operator.EQUAL)), - Arguments.of(io.trino.sql.planner.plan.JoinType.LEFT, Optional.of(ComparisonExpression.Operator.LESS_THAN)), - Arguments.of(io.trino.sql.planner.plan.JoinType.LEFT, Optional.of(ComparisonExpression.Operator.LESS_THAN_OR_EQUAL)), - Arguments.of(io.trino.sql.planner.plan.JoinType.LEFT, Optional.of(ComparisonExpression.Operator.GREATER_THAN)), - Arguments.of(io.trino.sql.planner.plan.JoinType.LEFT, Optional.of(ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL)), - Arguments.of(io.trino.sql.planner.plan.JoinType.LEFT, Optional.of(ComparisonExpression.Operator.NOT_EQUAL)), - Arguments.of(io.trino.sql.planner.plan.JoinType.LEFT, Optional.of(ComparisonExpression.Operator.IS_DISTINCT_FROM)), + Arguments.of(io.trino.sql.planner.plan.JoinType.LEFT, Optional.of(Comparison.Operator.EQUAL)), + Arguments.of(io.trino.sql.planner.plan.JoinType.LEFT, Optional.of(Comparison.Operator.LESS_THAN)), + Arguments.of(io.trino.sql.planner.plan.JoinType.LEFT, Optional.of(Comparison.Operator.LESS_THAN_OR_EQUAL)), + Arguments.of(io.trino.sql.planner.plan.JoinType.LEFT, Optional.of(Comparison.Operator.GREATER_THAN)), + Arguments.of(io.trino.sql.planner.plan.JoinType.LEFT, Optional.of(Comparison.Operator.GREATER_THAN_OR_EQUAL)), + Arguments.of(io.trino.sql.planner.plan.JoinType.LEFT, Optional.of(Comparison.Operator.NOT_EQUAL)), + Arguments.of(io.trino.sql.planner.plan.JoinType.LEFT, Optional.of(Comparison.Operator.IS_DISTINCT_FROM)), Arguments.of(io.trino.sql.planner.plan.JoinType.RIGHT, Optional.empty()), - Arguments.of(io.trino.sql.planner.plan.JoinType.RIGHT, Optional.of(ComparisonExpression.Operator.EQUAL)), - Arguments.of(io.trino.sql.planner.plan.JoinType.RIGHT, Optional.of(ComparisonExpression.Operator.LESS_THAN)), - Arguments.of(io.trino.sql.planner.plan.JoinType.RIGHT, Optional.of(ComparisonExpression.Operator.LESS_THAN_OR_EQUAL)), - Arguments.of(io.trino.sql.planner.plan.JoinType.RIGHT, Optional.of(ComparisonExpression.Operator.GREATER_THAN)), - Arguments.of(io.trino.sql.planner.plan.JoinType.RIGHT, Optional.of(ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL)), - Arguments.of(io.trino.sql.planner.plan.JoinType.RIGHT, Optional.of(ComparisonExpression.Operator.NOT_EQUAL)), - Arguments.of(io.trino.sql.planner.plan.JoinType.RIGHT, Optional.of(ComparisonExpression.Operator.IS_DISTINCT_FROM)), + Arguments.of(io.trino.sql.planner.plan.JoinType.RIGHT, Optional.of(Comparison.Operator.EQUAL)), + Arguments.of(io.trino.sql.planner.plan.JoinType.RIGHT, Optional.of(Comparison.Operator.LESS_THAN)), + Arguments.of(io.trino.sql.planner.plan.JoinType.RIGHT, Optional.of(Comparison.Operator.LESS_THAN_OR_EQUAL)), + Arguments.of(io.trino.sql.planner.plan.JoinType.RIGHT, Optional.of(Comparison.Operator.GREATER_THAN)), + Arguments.of(io.trino.sql.planner.plan.JoinType.RIGHT, Optional.of(Comparison.Operator.GREATER_THAN_OR_EQUAL)), + Arguments.of(io.trino.sql.planner.plan.JoinType.RIGHT, Optional.of(Comparison.Operator.NOT_EQUAL)), + Arguments.of(io.trino.sql.planner.plan.JoinType.RIGHT, Optional.of(Comparison.Operator.IS_DISTINCT_FROM)), Arguments.of(io.trino.sql.planner.plan.JoinType.FULL, Optional.empty()), - Arguments.of(io.trino.sql.planner.plan.JoinType.FULL, Optional.of(ComparisonExpression.Operator.EQUAL)), - Arguments.of(io.trino.sql.planner.plan.JoinType.FULL, Optional.of(ComparisonExpression.Operator.LESS_THAN)), - Arguments.of(io.trino.sql.planner.plan.JoinType.FULL, Optional.of(ComparisonExpression.Operator.LESS_THAN_OR_EQUAL)), - Arguments.of(io.trino.sql.planner.plan.JoinType.FULL, Optional.of(ComparisonExpression.Operator.GREATER_THAN)), - Arguments.of(io.trino.sql.planner.plan.JoinType.FULL, Optional.of(ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL)), - Arguments.of(io.trino.sql.planner.plan.JoinType.FULL, Optional.of(ComparisonExpression.Operator.NOT_EQUAL)), - Arguments.of(io.trino.sql.planner.plan.JoinType.FULL, Optional.of(ComparisonExpression.Operator.IS_DISTINCT_FROM))); + Arguments.of(io.trino.sql.planner.plan.JoinType.FULL, Optional.of(Comparison.Operator.EQUAL)), + Arguments.of(io.trino.sql.planner.plan.JoinType.FULL, Optional.of(Comparison.Operator.LESS_THAN)), + Arguments.of(io.trino.sql.planner.plan.JoinType.FULL, Optional.of(Comparison.Operator.LESS_THAN_OR_EQUAL)), + Arguments.of(io.trino.sql.planner.plan.JoinType.FULL, Optional.of(Comparison.Operator.GREATER_THAN)), + Arguments.of(io.trino.sql.planner.plan.JoinType.FULL, Optional.of(Comparison.Operator.GREATER_THAN_OR_EQUAL)), + Arguments.of(io.trino.sql.planner.plan.JoinType.FULL, Optional.of(Comparison.Operator.NOT_EQUAL)), + Arguments.of(io.trino.sql.planner.plan.JoinType.FULL, Optional.of(Comparison.Operator.IS_DISTINCT_FROM))); } /** @@ -275,9 +275,9 @@ public void testPushJoinIntoTableScanWithComplexFilter() INNER, left, right, - new ComparisonExpression( - ComparisonExpression.Operator.GREATER_THAN, - new ArithmeticBinaryExpression(MULTIPLY_BIGINT, MULTIPLY, new Constant(BIGINT, 44L), columnA1Symbol.toSymbolReference()), + new Comparison( + Comparison.Operator.GREATER_THAN, + new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Constant(BIGINT, 44L), columnA1Symbol.toSymbolReference()), columnB1Symbol.toSymbolReference())); }) .matches( @@ -628,7 +628,7 @@ private JoinType toSpiJoinType(io.trino.sql.planner.plan.JoinType joinType) }; } - private JoinCondition.Operator getConditionOperator(ComparisonExpression.Operator operator) + private JoinCondition.Operator getConditionOperator(Comparison.Operator operator) { switch (operator) { case EQUAL: diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushLimitThroughProject.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushLimitThroughProject.java index 2408a67079a1..4449153432f5 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushLimitThroughProject.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushLimitThroughProject.java @@ -19,10 +19,10 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; import io.trino.spi.type.RowType; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; @@ -32,8 +32,8 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.limit; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; @@ -57,12 +57,12 @@ public void testPushdownLimitNonIdentityProjection() Symbol a = p.symbol("a"); return p.limit(1, p.project( - Assignments.of(a, TRUE_LITERAL), + Assignments.of(a, TRUE), p.values())); }) .matches( strictProject( - ImmutableMap.of("b", expression(TRUE_LITERAL)), + ImmutableMap.of("b", expression(TRUE)), limit(1, values()))); } @@ -79,12 +79,12 @@ public void testPushdownLimitWithTiesNNonIdentityProjection() 1, ImmutableList.of(projectedA), p.project( - Assignments.of(projectedA, new SymbolReference(BIGINT, "a"), projectedB, new SymbolReference(BIGINT, "b")), + Assignments.of(projectedA, new Reference(BIGINT, "a"), projectedB, new Reference(BIGINT, "b")), p.values(a, b))); }) .matches( project( - ImmutableMap.of("projectedA", expression(new SymbolReference(BIGINT, "a")), "projectedB", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("projectedA", expression(new Reference(BIGINT, "a")), "projectedB", expression(new Reference(BIGINT, "b"))), limit(1, ImmutableList.of(sort("a", ASCENDING, FIRST)), values("a", "b")))); } @@ -102,15 +102,15 @@ public void testPushdownLimitWithTiesThroughProjectionWithExpression() ImmutableList.of(projectedA), p.project( Assignments.of( - projectedA, new SymbolReference(BIGINT, "a"), - projectedC, new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b"))), + projectedA, new Reference(BIGINT, "a"), + projectedC, new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Reference(BIGINT, "b"))), p.values(a, b))); }) .matches( project( ImmutableMap.of( - "projectedA", expression(new SymbolReference(BIGINT, "a")), - "projectedC", expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b")))), + "projectedA", expression(new Reference(BIGINT, "a")), + "projectedC", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))), limit(1, ImmutableList.of(sort("a", ASCENDING, FIRST)), values("a", "b")))); } @@ -128,8 +128,8 @@ public void testDoNotPushdownLimitWithTiesThroughProjectionWithExpression() ImmutableList.of(projectedC), p.project( Assignments.of( - projectedA, new SymbolReference(BIGINT, "a"), - projectedC, new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b"))), + projectedA, new Reference(BIGINT, "a"), + projectedC, new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Reference(BIGINT, "b"))), p.values(a, b))); }) .doesNotFire(); @@ -159,8 +159,8 @@ public void testDoesntPushDownLimitThroughExclusiveDereferences() return p.limit(1, p.project( Assignments.of( - p.symbol("b"), new SubscriptExpression(BIGINT, a.toSymbolReference(), new Constant(INTEGER, 1L)), - p.symbol("c"), new SubscriptExpression(BIGINT, a.toSymbolReference(), new Constant(INTEGER, 2L))), + p.symbol("b"), new Subscript(BIGINT, a.toSymbolReference(), new Constant(INTEGER, 1L)), + p.symbol("c"), new Subscript(BIGINT, a.toSymbolReference(), new Constant(INTEGER, 2L))), p.values(a))); }) .doesNotFire(); @@ -182,8 +182,8 @@ public void testLimitWithPreSortedInputs() ImmutableList.of(projectedC), p.project( Assignments.of( - projectedA, new SymbolReference(BIGINT, "a"), - projectedC, new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b"))), + projectedA, new Reference(BIGINT, "a"), + projectedC, new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Reference(BIGINT, "b"))), p.values(a, b))); }) .doesNotFire(); @@ -201,13 +201,13 @@ projectedC, new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference( ImmutableList.of(projectedA), p.project( Assignments.of( - projectedA, new SymbolReference(BIGINT, "a"), - projectedC, new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b"))), + projectedA, new Reference(BIGINT, "a"), + projectedC, new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Reference(BIGINT, "b"))), p.values(a, b))); }) .matches( project( - ImmutableMap.of("projectedA", expression(new SymbolReference(BIGINT, "a")), "projectedC", expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b")))), + ImmutableMap.of("projectedA", expression(new Reference(BIGINT, "a")), "projectedC", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))), limit(1, ImmutableList.of(), true, ImmutableList.of("a"), values("a", "b")))); } @@ -222,13 +222,13 @@ public void testPushDownLimitThroughOverlappingDereferences() return p.limit(1, p.project( Assignments.of( - p.symbol("b"), new SubscriptExpression(BIGINT, a.toSymbolReference(), new Constant(INTEGER, 1L)), + p.symbol("b"), new Subscript(BIGINT, a.toSymbolReference(), new Constant(INTEGER, 1L)), p.symbol("c", rowType), a.toSymbolReference()), p.values(a))); }) .matches( project( - ImmutableMap.of("b", io.trino.sql.planner.assertions.PlanMatchPattern.expression(new SubscriptExpression(BIGINT, new SymbolReference(rowType, "a"), new Constant(INTEGER, 1L))), "c", expression(new SymbolReference(rowType, "a"))), + ImmutableMap.of("b", io.trino.sql.planner.assertions.PlanMatchPattern.expression(new Subscript(BIGINT, new Reference(rowType, "a"), new Constant(INTEGER, 1L))), "c", expression(new Reference(rowType, "a"))), limit(1, values("a")))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushMergeWriterUpdateIntoConnector.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushMergeWriterUpdateIntoConnector.java index e1efceb5f983..2e42a980f4bc 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushMergeWriterUpdateIntoConnector.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushMergeWriterUpdateIntoConnector.java @@ -26,10 +26,10 @@ import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.TestingColumnHandle; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; import io.trino.sql.ir.Row; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.RuleTester; @@ -46,7 +46,7 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.assertions.PlanMatchPattern.node; public class TestPushMergeWriterUpdateIntoConnector @@ -70,7 +70,7 @@ public void testPushUpdateIntoConnector() Symbol rowId = p.symbol("row_id"); Symbol rowCount = p.symbol("row_count"); // set column name and constant update - Expression updateMergeRowExpression = new Row(ImmutableList.of(p.symbol("column_1").toSymbolReference(), new Constant(INTEGER, 1L), TRUE_LITERAL, new Constant(INTEGER, 1L), new Constant(INTEGER, 1L))); + Expression updateMergeRowExpression = new Row(ImmutableList.of(p.symbol("column_1").toSymbolReference(), new Constant(INTEGER, 1L), TRUE, new Constant(INTEGER, 1L), new Constant(INTEGER, 1L))); return p.tableFinish( p.merge( @@ -109,7 +109,7 @@ public void testPushUpdateIntoConnectorArithmeticExpression() Symbol rowCount = p.symbol("row_count"); // set arithmetic expression which we don't support yet Expression updateMergeRowExpression = new Row(ImmutableList.of(p.symbol("column_1").toSymbolReference(), - new ArithmeticBinaryExpression(MULTIPLY_INTEGER, ArithmeticBinaryExpression.Operator.MULTIPLY, p.symbol("col1").toSymbolReference(), new Constant(INTEGER, 5L)))); + new Arithmetic(MULTIPLY_INTEGER, Arithmetic.Operator.MULTIPLY, p.symbol("col1").toSymbolReference(), new Constant(INTEGER, 5L)))); return p.tableFinish( p.merge( @@ -147,7 +147,7 @@ public void testPushUpdateIntoConnectorUpdateAll() Symbol rowId = p.symbol("row_id"); Symbol rowCount = p.symbol("row_count"); // set function call, which represents update all columns statement - Expression updateMergeRowExpression = new Row(ImmutableList.of(new FunctionCall( + Expression updateMergeRowExpression = new Row(ImmutableList.of(new Call( ruleTester.getMetadata().resolveBuiltinFunction("from_base64", fromTypes(VARCHAR)), ImmutableList.of(new Constant(VARCHAR, Slices.utf8Slice("")))))); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushOffsetThroughProject.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushOffsetThroughProject.java index 00a437d0ca9b..882fe2936539 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushOffsetThroughProject.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushOffsetThroughProject.java @@ -19,7 +19,7 @@ import io.trino.sql.planner.plan.Assignments; import org.junit.jupiter.api.Test; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.offset; import static io.trino.sql.planner.assertions.PlanMatchPattern.strictProject; @@ -37,12 +37,12 @@ public void testPushdownOffsetNonIdentityProjection() return p.offset( 5, p.project( - Assignments.of(a, TRUE_LITERAL), + Assignments.of(a, TRUE), p.values())); }) .matches( strictProject( - ImmutableMap.of("b", expression(TRUE_LITERAL)), + ImmutableMap.of("b", expression(TRUE)), offset(5, values()))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPartialAggregationThroughJoin.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPartialAggregationThroughJoin.java index 5b1964d5f088..568762e3e62b 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPartialAggregationThroughJoin.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPartialAggregationThroughJoin.java @@ -15,8 +15,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.ComparisonExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Comparison; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; @@ -28,7 +28,7 @@ import static io.trino.SystemSessionProperties.PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.DoubleType.DOUBLE; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregationFunction; import static io.trino.sql.planner.assertions.PlanMatchPattern.join; @@ -55,19 +55,19 @@ public void testPushesPartialAggregationThroughJoin() ImmutableList.of(new EquiJoinClause(p.symbol("LEFT_EQUI"), p.symbol("RIGHT_EQUI"))), ImmutableList.of(p.symbol("LEFT_GROUP_BY"), p.symbol("LEFT_AGGR")), ImmutableList.of(p.symbol("RIGHT_GROUP_BY")), - Optional.of(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "LEFT_NON_EQUI"), new SymbolReference(BIGINT, "RIGHT_NON_EQUI"))), + Optional.of(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "LEFT_NON_EQUI"), new Reference(BIGINT, "RIGHT_NON_EQUI"))), Optional.of(p.symbol("LEFT_HASH")), Optional.of(p.symbol("RIGHT_HASH")))) - .addAggregation(p.symbol("AVG", DOUBLE), PlanBuilder.aggregation("AVG", ImmutableList.of(new SymbolReference(BIGINT, "LEFT_AGGR"))), ImmutableList.of(DOUBLE)) + .addAggregation(p.symbol("AVG", DOUBLE), PlanBuilder.aggregation("AVG", ImmutableList.of(new Reference(BIGINT, "LEFT_AGGR"))), ImmutableList.of(DOUBLE)) .singleGroupingSet(p.symbol("LEFT_GROUP_BY"), p.symbol("RIGHT_GROUP_BY")) .step(PARTIAL))) .matches(project(ImmutableMap.of( - "LEFT_GROUP_BY", PlanMatchPattern.expression(new SymbolReference(BIGINT, "LEFT_GROUP_BY")), - "RIGHT_GROUP_BY", PlanMatchPattern.expression(new SymbolReference(BIGINT, "RIGHT_GROUP_BY")), - "AVG", PlanMatchPattern.expression(new SymbolReference(DOUBLE, "AVG"))), + "LEFT_GROUP_BY", PlanMatchPattern.expression(new Reference(BIGINT, "LEFT_GROUP_BY")), + "RIGHT_GROUP_BY", PlanMatchPattern.expression(new Reference(BIGINT, "RIGHT_GROUP_BY")), + "AVG", PlanMatchPattern.expression(new Reference(DOUBLE, "AVG"))), join(INNER, builder -> builder .equiCriteria("LEFT_EQUI", "RIGHT_EQUI") - .filter(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "LEFT_NON_EQUI"), new SymbolReference(BIGINT, "RIGHT_NON_EQUI"))) + .filter(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "LEFT_NON_EQUI"), new Reference(BIGINT, "RIGHT_NON_EQUI"))) .left( aggregation( singleGroupingSet("LEFT_EQUI", "LEFT_NON_EQUI", "LEFT_GROUP_BY", "LEFT_HASH"), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateIntoTableScan.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateIntoTableScan.java index 5e755254ae69..878857527a97 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateIntoTableScan.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateIntoTableScan.java @@ -39,12 +39,12 @@ import io.trino.spi.predicate.NullableValue; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.Type; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.CoalesceExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Coalesce; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.testing.TestingTransactionHandle; import org.junit.jupiter.api.BeforeAll; @@ -60,10 +60,10 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MODULUS; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; -import static io.trino.sql.ir.LogicalExpression.Operator.OR; +import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Logical.Operator.AND; +import static io.trino.sql.ir.Logical.Operator.OR; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.constrainedTableScanWithTableLayout; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; @@ -129,7 +129,7 @@ public void testEliminateTableScanWhenNoLayoutExist() { tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter( - new ComparisonExpression(EQUAL, new SymbolReference(createVarcharType(1), "orderstatus"), new Constant(createVarcharType(1), utf8Slice("G"))), + new Comparison(EQUAL, new Reference(createVarcharType(1), "orderstatus"), new Constant(createVarcharType(1), utf8Slice("G"))), p.tableScan( ordersTableHandle, ImmutableList.of(p.symbol("orderstatus", createVarcharType(1))), @@ -143,7 +143,7 @@ public void testReplaceWithExistsWhenNoLayoutExist() ColumnHandle columnHandle = new TpchColumnHandle("nationkey", BIGINT); tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "nationkey"), new Constant(BIGINT, 44L)), + new Comparison(EQUAL, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 44L)), p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), @@ -159,7 +159,7 @@ public void testConsumesDeterministicPredicateIfNewDomainIsSame() ColumnHandle columnHandle = new TpchColumnHandle("nationkey", BIGINT); tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "nationkey"), new Constant(BIGINT, 44L)), + new Comparison(EQUAL, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 44L)), p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), @@ -178,7 +178,7 @@ public void testConsumesDeterministicPredicateIfNewDomainIsWider() ColumnHandle columnHandle = new TpchColumnHandle("nationkey", BIGINT); tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter( - new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "nationkey"), new Constant(BIGINT, 44L)), new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "nationkey"), new Constant(BIGINT, 45L)))), + new Logical(OR, ImmutableList.of(new Comparison(EQUAL, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 44L)), new Comparison(EQUAL, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 45L)))), p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), @@ -199,7 +199,7 @@ public void testConsumesDeterministicPredicateIfNewDomainIsNarrower() Map filterConstraint = ImmutableMap.of("orderstatus", singleValue(orderStatusType, utf8Slice("O"))); tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter( - new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(EQUAL, new SymbolReference(createVarcharType(1), "orderstatus"), new Constant(createVarcharType(1), utf8Slice("O"))), new ComparisonExpression(EQUAL, new SymbolReference(createVarcharType(1), "orderstatus"), new Constant(createVarcharType(1), utf8Slice("F"))))), + new Logical(OR, ImmutableList.of(new Comparison(EQUAL, new Reference(createVarcharType(1), "orderstatus"), new Constant(createVarcharType(1), utf8Slice("O"))), new Comparison(EQUAL, new Reference(createVarcharType(1), "orderstatus"), new Constant(createVarcharType(1), utf8Slice("F"))))), p.tableScan( ordersTableHandle, ImmutableList.of(p.symbol("orderstatus", orderStatusType)), @@ -216,25 +216,25 @@ public void testDoesNotConsumeRemainingPredicateIfNewDomainIsWider() ColumnHandle columnHandle = new TpchColumnHandle("nationkey", BIGINT); tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter( - new LogicalExpression( + new Logical( AND, ImmutableList.of( - new ComparisonExpression( + new Comparison( EQUAL, functionResolution .functionCallBuilder("rand") .build(), new Constant(BIGINT, 42L)), // non-translatable to connector expression - new CoalesceExpression( + new Coalesce( new Constant(BOOLEAN, null), - new ComparisonExpression( + new Comparison( EQUAL, - new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "nationkey"), new Constant(BIGINT, 17L)), + new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 17L)), new Constant(BIGINT, 44L))), - LogicalExpression.or( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "nationkey"), new Constant(BIGINT, 44L)), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "nationkey"), new Constant(BIGINT, 45L))))), + Logical.or( + new Comparison(EQUAL, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 44L)), + new Comparison(EQUAL, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 45L))))), p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), @@ -243,16 +243,16 @@ public void testDoesNotConsumeRemainingPredicateIfNewDomainIsWider() columnHandle, NullableValue.of(BIGINT, (long) 44)))))) .matches( filter( - LogicalExpression.and( - new ComparisonExpression( + Logical.and( + new Comparison( EQUAL, functionResolution .functionCallBuilder("rand") .build(), new Constant(BIGINT, 42L)), - new ComparisonExpression( + new Comparison( EQUAL, - new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "nationkey"), new Constant(BIGINT, 17L)), + new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 17L)), new Constant(BIGINT, 44L))), constrainedTableScanWithTableLayout( "nation", @@ -266,7 +266,7 @@ public void testDoesNotFireOnNonDeterministicPredicate() ColumnHandle columnHandle = new TpchColumnHandle("nationkey", BIGINT); tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter( - new ComparisonExpression( + new Comparison( EQUAL, functionResolution .functionCallBuilder("rand") @@ -285,7 +285,7 @@ public void testDoesNotFireIfRuleNotChangePlan() { tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(MODULUS_INTEGER, MODULUS, new SymbolReference(INTEGER, "nationkey"), new Constant(INTEGER, 17L)), new Constant(BIGINT, 44L)), new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(MODULUS_INTEGER, MODULUS, new SymbolReference(INTEGER, "nationkey"), new Constant(INTEGER, 15L)), new Constant(BIGINT, 43L)))), + new Logical(AND, ImmutableList.of(new Comparison(EQUAL, new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "nationkey"), new Constant(INTEGER, 17L)), new Constant(BIGINT, 44L)), new Comparison(EQUAL, new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "nationkey"), new Constant(INTEGER, 15L)), new Constant(BIGINT, 43L)))), p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), @@ -300,7 +300,7 @@ public void testRuleAddedTableLayoutToFilterTableScan() Map filterConstraint = ImmutableMap.of("orderstatus", singleValue(createVarcharType(1), utf8Slice("F"))); tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter( - new ComparisonExpression(EQUAL, new SymbolReference(createVarcharType(1), "orderstatus"), new Constant(createVarcharType(1), utf8Slice("F"))), + new Comparison(EQUAL, new Reference(createVarcharType(1), "orderstatus"), new Constant(createVarcharType(1), utf8Slice("F"))), p.tableScan( ordersTableHandle, ImmutableList.of(p.symbol("orderstatus", createVarcharType(1))), @@ -315,9 +315,9 @@ public void testNonDeterministicPredicate() Type orderStatusType = createVarcharType(1); tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter( - LogicalExpression.and( - new ComparisonExpression(EQUAL, new SymbolReference(createVarcharType(1), "orderstatus"), new Constant(createVarcharType(1), utf8Slice("O"))), - new ComparisonExpression( + Logical.and( + new Comparison(EQUAL, new Reference(createVarcharType(1), "orderstatus"), new Constant(createVarcharType(1), utf8Slice("O"))), + new Comparison( EQUAL, functionResolution .functionCallBuilder("rand") @@ -329,7 +329,7 @@ public void testNonDeterministicPredicate() ImmutableMap.of(p.symbol("orderstatus", orderStatusType), new TpchColumnHandle("orderstatus", orderStatusType))))) .matches( filter( - new ComparisonExpression( + new Comparison( EQUAL, functionResolution .functionCallBuilder("rand") @@ -350,7 +350,7 @@ public void testPartitioningChanged() assertThatThrownBy(() -> tester().assertThat(pushPredicateIntoTableScan) .withSession(session) .on(p -> p.filter( - new ComparisonExpression(EQUAL, new SymbolReference(VARCHAR, "col"), new Constant(VARCHAR, utf8Slice("G"))), + new Comparison(EQUAL, new Reference(VARCHAR, "col"), new Constant(VARCHAR, utf8Slice("G"))), p.tableScan( mockTableHandle(CONNECTOR_PARTITIONED_TABLE_HANDLE_TO_UNPARTITIONED), ImmutableList.of(p.symbol("col", VARCHAR)), @@ -362,7 +362,7 @@ public void testPartitioningChanged() tester().assertThat(pushPredicateIntoTableScan) .withSession(session) .on(p -> p.filter( - new ComparisonExpression(EQUAL, new SymbolReference(VARCHAR, "col"), new Constant(VARCHAR, utf8Slice("G"))), + new Comparison(EQUAL, new Reference(VARCHAR, "col"), new Constant(VARCHAR, utf8Slice("G"))), p.tableScan( mockTableHandle(CONNECTOR_PARTITIONED_TABLE_HANDLE), ImmutableList.of(p.symbol("col", VARCHAR)), @@ -387,7 +387,7 @@ public void testEliminateTableScanWhenPredicateIsNull() tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "nationkey"), new Constant(BIGINT, null)), + new Comparison(EQUAL, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, null)), p.tableScan( ordersTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), @@ -396,7 +396,7 @@ public void testEliminateTableScanWhenPredicateIsNull() tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "nationkey"), new Constant(BIGINT, 44L)), new Constant(BOOLEAN, null))), + new Logical(AND, ImmutableList.of(new Comparison(EQUAL, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 44L)), new Constant(BOOLEAN, null))), p.tableScan( ordersTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateThroughProjectIntoRowNumber.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateThroughProjectIntoRowNumber.java index b71b5282bf56..34ded67c19dc 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateThroughProjectIntoRowNumber.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateThroughProjectIntoRowNumber.java @@ -19,11 +19,11 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.RowNumberSymbolMatcher; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -34,11 +34,11 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MODULUS; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; +import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Logical.Operator.AND; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; @@ -59,7 +59,7 @@ public void testRowNumberSymbolPruned() Symbol a = p.symbol("a"); Symbol rowNumber = p.symbol("row_number"); return p.filter( - new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L)), + new Comparison(EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L)), p.project( Assignments.identity(a), p.rowNumber( @@ -79,7 +79,7 @@ public void testNoUpperBoundForRowNumberSymbol() Symbol a = p.symbol("a"); Symbol rowNumber = p.symbol("row_number"); return p.filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 1L)), + new Comparison(EQUAL, new Reference(BIGINT, "a"), new Constant(BIGINT, 1L)), p.project( Assignments.identity(a, rowNumber), p.rowNumber( @@ -99,7 +99,7 @@ public void testNonPositiveUpperBoundForRowNumberSymbol() Symbol a = p.symbol("a"); Symbol rowNumber = p.symbol("row_number"); return p.filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 1L)), new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "row_number"), new Constant(BIGINT, -10L)))), + new Logical(AND, ImmutableList.of(new Comparison(EQUAL, new Reference(BIGINT, "a"), new Constant(BIGINT, 1L)), new Comparison(LESS_THAN, new Reference(BIGINT, "row_number"), new Constant(BIGINT, -10L)))), p.project( Assignments.identity(a, rowNumber), p.rowNumber( @@ -119,7 +119,7 @@ public void testPredicateNotSatisfied() Symbol a = p.symbol("a"); Symbol rowNumber = p.symbol("row_number"); return p.filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "row_number"), new Constant(BIGINT, 2L)), new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "row_number"), new Constant(BIGINT, 5L)))), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "row_number"), new Constant(BIGINT, 2L)), new Comparison(LESS_THAN, new Reference(BIGINT, "row_number"), new Constant(BIGINT, 5L)))), p.project( Assignments.identity(rowNumber), p.rowNumber( @@ -129,9 +129,9 @@ public void testPredicateNotSatisfied() p.values(a)))); }) .matches(filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "row_number"), new Constant(BIGINT, 2L)), new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "row_number"), new Constant(BIGINT, 5L)))), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "row_number"), new Constant(BIGINT, 2L)), new Comparison(LESS_THAN, new Reference(BIGINT, "row_number"), new Constant(BIGINT, 5L)))), project( - ImmutableMap.of("row_number", io.trino.sql.planner.assertions.PlanMatchPattern.expression(new SymbolReference(BIGINT, "row_number"))), + ImmutableMap.of("row_number", io.trino.sql.planner.assertions.PlanMatchPattern.expression(new Reference(BIGINT, "row_number"))), rowNumber( pattern -> pattern .maxRowCountPerPartition(Optional.of(4)), @@ -147,7 +147,7 @@ public void testPredicateNotSatisfiedAndMaxRowCountNotUpdated() Symbol a = p.symbol("a"); Symbol rowNumber = p.symbol("row_number"); return p.filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "row_number"), new Constant(BIGINT, 2L)), new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "row_number"), new Constant(BIGINT, 5L)))), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "row_number"), new Constant(BIGINT, 2L)), new Comparison(LESS_THAN, new Reference(BIGINT, "row_number"), new Constant(BIGINT, 5L)))), p.project( Assignments.identity(rowNumber), p.rowNumber( @@ -167,7 +167,7 @@ public void testPredicateSatisfied() Symbol a = p.symbol("a"); Symbol rowNumber = p.symbol("row_number"); return p.filter( - new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "row_number"), new Constant(BIGINT, 5L)), + new Comparison(LESS_THAN, new Reference(BIGINT, "row_number"), new Constant(BIGINT, 5L)), p.project( Assignments.identity(rowNumber), p.rowNumber( @@ -177,7 +177,7 @@ public void testPredicateSatisfied() p.values(a)))); }) .matches(project( - ImmutableMap.of("row_number", io.trino.sql.planner.assertions.PlanMatchPattern.expression(new SymbolReference(BIGINT, "row_number"))), + ImmutableMap.of("row_number", io.trino.sql.planner.assertions.PlanMatchPattern.expression(new Reference(BIGINT, "row_number"))), rowNumber( pattern -> pattern .maxRowCountPerPartition(Optional.of(3)), @@ -189,7 +189,7 @@ public void testPredicateSatisfied() Symbol a = p.symbol("a"); Symbol rowNumber = p.symbol("row_number"); return p.filter( - new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "row_number"), new Constant(BIGINT, 3L)), + new Comparison(LESS_THAN, new Reference(BIGINT, "row_number"), new Constant(BIGINT, 3L)), p.project( Assignments.identity(rowNumber), p.rowNumber( @@ -199,7 +199,7 @@ public void testPredicateSatisfied() p.values(a)))); }) .matches(project( - ImmutableMap.of("row_number", io.trino.sql.planner.assertions.PlanMatchPattern.expression(new SymbolReference(BIGINT, "row_number"))), + ImmutableMap.of("row_number", io.trino.sql.planner.assertions.PlanMatchPattern.expression(new Reference(BIGINT, "row_number"))), rowNumber( pattern -> pattern .maxRowCountPerPartition(Optional.of(2)), @@ -215,7 +215,7 @@ public void testPredicatePartiallySatisfied() Symbol a = p.symbol("a"); Symbol rowNumber = p.symbol("row_number"); return p.filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "row_number"), new Constant(BIGINT, 5L)), new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 0L)))), + new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN, new Reference(BIGINT, "row_number"), new Constant(BIGINT, 5L)), new Comparison(GREATER_THAN, new Reference(BIGINT, "a"), new Constant(BIGINT, 0L)))), p.project( Assignments.identity(rowNumber, a), p.rowNumber( @@ -225,9 +225,9 @@ public void testPredicatePartiallySatisfied() p.values(a)))); }) .matches(filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 0L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "a"), new Constant(BIGINT, 0L)), project( - ImmutableMap.of("row_number", expression(new SymbolReference(BIGINT, "row_number")), "a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("row_number", expression(new Reference(BIGINT, "row_number")), "a", expression(new Reference(BIGINT, "a"))), rowNumber( pattern -> pattern .maxRowCountPerPartition(Optional.of(3)), @@ -239,7 +239,7 @@ public void testPredicatePartiallySatisfied() Symbol a = p.symbol("a"); Symbol rowNumber = p.symbol("row_number"); return p.filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "row_number"), new Constant(BIGINT, 5L)), new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(MODULUS_INTEGER, MODULUS, new SymbolReference(INTEGER, "row_number"), new Constant(INTEGER, 2L)), new Constant(BIGINT, 0L)))), + new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN, new Reference(BIGINT, "row_number"), new Constant(BIGINT, 5L)), new Comparison(EQUAL, new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "row_number"), new Constant(INTEGER, 2L)), new Constant(BIGINT, 0L)))), p.project( Assignments.identity(rowNumber), p.rowNumber( @@ -249,9 +249,9 @@ public void testPredicatePartiallySatisfied() p.values(a)))); }) .matches(filter( - new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(MODULUS_INTEGER, MODULUS, new SymbolReference(INTEGER, "row_number"), new Constant(INTEGER, 2L)), new Constant(BIGINT, 0L)), + new Comparison(EQUAL, new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "row_number"), new Constant(INTEGER, 2L)), new Constant(BIGINT, 0L)), project( - ImmutableMap.of("row_number", io.trino.sql.planner.assertions.PlanMatchPattern.expression(new SymbolReference(BIGINT, "row_number"))), + ImmutableMap.of("row_number", io.trino.sql.planner.assertions.PlanMatchPattern.expression(new Reference(BIGINT, "row_number"))), rowNumber( pattern -> pattern .maxRowCountPerPartition(Optional.of(3)), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateThroughProjectIntoWindow.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateThroughProjectIntoWindow.java index 27ea257b8f6e..b788a4df0089 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateThroughProjectIntoWindow.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateThroughProjectIntoWindow.java @@ -19,11 +19,11 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.TopNRankingSymbolMatcher; @@ -40,11 +40,11 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MODULUS; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; +import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Logical.Operator.AND; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; @@ -74,7 +74,7 @@ private void assertRankingSymbolPruned(Function rankingFunction) Symbol a = p.symbol("a"); Symbol ranking = p.symbol("ranking"); return p.filter( - new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L)), + new Comparison(EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L)), p.project( Assignments.identity(a), p.window( @@ -101,7 +101,7 @@ private void assertNoUpperBoundForRankingSymbol(Function rankingFunction) Symbol a = p.symbol("a"); Symbol ranking = p.symbol("ranking"); return p.filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 1L)), + new Comparison(EQUAL, new Reference(BIGINT, "a"), new Constant(BIGINT, 1L)), p.project( Assignments.identity(a, ranking), p.window( @@ -128,7 +128,7 @@ private void assertNonPositiveUpperBoundForRankingSymbol(Function rankingFunctio Symbol a = p.symbol("a"); Symbol ranking = p.symbol("ranking"); return p.filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 1L)), new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "ranking"), new Constant(BIGINT, -10L)))), + new Logical(AND, ImmutableList.of(new Comparison(EQUAL, new Reference(BIGINT, "a"), new Constant(BIGINT, 1L)), new Comparison(LESS_THAN, new Reference(BIGINT, "ranking"), new Constant(BIGINT, -10L)))), p.project( Assignments.identity(a, ranking), p.window( @@ -155,7 +155,7 @@ private void assertPredicateNotSatisfied(Function rankingFunction, RankingType r Symbol a = p.symbol("a"); Symbol ranking = p.symbol("ranking"); return p.filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "ranking"), new Constant(BIGINT, 2L)), new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "ranking"), new Constant(BIGINT, 5L)))), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "ranking"), new Constant(BIGINT, 2L)), new Comparison(LESS_THAN, new Reference(BIGINT, "ranking"), new Constant(BIGINT, 5L)))), p.project( Assignments.identity(ranking), p.window( @@ -166,9 +166,9 @@ private void assertPredicateNotSatisfied(Function rankingFunction, RankingType r p.values(a)))); }) .matches(filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "ranking"), new Constant(BIGINT, 2L)), new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "ranking"), new Constant(BIGINT, 5L)))), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "ranking"), new Constant(BIGINT, 2L)), new Comparison(LESS_THAN, new Reference(BIGINT, "ranking"), new Constant(BIGINT, 5L)))), project( - ImmutableMap.of("ranking", expression(new SymbolReference(BIGINT, "ranking"))), + ImmutableMap.of("ranking", expression(new Reference(BIGINT, "ranking"))), topNRanking( pattern -> pattern .specification( @@ -196,7 +196,7 @@ private void assertPredicateSatisfied(Function rankingFunction, RankingType rank Symbol a = p.symbol("a"); Symbol ranking = p.symbol("ranking"); return p.filter( - new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "ranking"), new Constant(BIGINT, 5L)), + new Comparison(LESS_THAN, new Reference(BIGINT, "ranking"), new Constant(BIGINT, 5L)), p.project( Assignments.identity(ranking), p.window( @@ -207,7 +207,7 @@ private void assertPredicateSatisfied(Function rankingFunction, RankingType rank p.values(a)))); }) .matches(project( - ImmutableMap.of("ranking", expression(new SymbolReference(BIGINT, "ranking"))), + ImmutableMap.of("ranking", expression(new Reference(BIGINT, "ranking"))), topNRanking( pattern -> pattern .specification( @@ -235,7 +235,7 @@ private void assertPredicatePartiallySatisfied(Function rankingFunction, Ranking Symbol a = p.symbol("a"); Symbol ranking = p.symbol("ranking"); return p.filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "ranking"), new Constant(BIGINT, 5L)), new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 0L)))), + new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN, new Reference(BIGINT, "ranking"), new Constant(BIGINT, 5L)), new Comparison(GREATER_THAN, new Reference(BIGINT, "a"), new Constant(BIGINT, 0L)))), p.project( Assignments.identity(ranking, a), p.window( @@ -246,9 +246,9 @@ private void assertPredicatePartiallySatisfied(Function rankingFunction, Ranking p.values(a)))); }) .matches(filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 0L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "a"), new Constant(BIGINT, 0L)), project( - ImmutableMap.of("ranking", expression(new SymbolReference(BIGINT, "ranking")), "a", expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("ranking", expression(new Reference(BIGINT, "ranking")), "a", expression(new Reference(BIGINT, "a"))), topNRanking( pattern -> pattern .specification( @@ -266,7 +266,7 @@ private void assertPredicatePartiallySatisfied(Function rankingFunction, Ranking Symbol a = p.symbol("a"); Symbol ranking = p.symbol("ranking"); return p.filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "ranking"), new Constant(BIGINT, 5L)), new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(MODULUS_INTEGER, MODULUS, new SymbolReference(INTEGER, "ranking"), new Constant(INTEGER, 2L)), new Constant(BIGINT, 0L)))), + new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN, new Reference(BIGINT, "ranking"), new Constant(BIGINT, 5L)), new Comparison(EQUAL, new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "ranking"), new Constant(INTEGER, 2L)), new Constant(BIGINT, 0L)))), p.project( Assignments.identity(ranking), p.window( @@ -277,9 +277,9 @@ private void assertPredicatePartiallySatisfied(Function rankingFunction, Ranking p.values(a)))); }) .matches(filter( - new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(MODULUS_INTEGER, MODULUS, new SymbolReference(INTEGER, "ranking"), new Constant(INTEGER, 2L)), new Constant(BIGINT, 0L)), + new Comparison(EQUAL, new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "ranking"), new Constant(INTEGER, 2L)), new Constant(BIGINT, 0L)), project( - ImmutableMap.of("ranking", expression(new SymbolReference(BIGINT, "ranking"))), + ImmutableMap.of("ranking", expression(new Reference(BIGINT, "ranking"))), topNRanking( pattern -> pattern .specification( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java index 2e4464f34053..8b7fd223143e 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java @@ -42,8 +42,8 @@ import io.trino.sql.PlannerContext; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.RuleTester; import io.trino.sql.planner.plan.Assignments; @@ -101,7 +101,7 @@ public void testDoesNotFire() .on(p -> { Symbol symbol = p.symbol(columnName, columnType); return p.project( - Assignments.of(p.symbol("symbol_dereference", BIGINT), new SubscriptExpression(BIGINT, symbol.toSymbolReference(), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("symbol_dereference", BIGINT), new Subscript(BIGINT, symbol.toSymbolReference(), new Constant(INTEGER, 1L))), p.tableScan( ruleTester.getCurrentCatalogTableHandle(TEST_SCHEMA, TEST_TABLE), ImmutableList.of(symbol), @@ -136,7 +136,7 @@ public void testPushProjection() // Prepare project node assignments Assignments inputProjections = Assignments.builder() .put(identity, baseColumn.toSymbolReference()) - .put(dereference, new SubscriptExpression(BIGINT, baseColumn.toSymbolReference(), new Constant(INTEGER, 1L))) + .put(dereference, new Subscript(BIGINT, baseColumn.toSymbolReference(), new Constant(INTEGER, 1L))) .put(constant, new Constant(INTEGER, 5L)) .build(); @@ -181,7 +181,7 @@ public void testPushProjection() e -> e.getKey().getName(), e -> { if (e.getValue() instanceof String value) { - return expression(new SymbolReference(BIGINT, value)); + return expression(new Reference(BIGINT, value)); } if (e.getValue() instanceof Expression value) { return expression(value); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java index 888f4e0f9667..742476e99d4e 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java @@ -19,9 +19,9 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.connector.SortOrder; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -30,7 +30,7 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MULTIPLY; +import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; @@ -94,7 +94,7 @@ public void testSimpleMultipleInputs() return p.project( Assignments.of( x, new Constant(INTEGER, 3L), - c2, new SymbolReference(BIGINT, "c")), + c2, new Reference(BIGINT, "c")), p.exchange(e -> e .addSource( p.values(a)) @@ -129,7 +129,7 @@ public void testHashMapping() Symbol cTimes5 = p.symbol("c_times_5"); return p.project( Assignments.of( - cTimes5, new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "c"), new Constant(INTEGER, 5L))), + cTimes5, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "c"), new Constant(INTEGER, 5L))), p.exchange(e -> e .addSource( p.values(a, h1)) @@ -144,9 +144,9 @@ cTimes5, new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolRe exchange( strictProject( ImmutableMap.of( - "a", expression(new SymbolReference(INTEGER, "a")), - "h_1", expression(new SymbolReference(BIGINT, "h_1")), - "a_times_5", expression(new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 5L)))), + "a", expression(new Reference(INTEGER, "a")), + "h_1", expression(new Reference(BIGINT, "h_1")), + "a_times_5", expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 5L)))), values(ImmutableList.of("a", "h_1")))))); } @@ -165,7 +165,7 @@ public void testSkipIdentityProjectionIfOutputPresent() Symbol aTimes5 = p.symbol("a_times_5"); return p.project( Assignments.of( - aTimes5, new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 5L)), + aTimes5, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 5L)), a, a.toSymbolReference()), p.exchange(e -> e .addSource(p.values(a)) @@ -175,7 +175,7 @@ aTimes5, new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolRe .matches( exchange( strictProject( - ImmutableMap.of("a_0", expression(new SymbolReference(INTEGER, "a")), "a_times_5", expression(new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 5L)))), + ImmutableMap.of("a_0", expression(new Reference(INTEGER, "a")), "a_times_5", expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 5L)))), values(ImmutableList.of("a"))))); // In the following example, the Projection over Exchange has got an identity assignment (b -> b). @@ -191,7 +191,7 @@ aTimes5, new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolRe Symbol b = p.symbol("b"); return p.project( Assignments.of( - bTimes5, new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 5L)), + bTimes5, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "b"), new Constant(INTEGER, 5L)), b, b.toSymbolReference()), p.exchange(e -> e .addSource(p.values(a)) @@ -201,7 +201,7 @@ bTimes5, new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolRe .matches( exchange( strictProject( - ImmutableMap.of("a_0", expression(new SymbolReference(INTEGER, "a")), "a_times_5", expression(new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 5L)))), + ImmutableMap.of("a_0", expression(new Reference(INTEGER, "a")), "a_times_5", expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 5L)))), values(ImmutableList.of("a"))))); } @@ -220,7 +220,7 @@ public void testDoNotSkipIdentityProjectionIfOutputAbsent() Symbol aTimes5 = p.symbol("a_times_5"); return p.project( Assignments.of( - aTimes5, new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 5L)), + aTimes5, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 5L)), a, a.toSymbolReference()), p.exchange(e -> e .addSource(p.values(a)) @@ -230,7 +230,7 @@ aTimes5, new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolRe .matches( exchange( strictProject( - ImmutableMap.of("a_0", expression(new SymbolReference(INTEGER, "a")), "a_times_5", expression(new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 5L)))), + ImmutableMap.of("a_0", expression(new Reference(INTEGER, "a")), "a_times_5", expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 5L)))), values(ImmutableList.of("a"))))); // In the following example, the Projection over Exchange has got an identity assignment (b -> b). @@ -246,7 +246,7 @@ aTimes5, new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolRe Symbol b = p.symbol("b"); return p.project( Assignments.of( - bTimes5, new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 5L)), + bTimes5, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "b"), new Constant(INTEGER, 5L)), b, b.toSymbolReference()), p.exchange(e -> e .addSource(p.values(a)) @@ -257,8 +257,8 @@ bTimes5, new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolRe exchange( strictProject( ImmutableMap.of( - "a_0", expression(new SymbolReference(INTEGER, "a")), - "a_times_5", expression(new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 5L)))), + "a_0", expression(new Reference(INTEGER, "a")), + "a_times_5", expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 5L)))), values(ImmutableList.of("a"))))); } @@ -275,9 +275,9 @@ public void testPartitioningColumnAndHashWithoutIdentityMappingInProjection() Symbol hTimes5 = p.symbol("h_times_5"); return p.project( Assignments.builder() - .put(aTimes5, new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 5L))) - .put(bTimes5, new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 5L))) - .put(hTimes5, new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "h"), new Constant(INTEGER, 5L))) + .put(aTimes5, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 5L))) + .put(bTimes5, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "b"), new Constant(INTEGER, 5L))) + .put(hTimes5, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "h"), new Constant(INTEGER, 5L))) .build(), p.exchange(e -> e .addSource( @@ -295,11 +295,11 @@ public void testPartitioningColumnAndHashWithoutIdentityMappingInProjection() values( ImmutableList.of("a", "b", "h")) ).withNumberOfOutputColumns(5) - .withAlias("b", expression(new SymbolReference(INTEGER, "b"))) - .withAlias("h", expression(new SymbolReference(INTEGER, "h"))) - .withAlias("a_times_5", expression(new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 5L)))) - .withAlias("b_times_5", expression(new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 5L)))) - .withAlias("h_times_5", expression(new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "h"), new Constant(INTEGER, 5L))))) + .withAlias("b", expression(new Reference(INTEGER, "b"))) + .withAlias("h", expression(new Reference(INTEGER, "h"))) + .withAlias("a_times_5", expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 5L)))) + .withAlias("b_times_5", expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "b"), new Constant(INTEGER, 5L)))) + .withAlias("h_times_5", expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "h"), new Constant(INTEGER, 5L))))) ).withNumberOfOutputColumns(3) .withExactOutputs("a_times_5", "b_times_5", "h_times_5")); } @@ -319,9 +319,9 @@ public void testOrderingColumnsArePreserved() OrderingScheme orderingScheme = new OrderingScheme(ImmutableList.of(sortSymbol), ImmutableMap.of(sortSymbol, SortOrder.ASC_NULLS_FIRST)); return p.project( Assignments.builder() - .put(aTimes5, new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 5L))) - .put(bTimes5, new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 5L))) - .put(hTimes5, new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "h"), new Constant(INTEGER, 5L))) + .put(aTimes5, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 5L))) + .put(bTimes5, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "b"), new Constant(INTEGER, 5L))) + .put(hTimes5, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "h"), new Constant(INTEGER, 5L))) .build(), p.exchange(e -> e .addSource( @@ -338,10 +338,10 @@ public void testOrderingColumnsArePreserved() values( ImmutableList.of("a", "b", "h", "sortSymbol"))) .withNumberOfOutputColumns(4) - .withAlias("a_times_5", expression(new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 5L)))) - .withAlias("b_times_5", expression(new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 5L)))) - .withAlias("h_times_5", expression(new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "h"), new Constant(INTEGER, 5L)))) - .withAlias("sortSymbol", expression(new SymbolReference(INTEGER, "sortSymbol")))) + .withAlias("a_times_5", expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 5L)))) + .withAlias("b_times_5", expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "b"), new Constant(INTEGER, 5L)))) + .withAlias("h_times_5", expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "h"), new Constant(INTEGER, 5L)))) + .withAlias("sortSymbol", expression(new Reference(INTEGER, "sortSymbol")))) ).withNumberOfOutputColumns(3) .withExactOutputs("a_times_5", "b_times_5", "h_times_5")); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughJoin.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughJoin.java index 2c4afe9c0dde..0cee5e499821 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughJoin.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughJoin.java @@ -19,9 +19,9 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ArithmeticNegation; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Negation; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Plan; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; @@ -41,7 +41,7 @@ import static io.trino.metadata.AbstractMockMetadata.dummyMetadata; import static io.trino.metadata.FunctionManager.createTestingFunctionManager; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT; import static io.trino.sql.planner.assertions.PlanAssert.assertPlan; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; @@ -75,14 +75,14 @@ public void testPushesProjectionThroughJoin() ProjectNode planNode = p.project( Assignments.of( - a3, new ArithmeticNegation(a2.toSymbolReference()), - b2, new ArithmeticNegation(b1.toSymbolReference())), + a3, new Negation(a2.toSymbolReference()), + b2, new Negation(b1.toSymbolReference())), p.join( INNER, // intermediate non-identity projections should be fully inlined p.project( Assignments.of( - a2, new ArithmeticNegation(a0.toSymbolReference()), + a2, new Negation(a0.toSymbolReference()), a1, a1.toSymbolReference()), p.project( Assignments.builder() @@ -106,16 +106,16 @@ a2, new ArithmeticNegation(a0.toSymbolReference()), .equiCriteria(ImmutableList.of(aliases -> new JoinNode.EquiJoinClause(new Symbol(UNKNOWN, "a1"), new Symbol(UNKNOWN, "b1")))) .left( strictProject(ImmutableMap.of( - "a3", expression(new ArithmeticNegation(new ArithmeticNegation(new SymbolReference(BIGINT, "a0")))), - "a1", expression(new SymbolReference(BIGINT, "a1"))), + "a3", expression(new Negation(new Negation(new Reference(BIGINT, "a0")))), + "a1", expression(new Reference(BIGINT, "a1"))), strictProject(ImmutableMap.of( - "a0", expression(new SymbolReference(BIGINT, "a0")), - "a1", expression(new SymbolReference(BIGINT, "a1"))), + "a0", expression(new Reference(BIGINT, "a0")), + "a1", expression(new Reference(BIGINT, "a1"))), PlanMatchPattern.values("a0", "a1")))) .right( strictProject(ImmutableMap.of( - "b2", expression(new ArithmeticNegation(new SymbolReference(BIGINT, "b1"))), - "b1", expression(new SymbolReference(BIGINT, "b1"))), + "b2", expression(new Negation(new Reference(BIGINT, "b1"))), + "b1", expression(new Reference(BIGINT, "b1"))), PlanMatchPattern.values("b0", "b1")))) .withExactOutputs("a3", "b2")); } @@ -130,7 +130,7 @@ public void testDoesNotPushStraddlingProjection() ProjectNode planNode = p.project( Assignments.of( - c, new ArithmeticBinaryExpression(ADD_BIGINT, ADD, a.toSymbolReference(), b.toSymbolReference())), + c, new Arithmetic(ADD_BIGINT, ADD, a.toSymbolReference(), b.toSymbolReference())), p.join( INNER, p.values(a), @@ -149,7 +149,7 @@ public void testDoesNotPushProjectionThroughOuterJoin() ProjectNode planNode = p.project( Assignments.of( - c, new ArithmeticNegation(a.toSymbolReference())), + c, new Negation(a.toSymbolReference())), p.join( LEFT, p.values(a), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java index 6b575fce5958..8b3c325f2b45 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java @@ -20,10 +20,10 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; import io.trino.spi.type.RowType; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; @@ -33,7 +33,7 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MULTIPLY; +import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.union; @@ -95,8 +95,8 @@ public void test() Symbol w = p.symbol("w", ROW_TYPE); return p.project( Assignments.of( - cTimes3, new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, c.toSymbolReference(), new Constant(INTEGER, 3L)), - dX, new SubscriptExpression(INTEGER, new SymbolReference(ROW_TYPE, "d"), new Constant(INTEGER, 1L))), + cTimes3, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, c.toSymbolReference(), new Constant(INTEGER, 3L)), + dX, new Subscript(INTEGER, new Reference(ROW_TYPE, "d"), new Constant(INTEGER, 1L))), p.union( ImmutableListMultimap.builder() .put(c, a) @@ -111,10 +111,10 @@ dX, new SubscriptExpression(INTEGER, new SymbolReference(ROW_TYPE, "d"), new Con .matches( union( project( - ImmutableMap.of("a_times_3", expression(new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 3L))), "z_x", expression(new SubscriptExpression(INTEGER, new SymbolReference(ROW_TYPE, "z"), new Constant(INTEGER, 1L)))), + ImmutableMap.of("a_times_3", expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 3L))), "z_x", expression(new Subscript(INTEGER, new Reference(ROW_TYPE, "z"), new Constant(INTEGER, 1L)))), values(ImmutableList.of("a", "z"))), project( - ImmutableMap.of("b_times_3", expression(new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 3L))), "w_x", expression(new SubscriptExpression(INTEGER, new SymbolReference(ROW_TYPE, "w"), new Constant(INTEGER, 1L)))), + ImmutableMap.of("b_times_3", expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "b"), new Constant(INTEGER, 3L))), "w_x", expression(new Subscript(INTEGER, new Reference(ROW_TYPE, "w"), new Constant(INTEGER, 1L)))), values(ImmutableList.of("b", "w")))) .withNumberOfOutputColumns(2) .withAlias("a_times_3") diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushTopNThroughProject.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushTopNThroughProject.java index e85283dcdf8e..99103df58032 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushTopNThroughProject.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushTopNThroughProject.java @@ -19,11 +19,11 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; import io.trino.spi.type.RowType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.BooleanLiteral; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Booleans; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; @@ -34,7 +34,7 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.sort; @@ -66,12 +66,12 @@ public void testPushdownTopNNonIdentityProjection() 1, ImmutableList.of(projectedA), p.project( - Assignments.of(projectedA, new SymbolReference(BIGINT, "a"), projectedB, new SymbolReference(BIGINT, "b")), + Assignments.of(projectedA, new Reference(BIGINT, "a"), projectedB, new Reference(BIGINT, "b")), p.values(a, b))); }) .matches( project( - ImmutableMap.of("projectedA", expression(new SymbolReference(BIGINT, "a")), "projectedB", expression(new SymbolReference(BIGINT, "b"))), + ImmutableMap.of("projectedA", expression(new Reference(BIGINT, "a")), "projectedB", expression(new Reference(BIGINT, "b"))), topN(1, ImmutableList.of(sort("a", ASCENDING, FIRST)), values("a", "b")))); } @@ -89,15 +89,15 @@ public void testPushdownTopNNonIdentityProjectionWithExpression() ImmutableList.of(projectedA), p.project( Assignments.of( - projectedA, new SymbolReference(BIGINT, "a"), - projectedC, new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b"))), + projectedA, new Reference(BIGINT, "a"), + projectedC, new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Reference(BIGINT, "b"))), p.values(a, b))); }) .matches( project( ImmutableMap.of( - "projectedA", expression(new SymbolReference(BIGINT, "a")), - "projectedC", expression(new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b")))), + "projectedA", expression(new Reference(BIGINT, "a")), + "projectedC", expression(new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))), topN(1, ImmutableList.of(sort("a", ASCENDING, FIRST)), values("a", "b")))); } @@ -125,9 +125,9 @@ public void testDoNotPushdownTopNThroughProjectionOverFilterOverTableScan() 1, ImmutableList.of(projectedA), p.project( - Assignments.of(projectedA, new SymbolReference(BIGINT, "a")), + Assignments.of(projectedA, new Reference(BIGINT, "a")), p.filter( - BooleanLiteral.TRUE_LITERAL, + Booleans.TRUE, p.tableScan(ImmutableList.of(), ImmutableMap.of())))); }).doesNotFire(); } @@ -143,7 +143,7 @@ public void testDoNotPushdownTopNThroughProjectionOverTableScan() 1, ImmutableList.of(projectedA), p.project( - Assignments.of(projectedA, new SymbolReference(BIGINT, "a")), + Assignments.of(projectedA, new Reference(BIGINT, "a")), p.tableScan( ImmutableList.of(a), ImmutableMap.of(a, new TestingMetadata.TestingColumnHandle("a"))))); @@ -161,8 +161,8 @@ public void testDoesntPushDownTopNThroughExclusiveDereferences() ImmutableList.of(p.symbol("c")), p.project( Assignments.builder() - .put(p.symbol("b"), new SubscriptExpression(BIGINT, a.toSymbolReference(), new Constant(INTEGER, 1L))) - .put(p.symbol("c"), new SubscriptExpression(BIGINT, a.toSymbolReference(), new Constant(INTEGER, 2L))) + .put(p.symbol("b"), new Subscript(BIGINT, a.toSymbolReference(), new Constant(INTEGER, 1L))) + .put(p.symbol("c"), new Subscript(BIGINT, a.toSymbolReference(), new Constant(INTEGER, 2L))) .build(), p.values(a))); }).doesNotFire(); @@ -180,7 +180,7 @@ public void testPushTopNThroughOverlappingDereferences() ImmutableList.of(d), p.project( Assignments.builder() - .put(p.symbol("b"), new SubscriptExpression(BIGINT, a.toSymbolReference(), new Constant(INTEGER, 1L))) + .put(p.symbol("b"), new Subscript(BIGINT, a.toSymbolReference(), new Constant(INTEGER, 1L))) .put(p.symbol("c", rowType), a.toSymbolReference()) .putIdentity(d) .build(), @@ -188,7 +188,7 @@ public void testPushTopNThroughOverlappingDereferences() }) .matches( project( - ImmutableMap.of("b", io.trino.sql.planner.assertions.PlanMatchPattern.expression(new SubscriptExpression(BIGINT, new SymbolReference(BIGINT, "a"), new Constant(INTEGER, 1L))), "c", expression(new SymbolReference(BIGINT, "a")), "d", expression(new SymbolReference(BIGINT, "d"))), + ImmutableMap.of("b", io.trino.sql.planner.assertions.PlanMatchPattern.expression(new Subscript(BIGINT, new Reference(BIGINT, "a"), new Constant(INTEGER, 1L))), "c", expression(new Reference(BIGINT, "a")), "d", expression(new Reference(BIGINT, "d"))), topN( 1, ImmutableList.of(sort("d", ASCENDING, FIRST)), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushdownFilterIntoRowNumber.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushdownFilterIntoRowNumber.java index 534bddf2d8a0..a67f6062e433 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushdownFilterIntoRowNumber.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushdownFilterIntoRowNumber.java @@ -15,10 +15,10 @@ import com.google.common.collect.ImmutableList; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.RowNumberSymbolMatcher; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -28,10 +28,10 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Logical.Operator.AND; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.rowNumber; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -47,7 +47,7 @@ public void testSourceRowNumber() Symbol a = p.symbol("a"); Symbol rowNumberSymbol = p.symbol("row_number_1"); return p.filter( - new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "row_number_1"), new Constant(BIGINT, 100L)), + new Comparison(LESS_THAN, new Reference(BIGINT, "row_number_1"), new Constant(BIGINT, 100L)), p.rowNumber( ImmutableList.of(a), Optional.empty(), @@ -65,7 +65,7 @@ public void testSourceRowNumber() Symbol a = p.symbol("a"); Symbol rowNumberSymbol = p.symbol("row_number_1"); return p.filter( - new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "row_number_1"), new Constant(BIGINT, 100L)), + new Comparison(LESS_THAN, new Reference(BIGINT, "row_number_1"), new Constant(BIGINT, 100L)), p.rowNumber( ImmutableList.of(a), Optional.of(10), @@ -83,7 +83,7 @@ public void testSourceRowNumber() Symbol a = p.symbol("a"); Symbol rowNumberSymbol = p.symbol("row_number_1"); return p.filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN, new Constant(BIGINT, 3L), new SymbolReference(BIGINT, "row_number_1")), new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "row_number_1"), new Constant(BIGINT, 5L)))), + new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN, new Constant(BIGINT, 3L), new Reference(BIGINT, "row_number_1")), new Comparison(LESS_THAN, new Reference(BIGINT, "row_number_1"), new Constant(BIGINT, 5L)))), p.rowNumber( ImmutableList.of(a), Optional.of(10), @@ -92,7 +92,7 @@ public void testSourceRowNumber() }) .matches( filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN, new Constant(BIGINT, 3L), new SymbolReference(BIGINT, "row_number_1")), new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "row_number_1"), new Constant(BIGINT, 5L)))), + new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN, new Constant(BIGINT, 3L), new Reference(BIGINT, "row_number_1")), new Comparison(LESS_THAN, new Reference(BIGINT, "row_number_1"), new Constant(BIGINT, 5L)))), rowNumber(rowNumber -> rowNumber .maxRowCountPerPartition(Optional.of(4)) .partitionBy(ImmutableList.of("a")), @@ -104,7 +104,7 @@ public void testSourceRowNumber() Symbol a = p.symbol("a"); Symbol rowNumberSymbol = p.symbol("row_number_1"); return p.filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "row_number_1"), new Constant(BIGINT, 5L)), new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 1L)))), + new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN, new Reference(BIGINT, "row_number_1"), new Constant(BIGINT, 5L)), new Comparison(EQUAL, new Reference(BIGINT, "a"), new Constant(BIGINT, 1L)))), p.rowNumber( ImmutableList.of(a), Optional.of(10), @@ -113,7 +113,7 @@ public void testSourceRowNumber() }) .matches( filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 1L)), + new Comparison(EQUAL, new Reference(BIGINT, "a"), new Constant(BIGINT, 1L)), rowNumber(rowNumber -> rowNumber .maxRowCountPerPartition(Optional.of(4)) .partitionBy(ImmutableList.of("a")), @@ -128,7 +128,7 @@ public void testNoOutputsThroughRowNumber() .on(p -> { Symbol rowNumberSymbol = p.symbol("row_number_1"); return p.filter( - new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "row_number_1"), new Constant(BIGINT, -100L)), + new Comparison(LESS_THAN, new Reference(BIGINT, "row_number_1"), new Constant(BIGINT, -100L)), p.rowNumber(ImmutableList.of(p.symbol("a")), Optional.empty(), rowNumberSymbol, p.values(p.symbol("a")))); }) @@ -142,7 +142,7 @@ public void testDoNotFire() .on(p -> { Symbol rowNumberSymbol = p.symbol("row_number_1"); return p.filter( - new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "not_row_number"), new Cast(new Constant(INTEGER, 100L), BIGINT)), + new Comparison(LESS_THAN, new Reference(BIGINT, "not_row_number"), new Cast(new Constant(INTEGER, 100L), BIGINT)), p.rowNumber(ImmutableList.of(p.symbol("a")), Optional.empty(), rowNumberSymbol, p.values(p.symbol("a"), p.symbol("not_row_number")))); }) @@ -152,7 +152,7 @@ public void testDoNotFire() .on(p -> { Symbol rowNumberSymbol = p.symbol("row_number_1"); return p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "row_number_1"), new Cast(new Constant(INTEGER, 100L), BIGINT)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "row_number_1"), new Cast(new Constant(INTEGER, 100L), BIGINT)), p.rowNumber(ImmutableList.of(p.symbol("a")), Optional.empty(), rowNumberSymbol, p.values(p.symbol("a")))); }) @@ -163,7 +163,7 @@ public void testDoNotFire() Symbol a = p.symbol("a"); Symbol rowNumberSymbol = p.symbol("row_number_1"); return p.filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN, new Cast(new Constant(INTEGER, 3L), BIGINT), new SymbolReference(BIGINT, "row_number_1")), new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "row_number_1"), new Cast(new Constant(INTEGER, 5L), BIGINT)))), + new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN, new Cast(new Constant(INTEGER, 3L), BIGINT), new Reference(BIGINT, "row_number_1")), new Comparison(LESS_THAN, new Reference(BIGINT, "row_number_1"), new Cast(new Constant(INTEGER, 5L), BIGINT)))), p.rowNumber( ImmutableList.of(a), Optional.of(4), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushdownFilterIntoWindow.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushdownFilterIntoWindow.java index 6ce4da0e68dd..004cd58deac4 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushdownFilterIntoWindow.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushdownFilterIntoWindow.java @@ -18,10 +18,10 @@ import io.trino.metadata.ResolvedFunction; import io.trino.spi.connector.SortOrder; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.TopNRankingSymbolMatcher; @@ -35,9 +35,9 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Logical.Operator.AND; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.topNRanking; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -64,7 +64,7 @@ private void assertEliminateFilter(String rankingFunctionName) ImmutableList.of(a), ImmutableMap.of(a, SortOrder.ASC_NULLS_FIRST)); return p.filter( - new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "rank_1"), new Constant(BIGINT, 100L)), + new Comparison(LESS_THAN, new Reference(BIGINT, "rank_1"), new Constant(BIGINT, 100L)), p.window( new DataOrganizationSpecification(ImmutableList.of(a), Optional.of(orderingScheme)), ImmutableMap.of(rankSymbol, newWindowNodeFunction(ranking, a)), @@ -94,14 +94,14 @@ private void assertKeepFilter(String rankingFunctionName) ImmutableList.of(a), ImmutableMap.of(a, SortOrder.ASC_NULLS_FIRST)); return p.filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN, new Constant(BIGINT, 3L), new SymbolReference(BIGINT, "row_number_1")), new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "row_number_1"), new Constant(BIGINT, 100L)))), + new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN, new Constant(BIGINT, 3L), new Reference(BIGINT, "row_number_1")), new Comparison(LESS_THAN, new Reference(BIGINT, "row_number_1"), new Constant(BIGINT, 100L)))), p.window( new DataOrganizationSpecification(ImmutableList.of(a), Optional.of(orderingScheme)), ImmutableMap.of(rowNumberSymbol, newWindowNodeFunction(ranking, a)), p.values(p.symbol("a")))); }) .matches(filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN, new Constant(BIGINT, 3L), new SymbolReference(BIGINT, "row_number_1")), new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "row_number_1"), new Constant(BIGINT, 100L)))), + new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN, new Constant(BIGINT, 3L), new Reference(BIGINT, "row_number_1")), new Comparison(LESS_THAN, new Reference(BIGINT, "row_number_1"), new Constant(BIGINT, 100L)))), topNRanking(pattern -> pattern .partial(false) .maxRankingPerPartition(99) @@ -119,16 +119,16 @@ private void assertKeepFilter(String rankingFunctionName) ImmutableList.of(a), ImmutableMap.of(a, SortOrder.ASC_NULLS_FIRST)); return p.filter( - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "row_number_1"), new Constant(BIGINT, 100L)), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 1L)))), + new Logical(AND, ImmutableList.of( + new Comparison(LESS_THAN, new Reference(BIGINT, "row_number_1"), new Constant(BIGINT, 100L)), + new Comparison(EQUAL, new Reference(BIGINT, "a"), new Constant(BIGINT, 1L)))), p.window( new DataOrganizationSpecification(ImmutableList.of(a), Optional.of(orderingScheme)), ImmutableMap.of(rowNumberSymbol, newWindowNodeFunction(ranking, a)), p.values(p.symbol("a")))); }) .matches(filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 1L)), + new Comparison(EQUAL, new Reference(BIGINT, "a"), new Constant(BIGINT, 1L)), topNRanking(pattern -> pattern .partial(false) .maxRankingPerPartition(99) @@ -157,7 +157,7 @@ private void assertNoUpperBound(String rankingFunctionName) ImmutableList.of(a), ImmutableMap.of(a, SortOrder.ASC_NULLS_FIRST)); return p.filter( - new ComparisonExpression(LESS_THAN, new Cast(new Constant(INTEGER, 3L), BIGINT), new SymbolReference(BIGINT, "row_number_1")), + new Comparison(LESS_THAN, new Cast(new Constant(INTEGER, 3L), BIGINT), new Reference(BIGINT, "row_number_1")), p.window( new DataOrganizationSpecification(ImmutableList.of(a), Optional.of(orderingScheme)), ImmutableMap.of(rowNumberSymbol, newWindowNodeFunction(ranking, a)), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveAggregationInSemiJoin.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveAggregationInSemiJoin.java index 7f59fd132221..58972485424c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveAggregationInSemiJoin.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveAggregationInSemiJoin.java @@ -14,7 +14,7 @@ package io.trino.sql.planner.iterative.rule; import com.google.common.collect.ImmutableList; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; @@ -79,7 +79,7 @@ private static PlanNode semiJoinWithAggregationAsFilteringSource(PlanBuilder p) p.values(leftKey), p.aggregation(builder -> builder .globalGrouping() - .addAggregation(rightKey, aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "rightValue"))), ImmutableList.of(BIGINT)) + .addAggregation(rightKey, aggregation("count", ImmutableList.of(new Reference(BIGINT, "rightValue"))), ImmutableList.of(BIGINT)) .source(p.values(p.symbol("rightValue"))))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveEmptyExceptBranches.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveEmptyExceptBranches.java index 7c3ac369fdf5..97848067ed88 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveEmptyExceptBranches.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveEmptyExceptBranches.java @@ -17,7 +17,7 @@ import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.AggregationNode.Step; @@ -109,7 +109,7 @@ public void testReplaceRedundantExceptAll() }) .matches( project( - ImmutableMap.of("output", expression(new SymbolReference(BIGINT, "input1"))), + ImmutableMap.of("output", expression(new Reference(BIGINT, "input1"))), values(ImmutableList.of("input1"), ImmutableList.of(ImmutableList.of(new Constant(BIGINT, null)))))); } @@ -139,7 +139,7 @@ public void testReplaceRedundantExceptDistinct() Optional.empty(), Step.SINGLE, project( - ImmutableMap.of("output", expression(new SymbolReference(BIGINT, "input1"))), + ImmutableMap.of("output", expression(new Reference(BIGINT, "input1"))), values(ImmutableList.of("input1"), ImmutableList.of(ImmutableList.of(new Constant(BIGINT, null))))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveEmptyUnionBranches.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveEmptyUnionBranches.java index 53267360923c..d1224f3b382b 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveEmptyUnionBranches.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveEmptyUnionBranches.java @@ -17,7 +17,7 @@ import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import org.junit.jupiter.api.Test; @@ -104,7 +104,7 @@ public void testReplaceUnionWithProjection() }) .matches( project( - ImmutableMap.of("output", expression(new SymbolReference(BIGINT, "input1"))), + ImmutableMap.of("output", expression(new Reference(BIGINT, "input1"))), values(ImmutableList.of("input1"), ImmutableList.of(ImmutableList.of(new Constant(BIGINT, null)))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveFullSample.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveFullSample.java index 6cef49367446..e24858425b0d 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveFullSample.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveFullSample.java @@ -15,15 +15,15 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.SampleNode.Type; import org.junit.jupiter.api.Test; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -51,7 +51,7 @@ public void test() 1.0, Type.BERNOULLI, p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 5L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 5L)), p.values( ImmutableList.of(p.symbol("a"), p.symbol("b")), ImmutableList.of( @@ -59,7 +59,7 @@ public void test() ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 11L))))))) // TODO: verify contents .matches(filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 5L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 5L)), values(ImmutableMap.of("a", 0, "b", 1)))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantEnforceSingleRowNode.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantEnforceSingleRowNode.java index 194c2e2e585b..2e2776cbad70 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantEnforceSingleRowNode.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantEnforceSingleRowNode.java @@ -14,7 +14,7 @@ package io.trino.sql.planner.iterative.rule; import com.google.common.collect.ImmutableList; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.AggregationNode; import org.junit.jupiter.api.Test; @@ -32,7 +32,7 @@ public void testRemoveEnforceWhenSourceScalar() { tester().assertThat(new RemoveRedundantEnforceSingleRowNode()) .on(p -> p.enforceSingleRow(p.aggregation(builder -> builder - .addAggregation(p.symbol("c"), aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "a"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("c"), aggregation("count", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .globalGrouping() .source(p.values(p.symbol("a")))))) .matches(node(AggregationNode.class, values("a"))); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantExists.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantExists.java index 0e15c42c156b..75212b470eb8 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantExists.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantExists.java @@ -21,8 +21,8 @@ import io.trino.testing.TestingMetadata; import org.junit.jupiter.api.Test; -import static io.trino.sql.ir.BooleanLiteral.FALSE_LITERAL; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.FALSE; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -40,7 +40,7 @@ public void testExistsFalse() p.values(0))) .matches( project( - ImmutableMap.of("exists", expression(FALSE_LITERAL)), + ImmutableMap.of("exists", expression(FALSE)), values())); } @@ -54,7 +54,7 @@ public void testExistsTrue() p.values(1))) .matches( project( - ImmutableMap.of("exists", expression(TRUE_LITERAL)), + ImmutableMap.of("exists", expression(TRUE)), values())); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantLimit.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantLimit.java index ff778a93d5fa..cbafa9f9aabc 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantLimit.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantLimit.java @@ -15,9 +15,9 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.AggregationNode; @@ -27,7 +27,7 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.node; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; import static io.trino.sql.planner.iterative.rule.test.PlanBuilder.aggregation; @@ -43,7 +43,7 @@ public void test() p.limit( 10, p.aggregation(builder -> builder - .addAggregation(p.symbol("c"), aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "foo"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("c"), aggregation("count", ImmutableList.of(new Reference(BIGINT, "foo"))), ImmutableList.of(BIGINT)) .globalGrouping() .source(p.values(p.symbol("foo")))))) .matches( @@ -73,7 +73,7 @@ public void testForZeroLimit() p.limit( 0, p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 5L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 5L)), p.values( ImmutableList.of(p.symbol("a"), p.symbol("b")), ImmutableList.of( @@ -93,7 +93,7 @@ public void testLimitWithPreSortedInputs() true, ImmutableList.of(p.symbol("c")), p.aggregation(builder -> builder - .addAggregation(p.symbol("c"), aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "foo"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("c"), aggregation("count", ImmutableList.of(new Reference(BIGINT, "foo"))), ImmutableList.of(BIGINT)) .globalGrouping() .source(p.values(p.symbol("foo")))))) .matches( @@ -107,7 +107,7 @@ public void testLimitWithPreSortedInputs() true, ImmutableList.of(p.symbol("a")), p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 5L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 5L)), p.values( ImmutableList.of(p.symbol("a"), p.symbol("b")), ImmutableList.of( @@ -126,7 +126,7 @@ public void doesNotFire() p.limit( 10, p.aggregation(builder -> builder - .addAggregation(p.symbol("c"), aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "foo"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("c"), aggregation("count", ImmutableList.of(new Reference(BIGINT, "foo"))), ImmutableList.of(BIGINT)) .singleGroupingSet(p.symbol("foo")) .source(p.values(20, p.symbol("foo")))))) .doesNotFire(); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantPredicateAboveTableScan.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantPredicateAboveTableScan.java index 6297b39f9f0b..c4f2365eb69a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantPredicateAboveTableScan.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantPredicateAboveTableScan.java @@ -28,12 +28,12 @@ import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.NullableValue; import io.trino.spi.predicate.TupleDomain; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.InPredicate; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.In; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -43,15 +43,15 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MODULUS; -import static io.trino.sql.ir.BooleanLiteral.FALSE_LITERAL; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; +import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; +import static io.trino.sql.ir.Booleans.FALSE; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; import static io.trino.sql.ir.IrExpressions.ifExpression; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; -import static io.trino.sql.ir.LogicalExpression.Operator.OR; +import static io.trino.sql.ir.Logical.Operator.AND; +import static io.trino.sql.ir.Logical.Operator.OR; import static io.trino.sql.planner.BuiltinFunctionCallBuilder.resolve; import static io.trino.sql.planner.assertions.PlanMatchPattern.constrainedTableScanWithTableLayout; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; @@ -99,7 +99,7 @@ public void consumesDeterministicPredicateIfNewDomainIsSame() ColumnHandle columnHandle = new TpchColumnHandle("nationkey", BIGINT); tester().assertThat(removeRedundantPredicateAboveTableScan) .on(p -> p.filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "nationkey"), new Constant(BIGINT, 44L)), + new Comparison(EQUAL, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 44L)), p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), @@ -118,9 +118,9 @@ public void consumesDeterministicPredicateIfNewDomainIsWider() ColumnHandle columnHandle = new TpchColumnHandle("nationkey", BIGINT); tester().assertThat(removeRedundantPredicateAboveTableScan) .on(p -> p.filter( - new LogicalExpression(OR, ImmutableList.of( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "nationkey"), new Constant(BIGINT, 44L)), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "nationkey"), new Constant(BIGINT, 45L)))), + new Logical(OR, ImmutableList.of( + new Comparison(EQUAL, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 44L)), + new Comparison(EQUAL, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 45L)))), p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), @@ -139,7 +139,7 @@ public void consumesDeterministicPredicateIfNewDomainIsNarrower() ColumnHandle columnHandle = new TpchColumnHandle("nationkey", BIGINT); tester().assertThat(removeRedundantPredicateAboveTableScan) .on(p -> p.filter( - new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "nationkey"), new Constant(BIGINT, 44L)), new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "nationkey"), new Constant(BIGINT, 45L)), new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "nationkey"), new Constant(BIGINT, 47L)))), + new Logical(OR, ImmutableList.of(new Comparison(EQUAL, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 44L)), new Comparison(EQUAL, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 45L)), new Comparison(EQUAL, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 47L)))), p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), @@ -147,7 +147,7 @@ public void consumesDeterministicPredicateIfNewDomainIsNarrower() TupleDomain.withColumnDomains(ImmutableMap.of(columnHandle, Domain.multipleValues(BIGINT, ImmutableList.of(44L, 45L, 46L))))))) .matches( filter( - new InPredicate(new SymbolReference(BIGINT, "nationkey"), ImmutableList.of(new Constant(BIGINT, 44L), new Constant(BIGINT, 45L))), + new In(new Reference(BIGINT, "nationkey"), ImmutableList.of(new Constant(BIGINT, 44L), new Constant(BIGINT, 45L))), constrainedTableScanWithTableLayout( "nation", ImmutableMap.of("nationkey", Domain.multipleValues(BIGINT, ImmutableList.of(44L, 45L, 46L))), @@ -160,22 +160,22 @@ public void doesNotConsumeRemainingPredicateIfNewDomainIsWider() ColumnHandle columnHandle = new TpchColumnHandle("nationkey", BIGINT); tester().assertThat(removeRedundantPredicateAboveTableScan) .on(p -> p.filter( - new LogicalExpression( + new Logical( AND, ImmutableList.of( - new ComparisonExpression( + new Comparison( EQUAL, resolve(tester().getMetadata()) .setName("rand") .build(), new Constant(BIGINT, 42L)), - new ComparisonExpression( + new Comparison( EQUAL, - new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "nationkey"), new Constant(BIGINT, 17L)), + new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 17L)), new Constant(BIGINT, 44L)), - LogicalExpression.or( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "nationkey"), new Constant(BIGINT, 44L)), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "nationkey"), new Constant(BIGINT, 45L))))), + Logical.or( + new Comparison(EQUAL, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 44L)), + new Comparison(EQUAL, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 45L))))), p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), @@ -184,16 +184,16 @@ public void doesNotConsumeRemainingPredicateIfNewDomainIsWider() columnHandle, NullableValue.of(BIGINT, (long) 44)))))) .matches( filter( - LogicalExpression.and( - new ComparisonExpression( + Logical.and( + new Comparison( EQUAL, resolve(tester().getMetadata()) .setName("rand") .build(), new Constant(BIGINT, 42L)), - new ComparisonExpression( + new Comparison( EQUAL, - new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "nationkey"), new Constant(BIGINT, 17L)), + new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 17L)), new Constant(BIGINT, 44L))), constrainedTableScanWithTableLayout( "nation", @@ -207,7 +207,7 @@ public void doesNotFireOnNonDeterministicPredicate() ColumnHandle columnHandle = new TpchColumnHandle("nationkey", BIGINT); tester().assertThat(removeRedundantPredicateAboveTableScan) .on(p -> p.filter( - new ComparisonExpression( + new Comparison( EQUAL, resolve(tester().getMetadata()) .setName("rand") @@ -226,7 +226,7 @@ public void doesNotFireIfRuleNotChangePlan() { tester().assertThat(removeRedundantPredicateAboveTableScan) .on(p -> p.filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(MODULUS_INTEGER, MODULUS, new SymbolReference(INTEGER, "nationkey"), new Constant(INTEGER, 17L)), new Constant(BIGINT, 44L)), new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(MODULUS_INTEGER, MODULUS, new SymbolReference(INTEGER, "nationkey"), new Constant(INTEGER, 15L)), new Constant(BIGINT, 43L)))), + new Logical(AND, ImmutableList.of(new Comparison(EQUAL, new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "nationkey"), new Constant(INTEGER, 17L)), new Constant(BIGINT, 44L)), new Comparison(EQUAL, new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "nationkey"), new Constant(INTEGER, 15L)), new Constant(BIGINT, 43L)))), p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), @@ -240,7 +240,7 @@ public void doesNotAddTableLayoutToFilterTableScan() { tester().assertThat(removeRedundantPredicateAboveTableScan) .on(p -> p.filter( - new ComparisonExpression(EQUAL, new SymbolReference(VARCHAR, "orderstatus"), new Constant(VARCHAR, Slices.utf8Slice("F"))), + new Comparison(EQUAL, new Reference(VARCHAR, "orderstatus"), new Constant(VARCHAR, Slices.utf8Slice("F"))), p.tableScan( ordersTableHandle, ImmutableList.of(p.symbol("orderstatus", createVarcharType(1))), @@ -254,7 +254,7 @@ public void doesNotFireOnNoTableScanPredicate() ColumnHandle columnHandle = new TpchColumnHandle("nationkey", BIGINT); tester().assertThat(removeRedundantPredicateAboveTableScan) .on(p -> p.filter( - new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "nationkey"), new Constant(INTEGER, 3L)), new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "nationkey"), new Constant(INTEGER, 0L)))), new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "nationkey"), new Constant(INTEGER, 3L)), new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "nationkey"), new Constant(INTEGER, 1L)))))), + new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(INTEGER, "nationkey"), new Constant(INTEGER, 3L)), new Comparison(GREATER_THAN, new Reference(INTEGER, "nationkey"), new Constant(INTEGER, 0L)))), new Logical(OR, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(INTEGER, "nationkey"), new Constant(INTEGER, 3L)), new Comparison(LESS_THAN, new Reference(INTEGER, "nationkey"), new Constant(INTEGER, 1L)))))), p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), @@ -270,9 +270,9 @@ public void skipNotFullyExtractedConjunct() ColumnHandle nationKeyColumnHandle = new TpchColumnHandle("nationkey", BIGINT); tester().assertThat(removeRedundantPredicateAboveTableScan) .on(p -> p.filter( - new LogicalExpression(AND, ImmutableList.of( - ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(VARCHAR, "name"), new Constant(VARCHAR, Slices.utf8Slice("x"))), TRUE_LITERAL, FALSE_LITERAL), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "nationkey"), new Constant(BIGINT, 44L)))), + new Logical(AND, ImmutableList.of( + ifExpression(new Comparison(EQUAL, new Reference(VARCHAR, "name"), new Constant(VARCHAR, Slices.utf8Slice("x"))), TRUE, FALSE), + new Comparison(EQUAL, new Reference(BIGINT, "nationkey"), new Constant(BIGINT, 44L)))), p.tableScan( nationTableHandle, ImmutableList.of( @@ -286,7 +286,7 @@ public void skipNotFullyExtractedConjunct() nationKeyColumnHandle, NullableValue.of(BIGINT, (long) 44)))))) .matches( filter( - ifExpression(new ComparisonExpression(EQUAL, new SymbolReference(VARCHAR, "name"), new Constant(VARCHAR, Slices.utf8Slice("x"))), TRUE_LITERAL, FALSE_LITERAL), + ifExpression(new Comparison(EQUAL, new Reference(VARCHAR, "name"), new Constant(VARCHAR, Slices.utf8Slice("x"))), TRUE, FALSE), constrainedTableScanWithTableLayout( "nation", ImmutableMap.of( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantSort.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantSort.java index 7ea0f9054929..adb40aab6a87 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantSort.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantSort.java @@ -14,7 +14,7 @@ package io.trino.sql.planner.iterative.rule; import com.google.common.collect.ImmutableList; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.ValuesNode; @@ -35,7 +35,7 @@ public void test() p.sort( ImmutableList.of(p.symbol("c")), p.aggregation(builder -> builder - .addAggregation(p.symbol("c"), aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "foo"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("c"), aggregation("count", ImmutableList.of(new Reference(BIGINT, "foo"))), ImmutableList.of(BIGINT)) .globalGrouping() .source(p.values(p.symbol("foo")))))) .matches( @@ -62,7 +62,7 @@ public void doesNotFire() p.sort( ImmutableList.of(p.symbol("c")), p.aggregation(builder -> builder - .addAggregation(p.symbol("c"), aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "foo"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("c"), aggregation("count", ImmutableList.of(new Reference(BIGINT, "foo"))), ImmutableList.of(BIGINT)) .singleGroupingSet(p.symbol("foo")) .source(p.values(20, p.symbol("foo")))))) .doesNotFire(); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantTopN.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantTopN.java index 37ae21fc0095..02174ed1474c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantTopN.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantTopN.java @@ -15,9 +15,9 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.FilterNode; @@ -27,7 +27,7 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.node; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; import static io.trino.sql.planner.iterative.rule.test.PlanBuilder.aggregation; @@ -44,7 +44,7 @@ public void test() 10, ImmutableList.of(p.symbol("c")), p.aggregation(builder -> builder - .addAggregation(p.symbol("c"), aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "foo"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("c"), aggregation("count", ImmutableList.of(new Reference(BIGINT, "foo"))), ImmutableList.of(BIGINT)) .globalGrouping() .source(p.values(p.symbol("foo")))))) .matches( @@ -57,7 +57,7 @@ public void test() 10, ImmutableList.of(p.symbol("a")), p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 5L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 5L)), p.values( ImmutableList.of(p.symbol("a"), p.symbol("b")), ImmutableList.of( @@ -79,7 +79,7 @@ public void testZeroTopN() 0, ImmutableList.of(p.symbol("a")), p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 5L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 5L)), p.values( ImmutableList.of(p.symbol("a"), p.symbol("b")), ImmutableList.of( @@ -98,7 +98,7 @@ public void doesNotFire() 10, ImmutableList.of(p.symbol("c")), p.aggregation(builder -> builder - .addAggregation(p.symbol("c"), aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "foo"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("c"), aggregation("count", ImmutableList.of(new Reference(BIGINT, "foo"))), ImmutableList.of(BIGINT)) .singleGroupingSet(p.symbol("foo")) .source(p.values(20, p.symbol("foo")))))) .doesNotFire(); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveTrivialFilters.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveTrivialFilters.java index c28dfde66f1e..ea75ef42c6ad 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveTrivialFilters.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveTrivialFilters.java @@ -14,16 +14,16 @@ package io.trino.sql.planner.iterative.rule; import com.google.common.collect.ImmutableList; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import org.junit.jupiter.api.Test; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.BooleanLiteral.FALSE_LITERAL; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; +import static io.trino.sql.ir.Booleans.FALSE; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; public class TestRemoveTrivialFilters @@ -34,7 +34,7 @@ public void testDoesNotFire() { tester().assertThat(new RemoveTrivialFilters()) .on(p -> p.filter( - new ComparisonExpression(EQUAL, new Constant(INTEGER, 1L), new Constant(INTEGER, 1L)), + new Comparison(EQUAL, new Constant(INTEGER, 1L), new Constant(INTEGER, 1L)), p.values())) .doesNotFire(); } @@ -43,7 +43,7 @@ public void testDoesNotFire() public void testRemovesTrueFilter() { tester().assertThat(new RemoveTrivialFilters()) - .on(p -> p.filter(TRUE_LITERAL, p.values())) + .on(p -> p.filter(TRUE, p.values())) .matches(values()); } @@ -52,7 +52,7 @@ public void testRemovesFalseFilter() { tester().assertThat(new RemoveTrivialFilters()) .on(p -> p.filter( - FALSE_LITERAL, + FALSE, p.values( ImmutableList.of(p.symbol("a")), ImmutableList.of(ImmutableList.of(new Constant(INTEGER, 1L)))))) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveUnreferencedScalarSubqueries.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveUnreferencedScalarSubqueries.java index febe1d16dda6..fd0d9d876430 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveUnreferencedScalarSubqueries.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveUnreferencedScalarSubqueries.java @@ -15,15 +15,15 @@ import com.google.common.collect.ImmutableList; import io.trino.spi.type.BigintType; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import org.junit.jupiter.api.Test; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; import static io.trino.sql.planner.plan.JoinType.FULL; import static io.trino.sql.planner.plan.JoinType.LEFT; @@ -53,7 +53,7 @@ public void testRemoveUnreferencedInput() emptyList(), p.values(emptyList(), ImmutableList.of(emptyList())), LEFT, - TRUE_LITERAL, + TRUE, p.values(2, b)); }) .matches(values("b")); @@ -65,7 +65,7 @@ public void testRemoveUnreferencedInput() emptyList(), p.values(emptyList(), ImmutableList.of(emptyList())), RIGHT, - TRUE_LITERAL, + TRUE, p.values(2, b)); }) .matches(values("b")); @@ -77,7 +77,7 @@ public void testRemoveUnreferencedInput() emptyList(), p.values(emptyList(), ImmutableList.of(emptyList())), FULL, - TRUE_LITERAL, + TRUE, p.values(2, b)); }) .matches(values("b")); @@ -93,9 +93,9 @@ public void testDoNotRemoveInputOfLeftOrFullJoinWhenSubqueryPotentiallyEmpty() emptyList(), p.values(emptyList(), ImmutableList.of(emptyList())), LEFT, - TRUE_LITERAL, + TRUE, p.filter( - new ComparisonExpression( + new Comparison( LESS_THAN, b.toSymbolReference(), new Constant(INTEGER, 3L)), @@ -110,9 +110,9 @@ public void testDoNotRemoveInputOfLeftOrFullJoinWhenSubqueryPotentiallyEmpty() emptyList(), p.values(emptyList(), ImmutableList.of(emptyList())), FULL, - TRUE_LITERAL, + TRUE, p.filter( - new ComparisonExpression( + new Comparison( LESS_THAN, b.toSymbolReference(), new Constant(INTEGER, 3L)), @@ -151,7 +151,7 @@ public void testRemoveUnreferencedSubquery() emptyList(), p.values(p.symbol("b", BigintType.BIGINT)), LEFT, - TRUE_LITERAL, + TRUE, p.values(emptyList(), ImmutableList.of(emptyList())))) .matches(values("b")); @@ -160,7 +160,7 @@ public void testRemoveUnreferencedSubquery() emptyList(), p.values(p.symbol("b", BigintType.BIGINT)), RIGHT, - TRUE_LITERAL, + TRUE, p.values(emptyList(), ImmutableList.of(emptyList())))) .matches(values("b")); @@ -169,7 +169,7 @@ public void testRemoveUnreferencedSubquery() emptyList(), p.values(p.symbol("b", BigintType.BIGINT)), FULL, - TRUE_LITERAL, + TRUE, p.values(emptyList(), ImmutableList.of(emptyList())))) .matches(values("b")); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReorderJoins.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReorderJoins.java index 868fc2700db8..96af650b4fe7 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReorderJoins.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReorderJoins.java @@ -20,9 +20,9 @@ import io.trino.cost.SymbolStatsEstimate; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.type.Type; -import io.trino.sql.ir.ArithmeticNegation; -import io.trino.sql.ir.ComparisonExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Comparison; +import io.trino.sql.ir.Negation; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.OptimizerConfig.JoinDistributionType; import io.trino.sql.planner.OptimizerConfig.JoinReorderingStrategy; import io.trino.sql.planner.Symbol; @@ -46,8 +46,8 @@ import static io.trino.SystemSessionProperties.JOIN_REORDERING_STRATEGY; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; import static io.trino.sql.planner.OptimizerConfig.JoinDistributionType.AUTOMATIC; import static io.trino.sql.planner.OptimizerConfig.JoinDistributionType.BROADCAST; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; @@ -346,7 +346,7 @@ public void testDoesNotFireForNonDeterministicFilter() ImmutableList.of(new EquiJoinClause(p.symbol("A1"), p.symbol("B1"))), ImmutableList.of(p.symbol("A1")), ImmutableList.of(p.symbol("B1")), - Optional.of(new ComparisonExpression( + Optional.of(new Comparison( LESS_THAN, p.symbol("A1").toSymbolReference(), new TestingFunctionResolution().functionCallBuilder("random").build())))) @@ -387,7 +387,7 @@ public void testPredicatesPushedDown() new EquiJoinClause(p.symbol("B2"), p.symbol("C1"))), ImmutableList.of(p.symbol("A1")), ImmutableList.of(), - Optional.of(new ComparisonExpression(EQUAL, p.symbol("A1").toSymbolReference(), p.symbol("B1").toSymbolReference())))) + Optional.of(new Comparison(EQUAL, p.symbol("A1").toSymbolReference(), p.symbol("B1").toSymbolReference())))) .matches( join(INNER, builder -> builder .equiCriteria("C1", "B2") @@ -421,7 +421,7 @@ public void testPushesProjectionsThroughJoin() INNER, p.project( Assignments.of( - p.symbol("P1"), new ArithmeticNegation(p.symbol("B1").toSymbolReference()), + p.symbol("P1"), new Negation(p.symbol("B1").toSymbolReference()), p.symbol("P2"), p.symbol("A1").toSymbolReference()), p.join( INNER, @@ -435,7 +435,7 @@ public void testPushesProjectionsThroughJoin() ImmutableList.of(new EquiJoinClause(p.symbol("P1"), p.symbol("C1"))), ImmutableList.of(p.symbol("P1")), ImmutableList.of(), - Optional.of(new ComparisonExpression(EQUAL, p.symbol("P2").toSymbolReference(), p.symbol("C1").toSymbolReference())))) + Optional.of(new Comparison(EQUAL, p.symbol("P2").toSymbolReference(), p.symbol("C1").toSymbolReference())))) .matches( join(INNER, builder -> builder .equiCriteria("C1", "P1") @@ -445,11 +445,11 @@ public void testPushesProjectionsThroughJoin() .equiCriteria("P2", "P1") .left( strictProject( - ImmutableMap.of("P2", expression(new SymbolReference(BIGINT, "A1"))), + ImmutableMap.of("P2", expression(new Reference(BIGINT, "A1"))), values("A1"))) .right( strictProject( - ImmutableMap.of("P1", expression(new ArithmeticNegation(new SymbolReference(BIGINT, "B1")))), + ImmutableMap.of("P1", expression(new Negation(new Reference(BIGINT, "B1")))), values("B1"))))))); } @@ -475,7 +475,7 @@ public void testDoesNotPushProjectionThroughJoinIfTooExpensive() INNER, p.project( Assignments.of( - p.symbol("P1"), new ArithmeticNegation(p.symbol("B1").toSymbolReference())), + p.symbol("P1"), new Negation(p.symbol("B1").toSymbolReference())), p.join( INNER, p.values(new PlanNodeId("valuesA"), 2, p.symbol("A1")), @@ -495,7 +495,7 @@ public void testDoesNotPushProjectionThroughJoinIfTooExpensive() .left(values("C1")) .right( strictProject( - ImmutableMap.of("P1", expression(new ArithmeticNegation(new SymbolReference(BIGINT, "B1")))), + ImmutableMap.of("P1", expression(new Negation(new Reference(BIGINT, "B1")))), join(INNER, rightJoinBuilder -> rightJoinBuilder .equiCriteria("A1", "B1") .left(values("A1")) @@ -536,7 +536,7 @@ public void testSmallerJoinFirst() new EquiJoinClause(p.symbol("B2"), p.symbol("C1"))), ImmutableList.of(p.symbol("A1")), ImmutableList.of(), - Optional.of(new ComparisonExpression(EQUAL, p.symbol("A1").toSymbolReference(), p.symbol("B1").toSymbolReference())))) + Optional.of(new Comparison(EQUAL, p.symbol("A1").toSymbolReference(), p.symbol("B1").toSymbolReference())))) .matches( join(INNER, builder -> builder .equiCriteria("A1", "B1") diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceJoinOverConstantWithProject.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceJoinOverConstantWithProject.java index fae39f6a676d..53386549b28c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceJoinOverConstantWithProject.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceJoinOverConstantWithProject.java @@ -17,12 +17,12 @@ import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slices; import io.trino.spi.type.VarcharType; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.FunctionCall; +import io.trino.sql.ir.Reference; import io.trino.sql.ir.Row; -import io.trino.sql.ir.SymbolReference; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.JoinNode.EquiJoinClause; @@ -35,7 +35,7 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.RowType.field; import static io.trino.spi.type.RowType.rowType; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.strictProject; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -85,7 +85,7 @@ public void testDoesNotFireOnJoinWithCondition() INNER, p.values(1, p.symbol("a")), p.values(5, p.symbol("b")), - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b")))) + new Comparison(GREATER_THAN, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))) .doesNotFire(); } @@ -134,7 +134,7 @@ public void testDoesNotFireOnOuterJoinWhenSourcePossiblyEmpty() LEFT, p.values(1, p.symbol("a")), p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 5L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 5L)), p.values(10, p.symbol("b"))))) .doesNotFire(); @@ -143,7 +143,7 @@ public void testDoesNotFireOnOuterJoinWhenSourcePossiblyEmpty() p.join( RIGHT, p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 5L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 5L)), p.values(10, p.symbol("a"))), p.values(1, p.symbol("b")))) .doesNotFire(); @@ -154,7 +154,7 @@ public void testDoesNotFireOnOuterJoinWhenSourcePossiblyEmpty() FULL, p.values(1, p.symbol("a")), p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 5L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 5L)), p.values(10, p.symbol("b"))))) .doesNotFire(); @@ -163,7 +163,7 @@ public void testDoesNotFireOnOuterJoinWhenSourcePossiblyEmpty() p.join( FULL, p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 5L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 5L)), p.values(10, p.symbol("a"))), p.values(1, p.symbol("b")))) .doesNotFire(); @@ -183,7 +183,7 @@ public void testReplaceInnerJoinWithProject() ImmutableMap.of( "a", PlanMatchPattern.expression(new Constant(INTEGER, 1L)), "b", PlanMatchPattern.expression(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x"))), - "c", PlanMatchPattern.expression(new SymbolReference(BIGINT, "c"))), + "c", PlanMatchPattern.expression(new Reference(BIGINT, "c"))), values("c"))); tester().assertThat(new ReplaceJoinOverConstantWithProject()) @@ -197,7 +197,7 @@ public void testReplaceInnerJoinWithProject() ImmutableMap.of( "a", PlanMatchPattern.expression(new Constant(INTEGER, 1L)), "b", PlanMatchPattern.expression(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x"))), - "c", PlanMatchPattern.expression(new SymbolReference(BIGINT, "c"))), + "c", PlanMatchPattern.expression(new Reference(BIGINT, "c"))), values("c"))); } @@ -215,7 +215,7 @@ public void testReplaceLeftJoinWithProject() ImmutableMap.of( "a", PlanMatchPattern.expression(new Constant(INTEGER, 1L)), "b", PlanMatchPattern.expression(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x"))), - "c", PlanMatchPattern.expression(new SymbolReference(BIGINT, "c"))), + "c", PlanMatchPattern.expression(new Reference(BIGINT, "c"))), values("c"))); tester().assertThat(new ReplaceJoinOverConstantWithProject()) @@ -229,7 +229,7 @@ public void testReplaceLeftJoinWithProject() ImmutableMap.of( "a", PlanMatchPattern.expression(new Constant(INTEGER, 1L)), "b", PlanMatchPattern.expression(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x"))), - "c", PlanMatchPattern.expression(new SymbolReference(BIGINT, "c"))), + "c", PlanMatchPattern.expression(new Reference(BIGINT, "c"))), values("c"))); } @@ -247,7 +247,7 @@ public void testReplaceRightJoinWithProject() ImmutableMap.of( "a", PlanMatchPattern.expression(new Constant(INTEGER, 1L)), "b", PlanMatchPattern.expression(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x"))), - "c", PlanMatchPattern.expression(new SymbolReference(BIGINT, "c"))), + "c", PlanMatchPattern.expression(new Reference(BIGINT, "c"))), values("c"))); tester().assertThat(new ReplaceJoinOverConstantWithProject()) @@ -261,7 +261,7 @@ public void testReplaceRightJoinWithProject() ImmutableMap.of( "a", PlanMatchPattern.expression(new Constant(INTEGER, 1L)), "b", PlanMatchPattern.expression(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x"))), - "c", PlanMatchPattern.expression(new SymbolReference(BIGINT, "c"))), + "c", PlanMatchPattern.expression(new Reference(BIGINT, "c"))), values("c"))); } @@ -279,7 +279,7 @@ public void testReplaceFullJoinWithProject() ImmutableMap.of( "a", PlanMatchPattern.expression(new Constant(INTEGER, 1L)), "b", PlanMatchPattern.expression(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x"))), - "c", PlanMatchPattern.expression(new SymbolReference(BIGINT, "c"))), + "c", PlanMatchPattern.expression(new Reference(BIGINT, "c"))), values("c"))); tester().assertThat(new ReplaceJoinOverConstantWithProject()) @@ -293,7 +293,7 @@ public void testReplaceFullJoinWithProject() ImmutableMap.of( "a", PlanMatchPattern.expression(new Constant(INTEGER, 1L)), "b", PlanMatchPattern.expression(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x"))), - "c", PlanMatchPattern.expression(new SymbolReference(BIGINT, "c"))), + "c", PlanMatchPattern.expression(new Reference(BIGINT, "c"))), values("c"))); } @@ -315,14 +315,14 @@ public void testRemoveOutputDuplicates() ImmutableMap.of( "a", PlanMatchPattern.expression(new Constant(INTEGER, 1L)), "b", PlanMatchPattern.expression(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x"))), - "c", PlanMatchPattern.expression(new SymbolReference(BIGINT, "c"))), + "c", PlanMatchPattern.expression(new Reference(BIGINT, "c"))), values("c"))); } @Test public void testNonDeterministicValues() { - FunctionCall randomFunction = new FunctionCall( + Call randomFunction = new Call( tester().getMetadata().resolveBuiltinFunction("random", ImmutableList.of()), ImmutableList.of()); @@ -334,7 +334,7 @@ public void testNonDeterministicValues() p.values(5, p.symbol("b")))) .doesNotFire(); - FunctionCall uuidFunction = new FunctionCall( + Call uuidFunction = new Call( tester().getMetadata().resolveBuiltinFunction("uuid", ImmutableList.of()), ImmutableList.of()); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceRedundantJoinWithProject.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceRedundantJoinWithProject.java index 92e7a404a3f2..2bcd116780d1 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceRedundantJoinWithProject.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceRedundantJoinWithProject.java @@ -16,7 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import org.junit.jupiter.api.Test; @@ -89,7 +89,7 @@ public void testReplaceLeftJoin() .matches( project( ImmutableMap.of( - "a", expression(new SymbolReference(BIGINT, "a")), + "a", expression(new Reference(BIGINT, "a")), "b", expression(new Constant(BIGINT, null))), values(ImmutableList.of("a"), nCopies(10, ImmutableList.of(new Constant(BIGINT, null)))))); } @@ -107,7 +107,7 @@ public void testReplaceRightJoin() project( ImmutableMap.of( "a", expression(new Constant(BIGINT, null)), - "b", expression(new SymbolReference(BIGINT, "b"))), + "b", expression(new Reference(BIGINT, "b"))), values(ImmutableList.of("b"), nCopies(10, ImmutableList.of(new Constant(BIGINT, null)))))); } @@ -123,7 +123,7 @@ public void testReplaceFULLJoin() .matches( project( ImmutableMap.of( - "a", expression(new SymbolReference(BIGINT, "a")), + "a", expression(new Reference(BIGINT, "a")), "b", expression(new Constant(BIGINT, null))), values(ImmutableList.of("a"), nCopies(10, ImmutableList.of(new Constant(BIGINT, null)))))); @@ -137,7 +137,7 @@ public void testReplaceFULLJoin() project( ImmutableMap.of( "a", expression(new Constant(BIGINT, null)), - "b", expression(new SymbolReference(BIGINT, "b"))), + "b", expression(new Reference(BIGINT, "b"))), values(ImmutableList.of("b"), nCopies(10, ImmutableList.of(new Constant(BIGINT, null)))))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceRedundantJoinWithSource.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceRedundantJoinWithSource.java index 8b11866aec7a..512bf80dd532 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceRedundantJoinWithSource.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceRedundantJoinWithSource.java @@ -15,9 +15,9 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -26,7 +26,7 @@ import java.util.Optional; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -142,10 +142,10 @@ public void testReplaceInnerJoinWithFilter() INNER, p.values(10, p.symbol("a", BIGINT)), p.values(1), - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 0L)))) + new Comparison(GREATER_THAN, new Reference(BIGINT, "a"), new Constant(BIGINT, 0L)))) .matches( filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 0L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "a"), new Constant(BIGINT, 0L)), values(ImmutableList.of("a"), nCopies(10, ImmutableList.of(new Constant(BIGINT, null)))))); tester().assertThat(new ReplaceRedundantJoinWithSource()) @@ -154,10 +154,10 @@ public void testReplaceInnerJoinWithFilter() INNER, p.values(1), p.values(10, p.symbol("b", BIGINT)), - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b"), new Constant(BIGINT, 0L)))) + new Comparison(GREATER_THAN, new Reference(BIGINT, "b"), new Constant(BIGINT, 0L)))) .matches( filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b"), new Constant(BIGINT, 0L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "b"), new Constant(BIGINT, 0L)), values(ImmutableList.of("b"), nCopies(10, ImmutableList.of(new Constant(BIGINT, null)))))); } @@ -180,7 +180,7 @@ public void testReplaceLeftJoin() LEFT, p.values(10, p.symbol("a", BIGINT)), p.values(1), - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 0L)))) + new Comparison(GREATER_THAN, new Reference(BIGINT, "a"), new Constant(BIGINT, 0L)))) .matches( values(ImmutableList.of("a"), nCopies(10, ImmutableList.of(new Constant(BIGINT, null))))); } @@ -204,7 +204,7 @@ public void testReplaceRightJoin() RIGHT, p.values(1), p.values(10, p.symbol("b", BIGINT)), - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b"), new Constant(BIGINT, 0L)))) + new Comparison(GREATER_THAN, new Reference(BIGINT, "b"), new Constant(BIGINT, 0L)))) .matches( values(ImmutableList.of("b"), nCopies(10, ImmutableList.of(new Constant(BIGINT, null))))); } @@ -237,7 +237,7 @@ public void testReplaceFullJoin() FULL, p.values(1), p.values(10, p.symbol("b", BIGINT)), - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b"), new Constant(BIGINT, 0L)))) + new Comparison(GREATER_THAN, new Reference(BIGINT, "b"), new Constant(BIGINT, 0L)))) .matches( values(ImmutableList.of("b"), nCopies(10, ImmutableList.of(new Constant(BIGINT, null))))); @@ -249,7 +249,7 @@ public void testReplaceFullJoin() p.join( FULL, p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 5L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "a"), new Constant(BIGINT, 5L)), p.values(10, p.symbol("a", BIGINT))), p.values(1))) .doesNotFire(); @@ -263,7 +263,7 @@ public void testReplaceFullJoin() FULL, p.values(1), p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 5L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "a"), new Constant(BIGINT, 5L)), p.values(10, p.symbol("a", BIGINT))))) .doesNotFire(); } @@ -282,11 +282,11 @@ public void testPruneOutputs() ImmutableList.of(), ImmutableList.of(a), ImmutableList.of(), - Optional.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b")))); + Optional.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))); }) .matches( project( - ImmutableMap.of("a", PlanMatchPattern.expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", PlanMatchPattern.expression(new Reference(BIGINT, "a"))), values(ImmutableList.of("a", "b"), nCopies(10, ImmutableList.of(new Constant(BIGINT, null), new Constant(BIGINT, null)))))); tester().assertThat(new ReplaceRedundantJoinWithSource()) @@ -300,13 +300,13 @@ public void testPruneOutputs() ImmutableList.of(), ImmutableList.of(a), ImmutableList.of(), - Optional.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b")))); + Optional.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")))); }) .matches( project( - ImmutableMap.of("a", PlanMatchPattern.expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", PlanMatchPattern.expression(new Reference(BIGINT, "a"))), filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "b")), + new Comparison(GREATER_THAN, new Reference(BIGINT, "a"), new Reference(BIGINT, "b")), values(ImmutableList.of("a", "b"), nCopies(10, ImmutableList.of(new Constant(BIGINT, null), new Constant(BIGINT, null))))))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyExpressions.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyExpressions.java index 3cce646b97ab..f4d987b34276 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyExpressions.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyExpressions.java @@ -19,16 +19,16 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; import io.trino.spi.type.Decimals; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.ir.ExpressionRewriter; import io.trino.sql.ir.ExpressionTreeRewriter; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.type.Reals; import io.trino.type.UnknownType; import io.trino.util.DateTimeUtils; @@ -50,21 +50,21 @@ import static io.trino.spi.type.TinyintType.TINYINT; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.DIVIDE; -import static io.trino.sql.ir.BooleanLiteral.FALSE_LITERAL; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.IS_DISTINCT_FROM; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.NOT_EQUAL; +import static io.trino.sql.ir.Arithmetic.Operator.DIVIDE; +import static io.trino.sql.ir.Booleans.FALSE; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.IS_DISTINCT_FROM; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.NOT_EQUAL; import static io.trino.sql.ir.IrExpressions.ifExpression; import static io.trino.sql.ir.IrUtils.extractPredicates; import static io.trino.sql.ir.IrUtils.logicalExpression; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; -import static io.trino.sql.ir.LogicalExpression.Operator.OR; +import static io.trino.sql.ir.Logical.Operator.AND; +import static io.trino.sql.ir.Logical.Operator.OR; import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT; import static io.trino.sql.planner.iterative.rule.SimplifyExpressions.rewrite; import static java.util.stream.Collectors.toList; @@ -80,214 +80,214 @@ public class TestSimplifyExpressions public void testPushesDownNegations() { assertSimplifies( - new NotExpression(new SymbolReference(BOOLEAN, "X")), - new NotExpression(new SymbolReference(BOOLEAN, "X"))); + new Not(new Reference(BOOLEAN, "X")), + new Not(new Reference(BOOLEAN, "X"))); assertSimplifies( - new NotExpression(new NotExpression(new SymbolReference(BOOLEAN, "X"))), - new SymbolReference(BOOLEAN, "X")); + new Not(new Not(new Reference(BOOLEAN, "X"))), + new Reference(BOOLEAN, "X")); assertSimplifies( - new NotExpression(new NotExpression(new NotExpression(new SymbolReference(BOOLEAN, "X")))), - new NotExpression(new SymbolReference(BOOLEAN, "X"))); + new Not(new Not(new Not(new Reference(BOOLEAN, "X")))), + new Not(new Reference(BOOLEAN, "X"))); assertSimplifies( - new NotExpression(new NotExpression(new NotExpression(new SymbolReference(BOOLEAN, "X")))), - new NotExpression(new SymbolReference(BOOLEAN, "X"))); + new Not(new Not(new Not(new Reference(BOOLEAN, "X")))), + new Not(new Reference(BOOLEAN, "X"))); assertSimplifies( - new NotExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y"))), - new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y"))); + new Not(new Comparison(GREATER_THAN, new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y"))), + new Comparison(LESS_THAN_OR_EQUAL, new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y"))); assertSimplifies( - new NotExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(BOOLEAN, "X"), new NotExpression(new NotExpression(new SymbolReference(BOOLEAN, "Y"))))), - new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y"))); + new Not(new Comparison(GREATER_THAN, new Reference(BOOLEAN, "X"), new Not(new Not(new Reference(BOOLEAN, "Y"))))), + new Comparison(LESS_THAN_OR_EQUAL, new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y"))); assertSimplifies( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BOOLEAN, "X"), new NotExpression(new NotExpression(new SymbolReference(BOOLEAN, "Y")))), - new ComparisonExpression(GREATER_THAN, new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y"))); + new Comparison(GREATER_THAN, new Reference(BOOLEAN, "X"), new Not(new Not(new Reference(BOOLEAN, "Y")))), + new Comparison(GREATER_THAN, new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y"))); assertSimplifies( - new NotExpression(new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y"), new NotExpression(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "Z"), new SymbolReference(BOOLEAN, "V"))))))), - new LogicalExpression(OR, ImmutableList.of(new NotExpression(new SymbolReference(BOOLEAN, "X")), new NotExpression(new SymbolReference(BOOLEAN, "Y")), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "Z"), new SymbolReference(BOOLEAN, "V")))))); + new Not(new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y"), new Not(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "Z"), new Reference(BOOLEAN, "V"))))))), + new Logical(OR, ImmutableList.of(new Not(new Reference(BOOLEAN, "X")), new Not(new Reference(BOOLEAN, "Y")), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "Z"), new Reference(BOOLEAN, "V")))))); assertSimplifies( - new NotExpression(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y"), new NotExpression(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "Z"), new SymbolReference(BOOLEAN, "V"))))))), - new LogicalExpression(AND, ImmutableList.of(new NotExpression(new SymbolReference(BOOLEAN, "X")), new NotExpression(new SymbolReference(BOOLEAN, "Y")), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "Z"), new SymbolReference(BOOLEAN, "V")))))); + new Not(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y"), new Not(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "Z"), new Reference(BOOLEAN, "V"))))))), + new Logical(AND, ImmutableList.of(new Not(new Reference(BOOLEAN, "X")), new Not(new Reference(BOOLEAN, "Y")), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "Z"), new Reference(BOOLEAN, "V")))))); assertSimplifies( - new NotExpression(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y"), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "Z"), new SymbolReference(BOOLEAN, "V")))))), - new LogicalExpression(AND, ImmutableList.of(new NotExpression(new SymbolReference(BOOLEAN, "X")), new NotExpression(new SymbolReference(BOOLEAN, "Y")), new LogicalExpression(AND, ImmutableList.of(new NotExpression(new SymbolReference(BOOLEAN, "Z")), new NotExpression(new SymbolReference(BOOLEAN, "V"))))))); + new Not(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y"), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "Z"), new Reference(BOOLEAN, "V")))))), + new Logical(AND, ImmutableList.of(new Not(new Reference(BOOLEAN, "X")), new Not(new Reference(BOOLEAN, "Y")), new Logical(AND, ImmutableList.of(new Not(new Reference(BOOLEAN, "Z")), new Not(new Reference(BOOLEAN, "V"))))))); assertSimplifies( - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y"))), - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y")))); + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y"))), + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y")))); assertSimplifies( - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y"))), - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y")))); + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y"))), + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y")))); assertSimplifies( - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y"))), - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y")))); + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y"))), + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y")))); assertSimplifies( - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y"))), - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y")))); + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y"))), + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y")))); } @Test public void testExtractCommonPredicates() { assertSimplifies( - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y")))); + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y")))); assertSimplifies( - new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y"))), - new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "Y"), new SymbolReference(BOOLEAN, "X")))); + new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y"))), + new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "Y"), new Reference(BOOLEAN, "X")))); assertSimplifies( - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "X"))), - new SymbolReference(BOOLEAN, "X")); + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "X"))), + new Reference(BOOLEAN, "X")); assertSimplifies( - new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "X"))), - new SymbolReference(BOOLEAN, "X")); + new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "X"))), + new Reference(BOOLEAN, "X")); assertSimplifies( - new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y"))), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y"))))), - new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y")))); + new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y"))), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y"))))), + new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y")))); assertSimplifies( - new LogicalExpression(OR, ImmutableList.of(new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "V"))), new SymbolReference(BOOLEAN, "V"))), - new SymbolReference(BOOLEAN, "V")); + new Logical(OR, ImmutableList.of(new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "V"))), new Reference(BOOLEAN, "V"))), + new Reference(BOOLEAN, "V")); assertSimplifies( - new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "V"))), new SymbolReference(BOOLEAN, "V"))), - new SymbolReference(BOOLEAN, "V")); + new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "V"))), new Reference(BOOLEAN, "V"))), + new Reference(BOOLEAN, "V")); assertSimplifies( - new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "B"), new SymbolReference(BOOLEAN, "C"))), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "B"))))), - new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "B")))); + new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "B"), new Reference(BOOLEAN, "C"))), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "B"))))), + new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "B")))); assertSimplifies( - new LogicalExpression(OR, ImmutableList.of(new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "B"))), new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "B"), new SymbolReference(BOOLEAN, "C"))))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "B")))); + new Logical(OR, ImmutableList.of(new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "B"))), new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "B"), new Reference(BOOLEAN, "C"))))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "B")))); assertSimplifies( - new ComparisonExpression(EQUAL, new SymbolReference(BOOLEAN, "I"), new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "B"))), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "B"), new SymbolReference(BOOLEAN, "C")))))), - new ComparisonExpression(EQUAL, new SymbolReference(BOOLEAN, "I"), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "B"))))); + new Comparison(EQUAL, new Reference(BOOLEAN, "I"), new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "B"))), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "B"), new Reference(BOOLEAN, "C")))))), + new Comparison(EQUAL, new Reference(BOOLEAN, "I"), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "B"))))); assertSimplifies( - new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y"))), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Z"))))), - new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y"))), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Z")))))); + new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y"))), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Z"))))), + new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y"))), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Z")))))); assertSimplifies( - new LogicalExpression(OR, ImmutableList.of(new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y"), new SymbolReference(BOOLEAN, "V"))), new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y"), new SymbolReference(BOOLEAN, "Z"))))), - new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y"))), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "V"), new SymbolReference(BOOLEAN, "Z")))))); + new Logical(OR, ImmutableList.of(new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y"), new Reference(BOOLEAN, "V"))), new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y"), new Reference(BOOLEAN, "Z"))))), + new Logical(AND, ImmutableList.of(new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y"))), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "V"), new Reference(BOOLEAN, "Z")))))); assertSimplifies( - new ComparisonExpression(EQUAL, new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y"), new SymbolReference(BOOLEAN, "V"))), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y"), new SymbolReference(BOOLEAN, "Z"))))), new SymbolReference(BOOLEAN, "I")), - new ComparisonExpression(EQUAL, new LogicalExpression(OR, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y"))), new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "V"), new SymbolReference(BOOLEAN, "Z"))))), new SymbolReference(BOOLEAN, "I"))); + new Comparison(EQUAL, new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y"), new Reference(BOOLEAN, "V"))), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y"), new Reference(BOOLEAN, "Z"))))), new Reference(BOOLEAN, "I")), + new Comparison(EQUAL, new Logical(OR, ImmutableList.of(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y"))), new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "V"), new Reference(BOOLEAN, "Z"))))), new Reference(BOOLEAN, "I"))); assertSimplifies( - new LogicalExpression(OR, ImmutableList.of(new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "V"))), new SymbolReference(BOOLEAN, "V"))), new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "V"))), new SymbolReference(BOOLEAN, "V"))))), - new SymbolReference(BOOLEAN, "V")); + new Logical(OR, ImmutableList.of(new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "V"))), new Reference(BOOLEAN, "V"))), new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "V"))), new Reference(BOOLEAN, "V"))))), + new Reference(BOOLEAN, "V")); assertSimplifies( - new LogicalExpression(OR, ImmutableList.of(new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "V"))), new SymbolReference(BOOLEAN, "X"))), new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "V"))), new SymbolReference(BOOLEAN, "V"))))), - new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "V")))); + new Logical(OR, ImmutableList.of(new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "V"))), new Reference(BOOLEAN, "X"))), new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "V"))), new Reference(BOOLEAN, "V"))))), + new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "V")))); assertSimplifies( - new LogicalExpression(OR, ImmutableList.of(new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "V"))), new SymbolReference(BOOLEAN, "Z"))), new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "V"))), new SymbolReference(BOOLEAN, "V"))))), - new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "V"))), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "Z"), new SymbolReference(BOOLEAN, "V")))))); + new Logical(OR, ImmutableList.of(new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "V"))), new Reference(BOOLEAN, "Z"))), new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "V"))), new Reference(BOOLEAN, "V"))))), + new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "V"))), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "Z"), new Reference(BOOLEAN, "V")))))); assertSimplifies( - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new LogicalExpression(OR, ImmutableList.of(new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "Y"), new SymbolReference(BOOLEAN, "Z"))), new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "Y"), new SymbolReference(BOOLEAN, "V"))), new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "Y"), new SymbolReference(BOOLEAN, "X"))))))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "X"), new SymbolReference(BOOLEAN, "Y"), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "Z"), new SymbolReference(BOOLEAN, "V"), new SymbolReference(BOOLEAN, "X")))))); + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "X"), new Logical(OR, ImmutableList.of(new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "Y"), new Reference(BOOLEAN, "Z"))), new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "Y"), new Reference(BOOLEAN, "V"))), new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "Y"), new Reference(BOOLEAN, "X"))))))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "X"), new Reference(BOOLEAN, "Y"), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "Z"), new Reference(BOOLEAN, "V"), new Reference(BOOLEAN, "X")))))); assertSimplifies( - new LogicalExpression(OR, ImmutableList.of(new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "B"), new SymbolReference(BOOLEAN, "C"), new SymbolReference(BOOLEAN, "D"))), new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "B"), new SymbolReference(BOOLEAN, "E"))), new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "F"))))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new LogicalExpression(OR, ImmutableList.of(new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "B"), new SymbolReference(BOOLEAN, "C"), new SymbolReference(BOOLEAN, "D"))), new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "B"), new SymbolReference(BOOLEAN, "E"))), new SymbolReference(BOOLEAN, "F")))))); + new Logical(OR, ImmutableList.of(new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "B"), new Reference(BOOLEAN, "C"), new Reference(BOOLEAN, "D"))), new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "B"), new Reference(BOOLEAN, "E"))), new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "F"))))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A"), new Logical(OR, ImmutableList.of(new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "B"), new Reference(BOOLEAN, "C"), new Reference(BOOLEAN, "D"))), new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "B"), new Reference(BOOLEAN, "E"))), new Reference(BOOLEAN, "F")))))); assertSimplifies( - new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "B"))), new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "C"))))), new SymbolReference(BOOLEAN, "D"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "B"), new SymbolReference(BOOLEAN, "C"))), new SymbolReference(BOOLEAN, "D")))); + new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "B"))), new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "C"))))), new Reference(BOOLEAN, "D"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A"), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "B"), new Reference(BOOLEAN, "C"))), new Reference(BOOLEAN, "D")))); assertSimplifies( - new LogicalExpression(OR, ImmutableList.of(new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "B"))), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "C"))))), new SymbolReference(BOOLEAN, "D"))), - new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "B"), new SymbolReference(BOOLEAN, "D"))), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "C"), new SymbolReference(BOOLEAN, "D")))))); + new Logical(OR, ImmutableList.of(new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "B"))), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "C"))))), new Reference(BOOLEAN, "D"))), + new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "B"), new Reference(BOOLEAN, "D"))), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "C"), new Reference(BOOLEAN, "D")))))); assertSimplifies( - new LogicalExpression(OR, ImmutableList.of(new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "B"))), new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "C"))))), new SymbolReference(BOOLEAN, "D"))), new SymbolReference(BOOLEAN, "E"))), - new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "E"))), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "B"), new SymbolReference(BOOLEAN, "C"), new SymbolReference(BOOLEAN, "E"))), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "D"), new SymbolReference(BOOLEAN, "E")))))); + new Logical(OR, ImmutableList.of(new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "B"))), new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "C"))))), new Reference(BOOLEAN, "D"))), new Reference(BOOLEAN, "E"))), + new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "E"))), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "B"), new Reference(BOOLEAN, "C"), new Reference(BOOLEAN, "E"))), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "D"), new Reference(BOOLEAN, "E")))))); assertSimplifies( - new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "B"))), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "C"))))), new SymbolReference(BOOLEAN, "D"))), new SymbolReference(BOOLEAN, "E"))), - new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "B"), new SymbolReference(BOOLEAN, "C"))), new SymbolReference(BOOLEAN, "D"))), new SymbolReference(BOOLEAN, "E")))); + new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "B"))), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "C"))))), new Reference(BOOLEAN, "D"))), new Reference(BOOLEAN, "E"))), + new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "A"), new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "B"), new Reference(BOOLEAN, "C"))), new Reference(BOOLEAN, "D"))), new Reference(BOOLEAN, "E")))); assertSimplifies( - new LogicalExpression(OR, ImmutableList.of(new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "B"))), new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "C"), new SymbolReference(BOOLEAN, "D"))))), - new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "C"))), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "D"))), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "B"), new SymbolReference(BOOLEAN, "C"))), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "B"), new SymbolReference(BOOLEAN, "D")))))); + new Logical(OR, ImmutableList.of(new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "B"))), new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "C"), new Reference(BOOLEAN, "D"))))), + new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "C"))), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "D"))), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "B"), new Reference(BOOLEAN, "C"))), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "B"), new Reference(BOOLEAN, "D")))))); // No distribution since it would add too many new terms assertSimplifies( - new LogicalExpression(OR, ImmutableList.of(new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "B"))), new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "C"), new SymbolReference(BOOLEAN, "D"))), new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "E"), new SymbolReference(BOOLEAN, "F"))))), - new LogicalExpression(OR, ImmutableList.of(new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A"), new SymbolReference(BOOLEAN, "B"))), new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "C"), new SymbolReference(BOOLEAN, "D"))), new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "E"), new SymbolReference(BOOLEAN, "F")))))); + new Logical(OR, ImmutableList.of(new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "B"))), new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "C"), new Reference(BOOLEAN, "D"))), new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "E"), new Reference(BOOLEAN, "F"))))), + new Logical(OR, ImmutableList.of(new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A"), new Reference(BOOLEAN, "B"))), new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "C"), new Reference(BOOLEAN, "D"))), new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "E"), new Reference(BOOLEAN, "F")))))); // Test overflow handling for large disjunct expressions assertSimplifies( - new LogicalExpression(OR, ImmutableList.of( - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A1"), new SymbolReference(BOOLEAN, "A2"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A3"), new SymbolReference(BOOLEAN, "A4"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A5"), new SymbolReference(BOOLEAN, "A6"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A7"), new SymbolReference(BOOLEAN, "A8"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A9"), new SymbolReference(BOOLEAN, "A10"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A11"), new SymbolReference(BOOLEAN, "A12"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A13"), new SymbolReference(BOOLEAN, "A14"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A15"), new SymbolReference(BOOLEAN, "A16"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A17"), new SymbolReference(BOOLEAN, "A18"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A19"), new SymbolReference(BOOLEAN, "A20"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A21"), new SymbolReference(BOOLEAN, "A22"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A23"), new SymbolReference(BOOLEAN, "A24"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A25"), new SymbolReference(BOOLEAN, "A26"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A27"), new SymbolReference(BOOLEAN, "A28"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A29"), new SymbolReference(BOOLEAN, "A30"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A31"), new SymbolReference(BOOLEAN, "A32"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A33"), new SymbolReference(BOOLEAN, "A34"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A35"), new SymbolReference(BOOLEAN, "A36"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A37"), new SymbolReference(BOOLEAN, "A38"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A39"), new SymbolReference(BOOLEAN, "A40"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A41"), new SymbolReference(BOOLEAN, "A42"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A43"), new SymbolReference(BOOLEAN, "A44"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A45"), new SymbolReference(BOOLEAN, "A46"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A47"), new SymbolReference(BOOLEAN, "A48"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A49"), new SymbolReference(BOOLEAN, "A50"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A51"), new SymbolReference(BOOLEAN, "A52"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A53"), new SymbolReference(BOOLEAN, "A54"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A55"), new SymbolReference(BOOLEAN, "A56"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A57"), new SymbolReference(BOOLEAN, "A58"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A59"), new SymbolReference(BOOLEAN, "A60"))))), - new LogicalExpression(OR, ImmutableList.of( - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A1"), new SymbolReference(BOOLEAN, "A2"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A3"), new SymbolReference(BOOLEAN, "A4"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A5"), new SymbolReference(BOOLEAN, "A6"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A7"), new SymbolReference(BOOLEAN, "A8"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A9"), new SymbolReference(BOOLEAN, "A10"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A11"), new SymbolReference(BOOLEAN, "A12"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A13"), new SymbolReference(BOOLEAN, "A14"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A15"), new SymbolReference(BOOLEAN, "A16"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A17"), new SymbolReference(BOOLEAN, "A18"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A19"), new SymbolReference(BOOLEAN, "A20"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A21"), new SymbolReference(BOOLEAN, "A22"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A23"), new SymbolReference(BOOLEAN, "A24"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A25"), new SymbolReference(BOOLEAN, "A26"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A27"), new SymbolReference(BOOLEAN, "A28"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A29"), new SymbolReference(BOOLEAN, "A30"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A31"), new SymbolReference(BOOLEAN, "A32"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A33"), new SymbolReference(BOOLEAN, "A34"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A35"), new SymbolReference(BOOLEAN, "A36"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A37"), new SymbolReference(BOOLEAN, "A38"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A39"), new SymbolReference(BOOLEAN, "A40"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A41"), new SymbolReference(BOOLEAN, "A42"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A43"), new SymbolReference(BOOLEAN, "A44"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A45"), new SymbolReference(BOOLEAN, "A46"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A47"), new SymbolReference(BOOLEAN, "A48"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A49"), new SymbolReference(BOOLEAN, "A50"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A51"), new SymbolReference(BOOLEAN, "A52"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A53"), new SymbolReference(BOOLEAN, "A54"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A55"), new SymbolReference(BOOLEAN, "A56"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A57"), new SymbolReference(BOOLEAN, "A58"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "A59"), new SymbolReference(BOOLEAN, "A60")))))); + new Logical(OR, ImmutableList.of( + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A1"), new Reference(BOOLEAN, "A2"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A3"), new Reference(BOOLEAN, "A4"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A5"), new Reference(BOOLEAN, "A6"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A7"), new Reference(BOOLEAN, "A8"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A9"), new Reference(BOOLEAN, "A10"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A11"), new Reference(BOOLEAN, "A12"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A13"), new Reference(BOOLEAN, "A14"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A15"), new Reference(BOOLEAN, "A16"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A17"), new Reference(BOOLEAN, "A18"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A19"), new Reference(BOOLEAN, "A20"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A21"), new Reference(BOOLEAN, "A22"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A23"), new Reference(BOOLEAN, "A24"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A25"), new Reference(BOOLEAN, "A26"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A27"), new Reference(BOOLEAN, "A28"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A29"), new Reference(BOOLEAN, "A30"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A31"), new Reference(BOOLEAN, "A32"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A33"), new Reference(BOOLEAN, "A34"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A35"), new Reference(BOOLEAN, "A36"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A37"), new Reference(BOOLEAN, "A38"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A39"), new Reference(BOOLEAN, "A40"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A41"), new Reference(BOOLEAN, "A42"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A43"), new Reference(BOOLEAN, "A44"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A45"), new Reference(BOOLEAN, "A46"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A47"), new Reference(BOOLEAN, "A48"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A49"), new Reference(BOOLEAN, "A50"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A51"), new Reference(BOOLEAN, "A52"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A53"), new Reference(BOOLEAN, "A54"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A55"), new Reference(BOOLEAN, "A56"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A57"), new Reference(BOOLEAN, "A58"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A59"), new Reference(BOOLEAN, "A60"))))), + new Logical(OR, ImmutableList.of( + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A1"), new Reference(BOOLEAN, "A2"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A3"), new Reference(BOOLEAN, "A4"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A5"), new Reference(BOOLEAN, "A6"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A7"), new Reference(BOOLEAN, "A8"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A9"), new Reference(BOOLEAN, "A10"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A11"), new Reference(BOOLEAN, "A12"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A13"), new Reference(BOOLEAN, "A14"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A15"), new Reference(BOOLEAN, "A16"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A17"), new Reference(BOOLEAN, "A18"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A19"), new Reference(BOOLEAN, "A20"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A21"), new Reference(BOOLEAN, "A22"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A23"), new Reference(BOOLEAN, "A24"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A25"), new Reference(BOOLEAN, "A26"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A27"), new Reference(BOOLEAN, "A28"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A29"), new Reference(BOOLEAN, "A30"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A31"), new Reference(BOOLEAN, "A32"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A33"), new Reference(BOOLEAN, "A34"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A35"), new Reference(BOOLEAN, "A36"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A37"), new Reference(BOOLEAN, "A38"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A39"), new Reference(BOOLEAN, "A40"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A41"), new Reference(BOOLEAN, "A42"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A43"), new Reference(BOOLEAN, "A44"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A45"), new Reference(BOOLEAN, "A46"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A47"), new Reference(BOOLEAN, "A48"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A49"), new Reference(BOOLEAN, "A50"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A51"), new Reference(BOOLEAN, "A52"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A53"), new Reference(BOOLEAN, "A54"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A55"), new Reference(BOOLEAN, "A56"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A57"), new Reference(BOOLEAN, "A58"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "A59"), new Reference(BOOLEAN, "A60")))))); } @Test public void testMultipleNulls() { assertSimplifies( - new LogicalExpression(AND, ImmutableList.of(new Constant(UnknownType.UNKNOWN, null), new Constant(UnknownType.UNKNOWN, null), new Constant(UnknownType.UNKNOWN, null), FALSE_LITERAL)), - FALSE_LITERAL); + new Logical(AND, ImmutableList.of(new Constant(UnknownType.UNKNOWN, null), new Constant(UnknownType.UNKNOWN, null), new Constant(UnknownType.UNKNOWN, null), FALSE)), + FALSE); assertSimplifies( - new LogicalExpression(AND, ImmutableList.of(new Constant(UnknownType.UNKNOWN, null), new Constant(UnknownType.UNKNOWN, null), new Constant(UnknownType.UNKNOWN, null), new SymbolReference(BOOLEAN, "B1"))), - new LogicalExpression(AND, ImmutableList.of(new Constant(UnknownType.UNKNOWN, null), new SymbolReference(BOOLEAN, "B1")))); + new Logical(AND, ImmutableList.of(new Constant(UnknownType.UNKNOWN, null), new Constant(UnknownType.UNKNOWN, null), new Constant(UnknownType.UNKNOWN, null), new Reference(BOOLEAN, "B1"))), + new Logical(AND, ImmutableList.of(new Constant(UnknownType.UNKNOWN, null), new Reference(BOOLEAN, "B1")))); assertSimplifies( - new LogicalExpression(OR, ImmutableList.of(new Constant(UnknownType.UNKNOWN, null), new Constant(UnknownType.UNKNOWN, null), new Constant(UnknownType.UNKNOWN, null), TRUE_LITERAL)), - TRUE_LITERAL); + new Logical(OR, ImmutableList.of(new Constant(UnknownType.UNKNOWN, null), new Constant(UnknownType.UNKNOWN, null), new Constant(UnknownType.UNKNOWN, null), TRUE)), + TRUE); assertSimplifies( - new LogicalExpression(OR, ImmutableList.of(new Constant(UnknownType.UNKNOWN, null), new Constant(UnknownType.UNKNOWN, null), new Constant(UnknownType.UNKNOWN, null), new SymbolReference(BOOLEAN, "B1"))), - new LogicalExpression(OR, ImmutableList.of(new Constant(UnknownType.UNKNOWN, null), new SymbolReference(BOOLEAN, "B1")))); + new Logical(OR, ImmutableList.of(new Constant(UnknownType.UNKNOWN, null), new Constant(UnknownType.UNKNOWN, null), new Constant(UnknownType.UNKNOWN, null), new Reference(BOOLEAN, "B1"))), + new Logical(OR, ImmutableList.of(new Constant(UnknownType.UNKNOWN, null), new Reference(BOOLEAN, "B1")))); } @Test @@ -309,8 +309,8 @@ public void testCastBigintToBoundedVarchar() new Cast(new Constant(BIGINT, -12300000000L), createVarcharType(3)), new Cast(new Constant(BIGINT, -12300000000L), createVarcharType(3))); assertSimplifies( - new ComparisonExpression(EQUAL, new Cast(new Constant(BIGINT, 12300000000L), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("12300000000"))), - new ComparisonExpression(EQUAL, new Cast(new Constant(BIGINT, 12300000000L), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("12300000000")))); + new Comparison(EQUAL, new Cast(new Constant(BIGINT, 12300000000L), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("12300000000"))), + new Comparison(EQUAL, new Cast(new Constant(BIGINT, 12300000000L), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("12300000000")))); } @Test @@ -332,8 +332,8 @@ public void testCastIntegerToBoundedVarchar() new Cast(new Constant(INTEGER, 1234L), createVarcharType(3)), new Cast(new Constant(INTEGER, 1234L), createVarcharType(3))); assertSimplifies( - new ComparisonExpression(EQUAL, new Cast(new Constant(INTEGER, 1234L), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("1234"))), - new ComparisonExpression(EQUAL, new Cast(new Constant(INTEGER, 1234L), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("1234")))); + new Comparison(EQUAL, new Cast(new Constant(INTEGER, 1234L), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("1234"))), + new Comparison(EQUAL, new Cast(new Constant(INTEGER, 1234L), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("1234")))); } @Test @@ -355,8 +355,8 @@ public void testCastSmallintToBoundedVarchar() new Cast(new Constant(SMALLINT, -1234L), createVarcharType(3)), new Cast(new Constant(SMALLINT, -1234L), createVarcharType(3))); assertSimplifies( - new ComparisonExpression(EQUAL, new Cast(new Constant(SMALLINT, 1234L), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("1234"))), - new ComparisonExpression(EQUAL, new Cast(new Constant(SMALLINT, 1234L), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("1234")))); + new Comparison(EQUAL, new Cast(new Constant(SMALLINT, 1234L), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("1234"))), + new Comparison(EQUAL, new Cast(new Constant(SMALLINT, 1234L), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("1234")))); } @Test @@ -378,8 +378,8 @@ public void testCastTinyintToBoundedVarchar() new Cast(new Constant(TINYINT, -123L), createVarcharType(2)), new Cast(new Constant(TINYINT, -123L), createVarcharType(2))); assertSimplifies( - new ComparisonExpression(EQUAL, new Cast(new Constant(TINYINT, 123L), createVarcharType(2)), new Constant(VARCHAR, Slices.utf8Slice("123"))), - new ComparisonExpression(EQUAL, new Cast(new Constant(TINYINT, 123L), createVarcharType(2)), new Constant(VARCHAR, Slices.utf8Slice("123")))); + new Comparison(EQUAL, new Cast(new Constant(TINYINT, 123L), createVarcharType(2)), new Constant(VARCHAR, Slices.utf8Slice("123"))), + new Comparison(EQUAL, new Cast(new Constant(TINYINT, 123L), createVarcharType(2)), new Constant(VARCHAR, Slices.utf8Slice("123")))); } @Test @@ -401,8 +401,8 @@ public void testCastShortDecimalToBoundedVarchar() new Cast(new Constant(createDecimalType(3, 1), Decimals.valueOfShort(new BigDecimal("-12.4"))), createVarcharType(3)), new Cast(new Constant(createDecimalType(3, 1), Decimals.valueOfShort(new BigDecimal("-12.4"))), createVarcharType(3))); assertSimplifies( - new ComparisonExpression(EQUAL, new Cast(new Constant(createDecimalType(3, 1), Decimals.valueOfShort(new BigDecimal("12.4"))), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("12.4"))), - new ComparisonExpression(EQUAL, new Cast(new Constant(createDecimalType(3, 1), Decimals.valueOfShort(new BigDecimal("12.4"))), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("12.4")))); + new Comparison(EQUAL, new Cast(new Constant(createDecimalType(3, 1), Decimals.valueOfShort(new BigDecimal("12.4"))), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("12.4"))), + new Comparison(EQUAL, new Cast(new Constant(createDecimalType(3, 1), Decimals.valueOfShort(new BigDecimal("12.4"))), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("12.4")))); } @Test @@ -424,8 +424,8 @@ public void testCastLongDecimalToBoundedVarchar() new Cast(new Constant(createDecimalType(19, 1), Decimals.valueOf(new BigDecimal("-100000000000000000.1"))), createVarcharType(3)), new Cast(new Constant(createDecimalType(19, 1), Decimals.valueOf(new BigDecimal("-100000000000000000.1"))), createVarcharType(3))); assertSimplifies( - new ComparisonExpression(EQUAL, new Cast(new Constant(createDecimalType(19, 1), Decimals.valueOf(new BigDecimal("100000000000000000.1"))), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("100000000000000000.1"))), - new ComparisonExpression(EQUAL, new Cast(new Constant(createDecimalType(19, 1), Decimals.valueOf(new BigDecimal("100000000000000000.1"))), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("100000000000000000.1")))); + new Comparison(EQUAL, new Cast(new Constant(createDecimalType(19, 1), Decimals.valueOf(new BigDecimal("100000000000000000.1"))), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("100000000000000000.1"))), + new Comparison(EQUAL, new Cast(new Constant(createDecimalType(19, 1), Decimals.valueOf(new BigDecimal("100000000000000000.1"))), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("100000000000000000.1")))); } @Test @@ -439,7 +439,7 @@ public void testCastDoubleToBoundedVarchar() new Cast(new Constant(DOUBLE, -0.0), createVarcharType(4)), new Constant(createVarcharType(4), Slices.utf8Slice("-0E0"))); assertSimplifies( - new Cast(new ArithmeticBinaryExpression(DIVIDE_DOUBLE, DIVIDE, new Constant(DOUBLE, 0.0), new Constant(DOUBLE, 0.0)), createVarcharType(3)), + new Cast(new Arithmetic(DIVIDE_DOUBLE, DIVIDE, new Constant(DOUBLE, 0.0), new Constant(DOUBLE, 0.0)), createVarcharType(3)), new Constant(createVarcharType(3), Slices.utf8Slice("NaN"))); assertSimplifies( new Cast(new Constant(DOUBLE, Double.POSITIVE_INFINITY), createVarcharType(8)), @@ -465,8 +465,8 @@ public void testCastDoubleToBoundedVarchar() new Cast(new Constant(DOUBLE, Double.POSITIVE_INFINITY), createVarcharType(7)), new Cast(new Constant(DOUBLE, Double.POSITIVE_INFINITY), createVarcharType(7))); assertSimplifies( - new ComparisonExpression(EQUAL, new Cast(new Constant(DOUBLE, 1200.0), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("1200.0"))), - new ComparisonExpression(EQUAL, new Cast(new Constant(DOUBLE, 1200.0), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("1200.0")))); + new Comparison(EQUAL, new Cast(new Constant(DOUBLE, 1200.0), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("1200.0"))), + new Comparison(EQUAL, new Cast(new Constant(DOUBLE, 1200.0), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("1200.0")))); } @Test @@ -480,7 +480,7 @@ public void testCastRealToBoundedVarchar() new Cast(new Constant(REAL, Reals.toReal(-0.0f)), createVarcharType(4)), new Constant(createVarcharType(4), Slices.utf8Slice("-0E0"))); assertSimplifies( - new Cast(new ArithmeticBinaryExpression(DIVIDE_REAL, DIVIDE, new Constant(REAL, Reals.toReal(0.0f)), new Constant(REAL, Reals.toReal(0.0f))), createVarcharType(3)), + new Cast(new Arithmetic(DIVIDE_REAL, DIVIDE, new Constant(REAL, Reals.toReal(0.0f)), new Constant(REAL, Reals.toReal(0.0f))), createVarcharType(3)), new Constant(createVarcharType(3), Slices.utf8Slice("NaN"))); assertSimplifies( new Cast(new Constant(REAL, Reals.toReal(Float.POSITIVE_INFINITY)), createVarcharType(8)), @@ -506,8 +506,8 @@ public void testCastRealToBoundedVarchar() new Cast(new Constant(REAL, Reals.toReal(Float.POSITIVE_INFINITY)), createVarcharType(7)), new Cast(new Constant(REAL, Reals.toReal(Float.POSITIVE_INFINITY)), createVarcharType(7))); assertSimplifies( - new ComparisonExpression(EQUAL, new Cast(new Constant(REAL, Reals.toReal(12e2f)), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("1200.0"))), - new ComparisonExpression(EQUAL, new Cast(new Constant(REAL, Reals.toReal(12e2f)), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("1200.0")))); + new Comparison(EQUAL, new Cast(new Constant(REAL, Reals.toReal(12e2f)), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("1200.0"))), + new Comparison(EQUAL, new Cast(new Constant(REAL, Reals.toReal(12e2f)), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("1200.0")))); } @Test @@ -526,8 +526,8 @@ public void testCastDateToBoundedVarchar() new Cast(new Constant(DATE, (long) DateTimeUtils.parseDate("2013-02-02")), createVarcharType(3)), new Cast(new Constant(DATE, (long) DateTimeUtils.parseDate("2013-02-02")), createVarcharType(3))); assertSimplifies( - new ComparisonExpression(EQUAL, new Cast(new Constant(DATE, (long) DateTimeUtils.parseDate("2013-02-02")), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("2013-02-02"))), - new ComparisonExpression(EQUAL, new Cast(new Constant(DATE, (long) DateTimeUtils.parseDate("2013-02-02")), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("2013-02-02")))); + new Comparison(EQUAL, new Cast(new Constant(DATE, (long) DateTimeUtils.parseDate("2013-02-02")), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("2013-02-02"))), + new Comparison(EQUAL, new Cast(new Constant(DATE, (long) DateTimeUtils.parseDate("2013-02-02")), createVarcharType(3)), new Constant(VARCHAR, Slices.utf8Slice("2013-02-02")))); } private static void assertSimplifies(Expression expression, Expression expected) @@ -540,118 +540,118 @@ private static void assertSimplifies(Expression expression, Expression expected) public void testPushesDownNegationsNumericTypes() { assertSimplifiesNumericTypes( - new NotExpression(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "I1"), new SymbolReference(INTEGER, "I2"))), - new ComparisonExpression(NOT_EQUAL, new SymbolReference(INTEGER, "I1"), new SymbolReference(INTEGER, "I2"))); + new Not(new Comparison(EQUAL, new Reference(INTEGER, "I1"), new Reference(INTEGER, "I2"))), + new Comparison(NOT_EQUAL, new Reference(INTEGER, "I1"), new Reference(INTEGER, "I2"))); assertSimplifiesNumericTypes( - new NotExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "I1"), new SymbolReference(INTEGER, "I2"))), - new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(INTEGER, "I1"), new SymbolReference(INTEGER, "I2"))); + new Not(new Comparison(GREATER_THAN, new Reference(INTEGER, "I1"), new Reference(INTEGER, "I2"))), + new Comparison(LESS_THAN_OR_EQUAL, new Reference(INTEGER, "I1"), new Reference(INTEGER, "I2"))); assertSimplifiesNumericTypes( - new NotExpression(new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "I1"), new SymbolReference(INTEGER, "I2")), new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "I3"), new SymbolReference(INTEGER, "I4"))))), - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(NOT_EQUAL, new SymbolReference(INTEGER, "I1"), new SymbolReference(INTEGER, "I2")), new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(INTEGER, "I3"), new SymbolReference(INTEGER, "I4"))))); + new Not(new Logical(OR, ImmutableList.of(new Comparison(EQUAL, new Reference(INTEGER, "I1"), new Reference(INTEGER, "I2")), new Comparison(GREATER_THAN, new Reference(INTEGER, "I3"), new Reference(INTEGER, "I4"))))), + new Logical(AND, ImmutableList.of(new Comparison(NOT_EQUAL, new Reference(INTEGER, "I1"), new Reference(INTEGER, "I2")), new Comparison(LESS_THAN_OR_EQUAL, new Reference(INTEGER, "I3"), new Reference(INTEGER, "I4"))))); assertSimplifiesNumericTypes( - new NotExpression(new NotExpression(new NotExpression(new LogicalExpression(OR, ImmutableList.of(new NotExpression(new NotExpression(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "I1"), new SymbolReference(INTEGER, "I2")))), new NotExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "I3"), new SymbolReference(INTEGER, "I4")))))))), - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(NOT_EQUAL, new SymbolReference(INTEGER, "I1"), new SymbolReference(INTEGER, "I2")), new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "I3"), new SymbolReference(INTEGER, "I4"))))); + new Not(new Not(new Not(new Logical(OR, ImmutableList.of(new Not(new Not(new Comparison(EQUAL, new Reference(INTEGER, "I1"), new Reference(INTEGER, "I2")))), new Not(new Comparison(GREATER_THAN, new Reference(INTEGER, "I3"), new Reference(INTEGER, "I4")))))))), + new Logical(AND, ImmutableList.of(new Comparison(NOT_EQUAL, new Reference(INTEGER, "I1"), new Reference(INTEGER, "I2")), new Comparison(GREATER_THAN, new Reference(INTEGER, "I3"), new Reference(INTEGER, "I4"))))); assertSimplifiesNumericTypes( - new NotExpression(new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "I1"), new SymbolReference(INTEGER, "I2")), new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "B1"), new SymbolReference(BOOLEAN, "B2"))), new NotExpression(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "B3"), new SymbolReference(BOOLEAN, "B4"))))))), - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(NOT_EQUAL, new SymbolReference(INTEGER, "I1"), new SymbolReference(INTEGER, "I2")), new LogicalExpression(OR, ImmutableList.of(new NotExpression(new SymbolReference(BOOLEAN, "B1")), new NotExpression(new SymbolReference(BOOLEAN, "B2")))), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "B3"), new SymbolReference(BOOLEAN, "B4")))))); + new Not(new Logical(OR, ImmutableList.of(new Comparison(EQUAL, new Reference(INTEGER, "I1"), new Reference(INTEGER, "I2")), new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "B1"), new Reference(BOOLEAN, "B2"))), new Not(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "B3"), new Reference(BOOLEAN, "B4"))))))), + new Logical(AND, ImmutableList.of(new Comparison(NOT_EQUAL, new Reference(INTEGER, "I1"), new Reference(INTEGER, "I2")), new Logical(OR, ImmutableList.of(new Not(new Reference(BOOLEAN, "B1")), new Not(new Reference(BOOLEAN, "B2")))), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "B3"), new Reference(BOOLEAN, "B4")))))); assertSimplifiesNumericTypes( - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(INTEGER, "I1"), new SymbolReference(INTEGER, "I2"))), - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(INTEGER, "I1"), new SymbolReference(INTEGER, "I2")))); + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(INTEGER, "I1"), new Reference(INTEGER, "I2"))), + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(INTEGER, "I1"), new Reference(INTEGER, "I2")))); /* Restricted rewrite for types having NaN */ assertSimplifiesNumericTypes( - new NotExpression(new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "D1"), new SymbolReference(DOUBLE, "D2"))), - new ComparisonExpression(NOT_EQUAL, new SymbolReference(DOUBLE, "D1"), new SymbolReference(DOUBLE, "D2"))); + new Not(new Comparison(EQUAL, new Reference(DOUBLE, "D1"), new Reference(DOUBLE, "D2"))), + new Comparison(NOT_EQUAL, new Reference(DOUBLE, "D1"), new Reference(DOUBLE, "D2"))); assertSimplifiesNumericTypes( - new NotExpression(new ComparisonExpression(NOT_EQUAL, new SymbolReference(DOUBLE, "D1"), new SymbolReference(DOUBLE, "D2"))), - new ComparisonExpression(EQUAL, new SymbolReference(DOUBLE, "D1"), new SymbolReference(DOUBLE, "D2"))); + new Not(new Comparison(NOT_EQUAL, new Reference(DOUBLE, "D1"), new Reference(DOUBLE, "D2"))), + new Comparison(EQUAL, new Reference(DOUBLE, "D1"), new Reference(DOUBLE, "D2"))); assertSimplifiesNumericTypes( - new NotExpression(new ComparisonExpression(EQUAL, new SymbolReference(REAL, "R1"), new SymbolReference(REAL, "R2"))), - new ComparisonExpression(NOT_EQUAL, new SymbolReference(REAL, "R1"), new SymbolReference(REAL, "R2"))); + new Not(new Comparison(EQUAL, new Reference(REAL, "R1"), new Reference(REAL, "R2"))), + new Comparison(NOT_EQUAL, new Reference(REAL, "R1"), new Reference(REAL, "R2"))); assertSimplifiesNumericTypes( - new NotExpression(new ComparisonExpression(NOT_EQUAL, new SymbolReference(REAL, "R1"), new SymbolReference(REAL, "R2"))), - new ComparisonExpression(EQUAL, new SymbolReference(REAL, "R1"), new SymbolReference(REAL, "R2"))); + new Not(new Comparison(NOT_EQUAL, new Reference(REAL, "R1"), new Reference(REAL, "R2"))), + new Comparison(EQUAL, new Reference(REAL, "R1"), new Reference(REAL, "R2"))); assertSimplifiesNumericTypes( - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(DOUBLE, "D1"), new SymbolReference(DOUBLE, "D2"))), - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(DOUBLE, "D1"), new SymbolReference(DOUBLE, "D2")))); + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(DOUBLE, "D1"), new Reference(DOUBLE, "D2"))), + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(DOUBLE, "D1"), new Reference(DOUBLE, "D2")))); assertSimplifiesNumericTypes( - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(REAL, "R1"), new SymbolReference(REAL, "R2"))), - new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, new SymbolReference(REAL, "R1"), new SymbolReference(REAL, "R2")))); + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(REAL, "R1"), new Reference(REAL, "R2"))), + new Not(new Comparison(IS_DISTINCT_FROM, new Reference(REAL, "R1"), new Reference(REAL, "R2")))); // DOUBLE: no negation pushdown for inequalities assertSimplifiesNumericTypes( - new NotExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "D1"), new SymbolReference(DOUBLE, "D2"))), - new NotExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "D1"), new SymbolReference(DOUBLE, "D2")))); + new Not(new Comparison(GREATER_THAN, new Reference(DOUBLE, "D1"), new Reference(DOUBLE, "D2"))), + new Not(new Comparison(GREATER_THAN, new Reference(DOUBLE, "D1"), new Reference(DOUBLE, "D2")))); assertSimplifiesNumericTypes( - new NotExpression(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(DOUBLE, "D1"), new SymbolReference(DOUBLE, "D2"))), - new NotExpression(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(DOUBLE, "D1"), new SymbolReference(DOUBLE, "D2")))); + new Not(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(DOUBLE, "D1"), new Reference(DOUBLE, "D2"))), + new Not(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(DOUBLE, "D1"), new Reference(DOUBLE, "D2")))); assertSimplifiesNumericTypes( - new NotExpression(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "D1"), new SymbolReference(DOUBLE, "D2"))), - new NotExpression(new ComparisonExpression(LESS_THAN, new SymbolReference(DOUBLE, "D1"), new SymbolReference(DOUBLE, "D2")))); + new Not(new Comparison(LESS_THAN, new Reference(DOUBLE, "D1"), new Reference(DOUBLE, "D2"))), + new Not(new Comparison(LESS_THAN, new Reference(DOUBLE, "D1"), new Reference(DOUBLE, "D2")))); assertSimplifiesNumericTypes( - new NotExpression(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(DOUBLE, "D1"), new SymbolReference(DOUBLE, "D2"))), - new NotExpression(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(DOUBLE, "D1"), new SymbolReference(DOUBLE, "D2")))); + new Not(new Comparison(LESS_THAN_OR_EQUAL, new Reference(DOUBLE, "D1"), new Reference(DOUBLE, "D2"))), + new Not(new Comparison(LESS_THAN_OR_EQUAL, new Reference(DOUBLE, "D1"), new Reference(DOUBLE, "D2")))); // REAL: no negation pushdown for inequalities assertSimplifiesNumericTypes( - new NotExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(REAL, "R1"), new SymbolReference(REAL, "R2"))), - new NotExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(REAL, "R1"), new SymbolReference(REAL, "R2")))); + new Not(new Comparison(GREATER_THAN, new Reference(REAL, "R1"), new Reference(REAL, "R2"))), + new Not(new Comparison(GREATER_THAN, new Reference(REAL, "R1"), new Reference(REAL, "R2")))); assertSimplifiesNumericTypes( - new NotExpression(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(REAL, "R1"), new SymbolReference(REAL, "R2"))), - new NotExpression(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(REAL, "R1"), new SymbolReference(REAL, "R2")))); + new Not(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(REAL, "R1"), new Reference(REAL, "R2"))), + new Not(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(REAL, "R1"), new Reference(REAL, "R2")))); assertSimplifiesNumericTypes( - new NotExpression(new ComparisonExpression(LESS_THAN, new SymbolReference(REAL, "R1"), new SymbolReference(REAL, "R2"))), - new NotExpression(new ComparisonExpression(LESS_THAN, new SymbolReference(REAL, "R1"), new SymbolReference(REAL, "R2")))); + new Not(new Comparison(LESS_THAN, new Reference(REAL, "R1"), new Reference(REAL, "R2"))), + new Not(new Comparison(LESS_THAN, new Reference(REAL, "R1"), new Reference(REAL, "R2")))); assertSimplifiesNumericTypes( - new NotExpression(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(REAL, "R1"), new SymbolReference(REAL, "R2"))), - new NotExpression(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(REAL, "R1"), new SymbolReference(REAL, "R2")))); + new Not(new Comparison(LESS_THAN_OR_EQUAL, new Reference(REAL, "R1"), new Reference(REAL, "R2"))), + new Not(new Comparison(LESS_THAN_OR_EQUAL, new Reference(REAL, "R1"), new Reference(REAL, "R2")))); // Multiple negations assertSimplifiesNumericTypes( - new NotExpression(new NotExpression(new NotExpression(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(DOUBLE, "D1"), new SymbolReference(DOUBLE, "D2"))))), - new NotExpression(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(DOUBLE, "D1"), new SymbolReference(DOUBLE, "D2")))); + new Not(new Not(new Not(new Comparison(LESS_THAN_OR_EQUAL, new Reference(DOUBLE, "D1"), new Reference(DOUBLE, "D2"))))), + new Not(new Comparison(LESS_THAN_OR_EQUAL, new Reference(DOUBLE, "D1"), new Reference(DOUBLE, "D2")))); assertSimplifiesNumericTypes( - new NotExpression(new NotExpression(new NotExpression(new NotExpression(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(DOUBLE, "D1"), new SymbolReference(DOUBLE, "D2")))))), - new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(DOUBLE, "D1"), new SymbolReference(DOUBLE, "D2"))); + new Not(new Not(new Not(new Not(new Comparison(LESS_THAN_OR_EQUAL, new Reference(DOUBLE, "D1"), new Reference(DOUBLE, "D2")))))), + new Comparison(LESS_THAN_OR_EQUAL, new Reference(DOUBLE, "D1"), new Reference(DOUBLE, "D2"))); assertSimplifiesNumericTypes( - new NotExpression(new NotExpression(new NotExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(REAL, "R1"), new SymbolReference(REAL, "R2"))))), - new NotExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(REAL, "R1"), new SymbolReference(REAL, "R2")))); + new Not(new Not(new Not(new Comparison(GREATER_THAN, new Reference(REAL, "R1"), new Reference(REAL, "R2"))))), + new Not(new Comparison(GREATER_THAN, new Reference(REAL, "R1"), new Reference(REAL, "R2")))); assertSimplifiesNumericTypes( - new NotExpression(new NotExpression(new NotExpression(new NotExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(REAL, "R1"), new SymbolReference(REAL, "R2")))))), - new ComparisonExpression(GREATER_THAN, new SymbolReference(REAL, "R1"), new SymbolReference(REAL, "R2"))); + new Not(new Not(new Not(new Not(new Comparison(GREATER_THAN, new Reference(REAL, "R1"), new Reference(REAL, "R2")))))), + new Comparison(GREATER_THAN, new Reference(REAL, "R1"), new Reference(REAL, "R2"))); // Nested comparisons assertSimplifiesNumericTypes( - new NotExpression(new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "I1"), new SymbolReference(INTEGER, "I2")), new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "D1"), new SymbolReference(DOUBLE, "D2"))))), - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(NOT_EQUAL, new SymbolReference(INTEGER, "I1"), new SymbolReference(INTEGER, "I2")), new NotExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "D1"), new SymbolReference(DOUBLE, "D2")))))); + new Not(new Logical(OR, ImmutableList.of(new Comparison(EQUAL, new Reference(INTEGER, "I1"), new Reference(INTEGER, "I2")), new Comparison(GREATER_THAN, new Reference(DOUBLE, "D1"), new Reference(DOUBLE, "D2"))))), + new Logical(AND, ImmutableList.of(new Comparison(NOT_EQUAL, new Reference(INTEGER, "I1"), new Reference(INTEGER, "I2")), new Not(new Comparison(GREATER_THAN, new Reference(DOUBLE, "D1"), new Reference(DOUBLE, "D2")))))); assertSimplifiesNumericTypes( - new NotExpression(new NotExpression(new NotExpression(new LogicalExpression(OR, ImmutableList.of(new NotExpression(new NotExpression(new ComparisonExpression(LESS_THAN, new SymbolReference(REAL, "R1"), new SymbolReference(REAL, "R2")))), new NotExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "I1"), new SymbolReference(INTEGER, "I2")))))))), - new LogicalExpression(AND, ImmutableList.of(new NotExpression(new ComparisonExpression(LESS_THAN, new SymbolReference(REAL, "R1"), new SymbolReference(REAL, "R2"))), new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "I1"), new SymbolReference(INTEGER, "I2"))))); + new Not(new Not(new Not(new Logical(OR, ImmutableList.of(new Not(new Not(new Comparison(LESS_THAN, new Reference(REAL, "R1"), new Reference(REAL, "R2")))), new Not(new Comparison(GREATER_THAN, new Reference(INTEGER, "I1"), new Reference(INTEGER, "I2")))))))), + new Logical(AND, ImmutableList.of(new Not(new Comparison(LESS_THAN, new Reference(REAL, "R1"), new Reference(REAL, "R2"))), new Comparison(GREATER_THAN, new Reference(INTEGER, "I1"), new Reference(INTEGER, "I2"))))); assertSimplifiesNumericTypes( - new NotExpression(new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "D1"), new SymbolReference(DOUBLE, "D2")), new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "B1"), new SymbolReference(BOOLEAN, "B2"))), new NotExpression(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "B3"), new SymbolReference(BOOLEAN, "B4"))))))), - new LogicalExpression(AND, ImmutableList.of(new NotExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "D1"), new SymbolReference(DOUBLE, "D2"))), new LogicalExpression(OR, ImmutableList.of(new NotExpression(new SymbolReference(BOOLEAN, "B1")), new NotExpression(new SymbolReference(BOOLEAN, "B2")))), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "B3"), new SymbolReference(BOOLEAN, "B4")))))); + new Not(new Logical(OR, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(DOUBLE, "D1"), new Reference(DOUBLE, "D2")), new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "B1"), new Reference(BOOLEAN, "B2"))), new Not(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "B3"), new Reference(BOOLEAN, "B4"))))))), + new Logical(AND, ImmutableList.of(new Not(new Comparison(GREATER_THAN, new Reference(DOUBLE, "D1"), new Reference(DOUBLE, "D2"))), new Logical(OR, ImmutableList.of(new Not(new Reference(BOOLEAN, "B1")), new Not(new Reference(BOOLEAN, "B2")))), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "B3"), new Reference(BOOLEAN, "B4")))))); assertSimplifiesNumericTypes( - new NotExpression(new LogicalExpression(OR, ImmutableList.of(new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "D1"), new SymbolReference(DOUBLE, "D2")), new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "I1"), new SymbolReference(INTEGER, "I2")))), new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "B1"), new SymbolReference(BOOLEAN, "B2"))), new ComparisonExpression(GREATER_THAN, new SymbolReference(REAL, "R1"), new SymbolReference(REAL, "R2"))))))), - new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new NotExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "D1"), new SymbolReference(DOUBLE, "D2"))), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(INTEGER, "I1"), new SymbolReference(INTEGER, "I2")))), new LogicalExpression(OR, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new NotExpression(new SymbolReference(BOOLEAN, "B1")), new NotExpression(new SymbolReference(BOOLEAN, "B2")))), new NotExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(REAL, "R1"), new SymbolReference(REAL, "R2")))))))); + new Not(new Logical(OR, ImmutableList.of(new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(DOUBLE, "D1"), new Reference(DOUBLE, "D2")), new Comparison(LESS_THAN, new Reference(INTEGER, "I1"), new Reference(INTEGER, "I2")))), new Logical(AND, ImmutableList.of(new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "B1"), new Reference(BOOLEAN, "B2"))), new Comparison(GREATER_THAN, new Reference(REAL, "R1"), new Reference(REAL, "R2"))))))), + new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new Not(new Comparison(GREATER_THAN, new Reference(DOUBLE, "D1"), new Reference(DOUBLE, "D2"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(INTEGER, "I1"), new Reference(INTEGER, "I2")))), new Logical(OR, ImmutableList.of(new Logical(OR, ImmutableList.of(new Not(new Reference(BOOLEAN, "B1")), new Not(new Reference(BOOLEAN, "B2")))), new Not(new Comparison(GREATER_THAN, new Reference(REAL, "R1"), new Reference(REAL, "R2")))))))); assertSimplifiesNumericTypes( - ifExpression(new NotExpression(new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "I1"), new SymbolReference(INTEGER, "I2"))), new SymbolReference(DOUBLE, "D1"), new SymbolReference(DOUBLE, "D2")), - ifExpression(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(INTEGER, "I1"), new SymbolReference(INTEGER, "I2")), new SymbolReference(DOUBLE, "D1"), new SymbolReference(DOUBLE, "D2"))); + ifExpression(new Not(new Comparison(LESS_THAN, new Reference(INTEGER, "I1"), new Reference(INTEGER, "I2"))), new Reference(DOUBLE, "D1"), new Reference(DOUBLE, "D2")), + ifExpression(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(INTEGER, "I1"), new Reference(INTEGER, "I2")), new Reference(DOUBLE, "D1"), new Reference(DOUBLE, "D2"))); // Symbol of type having NaN on either side of comparison assertSimplifiesNumericTypes( - new NotExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "D1"), new Constant(DOUBLE, 1.0))), - new NotExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "D1"), new Constant(DOUBLE, 1.0)))); + new Not(new Comparison(GREATER_THAN, new Reference(DOUBLE, "D1"), new Constant(DOUBLE, 1.0))), + new Not(new Comparison(GREATER_THAN, new Reference(DOUBLE, "D1"), new Constant(DOUBLE, 1.0)))); assertSimplifiesNumericTypes( - new NotExpression(new ComparisonExpression(GREATER_THAN, new Constant(DOUBLE, 1.0), new SymbolReference(DOUBLE, "D2"))), - new NotExpression(new ComparisonExpression(GREATER_THAN, new Constant(DOUBLE, 1.0), new SymbolReference(DOUBLE, "D2")))); + new Not(new Comparison(GREATER_THAN, new Constant(DOUBLE, 1.0), new Reference(DOUBLE, "D2"))), + new Not(new Comparison(GREATER_THAN, new Constant(DOUBLE, 1.0), new Reference(DOUBLE, "D2")))); assertSimplifiesNumericTypes( - new NotExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(REAL, "R1"), new Constant(REAL, Reals.toReal(1L)))), - new NotExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(REAL, "R1"), new Constant(REAL, Reals.toReal(1L))))); + new Not(new Comparison(GREATER_THAN, new Reference(REAL, "R1"), new Constant(REAL, Reals.toReal(1L)))), + new Not(new Comparison(GREATER_THAN, new Reference(REAL, "R1"), new Constant(REAL, Reals.toReal(1L))))); assertSimplifiesNumericTypes( - new NotExpression(new ComparisonExpression(GREATER_THAN, new Constant(REAL, Reals.toReal(1)), new SymbolReference(REAL, "R2"))), - new NotExpression(new ComparisonExpression(GREATER_THAN, new Constant(REAL, Reals.toReal(1)), new SymbolReference(REAL, "R2")))); + new Not(new Comparison(GREATER_THAN, new Constant(REAL, Reals.toReal(1)), new Reference(REAL, "R2"))), + new Not(new Comparison(GREATER_THAN, new Constant(REAL, Reals.toReal(1)), new Reference(REAL, "R2")))); } private static void assertSimplifiesNumericTypes(Expression expression, Expression expected) @@ -669,13 +669,13 @@ private static class NormalizeExpressionRewriter extends ExpressionRewriter { @Override - public Expression rewriteLogicalExpression(LogicalExpression node, Void context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteLogical(Logical node, Void context, ExpressionTreeRewriter treeRewriter) { - List predicates = extractPredicates(node.getOperator(), node).stream() + List predicates = extractPredicates(node.operator(), node).stream() .map(p -> treeRewriter.rewrite(p, context)) .sorted(Comparator.comparing(Expression::toString)) .collect(toList()); - return logicalExpression(node.getOperator(), predicates); + return logicalExpression(node.operator(), predicates); } @Override @@ -683,7 +683,7 @@ public Expression rewriteCast(Cast node, Void context, ExpressionTreeRewriter p.filter( - ifExpression(new SymbolReference(BOOLEAN, "a"), TRUE_LITERAL, FALSE_LITERAL), + ifExpression(new Reference(BOOLEAN, "a"), TRUE, FALSE), p.values(p.symbol("a")))) .matches( filter( - new SymbolReference(BOOLEAN, "a"), + new Reference(BOOLEAN, "a"), values("a"))); // true result iff the condition is true tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - ifExpression(new SymbolReference(BOOLEAN, "a"), TRUE_LITERAL, new Constant(UnknownType.UNKNOWN, null)), + ifExpression(new Reference(BOOLEAN, "a"), TRUE, new Constant(UnknownType.UNKNOWN, null)), p.values(p.symbol("a")))) .matches( filter( - new SymbolReference(BOOLEAN, "a"), + new Reference(BOOLEAN, "a"), values("a"))); // true result iff the condition is null or false tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - ifExpression(new SymbolReference(BOOLEAN, "a"), FALSE_LITERAL, TRUE_LITERAL), + ifExpression(new Reference(BOOLEAN, "a"), FALSE, TRUE), p.values(p.symbol("a")))) .matches( filter( - new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(BOOLEAN, "a")), new NotExpression(new SymbolReference(BOOLEAN, "a")))), + new Logical(OR, ImmutableList.of(new IsNull(new Reference(BOOLEAN, "a")), new Not(new Reference(BOOLEAN, "a")))), values("a"))); // true result iff the condition is null or false tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - ifExpression(new SymbolReference(BOOLEAN, "a"), new Constant(UnknownType.UNKNOWN, null), TRUE_LITERAL), + ifExpression(new Reference(BOOLEAN, "a"), new Constant(UnknownType.UNKNOWN, null), TRUE), p.values(p.symbol("a")))) .matches( filter( - new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new SymbolReference(BOOLEAN, "a")), new NotExpression(new SymbolReference(BOOLEAN, "a")))), + new Logical(OR, ImmutableList.of(new IsNull(new Reference(BOOLEAN, "a")), new Not(new Reference(BOOLEAN, "a")))), values("a"))); // always true tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - ifExpression(new SymbolReference(BOOLEAN, "a"), TRUE_LITERAL, TRUE_LITERAL), + ifExpression(new Reference(BOOLEAN, "a"), TRUE, TRUE), p.values(p.symbol("a")))) .matches( filter( - TRUE_LITERAL, + TRUE, values("a"))); // always false tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - ifExpression(new SymbolReference(BOOLEAN, "a"), FALSE_LITERAL, FALSE_LITERAL), + ifExpression(new Reference(BOOLEAN, "a"), FALSE, FALSE), p.values(p.symbol("a")))) .matches( filter( - FALSE_LITERAL, + FALSE, values("a"))); // both results equal tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - ifExpression(new SymbolReference(BOOLEAN, "a"), new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 0L)), new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 0L))), + ifExpression(new Reference(BOOLEAN, "a"), new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 0L)), new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 0L))), p.values(p.symbol("a"), p.symbol("b")))) .matches( filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 0L)), values("a", "b"))); // both results are equal non-deterministic expressions - FunctionCall randomFunction = new FunctionCall( + Call randomFunction = new Call( tester().getMetadata().resolveBuiltinFunction("random", ImmutableList.of()), ImmutableList.of()); tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( ifExpression( - new SymbolReference(BOOLEAN, "a"), - new ComparisonExpression(EQUAL, randomFunction, new Constant(INTEGER, 0L)), - new ComparisonExpression(EQUAL, randomFunction, new Constant(INTEGER, 0L))), + new Reference(BOOLEAN, "a"), + new Comparison(EQUAL, randomFunction, new Constant(INTEGER, 0L)), + new Comparison(EQUAL, randomFunction, new Constant(INTEGER, 0L))), p.values(p.symbol("a")))) .doesNotFire(); // always null (including the default) -> simplified to FALSE tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - ifExpression(new SymbolReference(BOOLEAN, "a"), new Constant(UnknownType.UNKNOWN, null)), + ifExpression(new Reference(BOOLEAN, "a"), new Constant(UnknownType.UNKNOWN, null)), p.values(p.symbol("a")))) .matches( filter( - FALSE_LITERAL, + FALSE, values("a"))); // condition is true -> first branch tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - ifExpression(TRUE_LITERAL, new SymbolReference(BOOLEAN, "a"), new NotExpression(new SymbolReference(BOOLEAN, "a"))), + ifExpression(TRUE, new Reference(BOOLEAN, "a"), new Not(new Reference(BOOLEAN, "a"))), p.values(p.symbol("a")))) .matches( filter( - new SymbolReference(BOOLEAN, "a"), + new Reference(BOOLEAN, "a"), values("a"))); // condition is true -> second branch tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - ifExpression(FALSE_LITERAL, new SymbolReference(BOOLEAN, "a"), new NotExpression(new SymbolReference(BOOLEAN, "a"))), + ifExpression(FALSE, new Reference(BOOLEAN, "a"), new Not(new Reference(BOOLEAN, "a"))), p.values(p.symbol("a")))) .matches( filter( - new NotExpression(new SymbolReference(BOOLEAN, "a")), + new Not(new Reference(BOOLEAN, "a")), values("a"))); // condition is true, no second branch -> the result is null, simplified to FALSE tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - ifExpression(FALSE_LITERAL, new SymbolReference(BOOLEAN, "a")), + ifExpression(FALSE, new Reference(BOOLEAN, "a")), p.values(p.symbol("a")))) .matches( filter( - FALSE_LITERAL, + FALSE, values("a"))); // not known result (`b`) - cannot optimize tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - ifExpression(new SymbolReference(BOOLEAN, "a"), TRUE_LITERAL, new SymbolReference(BOOLEAN, "b")), + ifExpression(new Reference(BOOLEAN, "a"), TRUE, new Reference(BOOLEAN, "b")), p.values(p.symbol("a"), p.symbol("b")))) .doesNotFire(); } @@ -195,15 +195,15 @@ public void testSimplifyNullIfExpression() // NULLIF(x, y) returns true if and only if: x != y AND x = true tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - new NullIfExpression(new SymbolReference(BOOLEAN, "a"), new SymbolReference(BOOLEAN, "b")), + new NullIf(new Reference(BOOLEAN, "a"), new Reference(BOOLEAN, "b")), p.values(p.symbol("a"), p.symbol("b")))) .matches( filter( - new LogicalExpression(AND, ImmutableList.of( - new SymbolReference(BOOLEAN, "a"), - new LogicalExpression(OR, ImmutableList.of( - new IsNullPredicate(new SymbolReference(BOOLEAN, "b")), - new NotExpression(new SymbolReference(BOOLEAN, "b")))))), + new Logical(AND, ImmutableList.of( + new Reference(BOOLEAN, "a"), + new Logical(OR, ImmutableList.of( + new IsNull(new Reference(BOOLEAN, "b")), + new Not(new Reference(BOOLEAN, "b")))))), values("a", "b"))); } @@ -212,188 +212,188 @@ public void testSimplifySearchedCaseExpression() { tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L)), TRUE_LITERAL), - new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L)), FALSE_LITERAL), - new WhenClause(new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L)), TRUE_LITERAL)), - Optional.of(FALSE_LITERAL)), + new Case(ImmutableList.of( + new WhenClause(new Comparison(LESS_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), TRUE), + new WhenClause(new Comparison(EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), FALSE), + new WhenClause(new Comparison(GREATER_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), TRUE)), + Optional.of(FALSE)), p.values(p.symbol("a")))) .doesNotFire(); // all results true tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L)), TRUE_LITERAL), - new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L)), TRUE_LITERAL), - new WhenClause(new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L)), TRUE_LITERAL)), - Optional.of(TRUE_LITERAL)), + new Case(ImmutableList.of( + new WhenClause(new Comparison(LESS_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), TRUE), + new WhenClause(new Comparison(EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), TRUE), + new WhenClause(new Comparison(GREATER_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), TRUE)), + Optional.of(TRUE)), p.values(p.symbol("a")))) .matches( filter( - TRUE_LITERAL, + TRUE, values("a"))); // all results not true tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L)), FALSE_LITERAL), - new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L)), new Constant(UnknownType.UNKNOWN, null)), - new WhenClause(new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L)), FALSE_LITERAL)), - Optional.of(FALSE_LITERAL)), + new Case(ImmutableList.of( + new WhenClause(new Comparison(LESS_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), FALSE), + new WhenClause(new Comparison(EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), new Constant(UnknownType.UNKNOWN, null)), + new WhenClause(new Comparison(GREATER_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), FALSE)), + Optional.of(FALSE)), p.values(p.symbol("a")))) .matches( filter( - FALSE_LITERAL, + FALSE, values("a"))); // all results not true (including default null result) tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L)), FALSE_LITERAL), - new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L)), new Constant(UnknownType.UNKNOWN, null)), - new WhenClause(new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L)), FALSE_LITERAL)), + new Case(ImmutableList.of( + new WhenClause(new Comparison(LESS_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), FALSE), + new WhenClause(new Comparison(EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), new Constant(UnknownType.UNKNOWN, null)), + new WhenClause(new Comparison(GREATER_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), FALSE)), Optional.empty()), p.values(p.symbol("a")))) .matches( filter( - FALSE_LITERAL, + FALSE, values("a"))); // one result true, and remaining results not true tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L)), FALSE_LITERAL), - new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L)), new Constant(UnknownType.UNKNOWN, null)), - new WhenClause(new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L)), TRUE_LITERAL)), - Optional.of(FALSE_LITERAL)), + new Case(ImmutableList.of( + new WhenClause(new Comparison(LESS_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), FALSE), + new WhenClause(new Comparison(EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), new Constant(UnknownType.UNKNOWN, null)), + new WhenClause(new Comparison(GREATER_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), TRUE)), + Optional.of(FALSE)), p.values(p.symbol("a")))) .matches( filter( - new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L))), new NotExpression(new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L))))), new LogicalExpression(OR, ImmutableList.of(new IsNullPredicate(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L))), new NotExpression(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L))))), new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L)))), + new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new IsNull(new Comparison(LESS_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L))), new Not(new Comparison(LESS_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L))))), new Logical(OR, ImmutableList.of(new IsNull(new Comparison(EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L))), new Not(new Comparison(EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L))))), new Comparison(GREATER_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)))), values("a"))); // first result true, and remaining results not true tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L)), TRUE_LITERAL), - new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L)), new Constant(UnknownType.UNKNOWN, null)), - new WhenClause(new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L)), FALSE_LITERAL)), - Optional.of(FALSE_LITERAL)), + new Case(ImmutableList.of( + new WhenClause(new Comparison(LESS_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), TRUE), + new WhenClause(new Comparison(EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), new Constant(UnknownType.UNKNOWN, null)), + new WhenClause(new Comparison(GREATER_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), FALSE)), + Optional.of(FALSE)), p.values(p.symbol("a")))) .matches( filter( - new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L)), + new Comparison(LESS_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), values("a"))); // all results not true, and default true tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L)), FALSE_LITERAL), - new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L)), new Constant(UnknownType.UNKNOWN, null)), - new WhenClause(new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L)), FALSE_LITERAL)), - Optional.of(TRUE_LITERAL)), + new Case(ImmutableList.of( + new WhenClause(new Comparison(LESS_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), FALSE), + new WhenClause(new Comparison(EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), new Constant(UnknownType.UNKNOWN, null)), + new WhenClause(new Comparison(GREATER_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), FALSE)), + Optional.of(TRUE)), p.values(p.symbol("a")))) .matches( filter( - new LogicalExpression(AND, ImmutableList.of( - new LogicalExpression(OR, ImmutableList.of( - new IsNullPredicate(new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L))), - new NotExpression(new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L))))), - new LogicalExpression(OR, ImmutableList.of( - new IsNullPredicate(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L))), - new NotExpression(new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L))))), - new LogicalExpression(OR, ImmutableList.of( - new IsNullPredicate(new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L))), - new NotExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 0L))))))), + new Logical(AND, ImmutableList.of( + new Logical(OR, ImmutableList.of( + new IsNull(new Comparison(LESS_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L))), + new Not(new Comparison(LESS_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L))))), + new Logical(OR, ImmutableList.of( + new IsNull(new Comparison(EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L))), + new Not(new Comparison(EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L))))), + new Logical(OR, ImmutableList.of( + new IsNull(new Comparison(GREATER_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L))), + new Not(new Comparison(GREATER_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L))))))), values("a"))); // all conditions not true - return the default tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(FALSE_LITERAL, new SymbolReference(BOOLEAN, "a")), - new WhenClause(FALSE_LITERAL, new SymbolReference(BOOLEAN, "a")), - new WhenClause(new Constant(UnknownType.UNKNOWN, null), new SymbolReference(BOOLEAN, "a"))), - Optional.of(new SymbolReference(BOOLEAN, "b"))), + new Case(ImmutableList.of( + new WhenClause(FALSE, new Reference(BOOLEAN, "a")), + new WhenClause(FALSE, new Reference(BOOLEAN, "a")), + new WhenClause(new Constant(UnknownType.UNKNOWN, null), new Reference(BOOLEAN, "a"))), + Optional.of(new Reference(BOOLEAN, "b"))), p.values(p.symbol("a"), p.symbol("b")))) .matches( filter( - new SymbolReference(BOOLEAN, "b"), + new Reference(BOOLEAN, "b"), values("a", "b"))); // all conditions not true, no default specified - return false tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(FALSE_LITERAL, new SymbolReference(BOOLEAN, "a")), - new WhenClause(FALSE_LITERAL, new NotExpression(new SymbolReference(BOOLEAN, "a"))), - new WhenClause(new Constant(UnknownType.UNKNOWN, null), new SymbolReference(BOOLEAN, "a"))), + new Case(ImmutableList.of( + new WhenClause(FALSE, new Reference(BOOLEAN, "a")), + new WhenClause(FALSE, new Not(new Reference(BOOLEAN, "a"))), + new WhenClause(new Constant(UnknownType.UNKNOWN, null), new Reference(BOOLEAN, "a"))), Optional.empty()), p.values(p.symbol("a")))) .matches( filter( - FALSE_LITERAL, + FALSE, values("a"))); // not true conditions preceding true condition - return the result associated with the true condition tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(FALSE_LITERAL, new SymbolReference(BOOLEAN, "a")), - new WhenClause(new Constant(UnknownType.UNKNOWN, null), new NotExpression(new SymbolReference(BOOLEAN, "a"))), - new WhenClause(TRUE_LITERAL, new SymbolReference(BOOLEAN, "b"))), + new Case(ImmutableList.of( + new WhenClause(FALSE, new Reference(BOOLEAN, "a")), + new WhenClause(new Constant(UnknownType.UNKNOWN, null), new Not(new Reference(BOOLEAN, "a"))), + new WhenClause(TRUE, new Reference(BOOLEAN, "b"))), Optional.empty()), p.values(p.symbol("a"), p.symbol("b")))) .matches( filter( - new SymbolReference(BOOLEAN, "b"), + new Reference(BOOLEAN, "b"), values("a", "b"))); // remove not true condition and move the result associated with the first true condition to default tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(FALSE_LITERAL, new SymbolReference(BOOLEAN, "a")), - new WhenClause(new SymbolReference(BOOLEAN, "b"), new NotExpression(new SymbolReference(BOOLEAN, "a"))), - new WhenClause(TRUE_LITERAL, new SymbolReference(BOOLEAN, "b"))), + new Case(ImmutableList.of( + new WhenClause(FALSE, new Reference(BOOLEAN, "a")), + new WhenClause(new Reference(BOOLEAN, "b"), new Not(new Reference(BOOLEAN, "a"))), + new WhenClause(TRUE, new Reference(BOOLEAN, "b"))), Optional.empty()), p.values(p.symbol("a"), p.symbol("b")))) .matches( filter( - new SearchedCaseExpression(ImmutableList.of(new WhenClause(new SymbolReference(BOOLEAN, "b"), new NotExpression(new SymbolReference(BOOLEAN, "a")))), Optional.of(new SymbolReference(BOOLEAN, "b"))), + new Case(ImmutableList.of(new WhenClause(new Reference(BOOLEAN, "b"), new Not(new Reference(BOOLEAN, "a")))), Optional.of(new Reference(BOOLEAN, "b"))), values("a", "b"))); // move the result associated with the first true condition to default tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 0L)), new SymbolReference(BOOLEAN, "a")), - new WhenClause(new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 0L)), new NotExpression(new SymbolReference(BOOLEAN, "a"))), - new WhenClause(TRUE_LITERAL, new SymbolReference(BOOLEAN, "b")), - new WhenClause(TRUE_LITERAL, new NotExpression(new SymbolReference(BOOLEAN, "b")))), + new Case(ImmutableList.of( + new WhenClause(new Comparison(LESS_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 0L)), new Reference(BOOLEAN, "a")), + new WhenClause(new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 0L)), new Not(new Reference(BOOLEAN, "a"))), + new WhenClause(TRUE, new Reference(BOOLEAN, "b")), + new WhenClause(TRUE, new Not(new Reference(BOOLEAN, "b")))), Optional.empty()), p.values(p.symbol("a"), p.symbol("b")))) .matches( filter( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 0L)), new SymbolReference(BOOLEAN, "a")), - new WhenClause(new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 0L)), new NotExpression(new SymbolReference(BOOLEAN, "a")))), - Optional.of(new SymbolReference(BOOLEAN, "b"))), + new Case(ImmutableList.of( + new WhenClause(new Comparison(LESS_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 0L)), new Reference(BOOLEAN, "a")), + new WhenClause(new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 0L)), new Not(new Reference(BOOLEAN, "a")))), + Optional.of(new Reference(BOOLEAN, "b"))), values("a", "b"))); // cannot remove any clause tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - new SearchedCaseExpression(ImmutableList.of( - new WhenClause(new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 0L)), new SymbolReference(BOOLEAN, "a")), - new WhenClause(new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 0L)), new NotExpression(new SymbolReference(BOOLEAN, "a")))), - Optional.of(new SymbolReference(BOOLEAN, "b"))), + new Case(ImmutableList.of( + new WhenClause(new Comparison(LESS_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 0L)), new Reference(BOOLEAN, "a")), + new WhenClause(new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 0L)), new Not(new Reference(BOOLEAN, "a")))), + Optional.of(new Reference(BOOLEAN, "b"))), p.values(p.symbol("a"), p.symbol("b")))) .doesNotFire(); } @@ -403,88 +403,88 @@ public void testSimplifySimpleCaseExpression() { tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - new SimpleCaseExpression( - new SymbolReference(BOOLEAN, "a"), + new Switch( + new Reference(BOOLEAN, "a"), ImmutableList.of( - new WhenClause(new SymbolReference(BOOLEAN, "b"), TRUE_LITERAL), - new WhenClause(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 1L)), FALSE_LITERAL)), - Optional.of(TRUE_LITERAL)), + new WhenClause(new Reference(BOOLEAN, "b"), TRUE), + new WhenClause(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "b"), new Constant(INTEGER, 1L)), FALSE)), + Optional.of(TRUE)), p.values(p.symbol("a"), p.symbol("b")))) .doesNotFire(); // comparison with null returns null - no WHEN branch matches, return default value tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - new SimpleCaseExpression( + new Switch( new Constant(UnknownType.UNKNOWN, null), ImmutableList.of( - new WhenClause(new Constant(UnknownType.UNKNOWN, null), TRUE_LITERAL), - new WhenClause(new SymbolReference(BOOLEAN, "a"), FALSE_LITERAL)), - Optional.of(new SymbolReference(BOOLEAN, "b"))), + new WhenClause(new Constant(UnknownType.UNKNOWN, null), TRUE), + new WhenClause(new Reference(BOOLEAN, "a"), FALSE)), + Optional.of(new Reference(BOOLEAN, "b"))), p.values(p.symbol("a"), p.symbol("b")))) .matches( filter( - new SymbolReference(BOOLEAN, "b"), + new Reference(BOOLEAN, "b"), values("a", "b"))); // comparison with null returns null - no WHEN branch matches, the result is default null, simplified to FALSE tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - new SimpleCaseExpression( + new Switch( new Constant(UnknownType.UNKNOWN, null), ImmutableList.of( - new WhenClause(new Constant(UnknownType.UNKNOWN, null), TRUE_LITERAL), - new WhenClause(new SymbolReference(BOOLEAN, "a"), FALSE_LITERAL)), + new WhenClause(new Constant(UnknownType.UNKNOWN, null), TRUE), + new WhenClause(new Reference(BOOLEAN, "a"), FALSE)), Optional.empty()), p.values(p.symbol("a")))) .matches( filter( - FALSE_LITERAL, + FALSE, values("a"))); // all results true tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - new SimpleCaseExpression( - new SymbolReference(BOOLEAN, "a"), + new Switch( + new Reference(BOOLEAN, "a"), ImmutableList.of( - new WhenClause(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 1L)), TRUE_LITERAL), - new WhenClause(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 2L)), TRUE_LITERAL)), - Optional.of(TRUE_LITERAL)), + new WhenClause(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "b"), new Constant(INTEGER, 1L)), TRUE), + new WhenClause(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "b"), new Constant(INTEGER, 2L)), TRUE)), + Optional.of(TRUE)), p.values(p.symbol("a"), p.symbol("b")))) .matches( filter( - TRUE_LITERAL, + TRUE, values("a", "b"))); // all results not true tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - new SimpleCaseExpression( - new SymbolReference(BOOLEAN, "a"), + new Switch( + new Reference(BOOLEAN, "a"), ImmutableList.of( - new WhenClause(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 1L)), FALSE_LITERAL), - new WhenClause(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 2L)), new Constant(UnknownType.UNKNOWN, null))), - Optional.of(FALSE_LITERAL)), + new WhenClause(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "b"), new Constant(INTEGER, 1L)), FALSE), + new WhenClause(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "b"), new Constant(INTEGER, 2L)), new Constant(UnknownType.UNKNOWN, null))), + Optional.of(FALSE)), p.values(p.symbol("a"), p.symbol("b")))) .matches( filter( - FALSE_LITERAL, + FALSE, values("a", "b"))); // all results not true (including default null result) tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( - new SimpleCaseExpression( - new SymbolReference(BOOLEAN, "a"), + new Switch( + new Reference(BOOLEAN, "a"), ImmutableList.of( - new WhenClause(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 1L)), FALSE_LITERAL), - new WhenClause(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 2L)), new Constant(UnknownType.UNKNOWN, null))), + new WhenClause(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "b"), new Constant(INTEGER, 1L)), FALSE), + new WhenClause(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "b"), new Constant(INTEGER, 2L)), new Constant(UnknownType.UNKNOWN, null))), Optional.empty()), p.values(p.symbol("a"), p.symbol("b")))) .matches( filter( - FALSE_LITERAL, + FALSE, values("a", "b"))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSingleDistinctAggregationToGroupBy.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSingleDistinctAggregationToGroupBy.java index ccace6cd92c2..351131c16d66 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSingleDistinctAggregationToGroupBy.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSingleDistinctAggregationToGroupBy.java @@ -15,9 +15,9 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.AggregationFunction; import io.trino.sql.planner.assertions.ExpectedValueProvider; @@ -30,7 +30,7 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregationFunction; import static io.trino.sql.planner.assertions.PlanMatchPattern.globalAggregation; @@ -48,7 +48,7 @@ public void testNoDistinct() tester().assertThat(new SingleDistinctAggregationToGroupBy()) .on(p -> p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("output1"), PlanBuilder.aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "input1"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output1"), PlanBuilder.aggregation("count", ImmutableList.of(new Reference(BIGINT, "input1"))), ImmutableList.of(BIGINT)) .source( p.values( p.symbol("input1"), @@ -62,8 +62,8 @@ public void testMultipleDistincts() tester().assertThat(new SingleDistinctAggregationToGroupBy()) .on(p -> p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("output1"), PlanBuilder.aggregation("count", true, ImmutableList.of(new SymbolReference(BIGINT, "input1"))), ImmutableList.of(BIGINT)) - .addAggregation(p.symbol("output2"), PlanBuilder.aggregation("count", true, ImmutableList.of(new SymbolReference(BIGINT, "input2"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output1"), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2"), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input2"))), ImmutableList.of(BIGINT)) .source( p.values( p.symbol("input1"), @@ -77,8 +77,8 @@ public void testMixedDistinctAndNonDistinct() tester().assertThat(new SingleDistinctAggregationToGroupBy()) .on(p -> p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("output1"), PlanBuilder.aggregation("count", true, ImmutableList.of(new SymbolReference(BIGINT, "input1"))), ImmutableList.of(BIGINT)) - .addAggregation(p.symbol("output2"), PlanBuilder.aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "input2"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output1"), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2"), PlanBuilder.aggregation("count", ImmutableList.of(new Reference(BIGINT, "input2"))), ImmutableList.of(BIGINT)) .source( p.values( p.symbol("input1"), @@ -92,13 +92,13 @@ public void testDistinctWithFilter() tester().assertThat(new SingleDistinctAggregationToGroupBy()) .on(p -> p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("output"), PlanBuilder.aggregation("count", true, ImmutableList.of(new SymbolReference(BIGINT, "input1")), new Symbol(UNKNOWN, "filter1")), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output"), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1")), new Symbol(UNKNOWN, "filter1")), ImmutableList.of(BIGINT)) .source( p.project( Assignments.builder() .putIdentity(p.symbol("input1")) .putIdentity(p.symbol("input2")) - .put(p.symbol("filter1"), new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "input2"), new Constant(INTEGER, 0L))) + .put(p.symbol("filter1"), new Comparison(GREATER_THAN, new Reference(INTEGER, "input2"), new Constant(INTEGER, 0L))) .build(), p.values( p.symbol("input1"), @@ -112,7 +112,7 @@ public void testSingleAggregation() tester().assertThat(new SingleDistinctAggregationToGroupBy()) .on(p -> p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("output"), PlanBuilder.aggregation("count", true, ImmutableList.of(new SymbolReference(BIGINT, "input"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output"), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input"))), ImmutableList.of(BIGINT)) .source( p.values(p.symbol("input"))))) .matches( @@ -137,8 +137,8 @@ public void testMultipleAggregations() tester().assertThat(new SingleDistinctAggregationToGroupBy()) .on(p -> p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("output1"), PlanBuilder.aggregation("count", true, ImmutableList.of(new SymbolReference(BIGINT, "input"))), ImmutableList.of(BIGINT)) - .addAggregation(p.symbol("output2"), PlanBuilder.aggregation("sum", true, ImmutableList.of(new SymbolReference(BIGINT, "input"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output1"), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2"), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input"))), ImmutableList.of(BIGINT)) .source( p.values(p.symbol("input"))))) .matches( @@ -164,8 +164,8 @@ public void testMultipleInputs() tester().assertThat(new SingleDistinctAggregationToGroupBy()) .on(p -> p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("corr", true, ImmutableList.of(new SymbolReference(BIGINT, "x"), new SymbolReference(BIGINT, "y"))), ImmutableList.of(BIGINT, BIGINT)) - .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("corr", true, ImmutableList.of(new SymbolReference(BIGINT, "y"), new SymbolReference(BIGINT, "x"))), ImmutableList.of(BIGINT, BIGINT)) + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("corr", true, ImmutableList.of(new Reference(BIGINT, "x"), new Reference(BIGINT, "y"))), ImmutableList.of(BIGINT, BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("corr", true, ImmutableList.of(new Reference(BIGINT, "y"), new Reference(BIGINT, "x"))), ImmutableList.of(BIGINT, BIGINT)) .source( p.values(p.symbol("x", BIGINT), p.symbol("y", BIGINT))))) .matches( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java index 0794d28e3a23..c3488b3db5e9 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java @@ -17,7 +17,7 @@ import com.google.common.collect.ImmutableMap; import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.assertions.ExpectedValueProvider; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.DataOrganizationSpecification; @@ -81,12 +81,12 @@ public void subsetComesFirst() ImmutableList.of(p.symbol("a")), Optional.empty()), ImmutableMap.of(p.symbol("avg_1", DOUBLE), - new WindowNode.Function(resolvedFunction, ImmutableList.of(new SymbolReference(BIGINT, "a")), DEFAULT_FRAME, false)), + new WindowNode.Function(resolvedFunction, ImmutableList.of(new Reference(BIGINT, "a")), DEFAULT_FRAME, false)), p.window(new DataOrganizationSpecification( ImmutableList.of(p.symbol("a"), p.symbol("b")), Optional.empty()), ImmutableMap.of(p.symbol("avg_2", DOUBLE), - new WindowNode.Function(resolvedFunction, ImmutableList.of(new SymbolReference(BIGINT, "b")), DEFAULT_FRAME, false)), + new WindowNode.Function(resolvedFunction, ImmutableList.of(new Reference(BIGINT, "b")), DEFAULT_FRAME, false)), p.values(p.symbol("a"), p.symbol("b"))))) .matches( window(windowMatcherBuilder -> windowMatcherBuilder @@ -107,12 +107,12 @@ public void dependentWindowsAreNotReordered() ImmutableList.of(p.symbol("a")), Optional.empty()), ImmutableMap.of(p.symbol("avg_1"), - new WindowNode.Function(resolvedFunction, ImmutableList.of(new SymbolReference(DOUBLE, "avg_2")), DEFAULT_FRAME, false)), + new WindowNode.Function(resolvedFunction, ImmutableList.of(new Reference(DOUBLE, "avg_2")), DEFAULT_FRAME, false)), p.window(new DataOrganizationSpecification( ImmutableList.of(p.symbol("a"), p.symbol("b")), Optional.empty()), ImmutableMap.of(p.symbol("avg_2"), - new WindowNode.Function(resolvedFunction, ImmutableList.of(new SymbolReference(BIGINT, "a")), DEFAULT_FRAME, false)), + new WindowNode.Function(resolvedFunction, ImmutableList.of(new Reference(BIGINT, "a")), DEFAULT_FRAME, false)), p.values(p.symbol("a"), p.symbol("b"))))) .doesNotFire(); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedDistinctAggregationWithProjection.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedDistinctAggregationWithProjection.java index d32e53a69135..1adb23d5801b 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedDistinctAggregationWithProjection.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedDistinctAggregationWithProjection.java @@ -18,10 +18,10 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.JoinType; @@ -31,9 +31,9 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; import static io.trino.sql.planner.assertions.PlanMatchPattern.assignUniqueId; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; @@ -85,31 +85,31 @@ public void rewritesOnSubqueryWithDistinct() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), JoinType.LEFT, - TRUE_LITERAL, + TRUE, p.project( - Assignments.of(p.symbol("x"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 100L))), + Assignments.of(p.symbol("x"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a"), new Constant(INTEGER, 100L))), p.aggregation(innerBuilder -> innerBuilder .singleGroupingSet(p.symbol("a")) .source(p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "corr")), + new Comparison(GREATER_THAN, new Reference(BIGINT, "b"), new Reference(BIGINT, "corr")), p.values(p.symbol("a"), p.symbol("b")))))))) .matches( project(ImmutableMap.of( - "corr", expression(new SymbolReference(BIGINT, "corr")), - "x", expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 100L)))), + "corr", expression(new Reference(BIGINT, "corr")), + "x", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a"), new Constant(INTEGER, 100L)))), aggregation( singleGroupingSet("corr", "unique", "a"), ImmutableMap.of(), Optional.empty(), SINGLE, join(LEFT, builder -> builder - .filter(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "corr"))) + .filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "b"), new Reference(BIGINT, "corr"))) .left( assignUniqueId( "unique", values("corr"))) .right(filter( - TRUE_LITERAL, + TRUE, values("a", "b"))))))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedDistinctAggregationWithoutProjection.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedDistinctAggregationWithoutProjection.java index c10860c58a24..c803febbf5b3 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedDistinctAggregationWithoutProjection.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedDistinctAggregationWithoutProjection.java @@ -15,8 +15,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.ComparisonExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Comparison; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.JoinType; import org.junit.jupiter.api.Test; @@ -24,8 +24,8 @@ import java.util.Optional; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; import static io.trino.sql.planner.assertions.PlanMatchPattern.assignUniqueId; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; @@ -72,28 +72,28 @@ public void rewritesOnSubqueryWithDistinct() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), JoinType.LEFT, - TRUE_LITERAL, + TRUE, p.aggregation(innerBuilder -> innerBuilder .singleGroupingSet(p.symbol("a")) .source(p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "corr")), + new Comparison(GREATER_THAN, new Reference(BIGINT, "b"), new Reference(BIGINT, "corr")), p.values(p.symbol("a"), p.symbol("b"))))))) .matches( - project(ImmutableMap.of("corr", expression(new SymbolReference(BIGINT, "corr")), "a", expression(new SymbolReference(BIGINT, "a"))), + project(ImmutableMap.of("corr", expression(new Reference(BIGINT, "corr")), "a", expression(new Reference(BIGINT, "a"))), aggregation( singleGroupingSet("corr", "unique", "a"), ImmutableMap.of(), Optional.empty(), SINGLE, join(LEFT, builder -> builder - .filter(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "corr"))) + .filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "b"), new Reference(BIGINT, "corr"))) .left( assignUniqueId( "unique", values("corr"))) .right( filter( - TRUE_LITERAL, + TRUE, values("a", "b"))))))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithProjection.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithProjection.java index c47b4083354b..9a9292409222 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithProjection.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithProjection.java @@ -18,11 +18,11 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; import io.trino.sql.planner.plan.Assignments; @@ -33,12 +33,12 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.SUBTRACT; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Logical.Operator.AND; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregationFunction; import static io.trino.sql.planner.assertions.PlanMatchPattern.assignUniqueId; @@ -97,7 +97,7 @@ public void doesNotFireOnCorrelatedWithNonScalarAggregation() p.values(p.symbol("corr")), p.aggregation(ab -> ab .source(p.values(p.symbol("a"), p.symbol("b"))) - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "a"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .singleGroupingSet(p.symbol("b"))))) .doesNotFire(); } @@ -110,12 +110,12 @@ public void doesNotFireOnMultipleProjections() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), p.project( - Assignments.of(p.symbol("expr_2"), new ArithmeticBinaryExpression(SUBTRACT_INTEGER, SUBTRACT, new SymbolReference(INTEGER, "expr"), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("expr_2"), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "expr"), new Constant(INTEGER, 1L))), p.project( - Assignments.of(p.symbol("expr"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "sum"), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("expr"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), p.aggregation(ab -> ab .source(p.values(p.symbol("a"), p.symbol("b"))) - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "a"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .globalGrouping()))))) .doesNotFire(); } @@ -129,7 +129,7 @@ public void doesNotFireOnSubqueryWithoutProjection() p.values(p.symbol("corr")), p.aggregation(ab -> ab .source(p.values(p.symbol("a"), p.symbol("b"))) - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "a"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .globalGrouping()))) .doesNotFire(); } @@ -141,18 +141,18 @@ public void rewritesOnSubqueryWithProjection() .on(p -> p.correlatedJoin( ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), - p.project(Assignments.of(p.symbol("expr"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "sum"), new Constant(INTEGER, 1L))), + p.project(Assignments.of(p.symbol("expr"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), p.aggregation(ab -> ab .source(p.values(p.symbol("a"), p.symbol("b"))) - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "a"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .globalGrouping())))) .matches( - project(ImmutableMap.of("corr", expression(new SymbolReference(BIGINT, "corr")), "expr", expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "sum_1"), new Constant(INTEGER, 1L)))), + project(ImmutableMap.of("corr", expression(new Reference(BIGINT, "corr")), "expr", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum_1"), new Constant(INTEGER, 1L)))), aggregation(ImmutableMap.of("sum_1", aggregationFunction("sum", ImmutableList.of("a"))), join(LEFT, builder -> builder .left(assignUniqueId("unique", values(ImmutableMap.of("corr", 0)))) - .right(project(ImmutableMap.of("non_null", expression(TRUE_LITERAL)), + .right(project(ImmutableMap.of("non_null", expression(TRUE)), values(ImmutableMap.of("a", 0, "b", 1)))))))); } @@ -165,22 +165,22 @@ public void rewritesOnSubqueryWithDistinct() p.values(p.symbol("corr")), p.project( Assignments.of( - p.symbol("expr_sum"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "sum"), new Constant(INTEGER, 1L)), - p.symbol("expr_count"), new ArithmeticBinaryExpression(SUBTRACT_INTEGER, SUBTRACT, new SymbolReference(INTEGER, "count"), new Constant(INTEGER, 1L))), + p.symbol("expr_sum"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)), + p.symbol("expr_count"), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "count"), new Constant(INTEGER, 1L))), p.aggregation(outerBuilder -> outerBuilder - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "a"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .addAggregation(p.symbol("count"), PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of()) .globalGrouping() .source(p.aggregation(innerBuilder -> innerBuilder .singleGroupingSet(p.symbol("a")) .source(p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "corr")), + new Comparison(GREATER_THAN, new Reference(BIGINT, "b"), new Reference(BIGINT, "corr")), p.values(p.symbol("a"), p.symbol("b")))))))))) .matches( project(ImmutableMap.of( - "corr", expression(new SymbolReference(BIGINT, "corr")), - "expr_sum", expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "sum_agg"), new Constant(INTEGER, 1L))), - "expr_count", expression(new ArithmeticBinaryExpression(SUBTRACT_INTEGER, SUBTRACT, new SymbolReference(INTEGER, "count_agg"), new Constant(INTEGER, 1L)))), + "corr", expression(new Reference(BIGINT, "corr")), + "expr_sum", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum_agg"), new Constant(INTEGER, 1L))), + "expr_count", expression(new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "count_agg"), new Constant(INTEGER, 1L)))), aggregation( singleGroupingSet("corr", "unique"), ImmutableMap.of(Optional.of("sum_agg"), aggregationFunction("sum", ImmutableList.of("a")), Optional.of("count_agg"), aggregationFunction("count", ImmutableList.of())), @@ -194,16 +194,16 @@ public void rewritesOnSubqueryWithDistinct() Optional.empty(), SINGLE, join(LEFT, builder -> builder - .filter(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "corr"))) + .filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "b"), new Reference(BIGINT, "corr"))) .left( assignUniqueId( "unique", values("corr"))) .right( project( - ImmutableMap.of("non_null", expression(TRUE_LITERAL)), + ImmutableMap.of("non_null", expression(TRUE)), filter( - TRUE_LITERAL, + TRUE, values("a", "b"))))))))); } @@ -218,22 +218,22 @@ public void rewritesOnSubqueryWithDecorrelatableDistinct() p.values(p.symbol("corr")), p.project( Assignments.of( - p.symbol("expr_sum"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "sum"), new Constant(INTEGER, 1L)), - p.symbol("expr_count"), new ArithmeticBinaryExpression(SUBTRACT_INTEGER, SUBTRACT, new SymbolReference(INTEGER, "count"), new Constant(INTEGER, 1L))), + p.symbol("expr_sum"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)), + p.symbol("expr_count"), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "count"), new Constant(INTEGER, 1L))), p.aggregation(outerBuilder -> outerBuilder - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "a"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .addAggregation(p.symbol("count"), PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of()) .globalGrouping() .source(p.aggregation(innerBuilder -> innerBuilder .singleGroupingSet(p.symbol("a")) .source(p.filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "corr")), + new Comparison(EQUAL, new Reference(BIGINT, "b"), new Reference(BIGINT, "corr")), p.values(p.symbol("a"), p.symbol("b")))))))))) .matches( project(ImmutableMap.of( - "corr", expression(new SymbolReference(BIGINT, "corr")), - "expr_sum", expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "sum_agg"), new Constant(INTEGER, 1L))), - "expr_count", expression(new ArithmeticBinaryExpression(SUBTRACT_INTEGER, SUBTRACT, new SymbolReference(INTEGER, "count_agg"), new Constant(INTEGER, 1L)))), + "corr", expression(new Reference(BIGINT, "corr")), + "expr_sum", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum_agg"), new Constant(INTEGER, 1L))), + "expr_count", expression(new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "count_agg"), new Constant(INTEGER, 1L)))), aggregation( singleGroupingSet("corr", "unique"), ImmutableMap.of(Optional.of("sum_agg"), aggregationFunction("sum", ImmutableList.of("a")), Optional.of("count_agg"), aggregationFunction("count", ImmutableList.of())), @@ -242,21 +242,21 @@ public void rewritesOnSubqueryWithDecorrelatableDistinct() Optional.empty(), SINGLE, join(LEFT, builder -> builder - .filter(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "corr"))) + .filter(new Comparison(EQUAL, new Reference(BIGINT, "b"), new Reference(BIGINT, "corr"))) .left( assignUniqueId( "unique", values("corr"))) .right( project( - ImmutableMap.of("non_null", expression(TRUE_LITERAL)), + ImmutableMap.of("non_null", expression(TRUE)), aggregation( singleGroupingSet("a", "b"), ImmutableMap.of(), Optional.empty(), SINGLE, filter( - TRUE_LITERAL, + TRUE, values("a", "b"))))))))); } @@ -267,13 +267,13 @@ public void testWithPreexistingMask() .on(p -> p.correlatedJoin( ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), - p.project(Assignments.of(p.symbol("expr"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "sum"), new Constant(INTEGER, 1L))), + p.project(Assignments.of(p.symbol("expr"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), p.aggregation(ab -> ab .source(p.values(p.symbol("a"), p.symbol("mask"))) - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "a"))), ImmutableList.of(BIGINT), p.symbol("mask")) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT), p.symbol("mask")) .globalGrouping())))) .matches( - project(ImmutableMap.of("corr", expression(new SymbolReference(BIGINT, "corr")), "expr", expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "sum_1"), new Constant(INTEGER, 1L)))), + project(ImmutableMap.of("corr", expression(new Reference(BIGINT, "corr")), "expr", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum_1"), new Constant(INTEGER, 1L)))), aggregation( singleGroupingSet("unique", "corr"), ImmutableMap.of(Optional.of("sum_1"), aggregationFunction("sum", ImmutableList.of("a"))), @@ -282,11 +282,11 @@ public void testWithPreexistingMask() Optional.empty(), SINGLE, project( - ImmutableMap.of("new_mask", expression(new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "mask"), new SymbolReference(BOOLEAN, "non_null"))))), + ImmutableMap.of("new_mask", expression(new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "mask"), new Reference(BOOLEAN, "non_null"))))), join(LEFT, builder -> builder .left(assignUniqueId("unique", values(ImmutableMap.of("corr", 0)))) - .right(project(ImmutableMap.of("non_null", expression(TRUE_LITERAL)), + .right(project(ImmutableMap.of("non_null", expression(TRUE)), values(ImmutableMap.of("a", 0, "mask", 1))))))))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithoutProjection.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithoutProjection.java index cb5fd3775b00..aa3d462d8e18 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithoutProjection.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithoutProjection.java @@ -18,11 +18,11 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; import io.trino.sql.planner.plan.Assignments; @@ -33,12 +33,12 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.SUBTRACT; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Logical.Operator.AND; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregationFunction; import static io.trino.sql.planner.assertions.PlanMatchPattern.assignUniqueId; @@ -97,7 +97,7 @@ public void doesNotFireOnCorrelatedWithNonScalarAggregation() p.values(p.symbol("corr")), p.aggregation(ab -> ab .source(p.values(p.symbol("a"), p.symbol("b"))) - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "a"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .singleGroupingSet(p.symbol("b"))))) .doesNotFire(); } @@ -110,12 +110,12 @@ public void doesNotFireOnMultipleProjections() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), p.project( - Assignments.of(p.symbol("expr_2"), new ArithmeticBinaryExpression(SUBTRACT_INTEGER, SUBTRACT, new SymbolReference(INTEGER, "expr"), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("expr_2"), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "expr"), new Constant(INTEGER, 1L))), p.project( - Assignments.of(p.symbol("expr"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "sum"), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("expr"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), p.aggregation(ab -> ab .source(p.values(p.symbol("a"), p.symbol("b"))) - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "a"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .globalGrouping()))))) .doesNotFire(); } @@ -129,15 +129,15 @@ public void rewritesOnSubqueryWithoutProjection() p.values(p.symbol("corr")), p.aggregation(ab -> ab .source(p.values(p.symbol("a"), p.symbol("b"))) - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "a"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .globalGrouping()))) .matches( - project(ImmutableMap.of("sum_1", expression(new SymbolReference(BIGINT, "sum_1")), "corr", expression(new SymbolReference(BIGINT, "corr"))), + project(ImmutableMap.of("sum_1", expression(new Reference(BIGINT, "sum_1")), "corr", expression(new Reference(BIGINT, "corr"))), aggregation(ImmutableMap.of("sum_1", aggregationFunction("sum", ImmutableList.of("a"))), join(LEFT, builder -> builder .left(assignUniqueId("unique", values(ImmutableMap.of("corr", 0)))) - .right(project(ImmutableMap.of("non_null", expression(TRUE_LITERAL)), + .right(project(ImmutableMap.of("non_null", expression(TRUE)), values(ImmutableMap.of("a", 0, "b", 1)))))))); } @@ -148,10 +148,10 @@ public void rewritesOnSubqueryWithProjection() .on(p -> p.correlatedJoin( ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), - p.project(Assignments.of(p.symbol("expr"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "sum"), new Constant(INTEGER, 1L))), + p.project(Assignments.of(p.symbol("expr"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), p.aggregation(ab -> ab .source(p.values(p.symbol("a"), p.symbol("b"))) - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "a"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .globalGrouping())))) .doesNotFire(); } @@ -166,7 +166,7 @@ public void testSubqueryWithCount() p.aggregation(ab -> ab .source(p.values(p.symbol("a"), p.symbol("b"))) .addAggregation(p.symbol("count_rows"), PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of()) - .addAggregation(p.symbol("count_non_null_values"), PlanBuilder.aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "a"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("count_non_null_values"), PlanBuilder.aggregation("count", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .globalGrouping()))) .matches( project( @@ -176,7 +176,7 @@ public void testSubqueryWithCount() join(LEFT, builder -> builder .left(assignUniqueId("unique", values(ImmutableMap.of("corr", 0)))) - .right(project(ImmutableMap.of("non_null", expression(TRUE_LITERAL)), + .right(project(ImmutableMap.of("non_null", expression(TRUE)), values(ImmutableMap.of("a", 0, "b", 1)))))))); } @@ -188,16 +188,16 @@ public void rewritesOnSubqueryWithDistinct() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), p.aggregation(outerBuilder -> outerBuilder - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "a"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .addAggregation(p.symbol("count"), PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of()) .globalGrouping() .source(p.aggregation(innerBuilder -> innerBuilder .singleGroupingSet(p.symbol("a")) .source(p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "corr")), + new Comparison(GREATER_THAN, new Reference(BIGINT, "b"), new Reference(BIGINT, "corr")), p.values(p.symbol("a"), p.symbol("b"))))))))) .matches( - project(ImmutableMap.of("corr", expression(new SymbolReference(BIGINT, "corr")), "sum_agg", expression(new SymbolReference(BIGINT, "sum_agg")), "count_agg", expression(new SymbolReference(BIGINT, "count_agg"))), + project(ImmutableMap.of("corr", expression(new Reference(BIGINT, "corr")), "sum_agg", expression(new Reference(BIGINT, "sum_agg")), "count_agg", expression(new Reference(BIGINT, "count_agg"))), aggregation( singleGroupingSet("corr", "unique"), ImmutableMap.of(Optional.of("sum_agg"), aggregationFunction("sum", ImmutableList.of("a")), Optional.of("count_agg"), aggregationFunction("count", ImmutableList.of())), @@ -211,16 +211,16 @@ public void rewritesOnSubqueryWithDistinct() Optional.empty(), SINGLE, join(LEFT, builder -> builder - .filter(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "corr"))) + .filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "b"), new Reference(BIGINT, "corr"))) .left( assignUniqueId( "unique", values("corr"))) .right( project( - ImmutableMap.of("non_null", expression(TRUE_LITERAL)), + ImmutableMap.of("non_null", expression(TRUE)), filter( - TRUE_LITERAL, + TRUE, values("a", "b"))))))))); } @@ -234,16 +234,16 @@ public void rewritesOnSubqueryWithDecorrelatableDistinct() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), p.aggregation(outerBuilder -> outerBuilder - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "a"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .addAggregation(p.symbol("count"), PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of()) .globalGrouping() .source(p.aggregation(innerBuilder -> innerBuilder .singleGroupingSet(p.symbol("a")) .source(p.filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "corr")), + new Comparison(EQUAL, new Reference(BIGINT, "b"), new Reference(BIGINT, "corr")), p.values(p.symbol("a"), p.symbol("b"))))))))) .matches( - project(ImmutableMap.of("corr", expression(new SymbolReference(BIGINT, "corr")), "sum_agg", expression(new SymbolReference(BIGINT, "sum_agg")), "count_agg", expression(new SymbolReference(BIGINT, "count_agg"))), + project(ImmutableMap.of("corr", expression(new Reference(BIGINT, "corr")), "sum_agg", expression(new Reference(BIGINT, "sum_agg")), "count_agg", expression(new Reference(BIGINT, "count_agg"))), aggregation( singleGroupingSet("corr", "unique"), ImmutableMap.of(Optional.of("sum_agg"), aggregationFunction("sum", ImmutableList.of("a")), Optional.of("count_agg"), aggregationFunction("count", ImmutableList.of())), @@ -252,21 +252,21 @@ public void rewritesOnSubqueryWithDecorrelatableDistinct() Optional.empty(), SINGLE, join(LEFT, builder -> builder - .filter(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "corr"))) + .filter(new Comparison(EQUAL, new Reference(BIGINT, "b"), new Reference(BIGINT, "corr"))) .left( assignUniqueId( "unique", values("corr"))) .right( project( - ImmutableMap.of("non_null", expression(TRUE_LITERAL)), + ImmutableMap.of("non_null", expression(TRUE)), aggregation( singleGroupingSet("a", "b"), ImmutableMap.of(), Optional.empty(), SINGLE, filter( - TRUE_LITERAL, + TRUE, values("a", "b"))))))))); } @@ -279,7 +279,7 @@ public void testWithPreexistingMask() p.values(p.symbol("corr")), p.aggregation(ab -> ab .source(p.values(p.symbol("a"), p.symbol("mask", BOOLEAN))) - .addAggregation(p.symbol("count_non_null_values"), PlanBuilder.aggregation("count", ImmutableList.of(new SymbolReference(BIGINT, "a"))), ImmutableList.of(BIGINT), p.symbol("mask", BOOLEAN)) + .addAggregation(p.symbol("count_non_null_values"), PlanBuilder.aggregation("count", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT), p.symbol("mask", BOOLEAN)) .globalGrouping()))) .matches( project( @@ -291,11 +291,11 @@ public void testWithPreexistingMask() Optional.empty(), SINGLE, project( - ImmutableMap.of("new_mask", expression(new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "mask"), new SymbolReference(BOOLEAN, "non_null"))))), + ImmutableMap.of("new_mask", expression(new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "mask"), new Reference(BOOLEAN, "non_null"))))), join(LEFT, builder -> builder .left(assignUniqueId("unique", values(ImmutableMap.of("corr", 0)))) - .right(project(ImmutableMap.of("non_null", expression(TRUE_LITERAL)), + .right(project(ImmutableMap.of("non_null", expression(TRUE)), values(ImmutableMap.of("a", 0, "mask", 1))))))))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGroupedAggregationWithProjection.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGroupedAggregationWithProjection.java index 0ed9490d6fa8..5ebd751c797a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGroupedAggregationWithProjection.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGroupedAggregationWithProjection.java @@ -18,10 +18,10 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; import io.trino.sql.planner.plan.Assignments; @@ -32,11 +32,11 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.SUBTRACT; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregationFunction; import static io.trino.sql.planner.assertions.PlanMatchPattern.assignUniqueId; @@ -78,7 +78,7 @@ public void doesNotFireOnCorrelatedWithNonGroupedAggregation() Assignments.identity(p.symbol("sum")), p.aggregation(ab -> ab .source(p.values(p.symbol("a"), p.symbol("b"))) - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "a"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .globalGrouping())))) .doesNotFire(); } @@ -91,37 +91,37 @@ public void rewritesOnSubqueryWithoutDistinct() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), INNER, - TRUE_LITERAL, + TRUE, p.project( Assignments.of( - p.symbol("expr_sum"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "sum"), new Constant(INTEGER, 1L)), - p.symbol("expr_count"), new ArithmeticBinaryExpression(SUBTRACT_INTEGER, SUBTRACT, new SymbolReference(INTEGER, "count"), new Constant(INTEGER, 1L))), + p.symbol("expr_sum"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)), + p.symbol("expr_count"), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "count"), new Constant(INTEGER, 1L))), p.aggregation(outerBuilder -> outerBuilder .singleGroupingSet(p.symbol("a")) - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "a"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .addAggregation(p.symbol("count"), PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of()) .source(p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "corr")), + new Comparison(GREATER_THAN, new Reference(BIGINT, "b"), new Reference(BIGINT, "corr")), p.values(p.symbol("a"), p.symbol("b")))))))) .matches( project(ImmutableMap.of( - "corr", expression(new SymbolReference(BIGINT, "corr")), - "expr_sum", expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "sum_agg"), new Constant(INTEGER, 1L))), - "expr_count", expression(new ArithmeticBinaryExpression(SUBTRACT_INTEGER, SUBTRACT, new SymbolReference(INTEGER, "count_agg"), new Constant(INTEGER, 1L)))), + "corr", expression(new Reference(BIGINT, "corr")), + "expr_sum", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum_agg"), new Constant(INTEGER, 1L))), + "expr_count", expression(new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "count_agg"), new Constant(INTEGER, 1L)))), aggregation( singleGroupingSet("corr", "unique", "a"), ImmutableMap.of(Optional.of("sum_agg"), aggregationFunction("sum", ImmutableList.of("a")), Optional.of("count_agg"), aggregationFunction("count", ImmutableList.of())), Optional.empty(), SINGLE, join(JoinType.INNER, builder -> builder - .filter(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "corr"))) + .filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "b"), new Reference(BIGINT, "corr"))) .left( assignUniqueId( "unique", values("corr"))) .right( filter( - TRUE_LITERAL, + TRUE, values("a", "b"))))))); } @@ -133,25 +133,25 @@ public void rewritesOnSubqueryWithDistinct() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), INNER, - TRUE_LITERAL, + TRUE, p.project( Assignments.of( - p.symbol("expr_sum"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "sum"), new Constant(INTEGER, 1L)), - p.symbol("expr_count"), new ArithmeticBinaryExpression(SUBTRACT_INTEGER, SUBTRACT, new SymbolReference(INTEGER, "count"), new Constant(INTEGER, 1L))), + p.symbol("expr_sum"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)), + p.symbol("expr_count"), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "count"), new Constant(INTEGER, 1L))), p.aggregation(outerBuilder -> outerBuilder .singleGroupingSet(p.symbol("a")) - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "a"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .addAggregation(p.symbol("count"), PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of()) .source(p.aggregation(innerBuilder -> innerBuilder .singleGroupingSet(p.symbol("a")) .source(p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "corr")), + new Comparison(GREATER_THAN, new Reference(BIGINT, "b"), new Reference(BIGINT, "corr")), p.values(p.symbol("a"), p.symbol("b")))))))))) .matches( project(ImmutableMap.of( - "corr", expression(new SymbolReference(BIGINT, "corr")), - "expr_sum", expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "sum_agg"), new Constant(INTEGER, 1L))), - "expr_count", expression(new ArithmeticBinaryExpression(SUBTRACT_INTEGER, SUBTRACT, new SymbolReference(INTEGER, "count_agg"), new Constant(INTEGER, 1L)))), + "corr", expression(new Reference(BIGINT, "corr")), + "expr_sum", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum_agg"), new Constant(INTEGER, 1L))), + "expr_count", expression(new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "count_agg"), new Constant(INTEGER, 1L)))), aggregation( singleGroupingSet("corr", "unique", "a"), ImmutableMap.of(Optional.of("sum_agg"), aggregationFunction("sum", ImmutableList.of("a")), Optional.of("count_agg"), aggregationFunction("count", ImmutableList.of())), @@ -163,14 +163,14 @@ public void rewritesOnSubqueryWithDistinct() Optional.empty(), SINGLE, join(JoinType.INNER, builder -> builder - .filter(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "corr"))) + .filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "b"), new Reference(BIGINT, "corr"))) .left( assignUniqueId( "unique", values("corr"))) .right( filter( - TRUE_LITERAL, + TRUE, values("a", "b")))))))); } @@ -184,32 +184,32 @@ public void rewritesOnSubqueryWithDecorrelatableDistinct() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), INNER, - TRUE_LITERAL, + TRUE, p.project( Assignments.of( - p.symbol("expr_sum"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "sum"), new Constant(INTEGER, 1L)), - p.symbol("expr_count"), new ArithmeticBinaryExpression(SUBTRACT_INTEGER, SUBTRACT, new SymbolReference(INTEGER, "count"), new Constant(INTEGER, 1L))), + p.symbol("expr_sum"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)), + p.symbol("expr_count"), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "count"), new Constant(INTEGER, 1L))), p.aggregation(outerBuilder -> outerBuilder .singleGroupingSet(p.symbol("a")) - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "a"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .addAggregation(p.symbol("count"), PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of()) .source(p.aggregation(innerBuilder -> innerBuilder .singleGroupingSet(p.symbol("a")) .source(p.filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "corr")), + new Comparison(EQUAL, new Reference(BIGINT, "b"), new Reference(BIGINT, "corr")), p.values(p.symbol("a"), p.symbol("b")))))))))) .matches( project(ImmutableMap.of( - "corr", expression(new SymbolReference(BIGINT, "corr")), - "expr_sum", expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "sum_agg"), new Constant(INTEGER, 1L))), - "expr_count", expression(new ArithmeticBinaryExpression(SUBTRACT_INTEGER, SUBTRACT, new SymbolReference(INTEGER, "count_agg"), new Constant(INTEGER, 1L)))), + "corr", expression(new Reference(BIGINT, "corr")), + "expr_sum", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "sum_agg"), new Constant(INTEGER, 1L))), + "expr_count", expression(new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "count_agg"), new Constant(INTEGER, 1L)))), aggregation( singleGroupingSet("corr", "unique", "a"), ImmutableMap.of(Optional.of("sum_agg"), aggregationFunction("sum", ImmutableList.of("a")), Optional.of("count_agg"), aggregationFunction("count", ImmutableList.of())), Optional.empty(), SINGLE, join(JoinType.INNER, builder -> builder - .filter(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "corr"))) + .filter(new Comparison(EQUAL, new Reference(BIGINT, "b"), new Reference(BIGINT, "corr"))) .left( assignUniqueId( "unique", @@ -221,7 +221,7 @@ public void rewritesOnSubqueryWithDecorrelatableDistinct() Optional.empty(), SINGLE, filter( - TRUE_LITERAL, + TRUE, values("a", "b")))))))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGroupedAggregationWithoutProjection.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGroupedAggregationWithoutProjection.java index 946a67ce5787..571364c5e659 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGroupedAggregationWithoutProjection.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGroupedAggregationWithoutProjection.java @@ -15,8 +15,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.ComparisonExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Comparison; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; import io.trino.sql.planner.plan.JoinType; @@ -25,9 +25,9 @@ import java.util.Optional; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregationFunction; import static io.trino.sql.planner.assertions.PlanMatchPattern.assignUniqueId; @@ -63,7 +63,7 @@ public void doesNotFireOnCorrelatedWithNonGroupedAggregation() p.values(p.symbol("corr")), p.aggregation(ab -> ab .source(p.values(p.symbol("a"), p.symbol("b"))) - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "a"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .globalGrouping()))) .doesNotFire(); } @@ -76,30 +76,30 @@ public void rewritesOnSubqueryWithoutDistinct() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), INNER, - TRUE_LITERAL, + TRUE, p.aggregation(outerBuilder -> outerBuilder .singleGroupingSet(p.symbol("a")) - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "a"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .addAggregation(p.symbol("count"), PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of()) .source(p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "corr")), + new Comparison(GREATER_THAN, new Reference(BIGINT, "b"), new Reference(BIGINT, "corr")), p.values(p.symbol("a"), p.symbol("b"))))))) .matches( - project(ImmutableMap.of("corr", expression(new SymbolReference(BIGINT, "corr")), "sum_agg", io.trino.sql.planner.assertions.PlanMatchPattern.expression(new SymbolReference(BIGINT, "sum_agg")), "count_agg", io.trino.sql.planner.assertions.PlanMatchPattern.expression(new SymbolReference(BIGINT, "count_agg"))), + project(ImmutableMap.of("corr", expression(new Reference(BIGINT, "corr")), "sum_agg", io.trino.sql.planner.assertions.PlanMatchPattern.expression(new Reference(BIGINT, "sum_agg")), "count_agg", io.trino.sql.planner.assertions.PlanMatchPattern.expression(new Reference(BIGINT, "count_agg"))), aggregation( singleGroupingSet("corr", "unique", "a"), ImmutableMap.of(Optional.of("sum_agg"), aggregationFunction("sum", ImmutableList.of("a")), Optional.of("count_agg"), aggregationFunction("count", ImmutableList.of())), Optional.empty(), SINGLE, join(JoinType.INNER, builder -> builder - .filter(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "corr"))) + .filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "b"), new Reference(BIGINT, "corr"))) .left( assignUniqueId( "unique", values("corr"))) .right( filter( - TRUE_LITERAL, + TRUE, values("a", "b"))))))); } @@ -111,18 +111,18 @@ public void rewritesOnSubqueryWithDistinct() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), INNER, - TRUE_LITERAL, + TRUE, p.aggregation(outerBuilder -> outerBuilder .singleGroupingSet(p.symbol("a")) - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "a"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .addAggregation(p.symbol("count"), PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of()) .source(p.aggregation(innerBuilder -> innerBuilder .singleGroupingSet(p.symbol("a")) .source(p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "corr")), + new Comparison(GREATER_THAN, new Reference(BIGINT, "b"), new Reference(BIGINT, "corr")), p.values(p.symbol("a"), p.symbol("b"))))))))) .matches( - project(ImmutableMap.of("corr", expression(new SymbolReference(BIGINT, "corr")), "sum_agg", io.trino.sql.planner.assertions.PlanMatchPattern.expression(new SymbolReference(BIGINT, "sum_agg")), "count_agg", io.trino.sql.planner.assertions.PlanMatchPattern.expression(new SymbolReference(BIGINT, "count_agg"))), + project(ImmutableMap.of("corr", expression(new Reference(BIGINT, "corr")), "sum_agg", io.trino.sql.planner.assertions.PlanMatchPattern.expression(new Reference(BIGINT, "sum_agg")), "count_agg", io.trino.sql.planner.assertions.PlanMatchPattern.expression(new Reference(BIGINT, "count_agg"))), aggregation( singleGroupingSet("corr", "unique", "a"), ImmutableMap.of(Optional.of("sum_agg"), aggregationFunction("sum", ImmutableList.of("a")), Optional.of("count_agg"), aggregationFunction("count", ImmutableList.of())), @@ -134,14 +134,14 @@ public void rewritesOnSubqueryWithDistinct() Optional.empty(), SINGLE, join(JoinType.INNER, builder -> builder - .filter(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "corr"))) + .filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "b"), new Reference(BIGINT, "corr"))) .left( assignUniqueId( "unique", values("corr"))) .right( filter( - TRUE_LITERAL, + TRUE, values("a", "b")))))))); } @@ -155,25 +155,25 @@ public void rewritesOnSubqueryWithDecorrelatableDistinct() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), INNER, - TRUE_LITERAL, + TRUE, p.aggregation(outerBuilder -> outerBuilder .singleGroupingSet(p.symbol("a")) - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "a"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .addAggregation(p.symbol("count"), PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of()) .source(p.aggregation(innerBuilder -> innerBuilder .singleGroupingSet(p.symbol("a")) .source(p.filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "corr")), + new Comparison(EQUAL, new Reference(BIGINT, "b"), new Reference(BIGINT, "corr")), p.values(p.symbol("a"), p.symbol("b"))))))))) .matches( - project(ImmutableMap.of("corr", expression(new SymbolReference(BIGINT, "corr")), "sum_agg", io.trino.sql.planner.assertions.PlanMatchPattern.expression(new SymbolReference(BIGINT, "sum_agg")), "count_agg", io.trino.sql.planner.assertions.PlanMatchPattern.expression(new SymbolReference(BIGINT, "count_agg"))), + project(ImmutableMap.of("corr", expression(new Reference(BIGINT, "corr")), "sum_agg", io.trino.sql.planner.assertions.PlanMatchPattern.expression(new Reference(BIGINT, "sum_agg")), "count_agg", io.trino.sql.planner.assertions.PlanMatchPattern.expression(new Reference(BIGINT, "count_agg"))), aggregation( singleGroupingSet("corr", "unique", "a"), ImmutableMap.of(Optional.of("sum_agg"), aggregationFunction("sum", ImmutableList.of("a")), Optional.of("count_agg"), aggregationFunction("count", ImmutableList.of())), Optional.empty(), SINGLE, join(JoinType.INNER, builder -> builder - .filter(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "corr"))) + .filter(new Comparison(EQUAL, new Reference(BIGINT, "b"), new Reference(BIGINT, "corr"))) .left( assignUniqueId( "unique", @@ -185,7 +185,7 @@ public void rewritesOnSubqueryWithDecorrelatableDistinct() Optional.empty(), SINGLE, filter( - TRUE_LITERAL, + TRUE, values("a", "b")))))))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedJoinToJoin.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedJoinToJoin.java index 0be5f09697dd..4401bc963323 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedJoinToJoin.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedJoinToJoin.java @@ -14,10 +14,10 @@ package io.trino.sql.planner.iterative.rule; import com.google.common.collect.ImmutableList; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.JoinType; @@ -25,11 +25,11 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Logical.Operator.AND; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.join; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -50,7 +50,7 @@ public void testRewriteInnerCorrelatedJoin() ImmutableList.of(a), p.values(a), p.filter( - new ComparisonExpression( + new Comparison( GREATER_THAN, b.toSymbolReference(), a.toSymbolReference()), @@ -58,11 +58,11 @@ public void testRewriteInnerCorrelatedJoin() }) .matches( join(JoinType.INNER, builder -> builder - .filter(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "a"))) + .filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "b"), new Reference(BIGINT, "a"))) .left(values("a")) .right( filter( - TRUE_LITERAL, + TRUE, values("b"))))); tester().assertThat(new TransformCorrelatedJoinToJoin(tester().getPlannerContext())) @@ -73,12 +73,12 @@ public void testRewriteInnerCorrelatedJoin() ImmutableList.of(a), p.values(a), INNER, - new ComparisonExpression( + new Comparison( LESS_THAN, b.toSymbolReference(), new Constant(INTEGER, 3L)), p.filter( - new ComparisonExpression( + new Comparison( GREATER_THAN, b.toSymbolReference(), a.toSymbolReference()), @@ -86,11 +86,11 @@ public void testRewriteInnerCorrelatedJoin() }) .matches( join(JoinType.INNER, builder -> builder - .filter(new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "a")), new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 3L))))) + .filter(new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "b"), new Reference(BIGINT, "a")), new Comparison(LESS_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 3L))))) .left(values("a")) .right( filter( - TRUE_LITERAL, + TRUE, values("b"))))); } @@ -105,9 +105,9 @@ public void testRewriteLeftCorrelatedJoin() ImmutableList.of(a), p.values(a), LEFT, - TRUE_LITERAL, + TRUE, p.filter( - new ComparisonExpression( + new Comparison( GREATER_THAN, b.toSymbolReference(), a.toSymbolReference()), @@ -115,11 +115,11 @@ public void testRewriteLeftCorrelatedJoin() }) .matches( join(JoinType.LEFT, builder -> builder - .filter(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "a"))) + .filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "b"), new Reference(BIGINT, "a"))) .left(values("a")) .right( filter( - TRUE_LITERAL, + TRUE, values("b"))))); tester().assertThat(new TransformCorrelatedJoinToJoin(tester().getPlannerContext())) @@ -130,12 +130,12 @@ public void testRewriteLeftCorrelatedJoin() ImmutableList.of(a), p.values(a), LEFT, - new ComparisonExpression( + new Comparison( LESS_THAN, b.toSymbolReference(), new Constant(INTEGER, 3L)), p.filter( - new ComparisonExpression( + new Comparison( GREATER_THAN, b.toSymbolReference(), a.toSymbolReference()), @@ -143,11 +143,11 @@ public void testRewriteLeftCorrelatedJoin() }) .matches( join(JoinType.LEFT, builder -> builder - .filter(new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "a")), new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 3L))))) + .filter(new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "b"), new Reference(BIGINT, "a")), new Comparison(LESS_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 3L))))) .left(values("a")) .right( filter( - TRUE_LITERAL, + TRUE, values("b"))))); } @@ -159,10 +159,10 @@ public void doesNotFireForEnforceSingleRow() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), INNER, - TRUE_LITERAL, + TRUE, p.enforceSingleRow( p.filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "corr"), new SymbolReference(BIGINT, "a")), + new Comparison(EQUAL, new Reference(BIGINT, "corr"), new Reference(BIGINT, "a")), p.values(p.symbol("a")))))) .doesNotFire(); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedScalarSubquery.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedScalarSubquery.java index 4507b835f068..67bbb84c080a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedScalarSubquery.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedScalarSubquery.java @@ -19,13 +19,13 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SimpleCaseExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Switch; import io.trino.sql.ir.WhenClause; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; @@ -41,10 +41,10 @@ import static io.trino.spi.StandardErrorCode.SUBQUERY_MULTIPLE_ROWS; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MULTIPLY; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.planner.LogicalPlanner.failFunction; import static io.trino.sql.planner.assertions.PlanMatchPattern.assignUniqueId; import static io.trino.sql.planner.assertions.PlanMatchPattern.correlatedJoin; @@ -107,7 +107,7 @@ public void rewritesOnSubqueryWithoutProjection() p.values(p.symbol("corr")), p.enforceSingleRow( p.filter( - new ComparisonExpression(EQUAL, new Constant(INTEGER, 1L), new SymbolReference(INTEGER, "a")), // TODO use correlated predicate, it requires support for correlated subqueries in plan matchers + new Comparison(EQUAL, new Constant(INTEGER, 1L), new Reference(INTEGER, "a")), // TODO use correlated predicate, it requires support for correlated subqueries in plan matchers p.values(ImmutableList.of(p.symbol("a")), TWO_ROWS))))) .matches( project( @@ -122,7 +122,7 @@ public void rewritesOnSubqueryWithoutProjection() "unique", values("corr")), filter( - new ComparisonExpression(EQUAL, new Constant(INTEGER, 1L), new SymbolReference(INTEGER, "a")), + new Comparison(EQUAL, new Constant(INTEGER, 1L), new Reference(INTEGER, "a")), values("a"))))))); } @@ -135,9 +135,9 @@ public void rewritesOnSubqueryWithProjection() p.values(p.symbol("corr")), p.enforceSingleRow( p.project( - Assignments.of(p.symbol("a2"), new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 2L))), + Assignments.of(p.symbol("a2"), new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 2L))), p.filter( - new ComparisonExpression(EQUAL, new Constant(INTEGER, 1L), new SymbolReference(INTEGER, "a")), // TODO use correlated predicate, it requires support for correlated subqueries in plan matchers + new Comparison(EQUAL, new Constant(INTEGER, 1L), new Reference(INTEGER, "a")), // TODO use correlated predicate, it requires support for correlated subqueries in plan matchers p.values(ImmutableList.of(p.symbol("a")), TWO_ROWS)))))) .matches( project( @@ -151,9 +151,9 @@ public void rewritesOnSubqueryWithProjection() assignUniqueId( "unique", values("corr")), - project(ImmutableMap.of("a2", expression(new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 2L)))), + project(ImmutableMap.of("a2", expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 2L)))), filter( - new ComparisonExpression(EQUAL, new Constant(INTEGER, 1L), new SymbolReference(INTEGER, "a")), + new Comparison(EQUAL, new Constant(INTEGER, 1L), new Reference(INTEGER, "a")), values("a")))))))); } @@ -165,12 +165,12 @@ public void rewritesOnSubqueryWithProjectionOnTopEnforceSingleNode() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), p.project( - Assignments.of(p.symbol("a3"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "a2"), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("a3"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a2"), new Constant(INTEGER, 1L))), p.enforceSingleRow( p.project( - Assignments.of(p.symbol("a2"), new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 2L))), + Assignments.of(p.symbol("a2"), new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 2L))), p.filter( - new ComparisonExpression(EQUAL, new Constant(INTEGER, 1L), new SymbolReference(INTEGER, "a")), // TODO use correlated predicate, it requires support for correlated subqueries in plan matchers + new Comparison(EQUAL, new Constant(INTEGER, 1L), new Reference(INTEGER, "a")), // TODO use correlated predicate, it requires support for correlated subqueries in plan matchers p.values(ImmutableList.of(p.symbol("a")), TWO_ROWS))))))) .matches( project( @@ -185,11 +185,11 @@ public void rewritesOnSubqueryWithProjectionOnTopEnforceSingleNode() "unique", values("corr")), project( - ImmutableMap.of("a3", expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "a2"), new Constant(INTEGER, 1L)))), + ImmutableMap.of("a3", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "a2"), new Constant(INTEGER, 1L)))), project( - ImmutableMap.of("a2", expression(new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 2L)))), + ImmutableMap.of("a2", expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "a"), new Constant(INTEGER, 2L)))), filter( - new ComparisonExpression(EQUAL, new Constant(INTEGER, 1L), new SymbolReference(INTEGER, "a")), + new Comparison(EQUAL, new Constant(INTEGER, 1L), new Reference(INTEGER, "a")), values("a"))))))))); } @@ -202,26 +202,26 @@ public void rewritesScalarSubquery() p.values(p.symbol("corr")), // make sure INNER correlated join is transformed to LEFT join if subplan could produce 0 rows INNER, - TRUE_LITERAL, + TRUE, p.enforceSingleRow( p.filter( - new ComparisonExpression(EQUAL, new Constant(INTEGER, 1L), new SymbolReference(INTEGER, "a")), // TODO use correlated predicate, it requires support for correlated subqueries in plan matchers + new Comparison(EQUAL, new Constant(INTEGER, 1L), new Reference(INTEGER, "a")), // TODO use correlated predicate, it requires support for correlated subqueries in plan matchers p.values(ImmutableList.of(p.symbol("a")), ONE_ROW))))) .matches( correlatedJoin( ImmutableList.of("corr"), values("corr"), filter( - new ComparisonExpression(EQUAL, new Constant(INTEGER, 1L), new SymbolReference(INTEGER, "a")), + new Comparison(EQUAL, new Constant(INTEGER, 1L), new Reference(INTEGER, "a")), values("a"))) .with(CorrelatedJoinNode.class, join -> join.getType() == LEFT)); } private Expression ensureScalarSubquery() { - return new SimpleCaseExpression( - new SymbolReference(BOOLEAN, "is_distinct"), - ImmutableList.of(new WhenClause(TRUE_LITERAL, TRUE_LITERAL)), + return new Switch( + new Reference(BOOLEAN, "is_distinct"), + ImmutableList.of(new WhenClause(TRUE, TRUE)), Optional.of(new Cast( failFunction(tester().getMetadata(), SUBQUERY_MULTIPLE_ROWS, "Scalar sub-query has returned multiple rows"), BOOLEAN))); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java index 90e978cf44fe..b67432bcb041 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java @@ -21,11 +21,11 @@ import io.trino.plugin.tpch.TpchColumnHandle; import io.trino.spi.function.OperatorType; import io.trino.spi.type.VarcharType; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Constant; +import io.trino.sql.ir.Reference; import io.trino.sql.ir.Row; -import io.trino.sql.ir.SymbolReference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -38,7 +38,7 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.RowType.field; import static io.trino.spi.type.RowType.rowType; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -70,15 +70,15 @@ public void testRewrite() ImmutableMap.of(p.symbol("l_nationkey"), new TpchColumnHandle("nationkey", BIGINT))), p.project( - Assignments.of(p.symbol("l_expr2"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "l_nationkey"), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("l_expr2"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "l_nationkey"), new Constant(INTEGER, 1L))), p.values( ImmutableList.of(), ImmutableList.of( ImmutableList.of()))))) .matches(project( ImmutableMap.of( - "l_expr2", PlanMatchPattern.expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "l_nationkey"), new Constant(INTEGER, 1L))), - "l_nationkey", PlanMatchPattern.expression(new SymbolReference(BIGINT, "l_nationkey"))), + "l_expr2", PlanMatchPattern.expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "l_nationkey"), new Constant(INTEGER, 1L))), + "l_nationkey", PlanMatchPattern.expression(new Reference(BIGINT, "l_nationkey"))), tableScan("nation", ImmutableMap.of("l_nationkey", "nationkey")))); } @@ -103,11 +103,11 @@ public void testCorrelatedValues() return p.correlatedJoin( ImmutableList.of(a), p.values(3, a), - p.values(ImmutableList.of(a), ImmutableList.of(ImmutableList.of(new SymbolReference(BIGINT, "a"))))); + p.values(ImmutableList.of(a), ImmutableList.of(ImmutableList.of(new Reference(BIGINT, "a"))))); }) .matches( project( - ImmutableMap.of("a", PlanMatchPattern.expression(new SymbolReference(BIGINT, "a"))), + ImmutableMap.of("a", PlanMatchPattern.expression(new Reference(BIGINT, "a"))), values(ImmutableList.of("a"), ImmutableList.of( ImmutableList.of(new Constant(BIGINT, null)), ImmutableList.of(new Constant(BIGINT, null)), @@ -121,13 +121,13 @@ public void testCorrelatedValues() return p.correlatedJoin( ImmutableList.of(a), p.values(3, a, b), - p.values(ImmutableList.of(a, c), ImmutableList.of(ImmutableList.of(new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 1L))))); + p.values(ImmutableList.of(a, c), ImmutableList.of(ImmutableList.of(new Reference(BIGINT, "a"), new Constant(BIGINT, 1L))))); }) .matches( project( ImmutableMap.of( - "a", PlanMatchPattern.expression(new SymbolReference(BIGINT, "a")), - "b", PlanMatchPattern.expression(new SymbolReference(BIGINT, "b")), + "a", PlanMatchPattern.expression(new Reference(BIGINT, "a")), + "b", PlanMatchPattern.expression(new Reference(BIGINT, "b")), "c", PlanMatchPattern.expression(new Constant(BIGINT, 1L))), values(ImmutableList.of("a", "b"), ImmutableList.of( ImmutableList.of(new Constant(BIGINT, null), new Constant(BIGINT, null)), @@ -150,7 +150,7 @@ public void testUncorrelatedValues() .matches( project( ImmutableMap.of( - "a", PlanMatchPattern.expression(new SymbolReference(BIGINT, "a")), + "a", PlanMatchPattern.expression(new Reference(BIGINT, "a")), "b", PlanMatchPattern.expression(new Constant(BIGINT, null))), values(ImmutableList.of("a"), ImmutableList.of( ImmutableList.of(new Constant(BIGINT, null)), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformExistsApplyToCorrelatedJoin.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformExistsApplyToCorrelatedJoin.java index 4a3d83909de3..ef7327f3aa02 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformExistsApplyToCorrelatedJoin.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformExistsApplyToCorrelatedJoin.java @@ -15,10 +15,10 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.CoalesceExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Coalesce; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.ApplyNode; @@ -28,10 +28,10 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; -import static io.trino.sql.ir.BooleanLiteral.FALSE_LITERAL; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Booleans.FALSE; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregationFunction; import static io.trino.sql.planner.assertions.PlanMatchPattern.correlatedJoin; @@ -72,7 +72,7 @@ public void testRewrite() ImmutableList.of(), values(ImmutableMap.of()), project( - ImmutableMap.of("b", PlanMatchPattern.expression(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "count_expr"), new Constant(BIGINT, 0L)))), + ImmutableMap.of("b", PlanMatchPattern.expression(new Comparison(GREATER_THAN, new Reference(BIGINT, "count_expr"), new Constant(BIGINT, 0L)))), aggregation(ImmutableMap.of("count_expr", aggregationFunction("count", ImmutableList.of())), values())))); } @@ -88,15 +88,15 @@ public void testRewritesToLimit() p.values(p.symbol("corr")), p.project(Assignments.of(), p.filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "corr"), new SymbolReference(BIGINT, "column")), + new Comparison(EQUAL, new Reference(BIGINT, "corr"), new Reference(BIGINT, "column")), p.values(p.symbol("column")))))) .matches( - project(ImmutableMap.of("b", PlanMatchPattern.expression(new CoalesceExpression(new SymbolReference(BOOLEAN, "subquerytrue"), FALSE_LITERAL))), + project(ImmutableMap.of("b", PlanMatchPattern.expression(new Coalesce(new Reference(BOOLEAN, "subquerytrue"), FALSE))), correlatedJoin( ImmutableList.of("corr"), values("corr"), project( - ImmutableMap.of("subquerytrue", PlanMatchPattern.expression(TRUE_LITERAL)), + ImmutableMap.of("subquerytrue", PlanMatchPattern.expression(TRUE)), limit(1, project(ImmutableMap.of(), node(FilterNode.class, diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformFilteringSemiJoinToInnerJoin.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformFilteringSemiJoinToInnerJoin.java index 2e54375ae00e..5108566fd711 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformFilteringSemiJoinToInnerJoin.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformFilteringSemiJoinToInnerJoin.java @@ -15,10 +15,10 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -30,9 +30,9 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Logical.Operator.AND; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; import static io.trino.sql.planner.assertions.PlanMatchPattern.join; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; @@ -53,7 +53,7 @@ public void testTransformSemiJoinToInnerJoin() Symbol b = p.symbol("b", BIGINT); Symbol aInB = p.symbol("a_in_b", BOOLEAN); return p.filter( - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "a_in_b"), new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 5L)))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "a_in_b"), new Comparison(GREATER_THAN, new Reference(BIGINT, "a"), new Constant(BIGINT, 5L)))), p.semiJoin( p.values(a), p.values(b), @@ -66,11 +66,11 @@ public void testTransformSemiJoinToInnerJoin() }) .matches(project( ImmutableMap.of( - "a", PlanMatchPattern.expression(new SymbolReference(BIGINT, "a")), - "a_in_b", PlanMatchPattern.expression(TRUE_LITERAL)), + "a", PlanMatchPattern.expression(new Reference(BIGINT, "a")), + "a_in_b", PlanMatchPattern.expression(TRUE)), join(INNER, builder -> builder .equiCriteria("a", "b") - .filter(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "a"), new Constant(BIGINT, 5L))) + .filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "a"), new Constant(BIGINT, 5L))) .left(values("a")) .right( aggregation( @@ -90,7 +90,7 @@ public void testRemoveRedundantFilter() Symbol b = p.symbol("b", BIGINT); Symbol aInB = p.symbol("a_in_b", BOOLEAN); return p.filter( - new SymbolReference(BOOLEAN, "a_in_b"), + new Reference(BOOLEAN, "a_in_b"), p.semiJoin( p.values(a), p.values(b), @@ -102,7 +102,7 @@ public void testRemoveRedundantFilter() Optional.empty())); }) .matches(project( - ImmutableMap.of("a", PlanMatchPattern.expression(new SymbolReference(BIGINT, "a")), "a_in_b", PlanMatchPattern.expression(TRUE_LITERAL)), + ImmutableMap.of("a", PlanMatchPattern.expression(new Reference(BIGINT, "a")), "a_in_b", PlanMatchPattern.expression(TRUE)), join(INNER, builder -> builder .equiCriteria("a", "b") .left(values("a")) @@ -124,7 +124,7 @@ public void testFilterNotMatching() Symbol b = p.symbol("b"); Symbol aInB = p.symbol("a_in_b"); return p.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 5L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 5L)), p.semiJoin( p.values(a), p.values(b), @@ -147,7 +147,7 @@ public void testDoNotRewriteInContextOfDelete() Symbol b = p.symbol("b"); Symbol aInB = p.symbol("a_in_b"); return p.filter( - new SymbolReference(BOOLEAN, "a_in_b"), + new Reference(BOOLEAN, "a_in_b"), p.semiJoin( p.tableScan( ImmutableList.of(a), @@ -169,10 +169,10 @@ public void testDoNotRewriteInContextOfDelete() Symbol c = p.symbol("c"); Symbol aInB = p.symbol("a_in_b"); return p.filter( - new SymbolReference(BOOLEAN, "a_in_b"), + new Reference(BOOLEAN, "a_in_b"), p.semiJoin( p.project( - Assignments.of(a, new SymbolReference(BIGINT, "c")), + Assignments.of(a, new Reference(BIGINT, "c")), p.tableScan( ImmutableList.of(c), true)), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformUncorrelatedSubqueryToJoin.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformUncorrelatedSubqueryToJoin.java index b47fcd807f14..4e29e8e35bc8 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformUncorrelatedSubqueryToJoin.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformUncorrelatedSubqueryToJoin.java @@ -15,17 +15,17 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.JoinType; import org.junit.jupiter.api.Test; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.ir.IrExpressions.ifExpression; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.join; @@ -51,7 +51,7 @@ public void testRewriteLeftCorrelatedJoinWithScalarSubquery() emptyList(), p.values(a), LEFT, - TRUE_LITERAL, + TRUE, p.values(1, b)); }) .matches( @@ -71,7 +71,7 @@ public void testRewriteInnerCorrelatedJoin() emptyList(), p.values(a), LEFT, - new ComparisonExpression( + new Comparison( GREATER_THAN, b.toSymbolReference(), a.toSymbolReference()), @@ -79,7 +79,7 @@ public void testRewriteInnerCorrelatedJoin() }) .matches( join(JoinType.LEFT, builder -> builder - .filter(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "a"))) + .filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "b"), new Reference(BIGINT, "a"))) .left(values("a")) .right(values("b")))); } @@ -95,7 +95,7 @@ public void testRewriteLeftCorrelatedJoin() emptyList(), p.values(a), LEFT, - new ComparisonExpression( + new Comparison( GREATER_THAN, b.toSymbolReference(), a.toSymbolReference()), @@ -103,7 +103,7 @@ public void testRewriteLeftCorrelatedJoin() }) .matches( join(JoinType.LEFT, builder -> builder - .filter(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "a"))) + .filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "b"), new Reference(BIGINT, "a"))) .left(values("a")) .right(values("b")))); } @@ -119,7 +119,7 @@ public void testRewriteRightCorrelatedJoin() emptyList(), p.values(a), RIGHT, - TRUE_LITERAL, + TRUE, p.values(b)); }) .matches( @@ -135,7 +135,7 @@ public void testRewriteRightCorrelatedJoin() emptyList(), p.values(a), RIGHT, - new ComparisonExpression( + new Comparison( GREATER_THAN, b.toSymbolReference(), a.toSymbolReference()), @@ -144,8 +144,8 @@ public void testRewriteRightCorrelatedJoin() .matches( project( ImmutableMap.of( - "a", expression(ifExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b"), new SymbolReference(BIGINT, "a")), new SymbolReference(BIGINT, "a"), new Constant(BIGINT, null))), - "b", expression(new SymbolReference(BIGINT, "b"))), + "a", expression(ifExpression(new Comparison(GREATER_THAN, new Reference(BIGINT, "b"), new Reference(BIGINT, "a")), new Reference(BIGINT, "a"), new Constant(BIGINT, null))), + "b", expression(new Reference(BIGINT, "b"))), join(JoinType.INNER, builder -> builder .left(values("a")) .right(values("b"))))); @@ -162,7 +162,7 @@ public void testRewriteFullCorrelatedJoin() emptyList(), p.values(a), FULL, - TRUE_LITERAL, + TRUE, p.values(b)); }) .matches( @@ -178,7 +178,7 @@ public void testRewriteFullCorrelatedJoin() emptyList(), p.values(a), FULL, - new ComparisonExpression( + new Comparison( GREATER_THAN, b.toSymbolReference(), a.toSymbolReference()), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestUnwrapRowSubscript.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestUnwrapRowSubscript.java index 20dcfff19a13..6f5f284a9063 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestUnwrapRowSubscript.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestUnwrapRowSubscript.java @@ -18,7 +18,7 @@ import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.ir.Row; -import io.trino.sql.ir.SubscriptExpression; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; @@ -41,26 +41,26 @@ public class TestUnwrapRowSubscript @Test public void testSimpleSubscript() { - test(new SubscriptExpression(INTEGER, new Row(ImmutableList.of(new Constant(INTEGER, 1L))), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L)); - test(new SubscriptExpression(INTEGER, new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L)); - test(new SubscriptExpression(INTEGER, new SubscriptExpression(anonymousRow(INTEGER), new Row(ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), new Constant(INTEGER, 3L))), new Constant(INTEGER, 1L)), new Constant(INTEGER, 2L)), new Constant(INTEGER, 2L)); + test(new Subscript(INTEGER, new Row(ImmutableList.of(new Constant(INTEGER, 1L))), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L)); + test(new Subscript(INTEGER, new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L)); + test(new Subscript(INTEGER, new Subscript(anonymousRow(INTEGER), new Row(ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), new Constant(INTEGER, 3L))), new Constant(INTEGER, 1L)), new Constant(INTEGER, 2L)), new Constant(INTEGER, 2L)); } @Test public void testWithCast() { test( - new SubscriptExpression(BIGINT, new Cast(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), rowType(field("a", BIGINT), field("b", BIGINT))), new Constant(INTEGER, 1L)), + new Subscript(BIGINT, new Cast(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), rowType(field("a", BIGINT), field("b", BIGINT))), new Constant(INTEGER, 1L)), new Cast(new Constant(INTEGER, 1L), BIGINT)); test( - new SubscriptExpression(BIGINT, new Cast(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), anonymousRow(BIGINT, BIGINT)), new Constant(INTEGER, 1L)), + new Subscript(BIGINT, new Cast(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), anonymousRow(BIGINT, BIGINT)), new Constant(INTEGER, 1L)), new Cast(new Constant(INTEGER, 1L), BIGINT)); test( - new SubscriptExpression( + new Subscript( BIGINT, - new Cast(new SubscriptExpression( + new Cast(new Subscript( anonymousRow(SMALLINT, SMALLINT), new Cast( new Row(ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), new Constant(INTEGER, 3L))), @@ -75,15 +75,15 @@ public void testWithCast() public void testWithTryCast() { test( - new SubscriptExpression(BIGINT, new Cast(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), rowType(field("a", BIGINT), field("b", BIGINT)), true), new Constant(INTEGER, 1L)), + new Subscript(BIGINT, new Cast(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), rowType(field("a", BIGINT), field("b", BIGINT)), true), new Constant(INTEGER, 1L)), new Cast(new Constant(INTEGER, 1L), BIGINT, true)); test( - new SubscriptExpression(BIGINT, new Cast(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), anonymousRow(BIGINT, BIGINT), true), new Constant(INTEGER, 1L)), + new Subscript(BIGINT, new Cast(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), anonymousRow(BIGINT, BIGINT), true), new Constant(INTEGER, 1L)), new Cast(new Constant(INTEGER, 1L), BIGINT, true)); test( - new SubscriptExpression(BIGINT, new Cast(new SubscriptExpression(rowType(field("x", BIGINT), field("y", BIGINT)), new Cast(new Row(ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), new Constant(INTEGER, 3L))), anonymousRow(anonymousRow(SMALLINT, SMALLINT), BIGINT), true), new Constant(INTEGER, 1L)), rowType(field("x", BIGINT), field("y", BIGINT)), true), new Constant(INTEGER, 2L)), + new Subscript(BIGINT, new Cast(new Subscript(rowType(field("x", BIGINT), field("y", BIGINT)), new Cast(new Row(ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), new Constant(INTEGER, 3L))), anonymousRow(anonymousRow(SMALLINT, SMALLINT), BIGINT), true), new Constant(INTEGER, 1L)), rowType(field("x", BIGINT), field("y", BIGINT)), true), new Constant(INTEGER, 2L)), new Cast(new Cast(new Constant(INTEGER, 2L), SMALLINT, true), BIGINT, true)); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestUnwrapSingleColumnRowInApply.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestUnwrapSingleColumnRowInApply.java index 4eddfe5a93ea..78643f387875 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestUnwrapSingleColumnRowInApply.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestUnwrapSingleColumnRowInApply.java @@ -16,8 +16,8 @@ import com.google.common.collect.ImmutableMap; import io.trino.spi.type.RowType; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.ExpressionMatcher; import io.trino.sql.planner.assertions.SetExpressionMatcher; @@ -83,14 +83,14 @@ public void testUnwrapInPredicate() .buildOrThrow(), project( ImmutableMap.builder() - .put("unwrappedValue", expression(new SubscriptExpression(INTEGER, new SymbolReference(RowType.anonymousRow(INTEGER), "rowValue"), new Constant(INTEGER, 1L)))) - .put("nonRowValue", expression(new SymbolReference(INTEGER, "nonRowValue"))) + .put("unwrappedValue", expression(new Subscript(INTEGER, new Reference(RowType.anonymousRow(INTEGER), "rowValue"), new Constant(INTEGER, 1L)))) + .put("nonRowValue", expression(new Reference(INTEGER, "nonRowValue"))) .buildOrThrow(), values("rowValue", "nonRowValue")), project( ImmutableMap.builder() - .put("unwrappedElement", expression(new SubscriptExpression(INTEGER, new SymbolReference(RowType.anonymousRow(INTEGER), "rowElement"), new Constant(INTEGER, 1L)))) - .put("nonRowElement", expression(new SymbolReference(INTEGER, "nonRowElement"))) + .put("unwrappedElement", expression(new Subscript(INTEGER, new Reference(RowType.anonymousRow(INTEGER), "rowElement"), new Constant(INTEGER, 1L)))) + .put("nonRowElement", expression(new Reference(INTEGER, "nonRowElement"))) .buildOrThrow(), values("rowElement", "nonRowElement"))))); } @@ -121,14 +121,14 @@ public void testUnwrapQuantifiedComparison() .buildOrThrow(), project( ImmutableMap.builder() - .put("unwrappedValue", expression(new SubscriptExpression(INTEGER, new SymbolReference(RowType.anonymousRow(INTEGER), "rowValue"), new Constant(INTEGER, 1L)))) - .put("nonRowValue", expression(new SymbolReference(INTEGER, "nonRowValue"))) + .put("unwrappedValue", expression(new Subscript(INTEGER, new Reference(RowType.anonymousRow(INTEGER), "rowValue"), new Constant(INTEGER, 1L)))) + .put("nonRowValue", expression(new Reference(INTEGER, "nonRowValue"))) .buildOrThrow(), values("rowValue", "nonRowValue")), project( ImmutableMap.builder() - .put("unwrappedElement", expression(new SubscriptExpression(INTEGER, new SymbolReference(RowType.anonymousRow(INTEGER), "rowElement"), new Constant(INTEGER, 1L)))) - .put("nonRowElement", expression(new SymbolReference(INTEGER, "nonRowElement"))) + .put("unwrappedElement", expression(new Subscript(INTEGER, new Reference(RowType.anonymousRow(INTEGER), "rowElement"), new Constant(INTEGER, 1L)))) + .put("nonRowElement", expression(new Reference(INTEGER, "nonRowElement"))) .buildOrThrow(), values("rowElement", "nonRowElement"))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java index 1707ce5058f6..82a91ae35f8c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java @@ -136,7 +136,7 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarbinaryType.VARBINARY; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; @@ -528,7 +528,7 @@ public AssignUniqueId assignUniqueId(Symbol unique, PlanNode source) public CorrelatedJoinNode correlatedJoin(List correlation, PlanNode input, PlanNode subquery) { - return correlatedJoin(correlation, input, JoinType.INNER, TRUE_LITERAL, subquery); + return correlatedJoin(correlation, input, JoinType.INNER, TRUE, subquery); } public CorrelatedJoinNode correlatedJoin(List correlation, PlanNode input, JoinType type, Expression filter, PlanNode subquery) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/TestRuleTester.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/TestRuleTester.java index 838a9fb6fd85..0cf91988072c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/TestRuleTester.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/TestRuleTester.java @@ -20,7 +20,7 @@ import io.trino.plugin.tpch.TpchTableHandle; import io.trino.spi.connector.TestingColumnHandle; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.iterative.Rule.Context; @@ -52,7 +52,7 @@ public void testReportWrongMatch() (node, captures, context) -> Result.ofPlanNode(node.replaceChildren(node.getSources())))) .on(p -> p.project( - Assignments.of(p.symbol("y"), new SymbolReference(INTEGER, "x")), + Assignments.of(p.symbol("y"), new Reference(INTEGER, "x")), p.values( ImmutableList.of(p.symbol("x")), ImmutableList.of(ImmutableList.of(new Constant(INTEGER, 1L)))))); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestAddExchangesPlans.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestAddExchangesPlans.java index ea0ce53552ac..008028252095 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestAddExchangesPlans.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestAddExchangesPlans.java @@ -29,10 +29,10 @@ import io.trino.spi.function.OperatorType; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.BigintType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.OptimizerConfig.JoinDistributionType; import io.trino.sql.planner.OptimizerConfig.JoinReorderingStrategy; import io.trino.sql.planner.assertions.BasePlanTest; @@ -65,9 +65,9 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MODULUS; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; +import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; import static io.trino.sql.planner.OptimizerConfig.JoinDistributionType.PARTITIONED; import static io.trino.sql.planner.OptimizerConfig.JoinReorderingStrategy.ELIMINATE_CROSS_JOINS; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION; @@ -310,14 +310,14 @@ public void testImplementOffsetWithOrderedSource() "SELECT name FROM nation ORDER BY regionkey, name OFFSET 5 LIMIT 2", output( project( - ImmutableMap.of("name", expression(new SymbolReference(VARCHAR, "name"))), + ImmutableMap.of("name", expression(new Reference(VARCHAR, "name"))), filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "row_num"), new Constant(BIGINT, 5L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "row_num"), new Constant(BIGINT, 5L)), rowNumber( pattern -> pattern .partitionBy(ImmutableList.of()), project( - ImmutableMap.of("name", expression(new SymbolReference(VARCHAR, "name"))), + ImmutableMap.of("name", expression(new Reference(VARCHAR, "name"))), topN( 7, ImmutableList.of(sort("regionkey", ASCENDING, LAST), sort("name", ASCENDING, LAST)), @@ -335,9 +335,9 @@ public void testImplementOffsetWithUnorderedSource() "SELECT name FROM nation OFFSET 5 LIMIT 2", any( project( - ImmutableMap.of("name", expression(new SymbolReference(VARCHAR, "name"))), + ImmutableMap.of("name", expression(new Reference(VARCHAR, "name"))), filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "row_num"), new Constant(BIGINT, 5L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "row_num"), new Constant(BIGINT, 5L)), exchange( LOCAL, REPARTITION, @@ -364,7 +364,7 @@ public void testExchangesAroundTrivialProjection() pattern -> pattern .partitionBy(ImmutableList.of()), project( - ImmutableMap.of("name", expression(new SymbolReference(VARCHAR, "name"))), + ImmutableMap.of("name", expression(new Reference(VARCHAR, "name"))), topN( 5, ImmutableList.of(sort("nationkey", ASCENDING, LAST)), @@ -385,9 +385,9 @@ public void testExchangesAroundTrivialProjection() LOCAL, GATHER, project( - ImmutableMap.of("b", expression(new SymbolReference(INTEGER, "b"))), + ImmutableMap.of("b", expression(new Reference(INTEGER, "b"))), filter( - new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 10L)), + new Comparison(LESS_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 10L)), exchange( LOCAL, REPARTITION, @@ -408,7 +408,7 @@ public void testExchangesAroundTrivialProjection() ImmutableList.of(), ImmutableSet.of("regionkey"), project( - ImmutableMap.of("regionkey", expression(new SymbolReference(BIGINT,"regionkey"))), + ImmutableMap.of("regionkey", expression(new Reference(BIGINT,"regionkey"))), topN( 5, ImmutableList.of(sort("nationkey", ASCENDING, LAST)), @@ -431,9 +431,9 @@ public void testExchangesAroundTrivialProjection() ImmutableList.of(), ImmutableSet.of("b"), project( - ImmutableMap.of("b", expression(new SymbolReference(INTEGER, "b"))), + ImmutableMap.of("b", expression(new Reference(INTEGER, "b"))), filter( - new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 10L)), + new Comparison(LESS_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 10L)), exchange( LOCAL, REPARTITION, @@ -449,7 +449,7 @@ public void testExchangesAroundTrivialProjection() ImmutableMap.of("count", aggregationFunction("count", ImmutableList.of("name"))), PARTIAL, project( - ImmutableMap.of("name", expression(new SymbolReference(VARCHAR, "name"))), + ImmutableMap.of("name", expression(new Reference(VARCHAR, "name"))), exchange( LOCAL, REPARTITION, @@ -470,9 +470,9 @@ public void testExchangesAroundTrivialProjection() ImmutableMap.of("count", aggregationFunction("count", ImmutableList.of("b"))), PARTIAL, project( - ImmutableMap.of("b", expression(new SymbolReference(INTEGER, "b"))), + ImmutableMap.of("b", expression(new Reference(INTEGER, "b"))), filter( - new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 10L)), + new Comparison(LESS_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 10L)), exchange( LOCAL, REPARTITION, @@ -705,7 +705,7 @@ SELECT suppkey, partkey, count(*) as count Optional.empty(), PARTIAL, project( - ImmutableMap.of("partkey_expr", expression(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "partkey"), new Constant(BIGINT, 10L)))), + ImmutableMap.of("partkey_expr", expression(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "partkey"), new Constant(BIGINT, 10L)))), tableScan("lineitem", ImmutableMap.of( "partkey", "partkey", "suppkey", "suppkey")))))))))))))); @@ -735,7 +735,7 @@ SELECT suppkey, partkey, count(*) as count Optional.empty(), Step.PARTIAL, project( - ImmutableMap.of("orderkey_expr", expression(new ArithmeticBinaryExpression(MODULUS_BIGINT, MODULUS, new SymbolReference(BIGINT, "orderkey"), new Constant(BIGINT, 10000L)))), + ImmutableMap.of("orderkey_expr", expression(new Arithmetic(MODULUS_BIGINT, MODULUS, new Reference(BIGINT, "orderkey"), new Constant(BIGINT, 10000L)))), tableScan("lineitem", ImmutableMap.of( "partkey", "partkey", "orderkey", "orderkey", diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestDeterminePartitionCount.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestDeterminePartitionCount.java index 323cdc91f626..3bacf128591c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestDeterminePartitionCount.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestDeterminePartitionCount.java @@ -24,8 +24,8 @@ import io.trino.spi.statistics.ColumnStatistics; import io.trino.spi.statistics.Estimate; import io.trino.spi.statistics.TableStatistics; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.assertions.BasePlanTest; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.DynamicFilterSourceNode; @@ -162,7 +162,7 @@ public void testSimpleFilter() output( project( filter( - new IsNullPredicate(new SymbolReference(BIGINT, "column_b")), + new IsNull(new Reference(BIGINT, "column_b")), tableScan("table_with_stats_a", ImmutableMap.of("column_a", "column_a", "column_b", "column_b")))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestEliminateCrossJoins.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestEliminateCrossJoins.java index f84d3b5f8a66..4b23bf60e95a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestEliminateCrossJoins.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestEliminateCrossJoins.java @@ -17,9 +17,9 @@ import io.airlift.slice.Slices; import io.trino.SystemSessionProperties; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.assertions.BasePlanTest; import io.trino.sql.planner.assertions.PlanMatchPattern; import org.junit.jupiter.api.Test; @@ -28,10 +28,10 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.NOT_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.NOT_EQUAL; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; @@ -138,14 +138,14 @@ public void testEliminateCrossJoinWithNonEqualityCondition() .left( join(INNER, leftJoinBuilder -> leftJoinBuilder .equiCriteria("P_PARTKEY", "L_PARTKEY") - .filter(new ComparisonExpression(LESS_THAN, new SymbolReference(VARCHAR, "P_NAME"), new SymbolReference(VARCHAR, "expr"))) + .filter(new Comparison(LESS_THAN, new Reference(VARCHAR, "P_NAME"), new Reference(VARCHAR, "expr"))) .left(anyTree(PART_WITH_NAME_TABLESCAN)) .right( anyTree( project( - ImmutableMap.of("expr", expression(new Cast(new SymbolReference(VARCHAR, "L_COMMENT"), createVarcharType(55)))), + ImmutableMap.of("expr", expression(new Cast(new Reference(VARCHAR, "L_COMMENT"), createVarcharType(55)))), filter( - new ComparisonExpression(NOT_EQUAL, new SymbolReference(BIGINT, "L_PARTKEY"), new SymbolReference(BIGINT, "L_ORDERKEY")), + new Comparison(NOT_EQUAL, new Reference(BIGINT, "L_PARTKEY"), new Reference(BIGINT, "L_ORDERKEY")), LINEITEM_WITH_COMMENT_TABLESCAN)))))) .right(anyTree(ORDERS_TABLESCAN))))); } @@ -163,11 +163,11 @@ public void testEliminateCrossJoinPreserveFilters() .equiCriteria("P_PARTKEY", "L_PARTKEY") .left(anyTree(PART_TABLESCAN)) .right(anyTree(filter( - new ComparisonExpression(EQUAL, new SymbolReference(createVarcharType(1), "L_RETURNFLAG"), new Constant(createVarcharType(1), Slices.utf8Slice("R"))), + new Comparison(EQUAL, new Reference(createVarcharType(1), "L_RETURNFLAG"), new Constant(createVarcharType(1), Slices.utf8Slice("R"))), LINEITEM_WITH_RETURNFLAG_TABLESCAN))))) .right( anyTree(filter( - new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(INTEGER, "O_SHIPPRIORITY"), new Constant(INTEGER, 10L)), + new Comparison(GREATER_THAN_OR_EQUAL, new Reference(INTEGER, "O_SHIPPRIORITY"), new Constant(INTEGER, 10L)), ORDERS_WITH_SHIPPRIORITY_TABLESCAN)))))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestEliminateSorts.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestEliminateSorts.java index f1a5aac15135..85eb76b4cc2e 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestEliminateSorts.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestEliminateSorts.java @@ -19,9 +19,9 @@ import io.trino.cost.TaskCountEstimator; import io.trino.spi.connector.SortOrder; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.RuleStatsRecorder; import io.trino.sql.planner.assertions.BasePlanTest; import io.trino.sql.planner.assertions.ExpectedValueProvider; @@ -37,7 +37,7 @@ import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.sort; @@ -95,7 +95,7 @@ SELECT quantity, row_number() OVER (ORDER BY quantity) sort( anyTree( filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "QUANTITY"), new Cast(new Constant(INTEGER, 10L), DOUBLE)), + new Comparison(GREATER_THAN, new Reference(DOUBLE, "QUANTITY"), new Cast(new Constant(INTEGER, 10L), DOUBLE)), window(windowMatcherBuilder -> windowMatcherBuilder .specification(windowSpec) .addFunction(windowFunction("row_number", ImmutableList.of(), DEFAULT_FRAME)), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestExpressionEquivalence.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestExpressionEquivalence.java index 38aeaab27fa3..78fc76695e8a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestExpressionEquivalence.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestExpressionEquivalence.java @@ -22,13 +22,13 @@ import io.trino.security.AllowAllAccessControl; import io.trino.spi.type.Decimals; import io.trino.sql.PlannerContext; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.transaction.TestingTransactionManager; import io.trino.transaction.TransactionManager; @@ -49,17 +49,17 @@ import static io.trino.spi.type.TimestampWithTimeZoneType.createTimestampWithTimeZoneType; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.BooleanLiteral.FALSE_LITERAL; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.IS_DISTINCT_FROM; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.NOT_EQUAL; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; -import static io.trino.sql.ir.LogicalExpression.Operator.OR; +import static io.trino.sql.ir.Booleans.FALSE; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.IS_DISTINCT_FROM; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.NOT_EQUAL; +import static io.trino.sql.ir.Logical.Operator.AND; +import static io.trino.sql.ir.Logical.Operator.OR; import static io.trino.sql.planner.SymbolsExtractor.extractUnique; import static io.trino.sql.planner.TestingPlannerContext.plannerContextBuilder; import static io.trino.testing.TransactionBuilder.transaction; @@ -87,11 +87,11 @@ public void testEquivalent() new Constant(BIGINT, null), new Constant(BIGINT, null)); assertEquivalent( - new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "a_bigint"), new SymbolReference(DOUBLE, "b_double")), - new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "b_double"), new SymbolReference(BIGINT, "a_bigint"))); + new Comparison(LESS_THAN, new Reference(BIGINT, "a_bigint"), new Reference(DOUBLE, "b_double")), + new Comparison(GREATER_THAN, new Reference(DOUBLE, "b_double"), new Reference(BIGINT, "a_bigint"))); assertEquivalent( - TRUE_LITERAL, - TRUE_LITERAL); + TRUE, + TRUE); assertEquivalent( new Constant(INTEGER, 4L), new Constant(INTEGER, 4L)); @@ -103,104 +103,104 @@ public void testEquivalent() new Constant(VARCHAR, Slices.utf8Slice("foo"))); assertEquivalent( - new ComparisonExpression(EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), - new ComparisonExpression(EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 4L))); + new Comparison(EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), + new Comparison(EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 4L))); assertEquivalent( - new ComparisonExpression(EQUAL, new Constant(createDecimalType(3, 1), Decimals.valueOfShort(new BigDecimal("4.4"))), new Constant(createDecimalType(3, 1), Decimals.valueOfShort(new BigDecimal("5.5")))), - new ComparisonExpression(EQUAL, new Constant(createDecimalType(3, 1), Decimals.valueOfShort(new BigDecimal("5.5"))), new Constant(createDecimalType(3, 1), Decimals.valueOfShort(new BigDecimal("4.4"))))); + new Comparison(EQUAL, new Constant(createDecimalType(3, 1), Decimals.valueOfShort(new BigDecimal("4.4"))), new Constant(createDecimalType(3, 1), Decimals.valueOfShort(new BigDecimal("5.5")))), + new Comparison(EQUAL, new Constant(createDecimalType(3, 1), Decimals.valueOfShort(new BigDecimal("5.5"))), new Constant(createDecimalType(3, 1), Decimals.valueOfShort(new BigDecimal("4.4"))))); assertEquivalent( - new ComparisonExpression(EQUAL, new Constant(VARCHAR, Slices.utf8Slice("foo")), new Constant(VARCHAR, Slices.utf8Slice("bar"))), - new ComparisonExpression(EQUAL, new Constant(VARCHAR, Slices.utf8Slice("bar")), new Constant(VARCHAR, Slices.utf8Slice("foo")))); + new Comparison(EQUAL, new Constant(VARCHAR, Slices.utf8Slice("foo")), new Constant(VARCHAR, Slices.utf8Slice("bar"))), + new Comparison(EQUAL, new Constant(VARCHAR, Slices.utf8Slice("bar")), new Constant(VARCHAR, Slices.utf8Slice("foo")))); assertEquivalent( - new ComparisonExpression(NOT_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), - new ComparisonExpression(NOT_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 4L))); + new Comparison(NOT_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), + new Comparison(NOT_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 4L))); assertEquivalent( - new ComparisonExpression(IS_DISTINCT_FROM, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), - new ComparisonExpression(IS_DISTINCT_FROM, new Constant(INTEGER, 5L), new Constant(INTEGER, 4L))); + new Comparison(IS_DISTINCT_FROM, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), + new Comparison(IS_DISTINCT_FROM, new Constant(INTEGER, 5L), new Constant(INTEGER, 4L))); assertEquivalent( - new ComparisonExpression(LESS_THAN, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), - new ComparisonExpression(GREATER_THAN, new Constant(INTEGER, 5L), new Constant(INTEGER, 4L))); + new Comparison(LESS_THAN, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), + new Comparison(GREATER_THAN, new Constant(INTEGER, 5L), new Constant(INTEGER, 4L))); assertEquivalent( - new ComparisonExpression(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), - new ComparisonExpression(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 4L))); + new Comparison(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), + new Comparison(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 4L))); assertEquivalent( - new ComparisonExpression(EQUAL, new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2020-05-10 12:34:56.123456789")), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2021-05-10 12:34:56.123456789"))), - new ComparisonExpression(EQUAL, new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2021-05-10 12:34:56.123456789")), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2020-05-10 12:34:56.123456789")))); + new Comparison(EQUAL, new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2020-05-10 12:34:56.123456789")), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2021-05-10 12:34:56.123456789"))), + new Comparison(EQUAL, new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2021-05-10 12:34:56.123456789")), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2020-05-10 12:34:56.123456789")))); assertEquivalent( - new ComparisonExpression(EQUAL, new Constant(createTimestampWithTimeZoneType(9), DateTimes.parseTimestampWithTimeZone(9, "2020-05-10 12:34:56.123456789 +8")), new Constant(createTimestampWithTimeZoneType(9), DateTimes.parseTimestampWithTimeZone(9, "2021-05-10 12:34:56.123456789 +8"))), - new ComparisonExpression(EQUAL, new Constant(createTimestampWithTimeZoneType(9), DateTimes.parseTimestampWithTimeZone(9, "2021-05-10 12:34:56.123456789 +8")), new Constant(createTimestampWithTimeZoneType(9), DateTimes.parseTimestampWithTimeZone(9, "2020-05-10 12:34:56.123456789 +8")))); + new Comparison(EQUAL, new Constant(createTimestampWithTimeZoneType(9), DateTimes.parseTimestampWithTimeZone(9, "2020-05-10 12:34:56.123456789 +8")), new Constant(createTimestampWithTimeZoneType(9), DateTimes.parseTimestampWithTimeZone(9, "2021-05-10 12:34:56.123456789 +8"))), + new Comparison(EQUAL, new Constant(createTimestampWithTimeZoneType(9), DateTimes.parseTimestampWithTimeZone(9, "2021-05-10 12:34:56.123456789 +8")), new Constant(createTimestampWithTimeZoneType(9), DateTimes.parseTimestampWithTimeZone(9, "2020-05-10 12:34:56.123456789 +8")))); assertEquivalent( - new FunctionCall(MOD, ImmutableList.of(new Constant(INTEGER, 4L), new Constant(INTEGER, 5L))), - new FunctionCall(MOD, ImmutableList.of(new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)))); + new Call(MOD, ImmutableList.of(new Constant(INTEGER, 4L), new Constant(INTEGER, 5L))), + new Call(MOD, ImmutableList.of(new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)))); assertEquivalent( - new SymbolReference(BIGINT, "a_bigint"), - new SymbolReference(BIGINT, "a_bigint")); + new Reference(BIGINT, "a_bigint"), + new Reference(BIGINT, "a_bigint")); assertEquivalent( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "a_bigint"), new SymbolReference(BIGINT, "b_bigint")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "b_bigint"), new SymbolReference(BIGINT, "a_bigint"))); + new Comparison(EQUAL, new Reference(BIGINT, "a_bigint"), new Reference(BIGINT, "b_bigint")), + new Comparison(EQUAL, new Reference(BIGINT, "b_bigint"), new Reference(BIGINT, "a_bigint"))); assertEquivalent( - new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "a_bigint"), new SymbolReference(BIGINT, "b_bigint")), - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b_bigint"), new SymbolReference(BIGINT, "a_bigint"))); + new Comparison(LESS_THAN, new Reference(BIGINT, "a_bigint"), new Reference(BIGINT, "b_bigint")), + new Comparison(GREATER_THAN, new Reference(BIGINT, "b_bigint"), new Reference(BIGINT, "a_bigint"))); assertEquivalent( - new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "a_bigint"), new SymbolReference(DOUBLE, "b_double")), - new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "b_double"), new SymbolReference(BIGINT, "a_bigint"))); + new Comparison(LESS_THAN, new Reference(BIGINT, "a_bigint"), new Reference(DOUBLE, "b_double")), + new Comparison(GREATER_THAN, new Reference(DOUBLE, "b_double"), new Reference(BIGINT, "a_bigint"))); assertEquivalent( - new LogicalExpression(AND, ImmutableList.of(TRUE_LITERAL, FALSE_LITERAL)), - new LogicalExpression(AND, ImmutableList.of(FALSE_LITERAL, TRUE_LITERAL))); + new Logical(AND, ImmutableList.of(TRUE, FALSE)), + new Logical(AND, ImmutableList.of(FALSE, TRUE))); assertEquivalent( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), new ComparisonExpression(LESS_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 7L)))), - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new Constant(INTEGER, 7L), new Constant(INTEGER, 6L)), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 4L))))); + new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), new Comparison(LESS_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 7L)))), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Constant(INTEGER, 7L), new Constant(INTEGER, 6L)), new Comparison(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 4L))))); assertEquivalent( - new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), new ComparisonExpression(LESS_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 7L)))), - new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new Constant(INTEGER, 7L), new Constant(INTEGER, 6L)), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 4L))))); + new Logical(OR, ImmutableList.of(new Comparison(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), new Comparison(LESS_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 7L)))), + new Logical(OR, ImmutableList.of(new Comparison(GREATER_THAN, new Constant(INTEGER, 7L), new Constant(INTEGER, 6L)), new Comparison(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 4L))))); assertEquivalent( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "a_bigint"), new SymbolReference(BIGINT, "b_bigint")), new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "c_bigint"), new SymbolReference(BIGINT, "d_bigint")))), - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "d_bigint"), new SymbolReference(BIGINT, "c_bigint")), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(BIGINT, "b_bigint"), new SymbolReference(BIGINT, "a_bigint"))))); + new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "a_bigint"), new Reference(BIGINT, "b_bigint")), new Comparison(LESS_THAN, new Reference(BIGINT, "c_bigint"), new Reference(BIGINT, "d_bigint")))), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "d_bigint"), new Reference(BIGINT, "c_bigint")), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(BIGINT, "b_bigint"), new Reference(BIGINT, "a_bigint"))))); assertEquivalent( - new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "a_bigint"), new SymbolReference(BIGINT, "b_bigint")), new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "c_bigint"), new SymbolReference(BIGINT, "d_bigint")))), - new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "d_bigint"), new SymbolReference(BIGINT, "c_bigint")), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(BIGINT, "b_bigint"), new SymbolReference(BIGINT, "a_bigint"))))); + new Logical(OR, ImmutableList.of(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "a_bigint"), new Reference(BIGINT, "b_bigint")), new Comparison(LESS_THAN, new Reference(BIGINT, "c_bigint"), new Reference(BIGINT, "d_bigint")))), + new Logical(OR, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "d_bigint"), new Reference(BIGINT, "c_bigint")), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(BIGINT, "b_bigint"), new Reference(BIGINT, "a_bigint"))))); assertEquivalent( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), new ComparisonExpression(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)))), - new ComparisonExpression(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L))); + new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), new Comparison(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)))), + new Comparison(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L))); assertEquivalent( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), new ComparisonExpression(LESS_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 7L)))), - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new Constant(INTEGER, 7L), new Constant(INTEGER, 6L)), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 4L)), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 4L))))); + new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), new Comparison(LESS_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 7L)))), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Constant(INTEGER, 7L), new Constant(INTEGER, 6L)), new Comparison(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 4L)), new Comparison(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 4L))))); assertEquivalent( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 2L), new Constant(INTEGER, 3L)), new ComparisonExpression(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), new ComparisonExpression(LESS_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 7L)))), - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new Constant(INTEGER, 7L), new Constant(INTEGER, 6L)), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 4L)), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 3L), new Constant(INTEGER, 2L))))); + new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 2L), new Constant(INTEGER, 3L)), new Comparison(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), new Comparison(LESS_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 7L)))), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Constant(INTEGER, 7L), new Constant(INTEGER, 6L)), new Comparison(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 4L)), new Comparison(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 3L), new Constant(INTEGER, 2L))))); assertEquivalent( - new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), new ComparisonExpression(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)))), - new ComparisonExpression(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L))); + new Logical(OR, ImmutableList.of(new Comparison(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), new Comparison(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)))), + new Comparison(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L))); assertEquivalent( - new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), new ComparisonExpression(LESS_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 7L)))), - new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new Constant(INTEGER, 7L), new Constant(INTEGER, 6L)), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 4L)), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 4L))))); + new Logical(OR, ImmutableList.of(new Comparison(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), new Comparison(LESS_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 7L)))), + new Logical(OR, ImmutableList.of(new Comparison(GREATER_THAN, new Constant(INTEGER, 7L), new Constant(INTEGER, 6L)), new Comparison(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 4L)), new Comparison(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 4L))))); assertEquivalent( - new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 2L), new Constant(INTEGER, 3L)), new ComparisonExpression(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), new ComparisonExpression(LESS_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 7L)))), - new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new Constant(INTEGER, 7L), new Constant(INTEGER, 6L)), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 4L)), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 3L), new Constant(INTEGER, 2L))))); + new Logical(OR, ImmutableList.of(new Comparison(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 2L), new Constant(INTEGER, 3L)), new Comparison(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), new Comparison(LESS_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 7L)))), + new Logical(OR, ImmutableList.of(new Comparison(GREATER_THAN, new Constant(INTEGER, 7L), new Constant(INTEGER, 6L)), new Comparison(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 4L)), new Comparison(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 3L), new Constant(INTEGER, 2L))))); assertEquivalent( - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "a_boolean"), new SymbolReference(BOOLEAN, "b_boolean"), new SymbolReference(BOOLEAN, "c_boolean"))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "c_boolean"), new SymbolReference(BOOLEAN, "b_boolean"), new SymbolReference(BOOLEAN, "a_boolean")))); + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "a_boolean"), new Reference(BOOLEAN, "b_boolean"), new Reference(BOOLEAN, "c_boolean"))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "c_boolean"), new Reference(BOOLEAN, "b_boolean"), new Reference(BOOLEAN, "a_boolean")))); assertEquivalent( - new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "a_boolean"), new SymbolReference(BOOLEAN, "b_boolean"))), new SymbolReference(BOOLEAN, "c_boolean"))), - new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "c_boolean"), new SymbolReference(BOOLEAN, "b_boolean"))), new SymbolReference(BOOLEAN, "a_boolean")))); + new Logical(AND, ImmutableList.of(new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "a_boolean"), new Reference(BOOLEAN, "b_boolean"))), new Reference(BOOLEAN, "c_boolean"))), + new Logical(AND, ImmutableList.of(new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "c_boolean"), new Reference(BOOLEAN, "b_boolean"))), new Reference(BOOLEAN, "a_boolean")))); assertEquivalent( - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "a_boolean"), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "b_boolean"), new SymbolReference(BOOLEAN, "c_boolean"))))), - new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "a_boolean"), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "c_boolean"), new SymbolReference(BOOLEAN, "b_boolean"))), new SymbolReference(BOOLEAN, "a_boolean")))); + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "a_boolean"), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "b_boolean"), new Reference(BOOLEAN, "c_boolean"))))), + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "a_boolean"), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "c_boolean"), new Reference(BOOLEAN, "b_boolean"))), new Reference(BOOLEAN, "a_boolean")))); assertEquivalent( - new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "a_boolean"), new SymbolReference(BOOLEAN, "b_boolean"), new SymbolReference(BOOLEAN, "c_boolean"))), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "d_boolean"), new SymbolReference(BOOLEAN, "e_boolean"))), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "f_boolean"), new SymbolReference(BOOLEAN, "g_boolean"), new SymbolReference(BOOLEAN, "h_boolean"))))), - new LogicalExpression(AND, ImmutableList.of(new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "h_boolean"), new SymbolReference(BOOLEAN, "g_boolean"), new SymbolReference(BOOLEAN, "f_boolean"))), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "b_boolean"), new SymbolReference(BOOLEAN, "a_boolean"), new SymbolReference(BOOLEAN, "c_boolean"))), new LogicalExpression(OR, ImmutableList.of(new SymbolReference(BOOLEAN, "e_boolean"), new SymbolReference(BOOLEAN, "d_boolean")))))); + new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "a_boolean"), new Reference(BOOLEAN, "b_boolean"), new Reference(BOOLEAN, "c_boolean"))), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "d_boolean"), new Reference(BOOLEAN, "e_boolean"))), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "f_boolean"), new Reference(BOOLEAN, "g_boolean"), new Reference(BOOLEAN, "h_boolean"))))), + new Logical(AND, ImmutableList.of(new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "h_boolean"), new Reference(BOOLEAN, "g_boolean"), new Reference(BOOLEAN, "f_boolean"))), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "b_boolean"), new Reference(BOOLEAN, "a_boolean"), new Reference(BOOLEAN, "c_boolean"))), new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "e_boolean"), new Reference(BOOLEAN, "d_boolean")))))); assertEquivalent( - new LogicalExpression(OR, ImmutableList.of(new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "a_boolean"), new SymbolReference(BOOLEAN, "b_boolean"), new SymbolReference(BOOLEAN, "c_boolean"))), new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "d_boolean"), new SymbolReference(BOOLEAN, "e_boolean"))), new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "f_boolean"), new SymbolReference(BOOLEAN, "g_boolean"), new SymbolReference(BOOLEAN, "h_boolean"))))), - new LogicalExpression(OR, ImmutableList.of(new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "h_boolean"), new SymbolReference(BOOLEAN, "g_boolean"), new SymbolReference(BOOLEAN, "f_boolean"))), new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "b_boolean"), new SymbolReference(BOOLEAN, "a_boolean"), new SymbolReference(BOOLEAN, "c_boolean"))), new LogicalExpression(AND, ImmutableList.of(new SymbolReference(BOOLEAN, "e_boolean"), new SymbolReference(BOOLEAN, "d_boolean")))))); + new Logical(OR, ImmutableList.of(new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "a_boolean"), new Reference(BOOLEAN, "b_boolean"), new Reference(BOOLEAN, "c_boolean"))), new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "d_boolean"), new Reference(BOOLEAN, "e_boolean"))), new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "f_boolean"), new Reference(BOOLEAN, "g_boolean"), new Reference(BOOLEAN, "h_boolean"))))), + new Logical(OR, ImmutableList.of(new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "h_boolean"), new Reference(BOOLEAN, "g_boolean"), new Reference(BOOLEAN, "f_boolean"))), new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "b_boolean"), new Reference(BOOLEAN, "a_boolean"), new Reference(BOOLEAN, "c_boolean"))), new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "e_boolean"), new Reference(BOOLEAN, "d_boolean")))))); } private static void assertEquivalent(Expression leftExpression, Expression rightExpression) @@ -220,13 +220,13 @@ public void testNotEquivalent() { assertNotEquivalent( new Constant(BOOLEAN, null), - FALSE_LITERAL); + FALSE); assertNotEquivalent( - FALSE_LITERAL, + FALSE, new Constant(BOOLEAN, null)); assertNotEquivalent( - TRUE_LITERAL, - FALSE_LITERAL); + TRUE, + FALSE); assertNotEquivalent( new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)); @@ -238,51 +238,51 @@ public void testNotEquivalent() new Constant(VARCHAR, Slices.utf8Slice("'bar'"))); assertNotEquivalent( - new ComparisonExpression(EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), - new ComparisonExpression(EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 6L))); + new Comparison(EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), + new Comparison(EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 6L))); assertNotEquivalent( - new ComparisonExpression(NOT_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), - new ComparisonExpression(NOT_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 6L))); + new Comparison(NOT_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), + new Comparison(NOT_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 6L))); assertNotEquivalent( - new ComparisonExpression(IS_DISTINCT_FROM, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), - new ComparisonExpression(IS_DISTINCT_FROM, new Constant(INTEGER, 5L), new Constant(INTEGER, 6L))); + new Comparison(IS_DISTINCT_FROM, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), + new Comparison(IS_DISTINCT_FROM, new Constant(INTEGER, 5L), new Constant(INTEGER, 6L))); assertNotEquivalent( - new ComparisonExpression(LESS_THAN, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), - new ComparisonExpression(GREATER_THAN, new Constant(INTEGER, 5L), new Constant(INTEGER, 6L))); + new Comparison(LESS_THAN, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), + new Comparison(GREATER_THAN, new Constant(INTEGER, 5L), new Constant(INTEGER, 6L))); assertNotEquivalent( - new ComparisonExpression(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), - new ComparisonExpression(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 6L))); + new Comparison(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), + new Comparison(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 6L))); assertNotEquivalent( - new FunctionCall(MOD, ImmutableList.of(new Constant(INTEGER, 4L), new Constant(INTEGER, 5L))), - new FunctionCall(MOD, ImmutableList.of(new Constant(INTEGER, 5L), new Constant(INTEGER, 4L)))); + new Call(MOD, ImmutableList.of(new Constant(INTEGER, 4L), new Constant(INTEGER, 5L))), + new Call(MOD, ImmutableList.of(new Constant(INTEGER, 5L), new Constant(INTEGER, 4L)))); assertNotEquivalent( - new SymbolReference(BIGINT, "a_bigint"), - new SymbolReference(BIGINT, "b_bigint")); + new Reference(BIGINT, "a_bigint"), + new Reference(BIGINT, "b_bigint")); assertNotEquivalent( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "a_bigint"), new SymbolReference(BIGINT, "b_bigint")), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "b_bigint"), new SymbolReference(BIGINT, "c_bigint"))); + new Comparison(EQUAL, new Reference(BIGINT, "a_bigint"), new Reference(BIGINT, "b_bigint")), + new Comparison(EQUAL, new Reference(BIGINT, "b_bigint"), new Reference(BIGINT, "c_bigint"))); assertNotEquivalent( - new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "a_bigint"), new SymbolReference(BIGINT, "b_bigint")), - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "b_bigint"), new SymbolReference(BIGINT, "c_bigint"))); + new Comparison(LESS_THAN, new Reference(BIGINT, "a_bigint"), new Reference(BIGINT, "b_bigint")), + new Comparison(GREATER_THAN, new Reference(BIGINT, "b_bigint"), new Reference(BIGINT, "c_bigint"))); assertNotEquivalent( - new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "a_bigint"), new SymbolReference(DOUBLE, "b_double")), - new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "b_double"), new SymbolReference(BIGINT, "c_bigint"))); + new Comparison(LESS_THAN, new Reference(BIGINT, "a_bigint"), new Reference(DOUBLE, "b_double")), + new Comparison(GREATER_THAN, new Reference(DOUBLE, "b_double"), new Reference(BIGINT, "c_bigint"))); assertNotEquivalent( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), new ComparisonExpression(LESS_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 7L)))), - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new Constant(INTEGER, 7L), new Constant(INTEGER, 6L)), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 6L))))); + new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), new Comparison(LESS_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 7L)))), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Constant(INTEGER, 7L), new Constant(INTEGER, 6L)), new Comparison(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 6L))))); assertNotEquivalent( - new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), new ComparisonExpression(LESS_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 7L)))), - new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new Constant(INTEGER, 7L), new Constant(INTEGER, 6L)), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 6L))))); + new Logical(OR, ImmutableList.of(new Comparison(LESS_THAN_OR_EQUAL, new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)), new Comparison(LESS_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 7L)))), + new Logical(OR, ImmutableList.of(new Comparison(GREATER_THAN, new Constant(INTEGER, 7L), new Constant(INTEGER, 6L)), new Comparison(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 6L))))); assertNotEquivalent( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "a_bigint"), new SymbolReference(BIGINT, "b_bigint")), new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "c_bigint"), new SymbolReference(BIGINT, "d_bigint")))), - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "d_bigint"), new SymbolReference(BIGINT, "c_bigint")), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(BIGINT, "b_bigint"), new SymbolReference(BIGINT, "c_bigint"))))); + new Logical(AND, ImmutableList.of(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "a_bigint"), new Reference(BIGINT, "b_bigint")), new Comparison(LESS_THAN, new Reference(BIGINT, "c_bigint"), new Reference(BIGINT, "d_bigint")))), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "d_bigint"), new Reference(BIGINT, "c_bigint")), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(BIGINT, "b_bigint"), new Reference(BIGINT, "c_bigint"))))); assertNotEquivalent( - new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference(BIGINT, "a_bigint"), new SymbolReference(BIGINT, "b_bigint")), new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "c_bigint"), new SymbolReference(BIGINT, "d_bigint")))), - new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "d_bigint"), new SymbolReference(BIGINT, "c_bigint")), new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(BIGINT, "b_bigint"), new SymbolReference(BIGINT, "c_bigint"))))); + new Logical(OR, ImmutableList.of(new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "a_bigint"), new Reference(BIGINT, "b_bigint")), new Comparison(LESS_THAN, new Reference(BIGINT, "c_bigint"), new Reference(BIGINT, "d_bigint")))), + new Logical(OR, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "d_bigint"), new Reference(BIGINT, "c_bigint")), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(BIGINT, "b_bigint"), new Reference(BIGINT, "c_bigint"))))); assertNotEquivalent( new Cast(new Constant(createTimeWithTimeZoneType(3), DateTimes.parseTimeWithTimeZone(3, "12:34:56.123 +00:00")), VARCHAR), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestFullOuterJoinWithCoalesce.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestFullOuterJoinWithCoalesce.java index de3135b0df7b..ca561070ee9e 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestFullOuterJoinWithCoalesce.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestFullOuterJoinWithCoalesce.java @@ -18,16 +18,16 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.CoalesceExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Coalesce; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.assertions.BasePlanTest; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange; @@ -62,13 +62,13 @@ public void testFullOuterJoinWithCoalesce() "FULL OUTER JOIN (VALUES 2, 5) r(a) on ts.a = r.a", anyTree( project( - ImmutableMap.of("expr", expression(new CoalesceExpression(new SymbolReference(INTEGER, "ts"), new SymbolReference(INTEGER, "r")))), + ImmutableMap.of("expr", expression(new Coalesce(new Reference(INTEGER, "ts"), new Reference(INTEGER, "r")))), join(FULL, builder -> builder .equiCriteria("ts", "r") .left( anyTree( project( - ImmutableMap.of("ts", expression(new CoalesceExpression(new SymbolReference(INTEGER, "t"), new SymbolReference(INTEGER, "s")))), + ImmutableMap.of("ts", expression(new Coalesce(new Reference(INTEGER, "t"), new Reference(INTEGER, "s")))), join(FULL, leftJoinBuilder -> leftJoinBuilder .equiCriteria("t", "s") .left(exchange(REMOTE, REPARTITION, anyTree(values(ImmutableList.of("t"))))) @@ -97,7 +97,7 @@ public void testArgumentsInDifferentOrder() PARTIAL, anyTree( project( - ImmutableMap.of("expr", expression(new CoalesceExpression(new SymbolReference(INTEGER, "l"), new SymbolReference(INTEGER, "r")))), + ImmutableMap.of("expr", expression(new Coalesce(new Reference(INTEGER, "l"), new Reference(INTEGER, "r")))), join(FULL, builder -> builder .equiCriteria("l", "r") .left(anyTree(values(ImmutableList.of("l")))) @@ -117,7 +117,7 @@ public void testArgumentsInDifferentOrder() PARTIAL, anyTree( project( - ImmutableMap.of("expr", expression(new CoalesceExpression(new SymbolReference(INTEGER, "r"), new SymbolReference(INTEGER, "l")))), + ImmutableMap.of("expr", expression(new Coalesce(new Reference(INTEGER, "r"), new Reference(INTEGER, "l")))), join(FULL, builder -> builder .equiCriteria("l", "r") .left(anyTree(values(ImmutableList.of("l")))) @@ -144,7 +144,7 @@ public void testCoalesceWithManyArguments() ImmutableMap.of(), PARTIAL, project( - ImmutableMap.of("expr", expression(new CoalesceExpression(new SymbolReference(INTEGER, "l"), new SymbolReference(INTEGER, "m"), new SymbolReference(INTEGER, "r")))), + ImmutableMap.of("expr", expression(new Coalesce(new Reference(INTEGER, "l"), new Reference(INTEGER, "m"), new Reference(INTEGER, "r")))), join(FULL, builder -> builder .equiCriteria("l", "r") .left( @@ -173,7 +173,7 @@ public void testComplexArgumentToCoalesce() ImmutableMap.of(), PARTIAL, project( - ImmutableMap.of("expr", expression(new CoalesceExpression(new SymbolReference(INTEGER, "l"), new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "m"), new Constant(INTEGER, 1L)), new SymbolReference(INTEGER, "r")))), + ImmutableMap.of("expr", expression(new Coalesce(new Reference(INTEGER, "l"), new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "m"), new Constant(INTEGER, 1L)), new Reference(INTEGER, "r")))), join(FULL, builder -> builder .equiCriteria("l", "r") .left( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestMergeWindows.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestMergeWindows.java index 637159c62f13..3e86754d1cf6 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestMergeWindows.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestMergeWindows.java @@ -18,11 +18,11 @@ import com.google.common.collect.ImmutableSet; import io.trino.spi.connector.SortOrder; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.RuleStatsRecorder; import io.trino.sql.planner.assertions.BasePlanTest; import io.trino.sql.planner.assertions.ExpectedValueProvider; @@ -44,7 +44,7 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.planner.PlanOptimizers.columnPruningRules; import static io.trino.sql.planner.assertions.PlanMatchPattern.any; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; @@ -211,7 +211,7 @@ public void testIdenticalWindowSpecificationsABcpA() .specification(specificationB) .addFunction(windowFunction("nth_value", ImmutableList.of(QUANTITY_ALIAS, "ONE"), COMMON_FRAME)), project( - ImmutableMap.of("ONE", expression(new Cast(new SymbolReference(INTEGER, "expr"), BIGINT))), + ImmutableMap.of("ONE", expression(new Cast(new Reference(INTEGER, "expr"), BIGINT))), project( ImmutableMap.of("expr", expression(new Constant(INTEGER, 1L))), LINEITEM_TABLESCAN_DOQSS))))))); @@ -243,7 +243,7 @@ public void testIdenticalWindowSpecificationsABfilterA() window(windowMatcherBuilder -> windowMatcherBuilder .specification(specificationB) .addFunction(windowFunction("sum", ImmutableList.of(QUANTITY_ALIAS), COMMON_FRAME)), - filter(new NotExpression(new IsNullPredicate(new SymbolReference(VARCHAR, SHIPDATE_ALIAS))), + filter(new Not(new IsNull(new Reference(VARCHAR, SHIPDATE_ALIAS))), project( window(windowMatcherBuilder -> windowMatcherBuilder .specification(specificationA) @@ -267,7 +267,7 @@ public void testIdenticalWindowSpecificationsAAcpA() .addFunction(windowFunction("sum", ImmutableList.of(DISCOUNT_ALIAS), COMMON_FRAME)) .addFunction(windowFunction("nth_value", ImmutableList.of(QUANTITY_ALIAS, "ONE"), COMMON_FRAME)) .addFunction(windowFunction("sum", ImmutableList.of(QUANTITY_ALIAS), COMMON_FRAME)), - project(ImmutableMap.of("ONE", expression(new Cast(new SymbolReference(INTEGER, "expr"), BIGINT))), + project(ImmutableMap.of("ONE", expression(new Cast(new Reference(INTEGER, "expr"), BIGINT))), project(ImmutableMap.of("expr", expression(new Constant(INTEGER, 1L))), LINEITEM_TABLESCAN_DOQS))))); } @@ -294,7 +294,7 @@ public void testIdenticalWindowSpecificationsAAfilterA() .addFunction(windowFunction("sum", ImmutableList.of(QUANTITY_ALIAS), COMMON_FRAME)) .addFunction(windowFunction("avg", ImmutableList.of(QUANTITY_ALIAS), COMMON_FRAME)), project( - filter(new NotExpression(new IsNullPredicate(new SymbolReference(VARCHAR, SHIPDATE_ALIAS))), + filter(new Not(new IsNull(new Reference(VARCHAR, SHIPDATE_ALIAS))), project( window(windowMatcherBuilder -> windowMatcherBuilder .specification(specificationA) @@ -462,7 +462,7 @@ public void testNotMergeAcrossJoinBranches() assertUnitPlan(sql, anyTree( - filter(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "SUM"), new SymbolReference(BIGINT, "AVG")), + filter(new Comparison(EQUAL, new Reference(BIGINT, "SUM"), new Reference(BIGINT, "AVG")), join(INNER, builder -> builder .left( any( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestPartialTopNWithPresortedInput.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestPartialTopNWithPresortedInput.java index 5a5421f27f06..af7c44fffcf8 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestPartialTopNWithPresortedInput.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestPartialTopNWithPresortedInput.java @@ -27,10 +27,10 @@ import io.trino.spi.connector.SortingProperty; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.RowType; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.assertions.BasePlanTest; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.testing.PlanTester; @@ -44,7 +44,7 @@ import static io.trino.spi.connector.SortOrder.ASC_NULLS_LAST; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; @@ -208,9 +208,9 @@ public void testNestedField() topN(1, ImmutableList.of(sort("k", ASCENDING, LAST)), FINAL, anyTree( limit(1, ImmutableList.of(), true, ImmutableList.of("k"), - project(ImmutableMap.of("k", expression(new SubscriptExpression(INTEGER, new SymbolReference(INTEGER, "nested"), new Constant(INTEGER, 1L)))), + project(ImmutableMap.of("k", expression(new Subscript(INTEGER, new Reference(INTEGER, "nested"), new Constant(INTEGER, 1L)))), filter( - new ComparisonExpression(EQUAL, new SubscriptExpression(INTEGER, new SymbolReference(INTEGER, "nested"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L)), + new Comparison(EQUAL, new Subscript(INTEGER, new Reference(INTEGER, "nested"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L)), tableScan("with_nested_field", ImmutableMap.of("nested", "nested"))))))))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestRemoveUnsupportedDynamicFilters.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestRemoveUnsupportedDynamicFilters.java index 79bcab6bcaba..2c46cba60bba 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestRemoveUnsupportedDynamicFilters.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestRemoveUnsupportedDynamicFilters.java @@ -28,14 +28,14 @@ import io.trino.spi.connector.CatalogHandle; import io.trino.spi.function.OperatorType; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Plan; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; @@ -64,9 +64,9 @@ import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.sql.DynamicFilters.createDynamicFilterExpression; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.ir.IrUtils.combineConjuncts; import static io.trino.sql.ir.IrUtils.combineDisjuncts; import static io.trino.sql.planner.assertions.PlanMatchPattern.join; @@ -123,7 +123,7 @@ public void testUnconsumedDynamicFilterInJoin() PlanNode root = builder.join( INNER, builder.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "ORDERS_OK"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "ORDERS_OK"), new Constant(INTEGER, 0L)), ordersTableScanNode), lineitemTableScanNode, ImmutableList.of(new JoinNode.EquiJoinClause(ordersOrderKeySymbol, lineitemOrderKeySymbol)), @@ -139,8 +139,8 @@ public void testUnconsumedDynamicFilterInJoin() .equiCriteria("ORDERS_OK", "LINEITEM_OK") .left( PlanMatchPattern.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "ORDERS_OK"), new Constant(INTEGER, 0L)), - TRUE_LITERAL, + new Comparison(GREATER_THAN, new Reference(INTEGER, "ORDERS_OK"), new Constant(INTEGER, 0L)), + TRUE, tableScan("orders", ImmutableMap.of("ORDERS_OK", "orderkey")))) .right( tableScan("lineitem", ImmutableMap.of("LINEITEM_OK", "orderkey"))))); @@ -172,8 +172,8 @@ public void testDynamicFilterConsumedOnBuildSide() .dynamicFilter(BIGINT, "ORDERS_OK", "LINEITEM_OK") .left( PlanMatchPattern.filter( - TRUE_LITERAL, - createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, new SymbolReference(BIGINT, "ORDERS_OK")), + TRUE, + createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, new Reference(BIGINT, "ORDERS_OK")), tableScan("orders", ImmutableMap.of("ORDERS_OK", "orderkey")))) .right( tableScan("lineitem", ImmutableMap.of("LINEITEM_OK", "orderkey"))))); @@ -190,7 +190,7 @@ public void testUnmatchedDynamicFilter() ordersTableScanNode, builder.filter( combineConjuncts( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "LINEITEM_OK"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "LINEITEM_OK"), new Constant(INTEGER, 0L)), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, lineitemOrderKeySymbol.toSymbolReference())), lineitemTableScanNode), ImmutableList.of(new JoinNode.EquiJoinClause(ordersOrderKeySymbol, lineitemOrderKeySymbol)), @@ -209,7 +209,7 @@ public void testUnmatchedDynamicFilter() tableScan("orders", ImmutableMap.of("ORDERS_OK", "orderkey"))) .right( filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "LINEITEM_OK"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "LINEITEM_OK"), new Constant(INTEGER, 0L)), tableScan("lineitem", ImmutableMap.of("LINEITEM_OK", "orderkey"))))))); } @@ -223,7 +223,7 @@ public void testRemoveDynamicFilterNotAboveTableScan() INNER, builder.filter( combineConjuncts( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "LINEITEM_OK"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "LINEITEM_OK"), new Constant(INTEGER, 0L)), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, ordersOrderKeySymbol.toSymbolReference())), builder.values(lineitemOrderKeySymbol)), ordersTableScanNode, @@ -241,7 +241,7 @@ public void testRemoveDynamicFilterNotAboveTableScan() .equiCriteria("LINEITEM_OK", "ORDERS_OK") .left( filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "LINEITEM_OK"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "LINEITEM_OK"), new Constant(INTEGER, 0L)), values("LINEITEM_OK"))) .right( tableScan("orders", ImmutableMap.of("ORDERS_OK", "orderkey")))))); @@ -259,10 +259,10 @@ public void testNestedDynamicFilterDisjunctionRewrite() builder.filter( combineConjuncts( combineDisjuncts( - new IsNullPredicate(new SymbolReference(BIGINT, "LINEITEM_OK")), + new IsNull(new Reference(BIGINT, "LINEITEM_OK")), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, lineitemOrderKeySymbol.toSymbolReference())), combineDisjuncts( - new NotExpression(new IsNullPredicate(new SymbolReference(BIGINT, "LINEITEM_OK"))), + new Not(new IsNull(new Reference(BIGINT, "LINEITEM_OK"))), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, lineitemOrderKeySymbol.toSymbolReference()))), lineitemTableScanNode), ImmutableList.of(new JoinNode.EquiJoinClause(ordersOrderKeySymbol, lineitemOrderKeySymbol)), @@ -293,10 +293,10 @@ public void testNestedDynamicFilterConjunctionRewrite() builder.filter( combineDisjuncts( combineConjuncts( - new IsNullPredicate(new SymbolReference(BIGINT, "LINEITEM_OK")), + new IsNull(new Reference(BIGINT, "LINEITEM_OK")), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, lineitemOrderKeySymbol.toSymbolReference())), combineConjuncts( - new NotExpression(new IsNullPredicate(new SymbolReference(BIGINT, "LINEITEM_OK"))), + new Not(new IsNull(new Reference(BIGINT, "LINEITEM_OK"))), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, lineitemOrderKeySymbol.toSymbolReference()))), lineitemTableScanNode), ImmutableList.of(new JoinNode.EquiJoinClause(ordersOrderKeySymbol, lineitemOrderKeySymbol)), @@ -316,8 +316,8 @@ public void testNestedDynamicFilterConjunctionRewrite() .right( filter( combineDisjuncts( - new IsNullPredicate(new SymbolReference(BIGINT, "LINEITEM_OK")), - new NotExpression(new IsNullPredicate(new SymbolReference(BIGINT, "LINEITEM_OK")))), + new IsNull(new Reference(BIGINT, "LINEITEM_OK")), + new Not(new IsNull(new Reference(BIGINT, "LINEITEM_OK")))), tableScan("lineitem", ImmutableMap.of("LINEITEM_OK", "orderkey"))))))); } @@ -331,7 +331,7 @@ public void testRemoveUnsupportedCast() builder.join( INNER, builder.filter( - createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, new Cast(new SymbolReference(DOUBLE, "LINEITEM_DOUBLE_OK"), BIGINT)), + createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, new Cast(new Reference(DOUBLE, "LINEITEM_DOUBLE_OK"), BIGINT)), builder.tableScan( lineitemTableHandle, ImmutableList.of(lineitemDoubleOrderKeySymbol), @@ -368,12 +368,12 @@ public void testSpatialJoin() builder.values(leftSymbol), builder.values(rightSymbol), ImmutableList.of(leftSymbol, rightSymbol), - createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "LEFT_SYMBOL"), new SymbolReference(INTEGER, "RIGHT_SYMBOL"))))); + createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "LEFT_SYMBOL"), new Reference(INTEGER, "RIGHT_SYMBOL"))))); assertPlan( removeUnsupportedDynamicFilters(root), output( spatialJoin( - TRUE_LITERAL, + TRUE, values("LEFT_SYMBOL"), values("RIGHT_SYMBOL")))); } @@ -383,7 +383,7 @@ public void testUnconsumedDynamicFilterInSemiJoin() { PlanNode root = builder.semiJoin( builder.filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "ORDERS_OK"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "ORDERS_OK"), new Constant(INTEGER, 0L)), ordersTableScanNode), lineitemTableScanNode, ordersOrderKeySymbol, @@ -397,7 +397,7 @@ public void testUnconsumedDynamicFilterInSemiJoin() removeUnsupportedDynamicFilters(root), semiJoin("ORDERS_OK", "LINEITEM_OK", "SEMIJOIN_OUTPUT", false, filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "ORDERS_OK"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "ORDERS_OK"), new Constant(INTEGER, 0L)), tableScan("orders", ImmutableMap.of("ORDERS_OK", "orderkey"))), tableScan("lineitem", ImmutableMap.of("LINEITEM_OK", "orderkey")))); } @@ -409,7 +409,7 @@ public void testDynamicFilterConsumedOnFilteringSourceSideInSemiJoin() ordersTableScanNode, builder.filter( combineConjuncts( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "LINEITEM_OK"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "LINEITEM_OK"), new Constant(INTEGER, 0L)), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, lineitemOrderKeySymbol.toSymbolReference())), lineitemTableScanNode), ordersOrderKeySymbol, @@ -424,7 +424,7 @@ public void testDynamicFilterConsumedOnFilteringSourceSideInSemiJoin() semiJoin("ORDERS_OK", "LINEITEM_OK", "SEMIJOIN_OUTPUT", false, tableScan("orders", ImmutableMap.of("ORDERS_OK", "orderkey")), filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "LINEITEM_OK"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "LINEITEM_OK"), new Constant(INTEGER, 0L)), tableScan("lineitem", ImmutableMap.of("LINEITEM_OK", "orderkey"))))); } @@ -434,7 +434,7 @@ public void testUnmatchedDynamicFilterInSemiJoin() PlanNode root = builder.semiJoin( builder.filter( combineConjuncts( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "ORDERS_OK"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "ORDERS_OK"), new Constant(INTEGER, 0L)), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, ordersOrderKeySymbol.toSymbolReference())), ordersTableScanNode), lineitemTableScanNode, @@ -449,7 +449,7 @@ public void testUnmatchedDynamicFilterInSemiJoin() removeUnsupportedDynamicFilters(root), semiJoin("ORDERS_OK", "LINEITEM_OK", "SEMIJOIN_OUTPUT", false, filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "ORDERS_OK"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "ORDERS_OK"), new Constant(INTEGER, 0L)), tableScan("orders", ImmutableMap.of("ORDERS_OK", "orderkey"))), tableScan("lineitem", ImmutableMap.of("LINEITEM_OK", "orderkey")))); } @@ -460,7 +460,7 @@ public void testRemoveDynamicFilterNotAboveTableScanWithSemiJoin() PlanNode root = builder.semiJoin( builder.filter( combineConjuncts( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "ORDERS_OK"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "ORDERS_OK"), new Constant(INTEGER, 0L)), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, ordersOrderKeySymbol.toSymbolReference())), builder.values(ordersOrderKeySymbol)), lineitemTableScanNode, @@ -476,7 +476,7 @@ public void testRemoveDynamicFilterNotAboveTableScanWithSemiJoin() removeUnsupportedDynamicFilters(root), semiJoin("ORDERS_OK", "LINEITEM_OK", "SEMIJOIN_OUTPUT", false, filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "ORDERS_OK"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "ORDERS_OK"), new Constant(INTEGER, 0L)), values("ORDERS_OK")), tableScan("lineitem", ImmutableMap.of("LINEITEM_OK", "orderkey")))); } @@ -484,7 +484,7 @@ public void testRemoveDynamicFilterNotAboveTableScanWithSemiJoin() private static PlanMatchPattern filter(Expression expectedPredicate, PlanMatchPattern source) { // assert explicitly that no dynamic filters are present - return PlanMatchPattern.filter(expectedPredicate, TRUE_LITERAL, source); + return PlanMatchPattern.filter(expectedPredicate, TRUE, source); } private PlanNode removeUnsupportedDynamicFilters(PlanNode root) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestReorderWindows.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestReorderWindows.java index 6f536c27d8f3..88b77c6b55db 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestReorderWindows.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestReorderWindows.java @@ -18,11 +18,11 @@ import com.google.common.collect.ImmutableSet; import io.trino.spi.connector.SortOrder; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.RuleStatsRecorder; import io.trino.sql.planner.assertions.BasePlanTest; import io.trino.sql.planner.assertions.ExpectedValueProvider; @@ -40,7 +40,7 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.PlanOptimizers.columnPruningRules; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; @@ -235,7 +235,7 @@ public void testReorderAcrossProjectNodes() window(windowMatcherBuilder -> windowMatcherBuilder .specification(windowA) .addFunction(windowFunction("lag", ImmutableList.of(QUANTITY_ALIAS, "ONE"), DEFAULT_FRAME)), - project(ImmutableMap.of("ONE", expression(new Cast(new SymbolReference(INTEGER, "expr"), BIGINT))), + project(ImmutableMap.of("ONE", expression(new Cast(new Reference(INTEGER, "expr"), BIGINT))), project(ImmutableMap.of("expr", expression(new Constant(INTEGER, 1L))), LINEITEM_TABLESCAN_DOQRST))))))); } @@ -261,7 +261,7 @@ public void testNotReorderAcrossNonPartitionFilter() .addFunction(windowFunction("avg", ImmutableList.of(QUANTITY_ALIAS), DEFAULT_FRAME)), project( filter( - new NotExpression(new IsNullPredicate(new SymbolReference(VARCHAR, RECEIPTDATE_ALIAS))), + new Not(new IsNull(new Reference(VARCHAR, RECEIPTDATE_ALIAS))), project( window(windowMatcherBuilder -> windowMatcherBuilder .specification(windowApp) @@ -293,7 +293,7 @@ public void testReorderAcrossPartitionFilter() .specification(windowA) .addFunction(windowFunction("avg", ImmutableList.of(QUANTITY_ALIAS), DEFAULT_FRAME)), filter( - new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "SUPPKEY"), new Constant(BIGINT, 0L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "SUPPKEY"), new Constant(BIGINT, 0L)), LINEITEM_TABLESCAN_DOQRST)))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestUnaliasSymbolReferences.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestUnaliasSymbolReferences.java index f493935190b4..94fd2557b1ac 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestUnaliasSymbolReferences.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestUnaliasSymbolReferences.java @@ -26,7 +26,7 @@ import io.trino.spi.connector.ColumnHandle; import io.trino.spi.type.BigintType; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Plan; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; @@ -47,7 +47,7 @@ import static io.trino.plugin.tpch.TpchMetadata.TINY_SCHEMA_NAME; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.sql.DynamicFilters.createDynamicFilterExpression; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.IrUtils.and; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.groupId; @@ -81,7 +81,7 @@ public void testDynamicFilterIdUnAliased() return p.join( INNER, p.filter( - TRUE_LITERAL, // additional filter to test recursive call + TRUE, // additional filter to test recursive call p.filter( and( dynamicFilterExpression(metadata, probeColumn1, dynamicFilterId1), @@ -105,13 +105,13 @@ probeColumn2, new TpchColumnHandle("suppkey", BIGINT))))), }, join(INNER, builder -> builder .dynamicFilter(ImmutableMap.of( - new SymbolReference(BIGINT, "probeColumn1"), "column", - new SymbolReference(BIGINT, "probeColumn2"), "column")) + new Reference(BIGINT, "probeColumn1"), "column", + new Reference(BIGINT, "probeColumn2"), "column")) .left( filter( - TRUE_LITERAL, + TRUE, filter( - TRUE_LITERAL, + TRUE, tableScan( probeTable, ImmutableMap.of("probeColumn1", "suppkey", "probeColumn2", "nationkey"))))) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestWindowFilterPushDown.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestWindowFilterPushDown.java index 16e94e176a3d..4e5dcfec609c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestWindowFilterPushDown.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestWindowFilterPushDown.java @@ -16,10 +16,10 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.Session; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.assertions.BasePlanTest; import io.trino.sql.planner.assertions.RowNumberSymbolMatcher; import io.trino.sql.planner.assertions.TopNRankingSymbolMatcher; @@ -34,9 +34,9 @@ import static io.trino.SystemSessionProperties.OPTIMIZE_TOP_N_RANKING; import static io.trino.spi.connector.SortOrder.ASC_NULLS_LAST; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Logical.Operator.AND; import static io.trino.sql.planner.assertions.PlanMatchPattern.any; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyNot; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; @@ -235,7 +235,7 @@ private void assertFilterAboveWindow(String rankingFunction, RankingType ranking output( ImmutableList.of("name", "ranking"), filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "ranking"), new Constant(BIGINT, 1L)), new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "ranking"), new Constant(BIGINT, 3L)))), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "ranking"), new Constant(BIGINT, 1L)), new Comparison(LESS_THAN, new Reference(BIGINT, "ranking"), new Constant(BIGINT, 3L)))), topNRanking( pattern -> pattern .rankingType(rankingType) @@ -298,7 +298,7 @@ public void testFilterAboveRowNumber() output( ImmutableList.of("name", "row_number"), filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "row_number"), new Constant(BIGINT, 1L)), new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "row_number"), new Constant(BIGINT, 3L)))), + new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "row_number"), new Constant(BIGINT, 1L)), new Comparison(LESS_THAN, new Reference(BIGINT, "row_number"), new Constant(BIGINT, 3L)))), rowNumber( pattern -> pattern .maxRowCountPerPartition(Optional.of(2)), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestPatternRecognitionNodeSerialization.java b/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestPatternRecognitionNodeSerialization.java index 6cb4a8a98948..b0ce7606e1f7 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestPatternRecognitionNodeSerialization.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestPatternRecognitionNodeSerialization.java @@ -28,11 +28,11 @@ import io.trino.spi.type.TestingTypeManager; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; -import io.trino.sql.ir.ArithmeticNegation; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Negation; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolKeyDeserializer; import io.trino.sql.planner.plan.PatternRecognitionNode.Measure; @@ -60,7 +60,7 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.ir.IrExpressions.ifExpression; import static io.trino.sql.planner.plan.FrameBoundType.CURRENT_ROW; import static io.trino.sql.planner.plan.FrameBoundType.UNBOUNDED_FOLLOWING; @@ -138,9 +138,9 @@ public void testExpressionAndValuePointersRoundtrip() assertJsonRoundTrip(EXPRESSION_AND_VALUE_POINTERS_CODEC, new ExpressionAndValuePointers( ifExpression( - new ComparisonExpression(GREATER_THAN, new SymbolReference(VARCHAR, "classifier"), new SymbolReference(VARCHAR, "x")), - new FunctionCall(RANDOM, ImmutableList.of()), - new ArithmeticNegation(new SymbolReference(INTEGER, "match_number"))), + new Comparison(GREATER_THAN, new Reference(VARCHAR, "classifier"), new Reference(VARCHAR, "x")), + new Call(RANDOM, ImmutableList.of()), + new Negation(new Reference(INTEGER, "match_number"))), ImmutableList.of( new ExpressionAndValuePointers.Assignment( new Symbol(VARCHAR, "classifier"), @@ -166,9 +166,9 @@ public void testMeasureRoundtrip() assertJsonRoundTrip(MEASURE_CODEC, new Measure( new ExpressionAndValuePointers( ifExpression( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "match_number"), new SymbolReference(INTEGER, "x")), + new Comparison(GREATER_THAN, new Reference(INTEGER, "match_number"), new Reference(INTEGER, "x")), new Constant(BIGINT, 10L), - new ArithmeticNegation(new SymbolReference(INTEGER, "y"))), + new Negation(new Reference(INTEGER, "y"))), ImmutableList.of( new ExpressionAndValuePointers.Assignment( new Symbol(BIGINT, "match_number"), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/planprinter/TestAnonymizeJsonRepresentation.java b/core/trino-main/src/test/java/io/trino/sql/planner/planprinter/TestAnonymizeJsonRepresentation.java index 8dd909cb5023..9cc20709a75c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/planprinter/TestAnonymizeJsonRepresentation.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/planprinter/TestAnonymizeJsonRepresentation.java @@ -28,8 +28,9 @@ import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.predicate.ValueSet; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.PlanNodeIdAllocator; +import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; import io.trino.sql.planner.plan.DynamicFilterId; import io.trino.sql.planner.plan.JoinNode; @@ -61,7 +62,6 @@ import static io.trino.sql.planner.plan.AggregationNode.Step.FINAL; import static io.trino.sql.planner.plan.JoinType.INNER; import static io.trino.sql.planner.planprinter.JsonRenderer.JsonRenderedNode; -import static io.trino.sql.planner.planprinter.NodeRepresentation.TypedSymbol.typedSymbol; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; @@ -108,7 +108,7 @@ public void testAggregationPlan() assertAnonymizedRepresentation( pb -> pb.aggregation(ab -> ab .step(FINAL) - .addAggregation(pb.symbol("sum", BIGINT), aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "x"))), ImmutableList.of(BIGINT)) + .addAggregation(pb.symbol("sum", BIGINT), aggregation("sum", ImmutableList.of(new Reference(BIGINT, "x"))), ImmutableList.of(BIGINT)) .singleGroupingSet(pb.symbol("y", BIGINT), pb.symbol("z", BIGINT)) .source(pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT), pb.symbol("z", BIGINT)))), new JsonRenderedNode( @@ -119,17 +119,17 @@ public void testAggregationPlan() "keys", "[symbol_1, symbol_2]", "hash", "[]"), ImmutableList.of( - typedSymbol("symbol_1", BIGINT), - typedSymbol("symbol_2", BIGINT), - typedSymbol("symbol_3", BIGINT)), + new Symbol(BIGINT, "symbol_1"), + new Symbol(BIGINT, "symbol_2"), + new Symbol(BIGINT, "symbol_3")), ImmutableList.of("symbol_3 := sum(\"symbol_4\")"), ImmutableList.of(), ImmutableList.of(valuesRepresentation( "0", ImmutableList.of( - typedSymbol("symbol_4", BIGINT), - typedSymbol("symbol_1", BIGINT), - typedSymbol("symbol_2", BIGINT)))))); + new Symbol(BIGINT, "symbol_4"), + new Symbol(BIGINT, "symbol_1"), + new Symbol(BIGINT, "symbol_2")))))); } @Test @@ -153,20 +153,20 @@ public void testJoinPlan() ImmutableMap.of( "criteria", "(\"symbol_1\" = \"symbol_2\")", "hash", "[]"), - ImmutableList.of(typedSymbol("symbol_3", BIGINT)), + ImmutableList.of(new Symbol(BIGINT, "symbol_3")), ImmutableList.of("dynamicFilterAssignments = {symbol_2 -> #DF}"), ImmutableList.of(), ImmutableList.of( valuesRepresentation( "0", ImmutableList.of( - typedSymbol("symbol_1", BIGINT), - typedSymbol("symbol_3", BIGINT))), + new Symbol(BIGINT, "symbol_1"), + new Symbol(BIGINT, "symbol_3"))), valuesRepresentation( "1", ImmutableList.of( - typedSymbol("symbol_4", BIGINT), - typedSymbol("symbol_2", BIGINT)))))); + new Symbol(BIGINT, "symbol_4"), + new Symbol(BIGINT, "symbol_2")))))); } @Test @@ -191,10 +191,10 @@ public void testTableScanPlan() ImmutableMap.of( "table", "[table = catalog_1.schema_1.table_1, connector = tpch]"), ImmutableList.of( - typedSymbol("symbol_1", BIGINT), - typedSymbol("symbol_2", BIGINT), - typedSymbol("symbol_3", BIGINT), - typedSymbol("symbol_4", BIGINT)), + new Symbol(BIGINT, "symbol_1"), + new Symbol(BIGINT, "symbol_2"), + new Symbol(BIGINT, "symbol_3"), + new Symbol(BIGINT, "symbol_4")), ImmutableList.of( "symbol_1 := column_1", " :: [[bigint_value_1]]", @@ -216,16 +216,16 @@ public void testSortPlan() "1", "Sort", ImmutableMap.of("orderBy", "[symbol_1 ASC NULLS FIRST]"), - ImmutableList.of(typedSymbol("symbol_1", BIGINT)), + ImmutableList.of(new Symbol(BIGINT, "symbol_1")), ImmutableList.of(), ImmutableList.of(), ImmutableList.of( valuesRepresentation( "0", - ImmutableList.of(typedSymbol("symbol_1", BIGINT)))))); + ImmutableList.of(new Symbol(BIGINT, "symbol_1")))))); } - private static JsonRenderedNode valuesRepresentation(String id, List outputs) + private static JsonRenderedNode valuesRepresentation(String id, List outputs) { return new JsonRenderedNode( id, diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/planprinter/TestCounterBasedAnonymizer.java b/core/trino-main/src/test/java/io/trino/sql/planner/planprinter/TestCounterBasedAnonymizer.java index 94d584a0efc2..3106dd99f3ee 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/planprinter/TestCounterBasedAnonymizer.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/planprinter/TestCounterBasedAnonymizer.java @@ -14,19 +14,19 @@ package io.trino.sql.planner.planprinter; import com.google.common.collect.ImmutableList; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; import io.trino.type.UnknownType; import org.junit.jupiter.api.Test; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Logical.Operator.AND; import static org.assertj.core.api.Assertions.assertThat; public class TestCounterBasedAnonymizer @@ -42,10 +42,10 @@ public void testTimestampWithTimeZoneValueAnonymization() @Test public void testSymbolReferenceAnonymization() { - LogicalExpression expression = new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 1L)), - new ComparisonExpression(LESS_THAN, new SymbolReference(INTEGER, "b"), new Constant(INTEGER, 2L)), - new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "c"), new Constant(INTEGER, 3L)))); + Logical expression = new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 1L)), + new Comparison(LESS_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 2L)), + new Comparison(EQUAL, new Reference(INTEGER, "c"), new Constant(INTEGER, 3L)))); CounterBasedAnonymizer anonymizer = new CounterBasedAnonymizer(); assertThat(anonymizer.anonymize(expression)) .isEqualTo("((\"symbol_1\" > 'integer_literal_1') AND (\"symbol_2\" < 'integer_literal_2') AND (\"symbol_3\" = 'integer_literal_3'))"); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/planprinter/TestJsonRepresentation.java b/core/trino-main/src/test/java/io/trino/sql/planner/planprinter/TestJsonRepresentation.java index 63052e0c7470..b0b518c1399d 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/planprinter/TestJsonRepresentation.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/planprinter/TestJsonRepresentation.java @@ -22,9 +22,10 @@ import io.trino.metadata.QualifiedObjectName; import io.trino.plugin.tpch.TpchPlugin; import io.trino.spi.predicate.TupleDomain; -import io.trino.sql.ir.ComparisonExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Comparison; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.PlanNodeIdAllocator; +import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; import io.trino.sql.planner.plan.DynamicFilterId; import io.trino.sql.planner.plan.JoinNode; @@ -51,14 +52,12 @@ import static io.trino.plugin.tpch.TpchMetadata.TINY_SCHEMA_NAME; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; import static io.trino.sql.planner.iterative.rule.test.PlanBuilder.aggregation; import static io.trino.sql.planner.plan.AggregationNode.Step.FINAL; import static io.trino.sql.planner.plan.ExchangeNode.Type.REPARTITION; import static io.trino.sql.planner.plan.JoinType.INNER; import static io.trino.sql.planner.planprinter.JsonRenderer.JsonRenderedNode; -import static io.trino.sql.planner.planprinter.NodeRepresentation.TypedSymbol; -import static io.trino.sql.planner.planprinter.NodeRepresentation.TypedSymbol.typedSymbol; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; @@ -100,21 +99,21 @@ public void testJsonPlan() "8", "Output", ImmutableMap.of("columnNames", "[_col0]"), - ImmutableList.of(typedSymbol("field", INTEGER)), + ImmutableList.of(new Symbol(INTEGER, "field")), ImmutableList.of("_col0 := field"), ImmutableList.of(new PlanNodeStatsAndCostSummary(1, 5, 0, 0, 0)), ImmutableList.of(new JsonRenderedNode( "90", "Limit", ImmutableMap.of("count", "1", "withTies", "", "inputPreSortedBy", "[]"), - ImmutableList.of(typedSymbol("field", INTEGER)), + ImmutableList.of(new Symbol(INTEGER, "field")), ImmutableList.of(), ImmutableList.of(new PlanNodeStatsAndCostSummary(1, 5, 5, 0, 0)), ImmutableList.of(new JsonRenderedNode( "0", "Values", ImmutableMap.of(), - ImmutableList.of(typedSymbol("field", INTEGER)), + ImmutableList.of(new Symbol(INTEGER, "field")), ImmutableList.of("(integer '1')", "(integer '2')"), ImmutableList.of(new PlanNodeStatsAndCostSummary(2, 10, 0, 0, 0)), ImmutableList.of()))))); @@ -132,7 +131,7 @@ public void testAggregationPlan() assertJsonRepresentation( pb -> pb.aggregation(ab -> ab .step(FINAL) - .addAggregation(pb.symbol("sum", BIGINT), aggregation("sum", ImmutableList.of(new SymbolReference(BIGINT, "x"))), ImmutableList.of(BIGINT)) + .addAggregation(pb.symbol("sum", BIGINT), aggregation("sum", ImmutableList.of(new Reference(BIGINT, "x"))), ImmutableList.of(BIGINT)) .singleGroupingSet(pb.symbol("y", BIGINT), pb.symbol("z", BIGINT)) .source(pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT), pb.symbol("z", BIGINT)))), new JsonRenderedNode( @@ -143,14 +142,14 @@ public void testAggregationPlan() "keys", "[y, z]", "hash", "[]"), ImmutableList.of( - typedSymbol("y", BIGINT), - typedSymbol("z", BIGINT), - typedSymbol("sum", BIGINT)), + new Symbol(BIGINT, "y"), + new Symbol(BIGINT, "z"), + new Symbol(BIGINT, "sum")), ImmutableList.of("sum := sum(x)"), ImmutableList.of(), ImmutableList.of(valuesRepresentation( "0", - ImmutableList.of(typedSymbol("x", BIGINT), typedSymbol("y", BIGINT), typedSymbol("z", BIGINT)))))); + ImmutableList.of(new Symbol(BIGINT, "x"), new Symbol(BIGINT, "y"), new Symbol(BIGINT, "z")))))); } @Test @@ -164,7 +163,7 @@ public void testJoinPlan() ImmutableList.of(new JoinNode.EquiJoinClause(pb.symbol("a", BIGINT), pb.symbol("d", BIGINT))), ImmutableList.of(pb.symbol("b", BIGINT)), ImmutableList.of(), - Optional.of(new ComparisonExpression(LESS_THAN, new SymbolReference(BIGINT, "a"), new SymbolReference(BIGINT, "c"))), + Optional.of(new Comparison(LESS_THAN, new Reference(BIGINT, "a"), new Reference(BIGINT, "c"))), Optional.empty(), Optional.empty(), ImmutableMap.of(new DynamicFilterId("DF"), pb.symbol("d", BIGINT))), @@ -172,12 +171,12 @@ public void testJoinPlan() "2", "InnerJoin", ImmutableMap.of("criteria", "(a = d)", "filter", "(a < c)", "hash", "[]"), - ImmutableList.of(typedSymbol("b", BIGINT)), + ImmutableList.of(new Symbol(BIGINT, "b")), ImmutableList.of("dynamicFilterAssignments = {d -> #DF}"), ImmutableList.of(), ImmutableList.of( - valuesRepresentation("0", ImmutableList.of(typedSymbol("a", BIGINT), typedSymbol("b", BIGINT))), - valuesRepresentation("1", ImmutableList.of(typedSymbol("c", BIGINT), typedSymbol("d", BIGINT)))))); + valuesRepresentation("0", ImmutableList.of(new Symbol(BIGINT, "a"), new Symbol(BIGINT, "b"))), + valuesRepresentation("1", ImmutableList.of(new Symbol(BIGINT, "c"), new Symbol(BIGINT, "d")))))); } @Test @@ -196,13 +195,13 @@ public void testSourceFragmentIdsInRemoteSource() "0", "RemoteSource", ImmutableMap.of("sourceFragmentIds", "[1, 2]"), - ImmutableList.of(typedSymbol("a", BIGINT), typedSymbol("b", BIGINT)), + ImmutableList.of(new Symbol(BIGINT, "a"), new Symbol(BIGINT, "b")), ImmutableList.of(), ImmutableList.of(), ImmutableList.of())); } - private static JsonRenderedNode valuesRepresentation(String id, List outputs) + private static JsonRenderedNode valuesRepresentation(String id, List outputs) { return new JsonRenderedNode( id, diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/sanity/TestDynamicFiltersChecker.java b/core/trino-main/src/test/java/io/trino/sql/planner/sanity/TestDynamicFiltersChecker.java index 227fbab2b51a..730e4902a8ce 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/sanity/TestDynamicFiltersChecker.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/sanity/TestDynamicFiltersChecker.java @@ -26,13 +26,13 @@ import io.trino.spi.connector.CatalogHandle; import io.trino.spi.function.OperatorType; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.BasePlanTest; @@ -51,8 +51,8 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.sql.DynamicFilters.createDynamicFilterExpression; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.ir.IrUtils.combineConjuncts; import static io.trino.sql.ir.IrUtils.combineDisjuncts; import static io.trino.sql.planner.plan.JoinType.INNER; @@ -100,7 +100,7 @@ public void testUnconsumedDynamicFilterInJoin() { PlanNode root = builder.join( INNER, - builder.filter(new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "ORDERS_OK"), new Constant(INTEGER, 0L)), ordersTableScanNode), + builder.filter(new Comparison(GREATER_THAN, new Reference(INTEGER, "ORDERS_OK"), new Constant(INTEGER, 0L)), ordersTableScanNode), lineitemTableScanNode, ImmutableList.of(new JoinNode.EquiJoinClause(ordersOrderKeySymbol, lineitemOrderKeySymbol)), ImmutableList.of(ordersOrderKeySymbol), @@ -148,7 +148,7 @@ public void testUnmatchedDynamicFilter() ordersTableScanNode, builder.filter( combineConjuncts( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "LINEITEM_OK"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "LINEITEM_OK"), new Constant(INTEGER, 0L)), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, lineitemOrderKeySymbol.toSymbolReference())), lineitemTableScanNode), ImmutableList.of(new JoinNode.EquiJoinClause(ordersOrderKeySymbol, lineitemOrderKeySymbol)), @@ -173,7 +173,7 @@ public void testDynamicFilterNotAboveTableScan() INNER, builder.filter( combineConjuncts( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "LINEITEM_OK"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "LINEITEM_OK"), new Constant(INTEGER, 0L)), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, ordersOrderKeySymbol.toSymbolReference())), builder.values(lineitemOrderKeySymbol)), ordersTableScanNode, @@ -201,10 +201,10 @@ public void testUnmatchedNestedDynamicFilter() builder.filter( combineConjuncts( combineDisjuncts( - new IsNullPredicate(new SymbolReference(BIGINT, "LINEITEM_OK")), + new IsNull(new Reference(BIGINT, "LINEITEM_OK")), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, lineitemOrderKeySymbol.toSymbolReference())), combineDisjuncts( - new NotExpression(new IsNullPredicate(new SymbolReference(BIGINT, "LINEITEM_OK"))), + new Not(new IsNull(new Reference(BIGINT, "LINEITEM_OK"))), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, lineitemOrderKeySymbol.toSymbolReference()))), lineitemTableScanNode), ImmutableList.of(new JoinNode.EquiJoinClause(ordersOrderKeySymbol, lineitemOrderKeySymbol)), @@ -226,7 +226,7 @@ public void testUnsupportedDynamicFilterExpression() builder.join( INNER, builder.filter( - createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "LINEITEM_OK"), new Constant(BIGINT, 1L))), + createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "LINEITEM_OK"), new Constant(BIGINT, 1L))), lineitemTableScanNode), ordersTableScanNode, ImmutableList.of(new JoinNode.EquiJoinClause(lineitemOrderKeySymbol, ordersOrderKeySymbol)), @@ -248,7 +248,7 @@ public void testUnsupportedCastExpression() builder.join( INNER, builder.filter( - createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, new Cast(new Cast(new SymbolReference(BIGINT, "LINEITEM_OK"), INTEGER), BIGINT)), + createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, new Cast(new Cast(new Reference(BIGINT, "LINEITEM_OK"), INTEGER), BIGINT)), lineitemTableScanNode), ordersTableScanNode, ImmutableList.of(new JoinNode.EquiJoinClause(lineitemOrderKeySymbol, ordersOrderKeySymbol)), @@ -265,7 +265,7 @@ public void testUnsupportedCastExpression() public void testUnconsumedDynamicFilterInSemiJoin() { PlanNode root = builder.semiJoin( - builder.filter(new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "ORDERS_OK"), new Constant(INTEGER, 0L)), ordersTableScanNode), + builder.filter(new Comparison(GREATER_THAN, new Reference(INTEGER, "ORDERS_OK"), new Constant(INTEGER, 0L)), ordersTableScanNode), lineitemTableScanNode, ordersOrderKeySymbol, lineitemOrderKeySymbol, @@ -285,12 +285,12 @@ public void testDynamicFilterConsumedOnFilteringSourceSideInSemiJoin() PlanNode root = builder.semiJoin( builder.filter( combineConjuncts( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "ORDERS_OK"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "ORDERS_OK"), new Constant(INTEGER, 0L)), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, lineitemOrderKeySymbol.toSymbolReference())), ordersTableScanNode), builder.filter( combineConjuncts( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "LINEITEM_OK"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "LINEITEM_OK"), new Constant(INTEGER, 0L)), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, lineitemOrderKeySymbol.toSymbolReference())), lineitemTableScanNode), ordersOrderKeySymbol, @@ -314,7 +314,7 @@ public void testUnmatchedDynamicFilterInSemiJoin() builder.semiJoin( builder.filter( combineConjuncts( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "ORDERS_OK"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "ORDERS_OK"), new Constant(INTEGER, 0L)), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, ordersOrderKeySymbol.toSymbolReference())), ordersTableScanNode), lineitemTableScanNode, @@ -336,7 +336,7 @@ public void testDynamicFilterNotAboveTableScanWithSemiJoin() PlanNode root = builder.semiJoin( builder.filter( combineConjuncts( - new ComparisonExpression(GREATER_THAN, new SymbolReference(INTEGER, "ORDERS_OK"), new Constant(INTEGER, 0L)), + new Comparison(GREATER_THAN, new Reference(INTEGER, "ORDERS_OK"), new Constant(INTEGER, 0L)), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, ordersOrderKeySymbol.toSymbolReference())), builder.values(ordersOrderKeySymbol)), lineitemTableScanNode, diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestFilteredAggregations.java b/core/trino-main/src/test/java/io/trino/sql/query/TestFilteredAggregations.java index 119352f60eb2..04d8e4592478 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestFilteredAggregations.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestFilteredAggregations.java @@ -15,10 +15,10 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.assertions.BasePlanTest; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.plan.FilterNode; @@ -27,8 +27,8 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.DoubleType.DOUBLE; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.LogicalExpression.Operator.OR; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Logical.Operator.OR; import static io.trino.sql.planner.LogicalPlanner.Stage.OPTIMIZED; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; @@ -126,7 +126,7 @@ public void rewriteAddFilterWithMultipleFilters() "SELECT sum(totalprice) FILTER(WHERE totalprice > 0), sum(custkey) FILTER(WHERE custkey > 0) FROM orders", anyTree( filter( - new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference(DOUBLE, "totalprice"), new Constant(DOUBLE, 0.0)), new ComparisonExpression(GREATER_THAN, new SymbolReference(BIGINT, "custkey"), new Constant(BIGINT, 0L)))), + new Logical(OR, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(DOUBLE, "totalprice"), new Constant(DOUBLE, 0.0)), new Comparison(GREATER_THAN, new Reference(BIGINT, "custkey"), new Constant(BIGINT, 0L)))), source))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestSubqueries.java b/core/trino-main/src/test/java/io/trino/sql/query/TestSubqueries.java index a1a0fc80f8c8..13d45030c442 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestSubqueries.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestSubqueries.java @@ -20,10 +20,10 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.plugin.tpch.TpchPlugin; import io.trino.spi.function.OperatorType; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.plan.JoinNode; import io.trino.testing.QueryRunner; import io.trino.testing.StandaloneQueryRunner; @@ -42,9 +42,9 @@ import static io.trino.spi.type.RowType.field; import static io.trino.spi.type.RowType.rowType; import static io.trino.spi.type.VarcharType.createVarcharType; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MULTIPLY; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.SUBTRACT; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; +import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; +import static io.trino.sql.ir.Arithmetic.Operator.SUBTRACT; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregationFunction; import static io.trino.sql.planner.assertions.PlanMatchPattern.any; @@ -113,7 +113,7 @@ public void testCorrelatedExistsSubqueriesWithOrPredicateAndNull() anyTree( values("y")), project( - ImmutableMap.of("NON_NULL", expression(TRUE_LITERAL)), + ImmutableMap.of("NON_NULL", expression(TRUE)), values("x")))))); assertions.assertQueryAndPlan( @@ -127,7 +127,7 @@ public void testCorrelatedExistsSubqueriesWithOrPredicateAndNull() anyTree( values("y")), project( - ImmutableMap.of("NON_NULL", expression(TRUE_LITERAL)), + ImmutableMap.of("NON_NULL", expression(TRUE)), values("x")))))); } @@ -223,13 +223,13 @@ public void testCorrelatedSubqueriesWithTopN() .equiCriteria("cast_b", "cast_a") .left( project( - ImmutableMap.of("cast_b", expression(new Cast(new SymbolReference(INTEGER, "b"), createDecimalType(11, 1)))), + ImmutableMap.of("cast_b", expression(new Cast(new Reference(INTEGER, "b"), createDecimalType(11, 1)))), any( values("b")))) .right( anyTree( project( - ImmutableMap.of("cast_a", expression(new Cast(new SymbolReference(INTEGER, "a"), createDecimalType(11, 1)))), + ImmutableMap.of("cast_a", expression(new Cast(new Reference(INTEGER, "a"), createDecimalType(11, 1)))), any( rowNumber( rowBuilder -> rowBuilder @@ -247,7 +247,7 @@ public void testCorrelatedSubqueriesWithTopN() .equiCriteria("expr", "a") .left( project( - ImmutableMap.of("expr", expression(new ArithmeticBinaryExpression(SUBTRACT_INTEGER, SUBTRACT, new ArithmeticBinaryExpression(MULTIPLY_INTEGER, MULTIPLY, new SymbolReference(INTEGER, "b"), new SymbolReference(INTEGER, "c")), new Constant(INTEGER, 1L)))), + ImmutableMap.of("expr", expression(new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "b"), new Reference(INTEGER, "c")), new Constant(INTEGER, 1L)))), any( values("b", "c")))) .right( @@ -523,7 +523,7 @@ public void testCorrelatedSubqueriesWithGroupBy() values("t2_b")), anyTree( project( - ImmutableMap.of("NON_NULL", expression(TRUE_LITERAL)), + ImmutableMap.of("NON_NULL", expression(TRUE)), anyTree( aggregation( ImmutableMap.of(), @@ -548,7 +548,7 @@ public void testCorrelatedSubqueriesWithGroupBy() values("t2_b")), anyTree( project( - ImmutableMap.of("NON_NULL", expression(TRUE_LITERAL)), + ImmutableMap.of("NON_NULL", expression(TRUE)), aggregation( ImmutableMap.of(), FINAL, diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeProjectionPushdownPlans.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeProjectionPushdownPlans.java index ea59c6d775d5..961793732091 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeProjectionPushdownPlans.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeProjectionPushdownPlans.java @@ -30,13 +30,13 @@ import io.trino.spi.predicate.TupleDomain; import io.trino.spi.security.PrincipalType; import io.trino.spi.type.RowType; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.assertions.BasePushdownPlanTest; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.testing.PlanTester; @@ -59,9 +59,9 @@ import static io.trino.plugin.deltalake.DeltaLakeQueryRunner.DELTA_CATALOG; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Logical.Operator.AND; import static io.trino.sql.planner.assertions.PlanMatchPattern.any; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; @@ -148,7 +148,7 @@ public void testPushdownDisabled() session, any( project( - ImmutableMap.of("expr", expression(new SubscriptExpression(BIGINT, new SymbolReference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), new Constant(INTEGER, 1L))), "expr_2", expression(new SubscriptExpression(BIGINT, new SymbolReference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), new Constant(INTEGER, 2L)))), + ImmutableMap.of("expr", expression(new Subscript(BIGINT, new Reference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), new Constant(INTEGER, 1L))), "expr_2", expression(new Subscript(BIGINT, new Reference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), new Constant(INTEGER, 2L)))), tableScan(testTable, ImmutableMap.of("col0", "col0"))))); } @@ -198,9 +198,9 @@ public void testDereferencePushdown() format("SELECT col0.x FROM %s WHERE col0.x = col1 + 3 and col0.y = 2", testTable), anyTree( filter( - new LogicalExpression(AND, ImmutableList.of( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "y"), new Constant(BIGINT, 2L)), - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "x"), new Cast(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "col1"), new Constant(INTEGER, 3L)), BIGINT)))), + new Logical(AND, ImmutableList.of( + new Comparison(EQUAL, new Reference(BIGINT, "y"), new Constant(BIGINT, 2L)), + new Comparison(EQUAL, new Reference(BIGINT, "x"), new Cast(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "col1"), new Constant(INTEGER, 3L)), BIGINT)))), source2))); // Projection and predicate pushdown with overlapping columns @@ -217,7 +217,7 @@ public void testDereferencePushdown() format("SELECT col0, col0.y expr_y FROM %s WHERE col0.x = 5", testTable), anyTree( filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "x"), new Constant(BIGINT, 5L)), + new Comparison(EQUAL, new Reference(BIGINT, "x"), new Constant(BIGINT, 5L)), source1))); // Projection and predicate pushdown with joins @@ -226,9 +226,9 @@ public void testDereferencePushdown() anyTree( project( ImmutableMap.of( - "expr_0_x", expression(new SubscriptExpression(BIGINT, new SymbolReference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0"), new Constant(INTEGER, 1L))), - "expr_0", expression(new SymbolReference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0")), - "expr_0_y", expression(new SubscriptExpression(BIGINT, new SymbolReference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0"), new Constant(INTEGER, 2L)))), + "expr_0_x", expression(new Subscript(BIGINT, new Reference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0"), new Constant(INTEGER, 1L))), + "expr_0", expression(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0")), + "expr_0_y", expression(new Subscript(BIGINT, new Reference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0"), new Constant(INTEGER, 2L)))), join(INNER, builder -> { PlanMatchPattern source = tableScan( table -> { @@ -247,7 +247,7 @@ public void testDereferencePushdown() .left( anyTree( filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "x"), new Constant(BIGINT, 2L)), + new Comparison(EQUAL, new Reference(BIGINT, "x"), new Constant(BIGINT, 2L)), source))) .right( anyTree( diff --git a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/AbstractTestExtractSpatial.java b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/AbstractTestExtractSpatial.java index adcf6a8d88b3..b00d5a76d84f 100644 --- a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/AbstractTestExtractSpatial.java +++ b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/AbstractTestExtractSpatial.java @@ -16,9 +16,9 @@ import com.google.common.collect.ImmutableList; import io.airlift.slice.Slices; import io.trino.spi.type.Type; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -38,43 +38,43 @@ public AbstractTestExtractSpatial() super(new GeoPlugin()); } - protected FunctionCall containsCall(Expression left, Expression right) + protected Call containsCall(Expression left, Expression right) { return functionCall("st_contains", ImmutableList.of(GEOMETRY, GEOMETRY), ImmutableList.of(left, right)); } - protected FunctionCall distanceCall(Expression left, Expression right) + protected Call distanceCall(Expression left, Expression right) { return functionCall("st_distance", ImmutableList.of(GEOMETRY, GEOMETRY), ImmutableList.of(left, right)); } - protected FunctionCall sphericalDistanceCall(Expression left, Expression right) + protected Call sphericalDistanceCall(Expression left, Expression right) { return functionCall("st_distance", ImmutableList.of(SPHERICAL_GEOGRAPHY, SPHERICAL_GEOGRAPHY), ImmutableList.of(left, right)); } - protected FunctionCall geometryFromTextCall(Symbol symbol) + protected Call geometryFromTextCall(Symbol symbol) { return functionCall("st_geometryfromtext", ImmutableList.of(VARCHAR), ImmutableList.of(symbol.toSymbolReference())); } - protected FunctionCall geometryFromTextCall(String text) + protected Call geometryFromTextCall(String text) { return functionCall("st_geometryfromtext", ImmutableList.of(VARCHAR), ImmutableList.of(new Constant(VARCHAR, Slices.utf8Slice(text)))); } - protected FunctionCall toSphericalGeographyCall(Symbol symbol) + protected Call toSphericalGeographyCall(Symbol symbol) { return functionCall("to_spherical_geography", ImmutableList.of(GEOMETRY), ImmutableList.of(geometryFromTextCall(symbol))); } - protected FunctionCall toPointCall(Expression x, Expression y) + protected Call toPointCall(Expression x, Expression y) { return functionCall("st_point", ImmutableList.of(BIGINT, BIGINT), ImmutableList.of(x, y)); } - private FunctionCall functionCall(String name, List types, List arguments) + private Call functionCall(String name, List types, List arguments) { - return new FunctionCall(tester().getMetadata().resolveBuiltinFunction(name, fromTypes(types)), arguments); + return new Call(tester().getMetadata().resolveBuiltinFunction(name, fromTypes(types)), arguments); } } diff --git a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestExtractSpatialInnerJoin.java b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestExtractSpatialInnerJoin.java index d4f96b1e1395..2f6b7303002d 100644 --- a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestExtractSpatialInnerJoin.java +++ b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestExtractSpatialInnerJoin.java @@ -17,12 +17,12 @@ import com.google.common.collect.ImmutableMap; import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.ExtractSpatialJoins.ExtractSpatialInnerJoin; import io.trino.sql.planner.iterative.rule.test.RuleBuilder; @@ -35,10 +35,10 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.NOT_EQUAL; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.NOT_EQUAL; +import static io.trino.sql.ir.Logical.Operator.AND; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.spatialJoin; @@ -75,9 +75,9 @@ public void testDoesNotFire() Symbol name1 = p.symbol("name_1"); Symbol name2 = p.symbol("name_2"); return p.filter( - LogicalExpression.or( + Logical.or( containsCall(geometryFromTextCall(wkt), point.toSymbolReference()), - new ComparisonExpression(NOT_EQUAL, name1.toSymbolReference(), name2.toSymbolReference())), + new Comparison(NOT_EQUAL, name1.toSymbolReference(), name2.toSymbolReference())), p.join(INNER, p.values(wkt, name1), p.values(point, name2))); }) .doesNotFire(); @@ -91,7 +91,7 @@ public void testDoesNotFire() Symbol name1 = p.symbol("name_1"); Symbol name2 = p.symbol("name_2"); return p.filter( - new NotExpression(containsCall(geometryFromTextCall(wkt), point.toSymbolReference())), + new Not(containsCall(geometryFromTextCall(wkt), point.toSymbolReference())), p.join(INNER, p.values(wkt, name1), p.values(point, name2))); @@ -105,7 +105,7 @@ public void testDoesNotFire() Symbol a = p.symbol("a", GEOMETRY); Symbol b = p.symbol("b", GEOMETRY); return p.filter( - new ComparisonExpression(GREATER_THAN, distanceCall(a.toSymbolReference(), b.toSymbolReference()), new Constant(INTEGER, 5L)), + new Comparison(GREATER_THAN, distanceCall(a.toSymbolReference(), b.toSymbolReference()), new Constant(INTEGER, 5L)), p.join(INNER, p.values(a), p.values(b))); @@ -119,7 +119,7 @@ public void testDoesNotFire() Symbol a = p.symbol("a", SPHERICAL_GEOGRAPHY); Symbol b = p.symbol("b", SPHERICAL_GEOGRAPHY); return p.filter( - new ComparisonExpression(LESS_THAN, sphericalDistanceCall(a.toSymbolReference(), b.toSymbolReference()), new Constant(INTEGER, 5L)), + new Comparison(LESS_THAN, sphericalDistanceCall(a.toSymbolReference(), b.toSymbolReference()), new Constant(INTEGER, 5L)), p.join(INNER, p.values(a), p.values(b))); @@ -133,7 +133,7 @@ public void testDoesNotFire() Symbol wkt = p.symbol("wkt", VARCHAR); Symbol point = p.symbol("point", SPHERICAL_GEOGRAPHY); return p.filter( - new ComparisonExpression(LESS_THAN, sphericalDistanceCall(toSphericalGeographyCall(wkt), point.toSymbolReference()), new Constant(INTEGER, 5L)), + new Comparison(LESS_THAN, sphericalDistanceCall(toSphericalGeographyCall(wkt), point.toSymbolReference()), new Constant(INTEGER, 5L)), p.join(INNER, p.values(wkt), p.values(point))); @@ -158,7 +158,7 @@ public void testConvertToSpatialJoin() }) .matches( spatialJoin( - new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "a"), new SymbolReference(GEOMETRY, "b"))), + new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "a"), new Reference(GEOMETRY, "b"))), values(ImmutableMap.of("a", 0)), values(ImmutableMap.of("b", 0)))); @@ -171,8 +171,8 @@ public void testConvertToSpatialJoin() Symbol name1 = p.symbol("name_1"); Symbol name2 = p.symbol("name_2"); return p.filter( - LogicalExpression.and( - new ComparisonExpression(NOT_EQUAL, name1.toSymbolReference(), name2.toSymbolReference()), + Logical.and( + new Comparison(NOT_EQUAL, name1.toSymbolReference(), name2.toSymbolReference()), containsCall(a.toSymbolReference(), b.toSymbolReference())), p.join(INNER, p.values(a, name1), @@ -180,7 +180,7 @@ public void testConvertToSpatialJoin() }) .matches( spatialJoin( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(NOT_EQUAL, new SymbolReference(VARCHAR, "name_1"), new SymbolReference(VARCHAR, "name_2")), new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "a"), new SymbolReference(GEOMETRY, "b"))))), + new Logical(AND, ImmutableList.of(new Comparison(NOT_EQUAL, new Reference(VARCHAR, "name_1"), new Reference(VARCHAR, "name_2")), new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "a"), new Reference(GEOMETRY, "b"))))), values(ImmutableMap.of("a", 0, "name_1", 1)), values(ImmutableMap.of("b", 0, "name_2", 1)))); @@ -193,7 +193,7 @@ public void testConvertToSpatialJoin() Symbol b1 = p.symbol("b1"); Symbol b2 = p.symbol("b2"); return p.filter( - LogicalExpression.and( + Logical.and( containsCall(a1.toSymbolReference(), b1.toSymbolReference()), containsCall(a2.toSymbolReference(), b2.toSymbolReference())), p.join(INNER, @@ -202,7 +202,7 @@ public void testConvertToSpatialJoin() }) .matches( spatialJoin( - new LogicalExpression(AND, ImmutableList.of(new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "a1"), new SymbolReference(GEOMETRY, "b1"))), new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "a2"), new SymbolReference(GEOMETRY, "b2"))))), + new Logical(AND, ImmutableList.of(new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "a1"), new Reference(GEOMETRY, "b1"))), new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "a2"), new Reference(GEOMETRY, "b2"))))), values(ImmutableMap.of("a1", 0, "a2", 1)), values(ImmutableMap.of("b1", 0, "b2", 1)))); } @@ -223,8 +223,8 @@ public void testPushDownFirstArgument() }) .matches( spatialJoin( - new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "st_geometryfromtext"), new SymbolReference(GEOMETRY, "point"))), - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new SymbolReference(VARCHAR, "wkt"))))), + new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "st_geometryfromtext"), new Reference(GEOMETRY, "point"))), + project(ImmutableMap.of("st_geometryfromtext", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Reference(VARCHAR, "wkt"))))), values(ImmutableMap.of("wkt", 0))), values(ImmutableMap.of("point", 0)))); @@ -258,9 +258,9 @@ public void testPushDownSecondArgument() }) .matches( spatialJoin( - new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "polygon"), new SymbolReference(GEOMETRY, "st_point"))), + new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "polygon"), new Reference(GEOMETRY, "st_point"))), values(ImmutableMap.of("polygon", 0)), - project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference(DOUBLE, "lng"), new SymbolReference(DOUBLE, "lat"))))), + project(ImmutableMap.of("st_point", expression(new Call(ST_POINT, ImmutableList.of(new Reference(DOUBLE, "lng"), new Reference(DOUBLE, "lat"))))), values(ImmutableMap.of("lat", 0, "lng", 1))))); assertRuleApplication() @@ -294,10 +294,10 @@ public void testPushDownBothArguments() }) .matches( spatialJoin( - new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "st_geometryfromtext"), new SymbolReference(GEOMETRY, "st_point"))), - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new SymbolReference(VARCHAR, "wkt"))))), + new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "st_geometryfromtext"), new Reference(GEOMETRY, "st_point"))), + project(ImmutableMap.of("st_geometryfromtext", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Reference(VARCHAR, "wkt"))))), values(ImmutableMap.of("wkt", 0))), - project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference(DOUBLE, "lng"), new SymbolReference(DOUBLE, "lat"))))), + project(ImmutableMap.of("st_point", expression(new Call(ST_POINT, ImmutableList.of(new Reference(DOUBLE, "lng"), new Reference(DOUBLE, "lat"))))), values(ImmutableMap.of("lat", 0, "lng", 1))))); } @@ -317,10 +317,10 @@ public void testPushDownOppositeOrder() p.values(wkt))); }) .matches( - spatialJoin(new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "st_geometryfromtext"), new SymbolReference(GEOMETRY, "st_point"))), - project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference(DOUBLE, "lng"), new SymbolReference(DOUBLE, "lat"))))), + spatialJoin(new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "st_geometryfromtext"), new Reference(GEOMETRY, "st_point"))), + project(ImmutableMap.of("st_point", expression(new Call(ST_POINT, ImmutableList.of(new Reference(DOUBLE, "lng"), new Reference(DOUBLE, "lat"))))), values(ImmutableMap.of("lat", 0, "lng", 1))), - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new SymbolReference(VARCHAR, "wkt"))))), + project(ImmutableMap.of("st_geometryfromtext", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Reference(VARCHAR, "wkt"))))), values(ImmutableMap.of("wkt", 0))))); } @@ -336,8 +336,8 @@ public void testPushDownAnd() Symbol name1 = p.symbol("name_1"); Symbol name2 = p.symbol("name_2"); return p.filter( - LogicalExpression.and( - new ComparisonExpression(NOT_EQUAL, name1.toSymbolReference(), name2.toSymbolReference()), + Logical.and( + new Comparison(NOT_EQUAL, name1.toSymbolReference(), name2.toSymbolReference()), containsCall(geometryFromTextCall(wkt), toPointCall(lng.toSymbolReference(), lat.toSymbolReference()))), p.join(INNER, p.values(wkt, name1), @@ -345,10 +345,10 @@ public void testPushDownAnd() }) .matches( spatialJoin( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(NOT_EQUAL, new SymbolReference(VARCHAR, "name_1"), new SymbolReference(VARCHAR, "name_2")), new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "st_geometryfromtext"), new SymbolReference(GEOMETRY, "st_point"))))), - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new SymbolReference(VARCHAR, "wkt"))))), + new Logical(AND, ImmutableList.of(new Comparison(NOT_EQUAL, new Reference(VARCHAR, "name_1"), new Reference(VARCHAR, "name_2")), new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "st_geometryfromtext"), new Reference(GEOMETRY, "st_point"))))), + project(ImmutableMap.of("st_geometryfromtext", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Reference(VARCHAR, "wkt"))))), values(ImmutableMap.of("wkt", 0, "name_1", 1))), - project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference(DOUBLE, "lng"), new SymbolReference(DOUBLE, "lat"))))), + project(ImmutableMap.of("st_point", expression(new Call(ST_POINT, ImmutableList.of(new Reference(DOUBLE, "lng"), new Reference(DOUBLE, "lat"))))), values(ImmutableMap.of("lat", 0, "lng", 1, "name_2", 2))))); // Multiple spatial functions - only the first one is being processed @@ -360,7 +360,7 @@ public void testPushDownAnd() Symbol geometry1 = p.symbol("geometry1"); Symbol geometry2 = p.symbol("geometry2"); return p.filter( - LogicalExpression.and( + Logical.and( containsCall(geometryFromTextCall(wkt1), geometry1.toSymbolReference()), containsCall(geometryFromTextCall(wkt2), geometry2.toSymbolReference())), p.join(INNER, @@ -369,8 +369,8 @@ public void testPushDownAnd() }) .matches( spatialJoin( - new LogicalExpression(AND, ImmutableList.of(new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "st_geometryfromtext"), new SymbolReference(GEOMETRY, "geometry1"))), new FunctionCall(ST_CONTAINS, ImmutableList.of(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new SymbolReference(GEOMETRY, "wkt2"))), new SymbolReference(GEOMETRY, "geometry2"))))), - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new SymbolReference(GEOMETRY, "wkt1"))))), + new Logical(AND, ImmutableList.of(new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "st_geometryfromtext"), new Reference(GEOMETRY, "geometry1"))), new Call(ST_CONTAINS, ImmutableList.of(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Reference(GEOMETRY, "wkt2"))), new Reference(GEOMETRY, "geometry2"))))), + project(ImmutableMap.of("st_geometryfromtext", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Reference(GEOMETRY, "wkt1"))))), values(ImmutableMap.of("wkt1", 0, "wkt2", 1))), values(ImmutableMap.of("geometry1", 0, "geometry2", 1)))); } diff --git a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestExtractSpatialLeftJoin.java b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestExtractSpatialLeftJoin.java index 298c986c7cb9..5e7b10b2c755 100644 --- a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestExtractSpatialLeftJoin.java +++ b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestExtractSpatialLeftJoin.java @@ -17,12 +17,12 @@ import com.google.common.collect.ImmutableMap; import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.ExtractSpatialJoins; import io.trino.sql.planner.iterative.rule.test.RuleBuilder; @@ -35,8 +35,8 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ComparisonExpression.Operator.NOT_EQUAL; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; +import static io.trino.sql.ir.Comparison.Operator.NOT_EQUAL; +import static io.trino.sql.ir.Logical.Operator.AND; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.spatialLeftJoin; @@ -77,9 +77,9 @@ public void testDoesNotFire() return p.join(LEFT, p.values(wkt, name1), p.values(point, name2), - LogicalExpression.or( + Logical.or( containsCall(geometryFromTextCall(wkt), point.toSymbolReference()), - new ComparisonExpression(NOT_EQUAL, name1.toSymbolReference(), name2.toSymbolReference()))); + new Comparison(NOT_EQUAL, name1.toSymbolReference(), name2.toSymbolReference()))); }) .doesNotFire(); @@ -94,7 +94,7 @@ public void testDoesNotFire() return p.join(LEFT, p.values(wkt, name1), p.values(point, name2), - new NotExpression(containsCall(geometryFromTextCall(wkt), point.toSymbolReference()))); + new Not(containsCall(geometryFromTextCall(wkt), point.toSymbolReference()))); }) .doesNotFire(); @@ -107,7 +107,7 @@ public void testDoesNotFire() return p.join(LEFT, p.values(a), p.values(b), - new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN, + new Comparison(Comparison.Operator.GREATER_THAN, distanceCall(a.toSymbolReference(), b.toSymbolReference()), new Constant(INTEGER, 5L))); }) @@ -122,7 +122,7 @@ public void testDoesNotFire() return p.join(LEFT, p.values(a), p.values(b), - new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN, + new Comparison(Comparison.Operator.GREATER_THAN, sphericalDistanceCall(a.toSymbolReference(), b.toSymbolReference()), new Constant(INTEGER, 5L))); }) @@ -137,7 +137,7 @@ public void testDoesNotFire() return p.join(LEFT, p.values(wkt), p.values(point), - new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN, + new Comparison(Comparison.Operator.GREATER_THAN, sphericalDistanceCall(toSphericalGeographyCall(wkt), point.toSymbolReference()), new Constant(INTEGER, 5L))); }) @@ -160,7 +160,7 @@ public void testConvertToSpatialJoin() }) .matches( spatialLeftJoin( - new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "a"), new SymbolReference(GEOMETRY, "b"))), + new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "a"), new Reference(GEOMETRY, "b"))), values(ImmutableMap.of("a", 0)), values(ImmutableMap.of("b", 0)))); @@ -175,13 +175,13 @@ public void testConvertToSpatialJoin() return p.join(LEFT, p.values(a, name1), p.values(b, name2), - LogicalExpression.and( - new ComparisonExpression(NOT_EQUAL, name1.toSymbolReference(), name2.toSymbolReference()), + Logical.and( + new Comparison(NOT_EQUAL, name1.toSymbolReference(), name2.toSymbolReference()), containsCall(a.toSymbolReference(), b.toSymbolReference()))); }) .matches( spatialLeftJoin( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(NOT_EQUAL, new SymbolReference(VARCHAR, "name_1"), new SymbolReference(VARCHAR, "name_2")), new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "a"), new SymbolReference(GEOMETRY, "b"))))), + new Logical(AND, ImmutableList.of(new Comparison(NOT_EQUAL, new Reference(VARCHAR, "name_1"), new Reference(VARCHAR, "name_2")), new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "a"), new Reference(GEOMETRY, "b"))))), values(ImmutableMap.of("a", 0, "name_1", 1)), values(ImmutableMap.of("b", 0, "name_2", 1)))); @@ -196,13 +196,13 @@ public void testConvertToSpatialJoin() return p.join(LEFT, p.values(a1, a2), p.values(b1, b2), - LogicalExpression.and( + Logical.and( containsCall(a1.toSymbolReference(), b1.toSymbolReference()), containsCall(a2.toSymbolReference(), b2.toSymbolReference()))); }) .matches( spatialLeftJoin( - new LogicalExpression(AND, ImmutableList.of(new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "a1"), new SymbolReference(GEOMETRY, "b1"))), new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "a2"), new SymbolReference(GEOMETRY, "b2"))))), + new Logical(AND, ImmutableList.of(new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "a1"), new Reference(GEOMETRY, "b1"))), new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "a2"), new Reference(GEOMETRY, "b2"))))), values(ImmutableMap.of("a1", 0, "a2", 1)), values(ImmutableMap.of("b1", 0, "b2", 1)))); } @@ -222,8 +222,8 @@ public void testPushDownFirstArgument() }) .matches( spatialLeftJoin( - new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "st_geometryfromtext"), new SymbolReference(GEOMETRY, "point"))), - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new SymbolReference(VARCHAR, "wkt"))))), + new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "st_geometryfromtext"), new Reference(GEOMETRY, "point"))), + project(ImmutableMap.of("st_geometryfromtext", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Reference(VARCHAR, "wkt"))))), values(ImmutableMap.of("wkt", 0))), values(ImmutableMap.of("point", 0)))); @@ -255,9 +255,9 @@ public void testPushDownSecondArgument() }) .matches( spatialLeftJoin( - new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "polygon"), new SymbolReference(GEOMETRY, "st_point"))), + new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "polygon"), new Reference(GEOMETRY, "st_point"))), values(ImmutableMap.of("polygon", 0)), - project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference(DOUBLE, "lng"), new SymbolReference(DOUBLE, "lat"))))), + project(ImmutableMap.of("st_point", expression(new Call(ST_POINT, ImmutableList.of(new Reference(DOUBLE, "lng"), new Reference(DOUBLE, "lat"))))), values(ImmutableMap.of("lat", 0, "lng", 1))))); assertRuleApplication() @@ -289,10 +289,10 @@ public void testPushDownBothArguments() }) .matches( spatialLeftJoin( - new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "st_geometryfromtext"), new SymbolReference(GEOMETRY, "st_point"))), - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new SymbolReference(VARCHAR, "wkt"))))), + new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "st_geometryfromtext"), new Reference(GEOMETRY, "st_point"))), + project(ImmutableMap.of("st_geometryfromtext", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Reference(VARCHAR, "wkt"))))), values(ImmutableMap.of("wkt", 0))), - project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference(DOUBLE, "lng"), new SymbolReference(DOUBLE, "lat"))))), + project(ImmutableMap.of("st_point", expression(new Call(ST_POINT, ImmutableList.of(new Reference(DOUBLE, "lng"), new Reference(DOUBLE, "lat"))))), values(ImmutableMap.of("lat", 0, "lng", 1))))); } @@ -312,9 +312,9 @@ public void testPushDownOppositeOrder() }) .matches( spatialLeftJoin( - new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "st_geometryfromtext"), new SymbolReference(GEOMETRY, "st_point"))), - project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference(DOUBLE, "lng"), new SymbolReference(DOUBLE, "lat"))))), values(ImmutableMap.of("lat", 0, "lng", 1))), - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new SymbolReference(VARCHAR, "wkt"))))), + new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "st_geometryfromtext"), new Reference(GEOMETRY, "st_point"))), + project(ImmutableMap.of("st_point", expression(new Call(ST_POINT, ImmutableList.of(new Reference(DOUBLE, "lng"), new Reference(DOUBLE, "lat"))))), values(ImmutableMap.of("lat", 0, "lng", 1))), + project(ImmutableMap.of("st_geometryfromtext", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Reference(VARCHAR, "wkt"))))), values(ImmutableMap.of("wkt", 0))))); } @@ -332,16 +332,16 @@ public void testPushDownAnd() return p.join(LEFT, p.values(wkt, name1), p.values(lat, lng, name2), - LogicalExpression.and( - new ComparisonExpression(NOT_EQUAL, name1.toSymbolReference(), name2.toSymbolReference()), + Logical.and( + new Comparison(NOT_EQUAL, name1.toSymbolReference(), name2.toSymbolReference()), containsCall(geometryFromTextCall(wkt), toPointCall(lng.toSymbolReference(), lat.toSymbolReference())))); }) .matches( spatialLeftJoin( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(NOT_EQUAL, new SymbolReference(VARCHAR, "name_1"), new SymbolReference(VARCHAR, "name_2")), new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "st_geometryfromtext"), new SymbolReference(GEOMETRY, "st_point"))))), - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new SymbolReference(VARCHAR, "wkt"))))), + new Logical(AND, ImmutableList.of(new Comparison(NOT_EQUAL, new Reference(VARCHAR, "name_1"), new Reference(VARCHAR, "name_2")), new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "st_geometryfromtext"), new Reference(GEOMETRY, "st_point"))))), + project(ImmutableMap.of("st_geometryfromtext", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Reference(VARCHAR, "wkt"))))), values(ImmutableMap.of("wkt", 0, "name_1", 1))), - project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference(DOUBLE, "lng"), new SymbolReference(DOUBLE, "lat"))))), + project(ImmutableMap.of("st_point", expression(new Call(ST_POINT, ImmutableList.of(new Reference(DOUBLE, "lng"), new Reference(DOUBLE, "lat"))))), values(ImmutableMap.of("lat", 0, "lng", 1, "name_2", 2))))); // Multiple spatial functions - only the first one is being processed @@ -355,14 +355,14 @@ public void testPushDownAnd() return p.join(LEFT, p.values(wkt1, wkt2), p.values(geometry1, geometry2), - LogicalExpression.and( + Logical.and( containsCall(geometryFromTextCall(wkt1), geometry1.toSymbolReference()), containsCall(geometryFromTextCall(wkt2), geometry2.toSymbolReference()))); }) .matches( spatialLeftJoin( - new LogicalExpression(AND, ImmutableList.of(new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "st_geometryfromtext"), new SymbolReference(GEOMETRY, "geometry1"))), new FunctionCall(ST_CONTAINS, ImmutableList.of(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new SymbolReference(GEOMETRY, "wkt2"))), new SymbolReference(GEOMETRY, "geometry2"))))), - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new SymbolReference(GEOMETRY, "wkt1"))))), + new Logical(AND, ImmutableList.of(new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "st_geometryfromtext"), new Reference(GEOMETRY, "geometry1"))), new Call(ST_CONTAINS, ImmutableList.of(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Reference(GEOMETRY, "wkt2"))), new Reference(GEOMETRY, "geometry2"))))), + project(ImmutableMap.of("st_geometryfromtext", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Reference(GEOMETRY, "wkt1"))))), values(ImmutableMap.of("wkt1", 0, "wkt2", 1))), values(ImmutableMap.of("geometry1", 0, "geometry2", 1)))); } diff --git a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestPruneSpatialJoinChildrenColumns.java b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestPruneSpatialJoinChildrenColumns.java index 7a32164e13b2..28110a5e46a2 100644 --- a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestPruneSpatialJoinChildrenColumns.java +++ b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestPruneSpatialJoinChildrenColumns.java @@ -17,9 +17,9 @@ import com.google.common.collect.ImmutableMap; import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; -import io.trino.sql.ir.ComparisonExpression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Comparison; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.PruneSpatialJoinChildrenColumns; @@ -32,7 +32,7 @@ import static io.trino.plugin.geospatial.GeometryType.GEOMETRY; import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.planner.assertions.PlanMatchPattern.spatialJoin; import static io.trino.sql.planner.assertions.PlanMatchPattern.strictProject; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -57,18 +57,18 @@ public void testPruneOneChild() p.values(a, unused), p.values(b, r), ImmutableList.of(a, b, r), - new ComparisonExpression( + new Comparison( LESS_THAN_OR_EQUAL, - new FunctionCall(TEST_ST_DISTANCE_FUNCTION, ImmutableList.of(a.toSymbolReference(), b.toSymbolReference())), + new Call(TEST_ST_DISTANCE_FUNCTION, ImmutableList.of(a.toSymbolReference(), b.toSymbolReference())), r.toSymbolReference())); }) .matches( spatialJoin( - new ComparisonExpression(LESS_THAN_OR_EQUAL, new FunctionCall(TEST_ST_DISTANCE_FUNCTION, ImmutableList.of(new SymbolReference(GEOMETRY, "a"), new SymbolReference(GEOMETRY, "b"))), new SymbolReference(DOUBLE, "r")), + new Comparison(LESS_THAN_OR_EQUAL, new Call(TEST_ST_DISTANCE_FUNCTION, ImmutableList.of(new Reference(GEOMETRY, "a"), new Reference(GEOMETRY, "b"))), new Reference(DOUBLE, "r")), Optional.empty(), Optional.of(ImmutableList.of("a", "b", "r")), strictProject( - ImmutableMap.of("a", PlanMatchPattern.expression(new SymbolReference(GEOMETRY, "a"))), + ImmutableMap.of("a", PlanMatchPattern.expression(new Reference(GEOMETRY, "a"))), values("a", "unused")), values("b", "r"))); } @@ -88,21 +88,21 @@ public void testPruneBothChildren() p.values(a, unusedLeft), p.values(b, r, unusedRight), ImmutableList.of(a, b, r), - new ComparisonExpression( + new Comparison( LESS_THAN_OR_EQUAL, - new FunctionCall(TEST_ST_DISTANCE_FUNCTION, ImmutableList.of(a.toSymbolReference(), b.toSymbolReference())), + new Call(TEST_ST_DISTANCE_FUNCTION, ImmutableList.of(a.toSymbolReference(), b.toSymbolReference())), r.toSymbolReference())); }) .matches( spatialJoin( - new ComparisonExpression(LESS_THAN_OR_EQUAL, new FunctionCall(TEST_ST_DISTANCE_FUNCTION, ImmutableList.of(new SymbolReference(GEOMETRY, "a"), new SymbolReference(GEOMETRY, "b"))), new SymbolReference(DOUBLE, "r")), + new Comparison(LESS_THAN_OR_EQUAL, new Call(TEST_ST_DISTANCE_FUNCTION, ImmutableList.of(new Reference(GEOMETRY, "a"), new Reference(GEOMETRY, "b"))), new Reference(DOUBLE, "r")), Optional.empty(), Optional.of(ImmutableList.of("a", "b", "r")), strictProject( - ImmutableMap.of("a", PlanMatchPattern.expression(new SymbolReference(GEOMETRY, "a"))), + ImmutableMap.of("a", PlanMatchPattern.expression(new Reference(GEOMETRY, "a"))), values("a", "unused_left")), strictProject( - ImmutableMap.of("b", PlanMatchPattern.expression(new SymbolReference(GEOMETRY, "b")), "r", PlanMatchPattern.expression(new SymbolReference(DOUBLE, "r"))), + ImmutableMap.of("b", PlanMatchPattern.expression(new Reference(GEOMETRY, "b")), "r", PlanMatchPattern.expression(new Reference(DOUBLE, "r"))), values("b", "r", "unused_right")))); } @@ -120,7 +120,7 @@ public void testDoNotPruneOneOutputOrFilterSymbols() p.values(a), p.values(b, r, output), ImmutableList.of(output), - new ComparisonExpression(LESS_THAN_OR_EQUAL, new FunctionCall(TEST_ST_DISTANCE_FUNCTION, ImmutableList.of(new SymbolReference(GEOMETRY, "a"), new SymbolReference(GEOMETRY, "b"))), new SymbolReference(DOUBLE, "r"))); + new Comparison(LESS_THAN_OR_EQUAL, new Call(TEST_ST_DISTANCE_FUNCTION, ImmutableList.of(new Reference(GEOMETRY, "a"), new Reference(GEOMETRY, "b"))), new Reference(DOUBLE, "r"))); }) .doesNotFire(); } @@ -140,7 +140,7 @@ public void testDoNotPrunePartitionSymbols() p.values(a, leftPartitionSymbol), p.values(b, r, rightPartitionSymbol), ImmutableList.of(a, b, r), - new ComparisonExpression(LESS_THAN_OR_EQUAL, new FunctionCall(TEST_ST_DISTANCE_FUNCTION, ImmutableList.of(new SymbolReference(GEOMETRY, "a"), new SymbolReference(GEOMETRY, "b"))), new SymbolReference(DOUBLE, "r")), + new Comparison(LESS_THAN_OR_EQUAL, new Call(TEST_ST_DISTANCE_FUNCTION, ImmutableList.of(new Reference(GEOMETRY, "a"), new Reference(GEOMETRY, "b"))), new Reference(DOUBLE, "r")), Optional.of(leftPartitionSymbol), Optional.of(rightPartitionSymbol), Optional.of("some nice kdb tree")); diff --git a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestPruneSpatialJoinColumns.java b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestPruneSpatialJoinColumns.java index 33b810999af9..c3461d0691bc 100644 --- a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestPruneSpatialJoinColumns.java +++ b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestPruneSpatialJoinColumns.java @@ -17,9 +17,9 @@ import com.google.common.collect.ImmutableMap; import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; -import io.trino.sql.ir.ComparisonExpression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Comparison; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.PruneSpatialJoinColumns; @@ -33,7 +33,7 @@ import static io.trino.plugin.geospatial.GeometryType.GEOMETRY; import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.planner.assertions.PlanMatchPattern.spatialJoin; import static io.trino.sql.planner.assertions.PlanMatchPattern.strictProject; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -59,16 +59,16 @@ public void notAllOutputsReferenced() p.values(a), p.values(b, r), ImmutableList.of(a, b, r), - new ComparisonExpression( + new Comparison( LESS_THAN_OR_EQUAL, - new FunctionCall(TEST_ST_DISTANCE_FUNCTION, ImmutableList.of(a.toSymbolReference(), b.toSymbolReference())), + new Call(TEST_ST_DISTANCE_FUNCTION, ImmutableList.of(a.toSymbolReference(), b.toSymbolReference())), r.toSymbolReference()))); }) .matches( strictProject( - ImmutableMap.of("a", PlanMatchPattern.expression(new SymbolReference(GEOMETRY, "a"))), + ImmutableMap.of("a", PlanMatchPattern.expression(new Reference(GEOMETRY, "a"))), spatialJoin( - new ComparisonExpression(LESS_THAN_OR_EQUAL, new FunctionCall(TEST_ST_DISTANCE_FUNCTION, ImmutableList.of(new SymbolReference(GEOMETRY, "a"), new SymbolReference(GEOMETRY, "b"))), new SymbolReference(DOUBLE, "r")), + new Comparison(LESS_THAN_OR_EQUAL, new Call(TEST_ST_DISTANCE_FUNCTION, ImmutableList.of(new Reference(GEOMETRY, "a"), new Reference(GEOMETRY, "b"))), new Reference(DOUBLE, "r")), Optional.empty(), Optional.of(ImmutableList.of("a")), values("a"), @@ -90,7 +90,7 @@ public void allOutputsReferenced() p.values(a), p.values(b, r), ImmutableList.of(a, b, r), - new ComparisonExpression(LESS_THAN_OR_EQUAL, new FunctionCall(TEST_ST_DISTANCE_FUNCTION, ImmutableList.of(new SymbolReference(GEOMETRY, "a"), new SymbolReference(GEOMETRY, "b"))), new SymbolReference(DOUBLE, "r")))); + new Comparison(LESS_THAN_OR_EQUAL, new Call(TEST_ST_DISTANCE_FUNCTION, ImmutableList.of(new Reference(GEOMETRY, "a"), new Reference(GEOMETRY, "b"))), new Reference(DOUBLE, "r")))); }) .doesNotFire(); } diff --git a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestRewriteSpatialPartitioningAggregation.java b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestRewriteSpatialPartitioningAggregation.java index 2ce154295249..5700880a2c3d 100644 --- a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestRewriteSpatialPartitioningAggregation.java +++ b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestRewriteSpatialPartitioningAggregation.java @@ -17,9 +17,9 @@ import com.google.common.collect.ImmutableMap; import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; +import io.trino.sql.ir.Call; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.iterative.rule.RewriteSpatialPartitioningAggregation; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; @@ -54,7 +54,7 @@ public void testDoesNotFire() .on(p -> p.aggregation(a -> a.globalGrouping() .step(AggregationNode.Step.SINGLE) - .addAggregation(p.symbol("sp"), PlanBuilder.aggregation("spatial_partitioning", ImmutableList.of(new SymbolReference(GEOMETRY, "geometry"), new SymbolReference(INTEGER, "n"))), ImmutableList.of(GEOMETRY, INTEGER)) + .addAggregation(p.symbol("sp"), PlanBuilder.aggregation("spatial_partitioning", ImmutableList.of(new Reference(GEOMETRY, "geometry"), new Reference(INTEGER, "n"))), ImmutableList.of(GEOMETRY, INTEGER)) .source(p.values(p.symbol("geometry"), p.symbol("n"))))) .doesNotFire(); } @@ -66,28 +66,28 @@ public void test() .on(p -> p.aggregation(a -> a.globalGrouping() .step(AggregationNode.Step.SINGLE) - .addAggregation(p.symbol("sp"), PlanBuilder.aggregation("spatial_partitioning", ImmutableList.of(new SymbolReference(GEOMETRY, "geometry"))), ImmutableList.of(GEOMETRY)) + .addAggregation(p.symbol("sp"), PlanBuilder.aggregation("spatial_partitioning", ImmutableList.of(new Reference(GEOMETRY, "geometry"))), ImmutableList.of(GEOMETRY)) .source(p.values(p.symbol("geometry"))))) .matches( aggregation( ImmutableMap.of("sp", aggregationFunction("spatial_partitioning", ImmutableList.of("envelope", "partition_count"))), project( ImmutableMap.of("partition_count", expression(new Constant(INTEGER, 100L)), - "envelope", expression(new FunctionCall(ST_ENVELOPE, ImmutableList.of(new SymbolReference(GEOMETRY, "geometry"))))), + "envelope", expression(new Call(ST_ENVELOPE, ImmutableList.of(new Reference(GEOMETRY, "geometry"))))), values("geometry")))); assertRuleApplication() .on(p -> p.aggregation(a -> a.globalGrouping() .step(AggregationNode.Step.SINGLE) - .addAggregation(p.symbol("sp"), PlanBuilder.aggregation("spatial_partitioning", ImmutableList.of(new SymbolReference(GEOMETRY, "envelope"))), ImmutableList.of(GEOMETRY)) + .addAggregation(p.symbol("sp"), PlanBuilder.aggregation("spatial_partitioning", ImmutableList.of(new Reference(GEOMETRY, "envelope"))), ImmutableList.of(GEOMETRY)) .source(p.values(p.symbol("envelope"))))) .matches( aggregation( ImmutableMap.of("sp", aggregationFunction("spatial_partitioning", ImmutableList.of("envelope", "partition_count"))), project( ImmutableMap.of("partition_count", expression(new Constant(INTEGER, 100L)), - "envelope", expression(new FunctionCall(ST_ENVELOPE, ImmutableList.of(new SymbolReference(GEOMETRY, "geometry"))))), + "envelope", expression(new Call(ST_ENVELOPE, ImmutableList.of(new Reference(GEOMETRY, "geometry"))))), values("geometry")))); } diff --git a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestSpatialJoinPlanning.java b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestSpatialJoinPlanning.java index de0fc53c573d..551221a0d211 100644 --- a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestSpatialJoinPlanning.java +++ b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestSpatialJoinPlanning.java @@ -26,16 +26,16 @@ import io.trino.plugin.tpch.TpchConnectorFactory; import io.trino.spi.TrinoException; import io.trino.spi.type.Type; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Case; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.SearchedCaseExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.ir.WhenClause; import io.trino.sql.planner.assertions.BasePlanTest; import io.trino.sql.planner.assertions.PlanMatchPattern; @@ -59,11 +59,11 @@ import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.ir.ComparisonExpression.Operator.NOT_EQUAL; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Comparison.Operator.NOT_EQUAL; +import static io.trino.sql.ir.Logical.Operator.AND; import static io.trino.sql.planner.assertions.PlanMatchPattern.any; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange; @@ -124,11 +124,11 @@ public void testSpatialJoinContains() "WHERE ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat))", anyTree( spatialJoin( - new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "st_geometryfromtext"), new SymbolReference(GEOMETRY, "st_point"))), - project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference(DOUBLE, "lng"), new SymbolReference(DOUBLE, "lat"))))), + new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "st_geometryfromtext"), new Reference(GEOMETRY, "st_point"))), + project(ImmutableMap.of("st_point", expression(new Call(ST_POINT, ImmutableList.of(new Reference(DOUBLE, "lng"), new Reference(DOUBLE, "lat"))))), tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name_a", "name"))), anyTree( - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference(VARCHAR, "wkt"), VARCHAR))))), + project(ImmutableMap.of("st_geometryfromtext", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new Reference(VARCHAR, "wkt"), VARCHAR))))), tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_b", "name"))))))); // Verify that projections generated by the ExtractSpatialJoins rule @@ -137,15 +137,15 @@ public void testSpatialJoinContains() "FROM (SELECT length(name), * FROM points), (SELECT length(name), * FROM polygons) " + "WHERE ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat))", anyTree( - spatialJoin(new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "st_geometryfromtext"), new SymbolReference(GEOMETRY, "st_point"))), + spatialJoin(new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "st_geometryfromtext"), new Reference(GEOMETRY, "st_point"))), project(ImmutableMap.of( - "st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference(DOUBLE, "lng"), new SymbolReference(DOUBLE, "lat")))), - "length", expression(new FunctionCall(LENGTH, ImmutableList.of(new SymbolReference(VARCHAR, "name"))))), + "st_point", expression(new Call(ST_POINT, ImmutableList.of(new Reference(DOUBLE, "lng"), new Reference(DOUBLE, "lat")))), + "length", expression(new Call(LENGTH, ImmutableList.of(new Reference(VARCHAR, "name"))))), tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name", "name"))), anyTree( project(ImmutableMap.of( - "st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference(VARCHAR, "wkt"), VARCHAR)))), - "length_2", expression(new FunctionCall(LENGTH, ImmutableList.of(new SymbolReference(VARCHAR, "name_2"))))), + "st_geometryfromtext", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new Reference(VARCHAR, "wkt"), VARCHAR)))), + "length_2", expression(new Call(LENGTH, ImmutableList.of(new Reference(VARCHAR, "name_2"))))), tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_2", "name"))))))); // distributed @@ -155,18 +155,18 @@ public void testSpatialJoinContains() withSpatialPartitioning("kdb_tree"), anyTree( spatialJoin( - new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "st_geometryfromtext"), new SymbolReference(GEOMETRY, "st_point"))), + new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "st_geometryfromtext"), new Reference(GEOMETRY, "st_point"))), Optional.of(KDB_TREE_JSON), Optional.empty(), anyTree( unnest( - project(ImmutableMap.of("partitions_a", expression(new FunctionCall(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference(GEOMETRY, "st_point"))))), - project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference(DOUBLE, "lng"), new SymbolReference(DOUBLE, "lat"))))), + project(ImmutableMap.of("partitions_a", expression(new Call(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new Reference(GEOMETRY, "st_point"))))), + project(ImmutableMap.of("st_point", expression(new Call(ST_POINT, ImmutableList.of(new Reference(DOUBLE, "lng"), new Reference(DOUBLE, "lat"))))), tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name", "name")))))), anyTree( unnest( - project(ImmutableMap.of("partitions_b", expression(new FunctionCall(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference(GEOMETRY, "st_geometryfromtext"))))), - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference(VARCHAR, "wkt"), VARCHAR))))), + project(ImmutableMap.of("partitions_b", expression(new Call(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new Reference(GEOMETRY, "st_geometryfromtext"))))), + project(ImmutableMap.of("st_geometryfromtext", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new Reference(VARCHAR, "wkt"), VARCHAR))))), tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_2", "name"))))))))); } @@ -179,11 +179,11 @@ public void testSpatialJoinWithin() "WHERE ST_Within(ST_GeometryFromText(wkt), ST_Point(lng, lat))", anyTree( spatialJoin( - new FunctionCall(ST_WITHIN, ImmutableList.of(new SymbolReference(GEOMETRY, "st_geometryfromtext"), new SymbolReference(GEOMETRY, "st_point"))), - project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference(DOUBLE, "lng"), new SymbolReference(DOUBLE, "lat"))))), + new Call(ST_WITHIN, ImmutableList.of(new Reference(GEOMETRY, "st_geometryfromtext"), new Reference(GEOMETRY, "st_point"))), + project(ImmutableMap.of("st_point", expression(new Call(ST_POINT, ImmutableList.of(new Reference(DOUBLE, "lng"), new Reference(DOUBLE, "lat"))))), tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name_a", "name"))), anyTree( - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference(VARCHAR, "wkt"), VARCHAR))))), + project(ImmutableMap.of("st_geometryfromtext", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new Reference(VARCHAR, "wkt"), VARCHAR))))), tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_b", "name"))))))); // Verify that projections generated by the ExtractSpatialJoins rule @@ -192,15 +192,15 @@ public void testSpatialJoinWithin() "FROM (SELECT length(name), * FROM points), (SELECT length(name), * FROM polygons) " + "WHERE ST_Within(ST_GeometryFromText(wkt), ST_Point(lng, lat))", anyTree( - spatialJoin(new FunctionCall(ST_WITHIN, ImmutableList.of(new SymbolReference(GEOMETRY, "st_geometryfromtext"), new SymbolReference(GEOMETRY, "st_point"))), + spatialJoin(new Call(ST_WITHIN, ImmutableList.of(new Reference(GEOMETRY, "st_geometryfromtext"), new Reference(GEOMETRY, "st_point"))), project(ImmutableMap.of( - "st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference(DOUBLE, "lng"), new SymbolReference(DOUBLE, "lat")))), - "length", expression(new FunctionCall(LENGTH, ImmutableList.of(new SymbolReference(VARCHAR, "name"))))), + "st_point", expression(new Call(ST_POINT, ImmutableList.of(new Reference(DOUBLE, "lng"), new Reference(DOUBLE, "lat")))), + "length", expression(new Call(LENGTH, ImmutableList.of(new Reference(VARCHAR, "name"))))), tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name", "name"))), anyTree( project(ImmutableMap.of( - "st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference(VARCHAR, "wkt"), VARCHAR)))), - "length_2", expression(new FunctionCall(LENGTH, ImmutableList.of(new SymbolReference(VARCHAR, "name_2"))))), + "st_geometryfromtext", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new Reference(VARCHAR, "wkt"), VARCHAR)))), + "length_2", expression(new Call(LENGTH, ImmutableList.of(new Reference(VARCHAR, "name_2"))))), tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_2", "name"))))))); // distributed @@ -210,18 +210,18 @@ public void testSpatialJoinWithin() withSpatialPartitioning("kdb_tree"), anyTree( spatialJoin( - new FunctionCall(ST_WITHIN, ImmutableList.of(new SymbolReference(GEOMETRY, "st_geometryfromtext"), new SymbolReference(GEOMETRY, "st_point"))), + new Call(ST_WITHIN, ImmutableList.of(new Reference(GEOMETRY, "st_geometryfromtext"), new Reference(GEOMETRY, "st_point"))), Optional.of(KDB_TREE_JSON), Optional.empty(), anyTree( unnest( - project(ImmutableMap.of("partitions_a", expression(new FunctionCall(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference(GEOMETRY, "st_point"))))), - project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference(DOUBLE, "lng"), new SymbolReference(DOUBLE, "lat"))))), + project(ImmutableMap.of("partitions_a", expression(new Call(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new Reference(GEOMETRY, "st_point"))))), + project(ImmutableMap.of("st_point", expression(new Call(ST_POINT, ImmutableList.of(new Reference(DOUBLE, "lng"), new Reference(DOUBLE, "lat"))))), tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name_a", "name")))))), anyTree( unnest( - project(ImmutableMap.of("partitions_b", expression(new FunctionCall(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference(GEOMETRY, "st_geometryfromtext"))))), - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference(VARCHAR, "wkt"), VARCHAR))))), + project(ImmutableMap.of("partitions_b", expression(new Call(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new Reference(GEOMETRY, "st_geometryfromtext"))))), + project(ImmutableMap.of("st_geometryfromtext", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new Reference(VARCHAR, "wkt"), VARCHAR))))), tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_b", "name"))))))))); } @@ -301,11 +301,11 @@ public void testSpatialJoinIntersects() "WHERE ST_Intersects(ST_GeometryFromText(a.wkt), ST_GeometryFromText(b.wkt))", anyTree( spatialJoin( - new FunctionCall(ST_INTERSECTS, ImmutableList.of(new SymbolReference(GEOMETRY, "geometry_a"), new SymbolReference(GEOMETRY, "geometry_b"))), - project(ImmutableMap.of("geometry_a", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference(VARCHAR, "wkt_a"), VARCHAR))))), + new Call(ST_INTERSECTS, ImmutableList.of(new Reference(GEOMETRY, "geometry_a"), new Reference(GEOMETRY, "geometry_b"))), + project(ImmutableMap.of("geometry_a", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new Reference(VARCHAR, "wkt_a"), VARCHAR))))), tableScan("polygons", ImmutableMap.of("wkt_a", "wkt", "name_a", "name"))), anyTree( - project(ImmutableMap.of("geometry_b", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference(VARCHAR, "wkt_b"), VARCHAR))))), + project(ImmutableMap.of("geometry_b", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new Reference(VARCHAR, "wkt_b"), VARCHAR))))), tableScan("polygons", ImmutableMap.of("wkt_b", "wkt", "name_b", "name"))))))); // distributed @@ -315,15 +315,15 @@ public void testSpatialJoinIntersects() withSpatialPartitioning("default.kdb_tree"), anyTree( spatialJoin( - new FunctionCall(ST_INTERSECTS, ImmutableList.of(new SymbolReference(GEOMETRY, "geometry_a"), new SymbolReference(GEOMETRY, "geometry_b"))), Optional.of(KDB_TREE_JSON), Optional.empty(), + new Call(ST_INTERSECTS, ImmutableList.of(new Reference(GEOMETRY, "geometry_a"), new Reference(GEOMETRY, "geometry_b"))), Optional.of(KDB_TREE_JSON), Optional.empty(), anyTree( unnest( - project(ImmutableMap.of("partitions_a", expression(new FunctionCall(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference(GEOMETRY, "geometry_a"))))), - project(ImmutableMap.of("geometry_a", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference(VARCHAR, "wkt_a"), VARCHAR))))), + project(ImmutableMap.of("partitions_a", expression(new Call(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new Reference(GEOMETRY, "geometry_a"))))), + project(ImmutableMap.of("geometry_a", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new Reference(VARCHAR, "wkt_a"), VARCHAR))))), tableScan("polygons", ImmutableMap.of("wkt_a", "wkt", "name_a", "name")))))), anyTree( - project(ImmutableMap.of("partitions_b", expression(new FunctionCall(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference(GEOMETRY, "geometry_b"))))), - project(ImmutableMap.of("geometry_b", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference(VARCHAR, "wkt_b"), VARCHAR))))), + project(ImmutableMap.of("partitions_b", expression(new Call(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new Reference(GEOMETRY, "geometry_b"))))), + project(ImmutableMap.of("geometry_b", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new Reference(VARCHAR, "wkt_b"), VARCHAR))))), tableScan("polygons", ImmutableMap.of("wkt_b", "wkt", "name_b", "name")))))))); } @@ -335,7 +335,7 @@ public void testNotContains() "WHERE NOT ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat))", anyTree( filter( - new NotExpression(new FunctionCall(ST_CONTAINS, ImmutableList.of(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference(VARCHAR, "wkt"), VARCHAR))), new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference(DOUBLE, "lng"), new SymbolReference(DOUBLE, "lat")))))), + new Not(new Call(ST_CONTAINS, ImmutableList.of(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new Reference(VARCHAR, "wkt"), VARCHAR))), new Call(ST_POINT, ImmutableList.of(new Reference(DOUBLE, "lng"), new Reference(DOUBLE, "lat")))))), join(INNER, builder -> builder .left(tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name_a", "name"))) .right( @@ -353,21 +353,21 @@ public void testNotIntersects() " WHERE NOT ST_Intersects(ST_GeometryFromText(a.wkt), ST_GeometryFromText(b.wkt))", singleRow()), anyTree( filter( - new NotExpression( + new Not( functionCall("ST_Intersects", ImmutableList.of(GEOMETRY, GEOMETRY), ImmutableList.of( - functionCall("ST_GeometryFromText", ImmutableList.of(VARCHAR), ImmutableList.of(new Cast(new SymbolReference(VARCHAR, "wkt_a"), VARCHAR))), - functionCall("ST_GeometryFromText", ImmutableList.of(VARCHAR), ImmutableList.of(new Cast(new SymbolReference(VARCHAR, "wkt_b"), VARCHAR)))))), + functionCall("ST_GeometryFromText", ImmutableList.of(VARCHAR), ImmutableList.of(new Cast(new Reference(VARCHAR, "wkt_a"), VARCHAR))), + functionCall("ST_GeometryFromText", ImmutableList.of(VARCHAR), ImmutableList.of(new Cast(new Reference(VARCHAR, "wkt_b"), VARCHAR)))))), join(INNER, builder -> builder .left( project( ImmutableMap.of( - "wkt_a", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new FunctionCall(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 0.0)), new Constant(createVarcharType(45), Slices.utf8Slice("POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))")))), Optional.empty())), + "wkt_a", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN_OR_EQUAL, new Call(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 0.0)), new Constant(createVarcharType(45), Slices.utf8Slice("POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))")))), Optional.empty())), "name_a", expression(new Constant(createVarcharType(1), Slices.utf8Slice("a")))), singleRow())) .right( any(project( ImmutableMap.of( - "wkt_b", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new FunctionCall(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 0.0)), new Constant(createVarcharType(45), Slices.utf8Slice("POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))")))), Optional.empty())), + "wkt_b", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN_OR_EQUAL, new Call(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 0.0)), new Constant(createVarcharType(45), Slices.utf8Slice("POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))")))), Optional.empty())), "name_b", expression(new Constant(createVarcharType(1), Slices.utf8Slice("a")))), singleRow()))))))); } @@ -381,7 +381,7 @@ public void testContainsWithEquiClause() anyTree( join(INNER, builder -> builder .equiCriteria("name_a", "name_b") - .filter(new FunctionCall(ST_CONTAINS, ImmutableList.of(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference(VARCHAR, "wkt"), VARCHAR))), new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference(DOUBLE, "lng"), new SymbolReference(DOUBLE, "lat")))))) + .filter(new Call(ST_CONTAINS, ImmutableList.of(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new Reference(VARCHAR, "wkt"), VARCHAR))), new Call(ST_POINT, ImmutableList.of(new Reference(DOUBLE, "lng"), new Reference(DOUBLE, "lat")))))) .left( anyTree( tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name_a", "name")))) @@ -399,7 +399,7 @@ public void testIntersectsWithEquiClause() anyTree( join(INNER, builder -> builder .equiCriteria("name_a", "name_b") - .filter(new FunctionCall(ST_INTERSECTS, ImmutableList.of(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference(VARCHAR, "wkt_a"), VARCHAR))), new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference(VARCHAR, "wkt_b"), VARCHAR)))))) + .filter(new Call(ST_INTERSECTS, ImmutableList.of(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new Reference(VARCHAR, "wkt_a"), VARCHAR))), new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new Reference(VARCHAR, "wkt_b"), VARCHAR)))))) .left( anyTree( tableScan("polygons", ImmutableMap.of("wkt_a", "wkt", "name_a", "name")))) @@ -416,11 +416,11 @@ public void testSpatialLeftJoins() "ON ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat))", anyTree( spatialLeftJoin( - new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "st_geometryfromtext"), new SymbolReference(GEOMETRY, "st_point"))), - project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference(DOUBLE, "lng"), new SymbolReference(DOUBLE, "lat"))))), + new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "st_geometryfromtext"), new Reference(GEOMETRY, "st_point"))), + project(ImmutableMap.of("st_point", expression(new Call(ST_POINT, ImmutableList.of(new Reference(DOUBLE, "lng"), new Reference(DOUBLE, "lat"))))), tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name_a", "name"))), anyTree( - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference(VARCHAR, "wkt"), VARCHAR))))), + project(ImmutableMap.of("st_geometryfromtext", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new Reference(VARCHAR, "wkt"), VARCHAR))))), tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_b", "name"))))))); // deterministic extra join predicate @@ -429,11 +429,11 @@ public void testSpatialLeftJoins() "ON ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat)) AND a.name <> b.name", anyTree( spatialLeftJoin( - new LogicalExpression(AND, ImmutableList.of(new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "st_geometryfromtext"), new SymbolReference(GEOMETRY, "st_point"))), new ComparisonExpression(NOT_EQUAL, new SymbolReference(VARCHAR, "name_a"), new SymbolReference(VARCHAR, "name_b")))), - project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference(DOUBLE, "lng"), new SymbolReference(DOUBLE, "lat"))))), + new Logical(AND, ImmutableList.of(new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "st_geometryfromtext"), new Reference(GEOMETRY, "st_point"))), new Comparison(NOT_EQUAL, new Reference(VARCHAR, "name_a"), new Reference(VARCHAR, "name_b")))), + project(ImmutableMap.of("st_point", expression(new Call(ST_POINT, ImmutableList.of(new Reference(DOUBLE, "lng"), new Reference(DOUBLE, "lat"))))), tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name_a", "name"))), anyTree( - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference(VARCHAR, "wkt"), VARCHAR))))), + project(ImmutableMap.of("st_geometryfromtext", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new Reference(VARCHAR, "wkt"), VARCHAR))))), tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_b", "name"))))))); // non-deterministic extra join predicate @@ -442,11 +442,11 @@ public void testSpatialLeftJoins() "ON ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat)) AND rand() < 0.5", anyTree( spatialLeftJoin( - new LogicalExpression(AND, ImmutableList.of(new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "st_geometryfromtext"), new SymbolReference(GEOMETRY, "st_point"))), new ComparisonExpression(LESS_THAN, new FunctionCall(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 0.5)))), - project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference(DOUBLE, "lng"), new SymbolReference(DOUBLE, "lat"))))), + new Logical(AND, ImmutableList.of(new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "st_geometryfromtext"), new Reference(GEOMETRY, "st_point"))), new Comparison(LESS_THAN, new Call(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 0.5)))), + project(ImmutableMap.of("st_point", expression(new Call(ST_POINT, ImmutableList.of(new Reference(DOUBLE, "lng"), new Reference(DOUBLE, "lat"))))), tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name_a", "name"))), anyTree( - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference(VARCHAR, "wkt"), VARCHAR))))), + project(ImmutableMap.of("st_geometryfromtext", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new Reference(VARCHAR, "wkt"), VARCHAR))))), tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_b", "name"))))))); // filter over join @@ -456,13 +456,13 @@ public void testSpatialLeftJoins() "WHERE concat(a.name, b.name) is null", anyTree( filter( - new IsNullPredicate(new FunctionCall(CONCAT, ImmutableList.of(new Cast(new SymbolReference(VARCHAR, "name_a"), VARCHAR), new Cast(new SymbolReference(VARCHAR, "name_b"), VARCHAR)))), + new IsNull(new Call(CONCAT, ImmutableList.of(new Cast(new Reference(VARCHAR, "name_a"), VARCHAR), new Cast(new Reference(VARCHAR, "name_b"), VARCHAR)))), spatialLeftJoin( - new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "st_geometryfromtext"), new SymbolReference(GEOMETRY, "st_point"))), - project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference(DOUBLE, "lng"), new SymbolReference(DOUBLE, "lat"))))), + new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "st_geometryfromtext"), new Reference(GEOMETRY, "st_point"))), + project(ImmutableMap.of("st_point", expression(new Call(ST_POINT, ImmutableList.of(new Reference(DOUBLE, "lng"), new Reference(DOUBLE, "lat"))))), tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name_a", "name"))), anyTree( - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference(VARCHAR, "wkt"), VARCHAR))))), + project(ImmutableMap.of("st_geometryfromtext", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new Reference(VARCHAR, "wkt"), VARCHAR))))), tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_b", "name")))))))); } @@ -475,19 +475,19 @@ public void testDistributedSpatialJoinOverUnion() "WHERE ST_Contains(ST_GeometryFromText(a.name), ST_GeometryFromText(b.name))", withSpatialPartitioning("kdb_tree"), anyTree( - spatialJoin(new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "g1"), new SymbolReference(GEOMETRY, "g3"))), Optional.of(KDB_TREE_JSON), Optional.empty(), + spatialJoin(new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "g1"), new Reference(GEOMETRY, "g3"))), Optional.of(KDB_TREE_JSON), Optional.empty(), anyTree( unnest(exchange(ExchangeNode.Scope.LOCAL, ExchangeNode.Type.REPARTITION, - project(ImmutableMap.of("p1", expression(new FunctionCall(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference(GEOMETRY, "g1"))))), - project(ImmutableMap.of("g1", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference(VARCHAR, "name_a1"), VARCHAR))))), + project(ImmutableMap.of("p1", expression(new Call(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new Reference(GEOMETRY, "g1"))))), + project(ImmutableMap.of("g1", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new Reference(VARCHAR, "name_a1"), VARCHAR))))), tableScan("region", ImmutableMap.of("name_a1", "name")))), - project(ImmutableMap.of("p2", expression(new FunctionCall(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference(GEOMETRY, "g2"))))), - project(ImmutableMap.of("g2", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference(VARCHAR, "name_a2"), VARCHAR))))), + project(ImmutableMap.of("p2", expression(new Call(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new Reference(GEOMETRY, "g2"))))), + project(ImmutableMap.of("g2", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new Reference(VARCHAR, "name_a2"), VARCHAR))))), tableScan("nation", ImmutableMap.of("name_a2", "name"))))))), anyTree( unnest( - project(ImmutableMap.of("p3", expression(new FunctionCall(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference(GEOMETRY, "g3"))))), - project(ImmutableMap.of("g3", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference(VARCHAR, "name_b"), VARCHAR))))), + project(ImmutableMap.of("p3", expression(new Call(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new Reference(GEOMETRY, "g3"))))), + project(ImmutableMap.of("g3", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new Reference(VARCHAR, "name_b"), VARCHAR))))), tableScan("customer", ImmutableMap.of("name_b", "name"))))))))); // union on the right side @@ -497,21 +497,21 @@ public void testDistributedSpatialJoinOverUnion() withSpatialPartitioning("kdb_tree"), anyTree( spatialJoin( - new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference(GEOMETRY, "g1"), new SymbolReference(GEOMETRY, "g2"))), + new Call(ST_CONTAINS, ImmutableList.of(new Reference(GEOMETRY, "g1"), new Reference(GEOMETRY, "g2"))), Optional.of(KDB_TREE_JSON), Optional.empty(), anyTree( unnest( - project(ImmutableMap.of("p1", expression(new FunctionCall(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference(GEOMETRY, "g1"))))), - project(ImmutableMap.of("g1", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference(VARCHAR, "name_a"), VARCHAR))))), + project(ImmutableMap.of("p1", expression(new Call(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new Reference(GEOMETRY, "g1"))))), + project(ImmutableMap.of("g1", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new Reference(VARCHAR, "name_a"), VARCHAR))))), tableScan("customer", ImmutableMap.of("name_a", "name")))))), anyTree( unnest(exchange(ExchangeNode.Scope.LOCAL, ExchangeNode.Type.REPARTITION, - project(ImmutableMap.of("p2", expression(new FunctionCall(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference(GEOMETRY, "g2"))))), - project(ImmutableMap.of("g2", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference(VARCHAR, "name_b1"), VARCHAR))))), + project(ImmutableMap.of("p2", expression(new Call(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new Reference(GEOMETRY, "g2"))))), + project(ImmutableMap.of("g2", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new Reference(VARCHAR, "name_b1"), VARCHAR))))), tableScan("region", ImmutableMap.of("name_b1", "name")))), - project(ImmutableMap.of("p3", expression(new FunctionCall(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference(GEOMETRY, "g3"))))), - project(ImmutableMap.of("g3", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference(VARCHAR, "name_b2"), VARCHAR))))), + project(ImmutableMap.of("p3", expression(new Call(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new Reference(GEOMETRY, "g3"))))), + project(ImmutableMap.of("g3", expression(new Call(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new Reference(VARCHAR, "name_b2"), VARCHAR))))), tableScan("nation", ImmutableMap.of("name_b2", "name")))))))))); } @@ -531,7 +531,7 @@ private String singleRow(String... columns) private PlanMatchPattern singleRow() { return filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "regionkey"), new Constant(BIGINT, 1L)), + new Comparison(EQUAL, new Reference(BIGINT, "regionkey"), new Constant(BIGINT, 1L)), tableScan("region", ImmutableMap.of("regionkey", "regionkey"))); } @@ -548,8 +548,8 @@ private static String doubleLiteral(double value) return format("%.16E", value); } - private FunctionCall functionCall(String name, List types, List arguments) + private Call functionCall(String name, List types, List arguments) { - return new FunctionCall(getPlanTester().getPlannerContext().getMetadata().resolveBuiltinFunction(name, fromTypes(types)), arguments); + return new Call(getPlanTester().getPlannerContext().getMetadata().resolveBuiltinFunction(name, fromTypes(types)), arguments); } } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestConnectorPushdownRulesWithHive.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestConnectorPushdownRulesWithHive.java index 601eadb6b354..6815039d7c90 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestConnectorPushdownRulesWithHive.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestConnectorPushdownRulesWithHive.java @@ -36,13 +36,13 @@ import io.trino.spi.security.PrincipalType; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ArithmeticNegation; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Negation; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.PruneTableScanColumns; import io.trino.sql.planner.iterative.rule.PushPredicateIntoTableScan; @@ -70,8 +70,8 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.RowType.field; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; @@ -170,7 +170,7 @@ public void testProjectionPushdown() ImmutableMap.of(p.symbol("struct_of_int", baseType), fullColumn)))) .matches( project( - ImmutableMap.of("expr", expression(new SymbolReference(RowType.anonymousRow(BIGINT, BIGINT), "col"))), + ImmutableMap.of("expr", expression(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "col"))), tableScan( hiveTable.withProjectedColumns(ImmutableSet.of(fullColumn))::equals, TupleDomain.all(), @@ -195,13 +195,13 @@ public void testProjectionPushdown() .on(p -> p.project( Assignments.of( - p.symbol("expr_deref", BIGINT), new SubscriptExpression(INTEGER, p.symbol("struct_of_int", baseType).toSymbolReference(), new Constant(INTEGER, 1L))), + p.symbol("expr_deref", BIGINT), new Subscript(INTEGER, p.symbol("struct_of_int", baseType).toSymbolReference(), new Constant(INTEGER, 1L))), p.tableScan( table, ImmutableList.of(p.symbol("struct_of_int", baseType)), ImmutableMap.of(p.symbol("struct_of_int", baseType), fullColumn)))) .matches(project( - ImmutableMap.of("expr_deref", expression(new SymbolReference(BIGINT, "struct_of_int#a"))), + ImmutableMap.of("expr_deref", expression(new Reference(BIGINT, "struct_of_int#a"))), tableScan( hiveTable.withProjectedColumns(ImmutableSet.of(partialColumn))::equals, TupleDomain.all(), @@ -226,13 +226,13 @@ public void testPredicatePushdown() tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter( - new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 5L)), + new Comparison(EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 5L)), p.tableScan( table, ImmutableList.of(p.symbol("a", INTEGER)), ImmutableMap.of(p.symbol("a", INTEGER), column)))) .matches(filter( - new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 5L)), + new Comparison(EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 5L)), tableScan( tableHandle -> ((HiveTableHandle) tableHandle).getCompactEffectivePredicate().getDomains().get() .equals(ImmutableMap.of(column, Domain.singleValue(INTEGER, 5L))), @@ -271,7 +271,7 @@ public void testColumnPruningProjectionPushdown() }) .matches( strictProject( - ImmutableMap.of("expr", expression(new SymbolReference(INTEGER, "COLA"))), + ImmutableMap.of("expr", expression(new Reference(INTEGER, "COLA"))), tableScan( hiveTable.withProjectedColumns(ImmutableSet.of(columnA))::equals, TupleDomain.all(), @@ -312,8 +312,8 @@ public void testPushdownWithDuplicateExpressions() // Test projection pushdown with duplicate column references tester().assertThat(pushProjectionIntoTableScan) .on(p -> { - SymbolReference column = p.symbol("just_bigint", BIGINT).toSymbolReference(); - Expression negation = new ArithmeticNegation(column); + Reference column = p.symbol("just_bigint", BIGINT).toSymbolReference(); + Expression negation = new Negation(column); return p.project( Assignments.of( // The column reference is part of both the assignments @@ -326,8 +326,8 @@ public void testPushdownWithDuplicateExpressions() }) .matches(project( ImmutableMap.of( - "column_ref", expression(new SymbolReference(BIGINT, "just_bigint_0")), - "negated_column_ref", expression(new ArithmeticNegation(new SymbolReference(BIGINT, "just_bigint_0")))), + "column_ref", expression(new Reference(BIGINT, "just_bigint_0")), + "negated_column_ref", expression(new Negation(new Reference(BIGINT, "just_bigint_0")))), tableScan( hiveTable.withProjectedColumns(ImmutableSet.of(bigintColumn))::equals, TupleDomain.all(), @@ -336,8 +336,8 @@ public void testPushdownWithDuplicateExpressions() // Test Dereference pushdown tester().assertThat(pushProjectionIntoTableScan) .on(p -> { - SubscriptExpression subscript = new SubscriptExpression(BIGINT, p.symbol("struct_of_bigint", ROW_TYPE).toSymbolReference(), new Constant(INTEGER, 1L)); - Expression sum = new ArithmeticBinaryExpression(ADD_INTEGER, ADD, subscript, new Constant(INTEGER, 2L)); + Subscript subscript = new Subscript(BIGINT, p.symbol("struct_of_bigint", ROW_TYPE).toSymbolReference(), new Constant(INTEGER, 1L)); + Expression sum = new Arithmetic(ADD_INTEGER, ADD, subscript, new Constant(INTEGER, 2L)); return p.project( Assignments.of( // The subscript expression instance is part of both the assignments @@ -350,8 +350,8 @@ public void testPushdownWithDuplicateExpressions() }) .matches(project( ImmutableMap.of( - "expr_deref", expression(new SymbolReference(INTEGER, "struct_of_bigint#a")), - "expr_deref_2", expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "struct_of_bigint#a"), new Constant(INTEGER, 2L)))), + "expr_deref", expression(new Reference(INTEGER, "struct_of_bigint#a")), + "expr_deref_2", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "struct_of_bigint#a"), new Constant(INTEGER, 2L)))), tableScan( hiveTable.withProjectedColumns(ImmutableSet.of(partialColumn))::equals, TupleDomain.all(), diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestHivePlans.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestHivePlans.java index c7b4bdbd33a3..e2f6d250a42f 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestHivePlans.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestHivePlans.java @@ -24,14 +24,14 @@ import io.trino.plugin.hive.metastore.HiveMetastoreFactory; import io.trino.spi.function.OperatorType; import io.trino.spi.security.PrincipalType; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.BetweenPredicate; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Between; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.ir.InPredicate; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.In; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.OptimizerConfig.JoinDistributionType; import io.trino.sql.planner.OptimizerConfig.JoinReorderingStrategy; import io.trino.sql.planner.assertions.BasePlanTest; @@ -58,11 +58,11 @@ import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MODULUS; -import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.ComparisonExpression.Operator.NOT_EQUAL; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; +import static io.trino.sql.ir.Arithmetic.Operator.MODULUS; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.NOT_EQUAL; +import static io.trino.sql.ir.Logical.Operator.AND; import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.join; @@ -156,7 +156,7 @@ public void testPruneSimplePartitionLikeFilter() "SELECT * FROM table_str_partitioned WHERE str_part LIKE 't%'", output( filter( - new FunctionCall(LIKE, ImmutableList.of(new SymbolReference(VARCHAR, "STR_PART"), new Constant(LIKE_PATTERN, LikePattern.compile("t%", Optional.empty())))), + new Call(LIKE, ImmutableList.of(new Reference(VARCHAR, "STR_PART"), new Constant(LIKE_PATTERN, LikePattern.compile("t%", Optional.empty())))), tableScan("table_str_partitioned", Map.of("INT_COL", "int_col", "STR_PART", "str_part"))))); } @@ -178,12 +178,12 @@ public void testPrunePartitionLikeFilter() .left( exchange(REMOTE, REPARTITION, filter( - new FunctionCall(LIKE, ImmutableList.of(new SymbolReference(VARCHAR, "L_STR_PART"), new Constant(LIKE_PATTERN, LikePattern.compile("t%", Optional.empty())))), + new Call(LIKE, ImmutableList.of(new Reference(VARCHAR, "L_STR_PART"), new Constant(LIKE_PATTERN, LikePattern.compile("t%", Optional.empty())))), tableScan("table_str_partitioned", Map.of("L_INT_COL", "int_col", "L_STR_PART", "str_part"))))) .right(exchange(LOCAL, exchange(REMOTE, REPARTITION, filter( - new LogicalExpression(AND, ImmutableList.of(new InPredicate(new SymbolReference(VARCHAR, "R_STR_COL"), ImmutableList.of(new Constant(createVarcharType(5), Slices.utf8Slice("three")), new Constant(createVarcharType(5), Slices.utf8Slice("two")))), new FunctionCall(LIKE, ImmutableList.of(new SymbolReference(VARCHAR, "R_STR_COL"), new Constant(LIKE_PATTERN, LikePattern.compile("t%", Optional.empty())))))), + new Logical(AND, ImmutableList.of(new In(new Reference(VARCHAR, "R_STR_COL"), ImmutableList.of(new Constant(createVarcharType(5), Slices.utf8Slice("three")), new Constant(createVarcharType(5), Slices.utf8Slice("two")))), new Call(LIKE, ImmutableList.of(new Reference(VARCHAR, "R_STR_COL"), new Constant(LIKE_PATTERN, LikePattern.compile("t%", Optional.empty())))))), tableScan("table_unpartitioned", Map.of("R_STR_COL", "str_col", "R_INT_COL", "int_col"))))))))); } @@ -202,13 +202,13 @@ public void testSubsumePartitionFilter() .left( exchange(REMOTE, REPARTITION, filter( - TRUE_LITERAL, + TRUE, tableScan("table_int_partitioned", Map.of("L_INT_PART", "int_part", "L_STR_COL", "str_col"))))) .right( exchange(LOCAL, exchange(REMOTE, REPARTITION, filter( - new InPredicate(new SymbolReference(INTEGER, "R_INT_COL"), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 3L), new Constant(INTEGER, 4L))), + new In(new Reference(INTEGER, "R_INT_COL"), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 3L), new Constant(INTEGER, 4L))), tableScan("table_unpartitioned", Map.of("R_STR_COL", "str_col", "R_INT_COL", "int_col"))))))))); } @@ -228,13 +228,13 @@ public void testSubsumePartitionPartOfAFilter() .left( exchange(REMOTE, REPARTITION, filter( - new ComparisonExpression(NOT_EQUAL, new SymbolReference(createVarcharType(5), "L_STR_COL"), new Constant(createVarcharType(5), Slices.utf8Slice("three"))), + new Comparison(NOT_EQUAL, new Reference(createVarcharType(5), "L_STR_COL"), new Constant(createVarcharType(5), Slices.utf8Slice("three"))), tableScan("table_int_partitioned", Map.of("L_INT_PART", "int_part", "L_STR_COL", "str_col"))))) .right( exchange(LOCAL, exchange(REMOTE, REPARTITION, filter( - new LogicalExpression(AND, ImmutableList.of(new InPredicate(new SymbolReference(INTEGER, "R_INT_COL"), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 3L), new Constant(INTEGER, 4L))), new BetweenPredicate(new SymbolReference(INTEGER, "R_INT_COL"), new Constant(INTEGER, 2L), new Constant(INTEGER, 4L)))), + new Logical(AND, ImmutableList.of(new In(new Reference(INTEGER, "R_INT_COL"), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 3L), new Constant(INTEGER, 4L))), new Between(new Reference(INTEGER, "R_INT_COL"), new Constant(INTEGER, 2L), new Constant(INTEGER, 4L)))), tableScan("table_unpartitioned", Map.of("R_STR_COL", "str_col", "R_INT_COL", "int_col"))))))))); } @@ -254,13 +254,13 @@ public void testSubsumePartitionPartWhenOtherFilterNotConvertibleToTupleDomain() .left( exchange(REMOTE, REPARTITION, filter( - new ComparisonExpression(NOT_EQUAL, new FunctionCall(SUBSTRING, ImmutableList.of(new SymbolReference(VARCHAR, "L_STR_COL"), new Constant(BIGINT, 2L))), new Constant(createVarcharType(5), Slices.utf8Slice("hree"))), + new Comparison(NOT_EQUAL, new Call(SUBSTRING, ImmutableList.of(new Reference(VARCHAR, "L_STR_COL"), new Constant(BIGINT, 2L))), new Constant(createVarcharType(5), Slices.utf8Slice("hree"))), tableScan("table_int_partitioned", Map.of("L_INT_PART", "int_part", "L_STR_COL", "str_col"))))) .right( exchange(LOCAL, exchange(REMOTE, REPARTITION, filter( - new LogicalExpression(AND, ImmutableList.of(new InPredicate(new SymbolReference(INTEGER, "R_INT_COL"), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 3L), new Constant(INTEGER, 4L))), new BetweenPredicate(new SymbolReference(INTEGER, "R_INT_COL"), new Constant(INTEGER, 2L), new Constant(INTEGER, 4L)))), + new Logical(AND, ImmutableList.of(new In(new Reference(INTEGER, "R_INT_COL"), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 3L), new Constant(INTEGER, 4L))), new Between(new Reference(INTEGER, "R_INT_COL"), new Constant(INTEGER, 2L), new Constant(INTEGER, 4L)))), tableScan("table_unpartitioned", Map.of("R_STR_COL", "str_col", "R_INT_COL", "int_col"))))))))); } @@ -280,13 +280,13 @@ public void testSubsumePartitionFilterNotConvertibleToTupleDomain() .left( exchange(REMOTE, REPARTITION, filter( - new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(MODULUS_INTEGER, MODULUS, new SymbolReference(INTEGER, "L_INT_PART"), new Constant(INTEGER, 2L)), new Constant(INTEGER, 0L)), + new Comparison(EQUAL, new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "L_INT_PART"), new Constant(INTEGER, 2L)), new Constant(INTEGER, 0L)), tableScan("table_int_partitioned", Map.of("L_INT_PART", "int_part", "L_STR_COL", "str_col"))))) .right( exchange(LOCAL, exchange(REMOTE, REPARTITION, filter( - new LogicalExpression(AND, ImmutableList.of(new InPredicate(new SymbolReference(INTEGER, "R_INT_COL"), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L))), new ComparisonExpression(EQUAL, new ArithmeticBinaryExpression(MODULUS_INTEGER, MODULUS, new SymbolReference(INTEGER, "R_INT_COL"), new Constant(INTEGER, 2L)), new Constant(INTEGER, 0L)))), + new Logical(AND, ImmutableList.of(new In(new Reference(INTEGER, "R_INT_COL"), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L))), new Comparison(EQUAL, new Arithmetic(MODULUS_INTEGER, MODULUS, new Reference(INTEGER, "R_INT_COL"), new Constant(INTEGER, 2L)), new Constant(INTEGER, 0L)))), tableScan("table_unpartitioned", Map.of("R_STR_COL", "str_col", "R_INT_COL", "int_col"))))))))); } @@ -303,13 +303,13 @@ public void testFilterDerivedFromTableProperties() .left( exchange(REMOTE, REPARTITION, filter( - TRUE_LITERAL, + TRUE, tableScan("table_int_partitioned", Map.of("L_INT_PART", "int_part", "L_STR_COL", "str_col"))))) .right( exchange(LOCAL, exchange(REMOTE, REPARTITION, filter( - new InPredicate(new SymbolReference(INTEGER, "R_INT_COL"), ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L), new Constant(INTEGER, 3L), new Constant(INTEGER, 4L), new Constant(INTEGER, 5L))), + new In(new Reference(INTEGER, "R_INT_COL"), ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L), new Constant(INTEGER, 3L), new Constant(INTEGER, 4L), new Constant(INTEGER, 5L))), tableScan("table_unpartitioned", Map.of("R_STR_COL", "str_col", "R_INT_COL", "int_col"))))))))); } @@ -324,7 +324,7 @@ public void testQueryScanningForTooManyPartitions() .equiCriteria("L_INT_PART", "R_INT_COL") .left( filter( - TRUE_LITERAL, + TRUE, tableScan("table_int_with_too_many_partitions", Map.of("L_INT_PART", "int_part", "L_STR_COL", "str_col")))) .right( exchange(LOCAL, diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestHiveProjectionPushdownIntoTableScan.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestHiveProjectionPushdownIntoTableScan.java index cbd0cc60d376..843f9c11fcac 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestHiveProjectionPushdownIntoTableScan.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestHiveProjectionPushdownIntoTableScan.java @@ -33,13 +33,13 @@ import io.trino.spi.predicate.TupleDomain; import io.trino.spi.security.PrincipalType; import io.trino.spi.type.RowType; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.assertions.BasePushdownPlanTest; import io.trino.testing.PlanTester; import org.junit.jupiter.api.AfterAll; @@ -58,9 +58,9 @@ import static io.trino.plugin.hive.TestingHiveUtils.getConnectorService; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Logical.Operator.AND; import static io.trino.sql.planner.assertions.PlanMatchPattern.any; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; @@ -134,8 +134,8 @@ public void testPushdownDisabled() any( project( ImmutableMap.of( - "expr", expression(new SubscriptExpression(BIGINT, new SymbolReference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), new Constant(INTEGER, 1L))), - "expr_2", expression(new SubscriptExpression(BIGINT, new SymbolReference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), new Constant(INTEGER, 2L)))), + "expr", expression(new Subscript(BIGINT, new Reference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), new Constant(INTEGER, 1L))), + "expr_2", expression(new Subscript(BIGINT, new Reference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), new Constant(INTEGER, 2L)))), tableScan(testTable, ImmutableMap.of("col0", "col0"))))); } @@ -179,7 +179,7 @@ public void testDereferencePushdown() format("SELECT col0.x FROM %s WHERE col0.x = col1 + 3 and col0.y = 2", testTable), anyTree( filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "col0_y"), new Constant(BIGINT, 2L)), new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "col0_x"), new Cast(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "col1"), new Constant(INTEGER, 3L)), BIGINT)))), + new Logical(AND, ImmutableList.of(new Comparison(EQUAL, new Reference(BIGINT, "col0_y"), new Constant(BIGINT, 2L)), new Comparison(EQUAL, new Reference(BIGINT, "col0_x"), new Cast(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "col1"), new Constant(INTEGER, 3L)), BIGINT)))), tableScan( table -> { HiveTableHandle hiveTableHandle = (HiveTableHandle) table; @@ -196,7 +196,7 @@ public void testDereferencePushdown() format("SELECT col0, col0.y expr_y FROM %s WHERE col0.x = 5", testTable), anyTree( filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "col0_x"), new Constant(BIGINT, 5L)), + new Comparison(EQUAL, new Reference(BIGINT, "col0_x"), new Constant(BIGINT, 5L)), tableScan( table -> { HiveTableHandle hiveTableHandle = (HiveTableHandle) table; @@ -214,15 +214,15 @@ public void testDereferencePushdown() anyTree( project( ImmutableMap.of( - "expr_0_x", expression(new SubscriptExpression(BIGINT, new SymbolReference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0"), new Constant(INTEGER, 1L))), - "expr_0", expression(new SymbolReference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0")), - "expr_0_y", expression(new SubscriptExpression(BIGINT, new SymbolReference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0"), new Constant(INTEGER, 2L)))), + "expr_0_x", expression(new Subscript(BIGINT, new Reference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0"), new Constant(INTEGER, 1L))), + "expr_0", expression(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0")), + "expr_0_y", expression(new Subscript(BIGINT, new Reference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0"), new Constant(INTEGER, 2L)))), join(INNER, builder -> builder .equiCriteria("t_expr_1", "s_expr_1") .left( anyTree( filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "expr_0_x"), new Constant(BIGINT, 2L)), + new Comparison(EQUAL, new Reference(BIGINT, "expr_0_x"), new Constant(BIGINT, 2L)), tableScan( table -> ((HiveTableHandle) table).getCompactEffectivePredicate().getDomains().get() .equals(ImmutableMap.of(columnX, Domain.singleValue(BIGINT, 2L))), diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergProjectionPushdownPlans.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergProjectionPushdownPlans.java index 89e356f4ac57..65d4627f185c 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergProjectionPushdownPlans.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergProjectionPushdownPlans.java @@ -30,13 +30,13 @@ import io.trino.spi.predicate.TupleDomain; import io.trino.spi.security.PrincipalType; import io.trino.spi.type.RowType; -import io.trino.sql.ir.ArithmeticBinaryExpression; +import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.assertions.BasePushdownPlanTest; import io.trino.testing.PlanTester; import org.junit.jupiter.api.AfterAll; @@ -55,9 +55,9 @@ import static com.google.common.io.RecursiveDeleteOption.ALLOW_INSECURE; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; -import static io.trino.sql.ir.LogicalExpression.Operator.AND; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Logical.Operator.AND; import static io.trino.sql.planner.assertions.PlanMatchPattern.any; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; @@ -140,7 +140,7 @@ public void testPushdownDisabled() session, any( project( - ImmutableMap.of("expr", expression(new SubscriptExpression(BIGINT, new SymbolReference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), new Constant(INTEGER, 1L))), "expr_2", expression(new SubscriptExpression(BIGINT, new SymbolReference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), new Constant(INTEGER, 2L)))), + ImmutableMap.of("expr", expression(new Subscript(BIGINT, new Reference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), new Constant(INTEGER, 1L))), "expr_2", expression(new Subscript(BIGINT, new Reference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), new Constant(INTEGER, 2L)))), tableScan(testTable, ImmutableMap.of("col0", "col0"))))); } @@ -193,7 +193,7 @@ public void testDereferencePushdown() format("SELECT col0.x FROM %s WHERE col0.x = col1 + 3 and col0.y = 2", testTable), anyTree( filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "y"), new Constant(BIGINT, 2L)), new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "x"), new Cast(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "col1"), new Constant(INTEGER, 3L)), BIGINT)))), + new Logical(AND, ImmutableList.of(new Comparison(EQUAL, new Reference(BIGINT, "y"), new Constant(BIGINT, 2L)), new Comparison(EQUAL, new Reference(BIGINT, "x"), new Cast(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "col1"), new Constant(INTEGER, 3L)), BIGINT)))), tableScan( table -> { IcebergTableHandle icebergTableHandle = (IcebergTableHandle) table; @@ -209,7 +209,7 @@ public void testDereferencePushdown() format("SELECT col0, col0.y expr_y FROM %s WHERE col0.x = 5", testTable), anyTree( filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "x"), new Constant(BIGINT, 5L)), + new Comparison(EQUAL, new Reference(BIGINT, "x"), new Constant(BIGINT, 5L)), tableScan( table -> { IcebergTableHandle icebergTableHandle = (IcebergTableHandle) table; @@ -226,9 +226,9 @@ public void testDereferencePushdown() anyTree( project( ImmutableMap.of( - "expr_0_x", expression(new SubscriptExpression(BIGINT, new SymbolReference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0"), new Constant(INTEGER, 1L))), - "expr_0", expression(new SymbolReference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0")), - "expr_0_y", expression(new SubscriptExpression(BIGINT, new SymbolReference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0"), new Constant(INTEGER, 2L)))), + "expr_0_x", expression(new Subscript(BIGINT, new Reference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0"), new Constant(INTEGER, 1L))), + "expr_0", expression(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0")), + "expr_0_y", expression(new Subscript(BIGINT, new Reference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0"), new Constant(INTEGER, 2L)))), join(INNER, builder -> builder .equiCriteria("s_expr_1", "t_expr_1") .left( @@ -240,7 +240,7 @@ public void testDereferencePushdown() .right( anyTree( filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "x"), new Constant(BIGINT, 2L)), + new Comparison(EQUAL, new Reference(BIGINT, "x"), new Constant(BIGINT, 2L)), tableScan( table -> { IcebergTableHandle icebergTableHandle = (IcebergTableHandle) table; diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/optimizer/TestConnectorPushdownRulesWithIceberg.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/optimizer/TestConnectorPushdownRulesWithIceberg.java index 1028b3500311..ba7edd2b05e3 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/optimizer/TestConnectorPushdownRulesWithIceberg.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/optimizer/TestConnectorPushdownRulesWithIceberg.java @@ -39,13 +39,13 @@ import io.trino.spi.security.PrincipalType; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ArithmeticNegation; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Negation; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.PruneTableScanColumns; import io.trino.sql.planner.iterative.rule.PushPredicateIntoTableScan; @@ -71,8 +71,8 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.RowType.field; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; @@ -195,7 +195,7 @@ public void testProjectionPushdown() ImmutableMap.of(p.symbol("struct_of_int", baseType), fullColumn)))) .matches( project( - ImmutableMap.of("expr", expression(new SymbolReference(INTEGER, "col"))), + ImmutableMap.of("expr", expression(new Reference(INTEGER, "col"))), tableScan( icebergTable.withProjectedColumns(ImmutableSet.of(fullColumn))::equals, TupleDomain.all(), @@ -220,13 +220,13 @@ public void testProjectionPushdown() .on(p -> p.project( Assignments.of( - p.symbol("expr_deref", BIGINT), new SubscriptExpression(BIGINT, p.symbol("struct_of_int", baseType).toSymbolReference(), new Constant(INTEGER, 1L))), + p.symbol("expr_deref", BIGINT), new Subscript(BIGINT, p.symbol("struct_of_int", baseType).toSymbolReference(), new Constant(INTEGER, 1L))), p.tableScan( table, ImmutableList.of(p.symbol("struct_of_int", baseType)), ImmutableMap.of(p.symbol("struct_of_int", baseType), fullColumn)))) .matches(project( - ImmutableMap.of("expr_deref", expression(new SymbolReference(BIGINT, "struct_of_int#a"))), + ImmutableMap.of("expr_deref", expression(new Reference(BIGINT, "struct_of_int#a"))), tableScan( icebergTable.withProjectedColumns(ImmutableSet.of(partialColumn))::equals, TupleDomain.all(), @@ -271,13 +271,13 @@ public void testPredicatePushdown() tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter( - new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 5L)), + new Comparison(EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 5L)), p.tableScan( table, ImmutableList.of(p.symbol("a", INTEGER)), ImmutableMap.of(p.symbol("a", INTEGER), column)))) .matches(filter( - new ComparisonExpression(EQUAL, new SymbolReference(INTEGER, "a"), new Constant(INTEGER, 5L)), + new Comparison(EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 5L)), tableScan( tableHandle -> ((IcebergTableHandle) tableHandle).getUnenforcedPredicate().getDomains().get() .equals(ImmutableMap.of(column, Domain.singleValue(INTEGER, 5L))), @@ -335,7 +335,7 @@ public void testColumnPruningProjectionPushdown() }) .matches( strictProject( - ImmutableMap.of("expr", expression(new SymbolReference(INTEGER, "COLA"))), + ImmutableMap.of("expr", expression(new Reference(INTEGER, "COLA"))), tableScan( icebergTable.withProjectedColumns(ImmutableSet.of(columnA))::equals, TupleDomain.all(), @@ -390,8 +390,8 @@ public void testPushdownWithDuplicateExpressions() // Test projection pushdown with duplicate column references tester().assertThat(pushProjectionIntoTableScan) .on(p -> { - SymbolReference column = p.symbol("just_bigint", BIGINT).toSymbolReference(); - Expression negation = new ArithmeticNegation(column); + Reference column = p.symbol("just_bigint", BIGINT).toSymbolReference(); + Expression negation = new Negation(column); return p.project( Assignments.of( // The column reference is part of both the assignments @@ -404,8 +404,8 @@ public void testPushdownWithDuplicateExpressions() }) .matches(project( ImmutableMap.of( - "column_ref", expression(new SymbolReference(BIGINT, "just_bigint_0")), - "negated_column_ref", expression(new ArithmeticNegation(new SymbolReference(BIGINT, "just_bigint_0")))), + "column_ref", expression(new Reference(BIGINT, "just_bigint_0")), + "negated_column_ref", expression(new Negation(new Reference(BIGINT, "just_bigint_0")))), tableScan( icebergTable.withProjectedColumns(ImmutableSet.of(bigintColumn))::equals, TupleDomain.all(), @@ -414,8 +414,8 @@ public void testPushdownWithDuplicateExpressions() // Test Dereference pushdown tester().assertThat(pushProjectionIntoTableScan) .on(p -> { - SubscriptExpression subscript = new SubscriptExpression(BIGINT, p.symbol("struct_of_bigint", ROW_TYPE).toSymbolReference(), new Constant(INTEGER, 1L)); - Expression sum = new ArithmeticBinaryExpression(ADD_INTEGER, ADD, subscript, new Constant(INTEGER, 2L)); + Subscript subscript = new Subscript(BIGINT, p.symbol("struct_of_bigint", ROW_TYPE).toSymbolReference(), new Constant(INTEGER, 1L)); + Expression sum = new Arithmetic(ADD_INTEGER, ADD, subscript, new Constant(INTEGER, 2L)); return p.project( Assignments.of( // The subscript expression instance is part of both the assignments @@ -428,8 +428,8 @@ public void testPushdownWithDuplicateExpressions() }) .matches(project( ImmutableMap.of( - "expr_deref", expression(new SymbolReference(INTEGER, "struct_of_bigint#a")), - "expr_deref_2", expression(new ArithmeticBinaryExpression(ADD_INTEGER, ADD, new SymbolReference(INTEGER, "struct_of_bigint#a"), new Constant(INTEGER, 2L)))), + "expr_deref", expression(new Reference(INTEGER, "struct_of_bigint#a")), + "expr_deref_2", expression(new Arithmetic(ADD_INTEGER, ADD, new Reference(INTEGER, "struct_of_bigint#a"), new Constant(INTEGER, 2L)))), tableScan( icebergTable.withProjectedColumns(ImmutableSet.of(partialColumn))::equals, TupleDomain.all(), diff --git a/plugin/trino-ignite/src/test/java/io/trino/plugin/ignite/TestIgniteClient.java b/plugin/trino-ignite/src/test/java/io/trino/plugin/ignite/TestIgniteClient.java index 7d8ca4d8e8d8..b2fa8792ea14 100644 --- a/plugin/trino-ignite/src/test/java/io/trino/plugin/ignite/TestIgniteClient.java +++ b/plugin/trino-ignite/src/test/java/io/trino/plugin/ignite/TestIgniteClient.java @@ -28,9 +28,9 @@ import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.expression.Variable; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.ConnectorExpressionTranslator; import org.junit.jupiter.api.Test; @@ -166,8 +166,8 @@ public void testConvertIsNull() // c_varchar IS NULL ParameterizedExpression converted = JDBC_CLIENT.convertPredicate(SESSION, translateToConnectorExpression( - new IsNullPredicate( - new SymbolReference(VARCHAR, "c_varchar_symbol"))), + new IsNull( + new Reference(VARCHAR, "c_varchar_symbol"))), Map.of("c_varchar_symbol", VARCHAR_COLUMN)) .orElseThrow(); assertThat(converted.expression()).isEqualTo("(`c_varchar`) IS NULL"); @@ -180,7 +180,7 @@ public void testConvertIsNotNull() // c_varchar IS NOT NULL ParameterizedExpression converted = JDBC_CLIENT.convertPredicate(SESSION, translateToConnectorExpression( - new NotExpression(new IsNullPredicate(new SymbolReference(VARCHAR, "c_varchar_symbol")))), + new Not(new IsNull(new Reference(VARCHAR, "c_varchar_symbol")))), Map.of("c_varchar_symbol", VARCHAR_COLUMN)) .orElseThrow(); assertThat(converted.expression()).isEqualTo("(`c_varchar`) IS NOT NULL"); @@ -193,8 +193,8 @@ public void testConvertNotExpression() // NOT(expression) ParameterizedExpression converted = JDBC_CLIENT.convertPredicate(SESSION, translateToConnectorExpression( - new NotExpression( - new NotExpression(new IsNullPredicate(new SymbolReference(VARCHAR, "c_varchar_symbol"))))), + new Not( + new Not(new IsNull(new Reference(VARCHAR, "c_varchar_symbol"))))), Map.of("c_varchar_symbol", VARCHAR_COLUMN)) .orElseThrow(); assertThat(converted.expression()).isEqualTo("NOT ((`c_varchar`) IS NOT NULL)"); diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoProjectionPushdownPlans.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoProjectionPushdownPlans.java index 9b1716ca021f..b89121454ab8 100644 --- a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoProjectionPushdownPlans.java +++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoProjectionPushdownPlans.java @@ -29,11 +29,11 @@ import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; -import io.trino.sql.ir.SubscriptExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Subscript; import io.trino.sql.planner.assertions.BasePushdownPlanTest; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.testing.PlanTester; @@ -53,8 +53,8 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; -import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; -import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; +import static io.trino.sql.ir.Arithmetic.Operator.ADD; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.planner.assertions.PlanMatchPattern.any; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; @@ -133,7 +133,7 @@ public void testPushdownDisabled() session, any( project( - ImmutableMap.of("expr_1", expression(new SubscriptExpression(BIGINT, new SymbolReference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), new Constant(INTEGER, 1L))), "expr_2", expression(new SubscriptExpression(BIGINT, new SymbolReference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), new Constant(INTEGER, 2L)))), + ImmutableMap.of("expr_1", expression(new Subscript(BIGINT, new Reference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), new Constant(INTEGER, 1L))), "expr_2", expression(new Subscript(BIGINT, new Reference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), new Constant(INTEGER, 2L)))), tableScan(tableName, ImmutableMap.of("col0", "col0"))))); } @@ -174,7 +174,7 @@ public void testDereferencePushdown() "SELECT col0.x FROM " + tableName + " WHERE col0.x = col1 + 3 and col0.y = 2", anyTree( filter( - new ComparisonExpression(EQUAL, new SymbolReference(BIGINT, "x"), new ArithmeticBinaryExpression(ADD_BIGINT, ADD, new SymbolReference(BIGINT, "col1"), new Constant(BIGINT, 3L))), + new Comparison(EQUAL, new Reference(BIGINT, "x"), new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "col1"), new Constant(BIGINT, 3L))), tableScan( table -> { MongoTableHandle actualTableHandle = (MongoTableHandle) table; @@ -205,9 +205,9 @@ public void testDereferencePushdown() anyTree( project( ImmutableMap.of( - "expr_0_x", expression(new SubscriptExpression(INTEGER, new SymbolReference(RowType.anonymousRow(INTEGER), "expr_0"), new Constant(INTEGER, 1L))), - "expr_0", expression(new SymbolReference(RowType.anonymousRow(INTEGER), "expr_0")), - "expr_0_y", expression(new SubscriptExpression(INTEGER, new SymbolReference(RowType.anonymousRow(INTEGER), "expr_0"), new Constant(INTEGER, 2L)))), + "expr_0_x", expression(new Subscript(INTEGER, new Reference(RowType.anonymousRow(INTEGER), "expr_0"), new Constant(INTEGER, 1L))), + "expr_0", expression(new Reference(RowType.anonymousRow(INTEGER), "expr_0")), + "expr_0_y", expression(new Subscript(INTEGER, new Reference(RowType.anonymousRow(INTEGER), "expr_0"), new Constant(INTEGER, 2L)))), PlanMatchPattern.join(INNER, builder -> builder .equiCriteria("t_expr_1", "s_expr_1") .left( diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java index 7009d5f3922f..e0c229eaf36c 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java @@ -37,17 +37,17 @@ import io.trino.spi.expression.Variable; import io.trino.spi.function.OperatorType; import io.trino.spi.session.PropertyMetadata; -import io.trino.sql.ir.ArithmeticBinaryExpression; -import io.trino.sql.ir.ArithmeticNegation; -import io.trino.sql.ir.ComparisonExpression; +import io.trino.sql.ir.Arithmetic; +import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.InPredicate; -import io.trino.sql.ir.IsNullPredicate; -import io.trino.sql.ir.LogicalExpression; -import io.trino.sql.ir.NotExpression; -import io.trino.sql.ir.NullIfExpression; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.In; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Negation; +import io.trino.sql.ir.Not; +import io.trino.sql.ir.NullIf; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.ConnectorExpressionTranslator; import io.trino.testing.TestingConnectorSession; import org.junit.jupiter.api.Test; @@ -222,11 +222,11 @@ public void testConvertOr() ParameterizedExpression converted = JDBC_CLIENT.convertPredicate( SESSION, translateToConnectorExpression( - new LogicalExpression( - LogicalExpression.Operator.OR, + new Logical( + Logical.Operator.OR, List.of( - new ComparisonExpression(ComparisonExpression.Operator.EQUAL, new SymbolReference(BIGINT, "c_bigint_symbol"), new Constant(BIGINT, 42L)), - new ComparisonExpression(ComparisonExpression.Operator.EQUAL, new SymbolReference(BIGINT, "c_bigint_symbol_2"), new Constant(BIGINT, 415L))))), + new Comparison(Comparison.Operator.EQUAL, new Reference(BIGINT, "c_bigint_symbol"), new Constant(BIGINT, 42L)), + new Comparison(Comparison.Operator.EQUAL, new Reference(BIGINT, "c_bigint_symbol_2"), new Constant(BIGINT, 415L))))), Map.of( "c_bigint_symbol", BIGINT_COLUMN, "c_bigint_symbol_2", BIGINT_COLUMN)) @@ -243,15 +243,15 @@ public void testConvertOrWithAnd() ParameterizedExpression converted = JDBC_CLIENT.convertPredicate( SESSION, translateToConnectorExpression( - new LogicalExpression( - LogicalExpression.Operator.OR, + new Logical( + Logical.Operator.OR, List.of( - new ComparisonExpression(ComparisonExpression.Operator.EQUAL, new SymbolReference(BIGINT, "c_bigint_symbol"), new Constant(BIGINT, 42L)), - new LogicalExpression( - LogicalExpression.Operator.AND, + new Comparison(Comparison.Operator.EQUAL, new Reference(BIGINT, "c_bigint_symbol"), new Constant(BIGINT, 42L)), + new Logical( + Logical.Operator.AND, List.of( - new ComparisonExpression(ComparisonExpression.Operator.EQUAL, new SymbolReference(BIGINT, "c_bigint_symbol"), new Constant(BIGINT, 43L)), - new ComparisonExpression(ComparisonExpression.Operator.EQUAL, new SymbolReference(BIGINT, "c_bigint_symbol_2"), new Constant(BIGINT, 44L))))))), + new Comparison(Comparison.Operator.EQUAL, new Reference(BIGINT, "c_bigint_symbol"), new Constant(BIGINT, 43L)), + new Comparison(Comparison.Operator.EQUAL, new Reference(BIGINT, "c_bigint_symbol_2"), new Constant(BIGINT, 44L))))))), Map.of( "c_bigint_symbol", BIGINT_COLUMN, "c_bigint_symbol_2", BIGINT_COLUMN)) @@ -266,11 +266,11 @@ public void testConvertOrWithAnd() @Test public void testConvertComparison() { - for (ComparisonExpression.Operator operator : ComparisonExpression.Operator.values()) { + for (Comparison.Operator operator : Comparison.Operator.values()) { Optional converted = JDBC_CLIENT.convertPredicate( SESSION, translateToConnectorExpression( - new ComparisonExpression(operator, new SymbolReference(BIGINT, "c_bigint_symbol"), new Constant(BIGINT, 42L))), + new Comparison(operator, new Reference(BIGINT, "c_bigint_symbol"), new Constant(BIGINT, 42L))), Map.of("c_bigint_symbol", BIGINT_COLUMN)); switch (operator) { @@ -298,11 +298,11 @@ public void testConvertArithmeticBinary() { TestingFunctionResolution resolver = new TestingFunctionResolution(); - for (ArithmeticBinaryExpression.Operator operator : ArithmeticBinaryExpression.Operator.values()) { + for (Arithmetic.Operator operator : Arithmetic.Operator.values()) { ParameterizedExpression converted = JDBC_CLIENT.convertPredicate( SESSION, translateToConnectorExpression( - new ArithmeticBinaryExpression(resolver.resolveOperator( + new Arithmetic(resolver.resolveOperator( switch (operator) { case ADD -> OperatorType.ADD; case SUBTRACT -> OperatorType.SUBTRACT; @@ -310,7 +310,7 @@ public void testConvertArithmeticBinary() case DIVIDE -> OperatorType.DIVIDE; case MODULUS -> OperatorType.MODULUS; }, - ImmutableList.of(BIGINT, BIGINT)), operator, new SymbolReference(BIGINT, "c_bigint_symbol"), new Constant(BIGINT, 42L))), + ImmutableList.of(BIGINT, BIGINT)), operator, new Reference(BIGINT, "c_bigint_symbol"), new Constant(BIGINT, 42L))), Map.of("c_bigint_symbol", BIGINT_COLUMN)) .orElseThrow(); @@ -325,7 +325,7 @@ public void testConvertArithmeticUnaryMinus() ParameterizedExpression converted = JDBC_CLIENT.convertPredicate( SESSION, translateToConnectorExpression( - new ArithmeticNegation(new SymbolReference(BIGINT, "c_bigint_symbol"))), + new Negation(new Reference(BIGINT, "c_bigint_symbol"))), Map.of("c_bigint_symbol", BIGINT_COLUMN)) .orElseThrow(); @@ -339,8 +339,8 @@ public void testConvertIsNull() // c_varchar IS NULL ParameterizedExpression converted = JDBC_CLIENT.convertPredicate(SESSION, translateToConnectorExpression( - new IsNullPredicate( - new SymbolReference(VARCHAR, "c_varchar_symbol"))), + new IsNull( + new Reference(VARCHAR, "c_varchar_symbol"))), Map.of("c_varchar_symbol", VARCHAR_COLUMN)) .orElseThrow(); assertThat(converted.expression()).isEqualTo("(\"c_varchar\") IS NULL"); @@ -353,7 +353,7 @@ public void testConvertIsNotNull() // c_varchar IS NOT NULL ParameterizedExpression converted = JDBC_CLIENT.convertPredicate(SESSION, translateToConnectorExpression( - new NotExpression(new IsNullPredicate(new SymbolReference(VARCHAR, "c_varchar_symbol")))), + new Not(new IsNull(new Reference(VARCHAR, "c_varchar_symbol")))), Map.of("c_varchar_symbol", VARCHAR_COLUMN)) .orElseThrow(); assertThat(converted.expression()).isEqualTo("(\"c_varchar\") IS NOT NULL"); @@ -366,9 +366,9 @@ public void testConvertNullIf() // nullif(a_varchar, b_varchar) ParameterizedExpression converted = JDBC_CLIENT.convertPredicate(SESSION, translateToConnectorExpression( - new NullIfExpression( - new SymbolReference(VARCHAR, "a_varchar_symbol"), - new SymbolReference(VARCHAR, "b_varchar_symbol"))), + new NullIf( + new Reference(VARCHAR, "a_varchar_symbol"), + new Reference(VARCHAR, "b_varchar_symbol"))), ImmutableMap.of("a_varchar_symbol", VARCHAR_COLUMN, "b_varchar_symbol", VARCHAR_COLUMN)) .orElseThrow(); assertThat(converted.expression()).isEqualTo("NULLIF((\"c_varchar\"), (\"c_varchar\"))"); @@ -381,8 +381,8 @@ public void testConvertNotExpression() // NOT(expression) ParameterizedExpression converted = JDBC_CLIENT.convertPredicate(SESSION, translateToConnectorExpression( - new NotExpression( - new NotExpression(new IsNullPredicate(new SymbolReference(VARCHAR, "c_varchar_symbol"))))), + new Not( + new Not(new IsNull(new Reference(VARCHAR, "c_varchar_symbol"))))), Map.of("c_varchar_symbol", VARCHAR_COLUMN)) .orElseThrow(); assertThat(converted.expression()).isEqualTo("NOT ((\"c_varchar\") IS NOT NULL)"); @@ -395,12 +395,12 @@ public void testConvertIn() ParameterizedExpression converted = JDBC_CLIENT.convertPredicate( SESSION, translateToConnectorExpression( - new InPredicate( - new SymbolReference(createVarcharType(10), "c_varchar"), + new In( + new Reference(createVarcharType(10), "c_varchar"), List.of( new Constant(VARCHAR_COLUMN.getColumnType(), utf8Slice("value1")), new Constant(VARCHAR_COLUMN.getColumnType(), utf8Slice("value2")), - new SymbolReference(createVarcharType(10), "c_varchar2")))), + new Reference(createVarcharType(10), "c_varchar2")))), Map.of(VARCHAR_COLUMN.getColumnName(), VARCHAR_COLUMN, VARCHAR_COLUMN2.getColumnName(), VARCHAR_COLUMN2)) .orElseThrow(); assertThat(converted.expression()).isEqualTo("(\"c_varchar\") IN (?, ?, \"c_varchar2\")"); diff --git a/plugin/trino-thrift/src/test/java/io/trino/plugin/thrift/integration/TestThriftProjectionPushdown.java b/plugin/trino-thrift/src/test/java/io/trino/plugin/thrift/integration/TestThriftProjectionPushdown.java index a16e8d05f40c..fd99cd453cce 100644 --- a/plugin/trino-thrift/src/test/java/io/trino/plugin/thrift/integration/TestThriftProjectionPushdown.java +++ b/plugin/trino-thrift/src/test/java/io/trino/plugin/thrift/integration/TestThriftProjectionPushdown.java @@ -28,7 +28,7 @@ import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.predicate.TupleDomain; -import io.trino.sql.ir.SymbolReference; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.PruneTableScanColumns; import io.trino.sql.planner.iterative.rule.PushProjectionIntoTableScan; @@ -184,7 +184,7 @@ public void testProjectionPushdown() ImmutableMap.of(orderStatusSymbol, columnHandle))); }) .matches(project( - ImmutableMap.of("expr_2", expression(new SymbolReference(VARCHAR, columnName))), + ImmutableMap.of("expr_2", expression(new Reference(VARCHAR, columnName))), tableScan( projectedThriftHandle::equals, TupleDomain.all(), @@ -217,7 +217,7 @@ public void testPruneColumns() .buildOrThrow())); }) .matches(project( - ImmutableMap.of("expr", expression(new SymbolReference(BIGINT, nationKeyColumn.getColumnName()))), + ImmutableMap.of("expr", expression(new Reference(BIGINT, nationKeyColumn.getColumnName()))), tableScan( new ThriftTableHandle( TINY_SCHEMA, diff --git a/testing/trino-tests/src/test/java/io/trino/execution/TestEventListenerBasic.java b/testing/trino-tests/src/test/java/io/trino/execution/TestEventListenerBasic.java index ff3fda57205d..c4682b6b080b 100644 --- a/testing/trino-tests/src/test/java/io/trino/execution/TestEventListenerBasic.java +++ b/testing/trino-tests/src/test/java/io/trino/execution/TestEventListenerBasic.java @@ -93,7 +93,6 @@ import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.sql.planner.planprinter.JsonRenderer.JsonRenderedNode; -import static io.trino.sql.planner.planprinter.NodeRepresentation.TypedSymbol.typedSymbol; import static io.trino.testing.TestingSession.testSessionBuilder; import static java.lang.Double.NaN; import static java.lang.String.format; @@ -1448,14 +1447,14 @@ public void testAnonymizedJsonPlan() "6", "Output", ImmutableMap.of("columnNames", "[column_1]"), - ImmutableList.of(typedSymbol("symbol_1", DOUBLE)), + ImmutableList.of(new Symbol(DOUBLE, "symbol_1")), ImmutableList.of(), ImmutableList.of(new PlanNodeStatsAndCostSummary(10., 90., 0., 0., 0.)), ImmutableList.of(new JsonRenderedNode( "100", "Limit", ImmutableMap.of("count", "10", "withTies", "", "inputPreSortedBy", "[]"), - ImmutableList.of(typedSymbol("symbol_1", DOUBLE)), + ImmutableList.of(new Symbol(DOUBLE, "symbol_1")), ImmutableList.of(), ImmutableList.of(new PlanNodeStatsAndCostSummary(10., 90., 90., 0., 0.)), ImmutableList.of(new JsonRenderedNode( @@ -1466,14 +1465,14 @@ public void testAnonymizedJsonPlan() "isReplicateNullsAndAny", "", "hashColumn", "[]", "arguments", "[]"), - ImmutableList.of(typedSymbol("symbol_1", DOUBLE)), + ImmutableList.of(new Symbol(DOUBLE, "symbol_1")), ImmutableList.of(), ImmutableList.of(new PlanNodeStatsAndCostSummary(10., 90., 0., 0., 0.)), ImmutableList.of(new JsonRenderedNode( "140", "RemoteSource", ImmutableMap.of("sourceFragmentIds", "[1]"), - ImmutableList.of(typedSymbol("symbol_1", DOUBLE)), + ImmutableList.of(new Symbol(DOUBLE, "symbol_1")), ImmutableList.of(), ImmutableList.of(), ImmutableList.of()))))))), @@ -1484,7 +1483,7 @@ public void testAnonymizedJsonPlan() "count", "10", "withTies", "", "inputPreSortedBy", "[]"), - ImmutableList.of(typedSymbol("symbol_1", DOUBLE)), + ImmutableList.of(new Symbol(DOUBLE, "symbol_1")), ImmutableList.of(), ImmutableList.of(new PlanNodeStatsAndCostSummary(10., 90., 90., 0., 0.)), ImmutableList.of(new JsonRenderedNode( @@ -1492,7 +1491,7 @@ public void testAnonymizedJsonPlan() "TableScan", ImmutableMap.of( "table", "[table = catalog_1.schema_1.table_1, connector = tpch]"), - ImmutableList.of(typedSymbol("symbol_1", DOUBLE)), + ImmutableList.of(new Symbol(DOUBLE, "symbol_1")), ImmutableList.of("symbol_1 := column_2"), ImmutableList.of(new PlanNodeStatsAndCostSummary(NaN, NaN, NaN, 0., 0.)), ImmutableList.of()))));