diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/DefaultTraversalVisitor.java b/core/trino-main/src/main/java/io/trino/sql/ir/DefaultTraversalVisitor.java index 4b8b20d4fc0..cf0951283c9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/DefaultTraversalVisitor.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/DefaultTraversalVisitor.java @@ -53,10 +53,9 @@ protected Void visitCoalesce(Coalesce node, C context) } @Override - protected Void visitSubscript(Subscript node, C context) + protected Void visitFieldReference(FieldReference node, C context) { process(node.base(), context); - process(node.index(), context); return null; } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/Expression.java b/core/trino-main/src/main/java/io/trino/sql/ir/Expression.java index 4c4860f44b6..6874d85afea 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/Expression.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/Expression.java @@ -42,13 +42,13 @@ @JsonSubTypes.Type(value = Row.class, name = "row"), @JsonSubTypes.Type(value = Case.class, name = "case"), @JsonSubTypes.Type(value = Switch.class, name = "switch"), - @JsonSubTypes.Type(value = Subscript.class, name = "subscript"), + @JsonSubTypes.Type(value = FieldReference.class, name = "field"), @JsonSubTypes.Type(value = Reference.class, name = "reference"), }) public sealed interface Expression permits Arithmetic, Between, Bind, Call, Case, Cast, Coalesce, - Comparison, Constant, In, IsNull, Lambda, Logical, Negation, - Not, NullIf, Reference, Row, Subscript, Switch + Comparison, Constant, FieldReference, In, IsNull, Lambda, Logical, + Negation, Not, NullIf, Reference, Row, Switch { Type type(); diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionFormatter.java b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionFormatter.java index 29e4fe325be..566b91b8be4 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionFormatter.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionFormatter.java @@ -62,9 +62,9 @@ protected String visitExpression(Expression node, Void context) } @Override - protected String visitSubscript(Subscript node, Void context) + protected String visitFieldReference(FieldReference node, Void context) { - return formatExpression(node.base()) + "[" + formatExpression(node.index()) + "]"; + return formatExpression(node.base()) + "." + node.field(); } @Override diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionRewriter.java b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionRewriter.java index 282c70525c3..3d5cbb662af 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionRewriter.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionRewriter.java @@ -105,7 +105,7 @@ public Expression rewriteConstant(Constant node, C context, ExpressionTreeRewrit return rewriteExpression(node, context, treeRewriter); } - public Expression rewriteSubscript(Subscript node, C context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteSubscript(FieldReference node, C context, ExpressionTreeRewriter treeRewriter) { return rewriteExpression(node, context, treeRewriter); } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionTreeRewriter.java b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionTreeRewriter.java index bf5dcaab989..a8e8f4eceaa 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionTreeRewriter.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionTreeRewriter.java @@ -136,7 +136,7 @@ public Expression visitArithmetic(Arithmetic node, Context context) } @Override - protected Expression visitSubscript(Subscript node, Context context) + protected Expression visitFieldReference(FieldReference node, Context context) { if (!context.isDefaultRewrite()) { Expression result = rewriter.rewriteSubscript(node, context.get(), ExpressionTreeRewriter.this); @@ -146,10 +146,9 @@ protected Expression visitSubscript(Subscript node, Context context) } Expression base = rewrite(node.base(), context.get()); - Expression index = rewrite(node.index(), context.get()); - if (base != node.base() || index != node.index()) { - return new Subscript(base, index); + if (base != node.base()) { + return new FieldReference(base, node.field()); } return node; diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/Subscript.java b/core/trino-main/src/main/java/io/trino/sql/ir/FieldReference.java similarity index 77% rename from core/trino-main/src/main/java/io/trino/sql/ir/Subscript.java rename to core/trino-main/src/main/java/io/trino/sql/ir/FieldReference.java index d12b0b58ad7..376843d35c7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/Subscript.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/FieldReference.java @@ -21,40 +21,35 @@ import java.util.List; import static com.google.common.base.Preconditions.checkArgument; -import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.sql.ir.IrUtils.validateType; @JsonSerialize -public record Subscript(Expression base, Expression index) +public record FieldReference(Expression base, int field) implements Expression { - public Subscript + public FieldReference { if (!(base.type() instanceof RowType rowType)) { throw new IllegalArgumentException("Expected 'row' type but found '%s' for expression: %s".formatted(base.type(), base)); } - validateType(INTEGER, index); - int field = (int) (long) ((Constant) index).value() - 1; checkArgument(field < rowType.getFields().size(), "Expected 'row' type to have at least %s fields, but has: %s", field + 1, rowType.getFields().size()); } @Override public Type type() { - int field = (int) (long) ((Constant) index).value() - 1; return ((RowType) base.type()).getFields().get(field).getType(); } @Override public R accept(IrVisitor visitor, C context) { - return visitor.visitSubscript(this, context); + return visitor.visitFieldReference(this, context); } @Override public List children() { - return ImmutableList.of(base, index); + return ImmutableList.of(base); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/IrVisitor.java b/core/trino-main/src/main/java/io/trino/sql/ir/IrVisitor.java index d6662ebc012..1f6ef21321a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/IrVisitor.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/IrVisitor.java @@ -102,7 +102,7 @@ protected R visitIsNull(IsNull node, C context) return visitExpression(node, context); } - protected R visitSubscript(Subscript node, C context) + protected R visitFieldReference(FieldReference node, C context) { return visitExpression(node, context); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java index e1d9396c69b..e93eceeac7e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java @@ -43,6 +43,7 @@ import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.In; import io.trino.sql.ir.IrVisitor; import io.trino.sql.ir.IsNull; @@ -51,7 +52,6 @@ import io.trino.sql.ir.Not; import io.trino.sql.ir.NullIf; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.tree.QualifiedName; import io.trino.type.JoniRegexp; import io.trino.type.JsonPathType; @@ -94,7 +94,6 @@ import static io.trino.spi.expression.StandardFunctions.OR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.SUBTRACT_FUNCTION_NAME; import static io.trino.spi.type.BooleanType.BOOLEAN; -import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.sql.DynamicFilters.isDynamicFilterFunction; @@ -208,7 +207,7 @@ public Optional translate(ConnectorExpression expression) if (expression instanceof FieldDereference dereference) { return translate(dereference.getTarget()) - .map(base -> new Subscript(base, new Constant(INTEGER, (long) (dereference.getField() + 1)))); + .map(base -> new FieldReference(base, dereference.getField())); } if (expression instanceof io.trino.spi.expression.Call call) { @@ -793,7 +792,7 @@ protected Optional visitNullIf(NullIf node, Void context) } @Override - protected Optional visitSubscript(Subscript node, Void context) + protected Optional visitFieldReference(FieldReference node, Void context) { if (!(node.base().type() instanceof RowType)) { return Optional.empty(); @@ -804,7 +803,7 @@ protected Optional visitSubscript(Subscript node, Void cont return Optional.empty(); } - return Optional.of(new FieldDereference(((Expression) node).type(), translatedBase.get(), (int) ((long) ((Constant) node.index()).value() - 1))); + return Optional.of(new FieldDereference(((Expression) node).type(), translatedBase.get(), node.field())); } @Override diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java b/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java index 3dca9fa3eab..04a21e6b6fe 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java @@ -18,7 +18,6 @@ import io.trino.Session; import io.trino.metadata.Metadata; import io.trino.metadata.ResolvedFunction; -import io.trino.operator.scalar.ArraySubscriptOperator; import io.trino.spi.TrinoException; import io.trino.spi.block.SqlRow; import io.trino.spi.connector.ConnectorSession; @@ -43,6 +42,7 @@ import io.trino.sql.ir.Comparison.Operator; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.In; import io.trino.sql.ir.IrVisitor; import io.trino.sql.ir.IsNull; @@ -53,7 +53,6 @@ import io.trino.sql.ir.NullIf; import io.trino.sql.ir.Reference; import io.trino.sql.ir.Row; -import io.trino.sql.ir.Subscript; import io.trino.sql.ir.Switch; import io.trino.sql.ir.WhenClause; import io.trino.type.FunctionType; @@ -82,7 +81,6 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName; -import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.block.RowValueBuilder.buildRowValue; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; @@ -95,7 +93,6 @@ import static io.trino.sql.DynamicFilters.isDynamicFilter; import static io.trino.sql.gen.VarArgsToMapAdapterGenerator.generateVarArgsToMapAdapter; import static io.trino.sql.planner.DeterminismEvaluator.isDeterministic; -import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toList; @@ -883,36 +880,20 @@ protected Object visitRow(Row node, Object context) } @Override - protected Object visitSubscript(Subscript node, Object context) + protected Object visitFieldReference(FieldReference node, Object context) { Object base = processWithExceptionHandling(node.base(), context); if (base == null) { return null; } - Object index = processWithExceptionHandling(node.index(), context); - if (index == null) { - return null; - } - if ((index instanceof Long) && isArray(node.base().type())) { - ArraySubscriptOperator.checkArrayIndex((Long) index); - } - if (hasUnresolvedValue(base, index)) { - return new Subscript(toExpression(base, node.base().type()), toExpression(index, node.index().type())); - } - - // Subscript on Row hasn't got a dedicated operator. It is interpreted by hand. - if (base instanceof SqlRow row) { - int fieldIndex = toIntExact((long) index - 1); - if (fieldIndex < 0 || fieldIndex >= row.getFieldCount()) { - throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "ROW index out of bounds: " + (fieldIndex + 1)); - } - Type returnType = node.base().type().getTypeParameters().get(fieldIndex); - return readNativeValue(returnType, row.getRawFieldBlock(fieldIndex), row.getRawIndex()); + if (hasUnresolvedValue(base)) { + return new FieldReference(toExpression(base, node.base().type()), node.field()); } - // Subscript on Array or Map is interpreted using operator. - return invokeOperator(OperatorType.SUBSCRIPT, types(node.base(), node.index()), ImmutableList.of(base, index)); + SqlRow row = (SqlRow) base; + Type returnType = node.base().type().getTypeParameters().get(node.field()); + return readNativeValue(returnType, row.getRawFieldBlock(node.field()), row.getRawIndex()); } @Override diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/NullabilityAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/planner/NullabilityAnalyzer.java index 5504f8c5191..3b2468bb089 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/NullabilityAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/NullabilityAnalyzer.java @@ -19,9 +19,9 @@ import io.trino.sql.ir.Constant; import io.trino.sql.ir.DefaultTraversalVisitor; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.In; import io.trino.sql.ir.NullIf; -import io.trino.sql.ir.Subscript; import io.trino.sql.ir.Switch; import java.util.concurrent.atomic.AtomicBoolean; @@ -97,7 +97,7 @@ protected Void visitSwitch(Switch node, AtomicBoolean result) } @Override - protected Void visitSubscript(Subscript node, AtomicBoolean result) + protected Void visitFieldReference(FieldReference node, AtomicBoolean result) { result.set(true); return null; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java index 09cfb0c9f2c..eba5b3a9fde 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java @@ -48,11 +48,11 @@ import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.IsNull; import io.trino.sql.ir.Logical; import io.trino.sql.ir.Not; import io.trino.sql.ir.Row; -import io.trino.sql.ir.Subscript; import io.trino.sql.ir.WhenClause; import io.trino.sql.planner.RelationPlanner.PatternRecognitionComponents; import io.trino.sql.planner.plan.AggregationNode; @@ -81,7 +81,6 @@ import io.trino.sql.planner.plan.WindowNode; import io.trino.sql.tree.Delete; import io.trino.sql.tree.FetchFirst; -import io.trino.sql.tree.FieldReference; import io.trino.sql.tree.FrameBound; import io.trino.sql.tree.FunctionCall.NullTreatment; import io.trino.sql.tree.Join; @@ -513,7 +512,7 @@ public PlanNode plan(Delete node) builder = filter(builder, node.getWhere().get(), node); } - FieldReference reference = analysis.getRowIdField(table); + io.trino.sql.tree.FieldReference reference = analysis.getRowIdField(table); Symbol rowIdSymbol = builder.translate(reference); List outputs = ImmutableList.of( symbolAllocator.newSymbol("partialrows", BIGINT), @@ -655,7 +654,7 @@ public PlanNode plan(Update node) } } - FieldReference rowIdReference = analysis.getRowIdField(mergeAnalysis.getTargetTable()); + io.trino.sql.tree.FieldReference rowIdReference = analysis.getRowIdField(mergeAnalysis.getTargetTable()); assignments.putIdentity(relationPlan.getFieldMappings().get(rowIdReference.getFieldIndex())); // Add the "present" field @@ -770,7 +769,7 @@ public MergeWriterNode plan(Merge merge) PlanBuilder subPlan = newPlanBuilder(joinPlan, analysis, lambdaDeclarationToSymbolMap, session, plannerContext); - FieldReference rowIdReference = analysis.getRowIdField(mergeAnalysis.getTargetTable()); + io.trino.sql.tree.FieldReference rowIdReference = analysis.getRowIdField(mergeAnalysis.getTargetTable()); Symbol rowIdSymbol = planWithPresentColumn.getFieldMappings().get(rowIdReference.getFieldIndex()); // Build the SearchedCaseExpression that creates the project merge_row @@ -889,7 +888,7 @@ public MergeWriterNode plan(Merge merge) subPlanProject, Assignments.builder() .putIdentities(subPlanProject.getOutputSymbols()) - .put(caseNumberSymbol, new Subscript(mergeRowSymbol.toSymbolReference(), new Constant(INTEGER, (long) mergeAnalysis.getMergeRowType().getFields().size()))) + .put(caseNumberSymbol, new FieldReference(mergeRowSymbol.toSymbolReference(), mergeAnalysis.getMergeRowType().getFields().size() - 1)) .build()); // Mark distinct combinations of the unique_id value and the case_number diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java b/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java index a640f75167d..3a9d65f8eb4 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java @@ -47,6 +47,7 @@ import io.trino.sql.ir.Coalesce; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.In; import io.trino.sql.ir.IsNull; import io.trino.sql.ir.Lambda; @@ -55,7 +56,6 @@ import io.trino.sql.ir.Not; import io.trino.sql.ir.NullIf; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.ir.Switch; import io.trino.sql.tree.ArithmeticBinaryExpression; import io.trino.sql.tree.ArithmeticUnaryExpression; @@ -79,7 +79,6 @@ import io.trino.sql.tree.DoubleLiteral; import io.trino.sql.tree.Expression; import io.trino.sql.tree.Extract; -import io.trino.sql.tree.FieldReference; import io.trino.sql.tree.Format; import io.trino.sql.tree.FunctionCall; import io.trino.sql.tree.GenericLiteral; @@ -138,7 +137,6 @@ import static io.trino.spi.StandardErrorCode.TOO_MANY_ARGUMENTS; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.DoubleType.DOUBLE; -import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.TimeWithTimeZoneType.createTimeWithTimeZoneType; import static io.trino.spi.type.TimestampWithTimeZoneType.createTimestampWithTimeZoneType; import static io.trino.spi.type.TinyintType.TINYINT; @@ -158,6 +156,7 @@ import static io.trino.util.DateTimeUtils.parseDayTimeInterval; import static io.trino.util.DateTimeUtils.parseYearMonthInterval; import static io.trino.util.Failures.checkCondition; +import static java.lang.Math.toIntExact; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -274,7 +273,7 @@ public boolean canTranslate(Expression expression) { if (astToSymbols.containsKey(scopeAwareKey(expression, analysis, scope)) || substitutions.containsKey(NodeRef.of(expression)) || - expression instanceof FieldReference) { + expression instanceof io.trino.sql.tree.FieldReference) { return true; } @@ -308,7 +307,7 @@ private io.trino.sql.ir.Expression translate(Expression expr, boolean isRoot) } else { result = switch (expr) { - case FieldReference expression -> translate(expression); + case io.trino.sql.tree.FieldReference expression -> translate(expression); case Identifier expression -> translate(expression); case FunctionCall expression -> translate(expression); case DereferenceExpression expression -> translate(expression); @@ -611,7 +610,7 @@ private io.trino.sql.ir.Expression translate(LongLiteral expression) return new Constant(analysis.getType(expression), expression.getParsedValue()); } - private io.trino.sql.ir.Expression translate(FieldReference expression) + private io.trino.sql.ir.Expression translate(io.trino.sql.tree.FieldReference expression) { return getSymbolForColumn(expression) .map(Symbol::toSymbolReference) @@ -670,9 +669,7 @@ private io.trino.sql.ir.Expression translate(DereferenceExpression expression) checkState(index >= 0, "could not find field name: %s", fieldName); - return new Subscript( - translateExpression(expression.getBase()), - new Constant(INTEGER, (long) (index + 1))); + return new FieldReference(translateExpression(expression.getBase()), index); } private io.trino.sql.ir.Expression translate(Array expression) @@ -957,8 +954,8 @@ private io.trino.sql.ir.Expression translate(SubscriptExpression node) // Do not rewrite subscript index into symbol. Row subscript index is required to be a literal. io.trino.sql.ir.Expression rewrittenBase = translateExpression(node.getBase()); LongLiteral index = (LongLiteral) node.getIndex(); - return new Subscript( - rewrittenBase, new Constant(INTEGER, index.getParsedValue())); + return new FieldReference( + rewrittenBase, toIntExact(index.getParsedValue() - 1)); } ResolvedFunction operator = plannerContext.getMetadata() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DereferencePushdown.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DereferencePushdown.java index 4aeec100e0f..80d73f883dc 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DereferencePushdown.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DereferencePushdown.java @@ -17,9 +17,9 @@ import io.trino.spi.type.RowType; import io.trino.sql.ir.DefaultTraversalVisitor; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Lambda; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.Symbol; import java.util.Collection; @@ -38,7 +38,7 @@ class DereferencePushdown { private DereferencePushdown() {} - public static Set extractRowSubscripts(Collection expressions, boolean allowOverlap) + public static Set extractRowSubscripts(Collection expressions, boolean allowOverlap) { Set symbolReferencesAndRowSubscripts = expressions.stream() .flatMap(expression -> getSymbolReferencesAndRowSubscripts(expression).stream()) @@ -54,8 +54,8 @@ public static Set extractRowSubscripts(Collection express // Retain row subscript expressions return candidateExpressions.stream() - .filter(Subscript.class::isInstance) - .map(Subscript.class::cast) + .filter(FieldReference.class::isInstance) + .map(FieldReference.class::cast) .collect(toImmutableSet()); } @@ -63,19 +63,19 @@ public static boolean exclusiveDereferences(Set projections) { return projections.stream() .allMatch(expression -> expression instanceof Reference || - (expression instanceof Subscript subscript && - isRowSubscriptChain(subscript) && + (expression instanceof FieldReference fieldReference && + isRowSubscriptChain(fieldReference) && !prefixExists(expression, projections))); } - public static Symbol getBase(Subscript expression) + public static Symbol getBase(FieldReference expression) { return getOnlyElement(extractAll(expression)); } /** - * Extract the sub-expressions of type {@link Subscript} or {@link Reference} from the expression - * in a top-down manner. The expressions within the base of a valid {@link Subscript} sequence are not extracted. + * Extract the sub-expressions of type {@link FieldReference} or {@link Reference} from the expression + * in a top-down manner. The expressions within the base of a valid {@link FieldReference} sequence are not extracted. */ private static List getSymbolReferencesAndRowSubscripts(Expression expression) { @@ -84,7 +84,7 @@ private static List getSymbolReferencesAndRowSubscripts(Expression e new DefaultTraversalVisitor>() { @Override - protected Void visitSubscript(Subscript node, ImmutableList.Builder context) + protected Void visitFieldReference(FieldReference node, ImmutableList.Builder context) { if (isRowSubscriptChain(node)) { context.add(node); @@ -109,21 +109,21 @@ protected Void visitLambda(Lambda node, ImmutableList.Builder contex return builder.build(); } - private static boolean isRowSubscriptChain(Subscript expression) + private static boolean isRowSubscriptChain(FieldReference expression) { if (!(expression.base().type() instanceof RowType)) { return false; } return (expression.base() instanceof Reference) || - ((expression.base() instanceof Subscript subscript) && isRowSubscriptChain(subscript)); + ((expression.base() instanceof FieldReference fieldReference) && isRowSubscriptChain(fieldReference)); } private static boolean prefixExists(Expression expression, Set expressions) { Expression current = expression; - while (current instanceof Subscript subscript) { - current = subscript.base(); + while (current instanceof FieldReference fieldReference) { + current = fieldReference.base(); if (expressions.contains(current)) { return true; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractDereferencesFromFilterAboveScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractDereferencesFromFilterAboveScan.java index aa1b098147a..079b5270040 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractDereferencesFromFilterAboveScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractDereferencesFromFilterAboveScan.java @@ -19,8 +19,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.FilterNode; @@ -78,7 +78,7 @@ public Pattern getPattern() @Override public Result apply(FilterNode node, Captures captures, Context context) { - Set dereferences = extractRowSubscripts(ImmutableList.of(node.getPredicate()), true); + Set dereferences = extractRowSubscripts(ImmutableList.of(node.getPredicate()), true); if (dereferences.isEmpty()) { return Result.empty(); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/InlineProjections.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/InlineProjections.java index d04a5575d0c..29048a8bdb2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/InlineProjections.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/InlineProjections.java @@ -22,8 +22,8 @@ import io.trino.spi.type.RowType; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolsExtractor; import io.trino.sql.planner.iterative.Rule; @@ -180,8 +180,8 @@ private static Set extractInliningTargets(ProjectNode parent, ProjectNod // skip dereferences, otherwise, inlining can cause conflicts with PushdownDereferences Expression assignment = child.getAssignments().get(entry.getKey()); - if (assignment instanceof Subscript) { - if (((Subscript) assignment).base().type() instanceof RowType) { + if (assignment instanceof FieldReference) { + if (((FieldReference) assignment).base().type() instanceof RowType) { return false; } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughFilter.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughFilter.java index 5c23f7d0bd8..6e1946e162e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughFilter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughFilter.java @@ -19,8 +19,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.FilterNode; @@ -80,7 +80,7 @@ public Result apply(ProjectNode node, Captures captures, Rule.Context context) .build(); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(expressions, false); + Set dereferences = extractRowSubscripts(expressions, false); if (dereferences.isEmpty()) { return Result.empty(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughJoin.java index 63afe746a73..85ca829ebeb 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughJoin.java @@ -20,8 +20,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; @@ -92,7 +92,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) ImmutableList.Builder expressionsBuilder = ImmutableList.builder(); expressionsBuilder.addAll(projectNode.getAssignments().getExpressions()); joinNode.getFilter().ifPresent(expressionsBuilder::add); - Set dereferences = extractRowSubscripts(expressionsBuilder.build(), false); + Set dereferences = extractRowSubscripts(expressionsBuilder.build(), false); // Exclude criteria symbols ImmutableSet.Builder criteriaSymbolsBuilder = ImmutableSet.builder(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughProject.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughProject.java index d97bb3cf3bb..076054eada3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughProject.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughProject.java @@ -18,8 +18,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.ProjectNode; @@ -66,7 +66,7 @@ public Result apply(ProjectNode node, Captures captures, Context context) ProjectNode child = captures.get(CHILD); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(node.getAssignments().getExpressions(), false); + Set dereferences = extractRowSubscripts(node.getAssignments().getExpressions(), false); // Exclude dereferences on symbols being synthesized within child dereferences = dereferences.stream() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughSemiJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughSemiJoin.java index f212b553df9..0b5792fd1f3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughSemiJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughSemiJoin.java @@ -19,8 +19,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.PlanNode; @@ -78,7 +78,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) SemiJoinNode semiJoinNode = captures.get(CHILD); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); + Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); // All dereferences can be assumed on the symbols coming from source, since filteringSource output is not propagated, // and semiJoinOutput is of type boolean. We exclude pushdown of dereferences on sourceJoinSymbol. diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughUnnest.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughUnnest.java index 3f0f0bbe764..1e8219250e5 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughUnnest.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughUnnest.java @@ -19,8 +19,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; @@ -79,7 +79,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) expressionsBuilder.addAll(projectNode.getAssignments().getExpressions()); // Extract dereferences for pushdown - Set dereferences = extractRowSubscripts(expressionsBuilder.build(), false); + Set dereferences = extractRowSubscripts(expressionsBuilder.build(), false); // Only retain dereferences on replicate symbols dereferences = dereferences.stream() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughAssignUniqueId.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughAssignUniqueId.java index 33df63435a9..c40897f48c0 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughAssignUniqueId.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughAssignUniqueId.java @@ -19,8 +19,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.AssignUniqueId; import io.trino.sql.planner.plan.Assignments; @@ -70,7 +70,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) AssignUniqueId assignUniqueId = captures.get(CHILD); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); + Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); // We do not need to filter dereferences on idColumn symbol since it is supposed to be of BIGINT type. diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughLimit.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughLimit.java index 80486469689..e90563ad4e9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughLimit.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughLimit.java @@ -20,8 +20,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; @@ -75,7 +75,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) LimitNode limitNode = captures.get(CHILD); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); + Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); // Exclude dereferences on symbols being used in tiesResolvingScheme and requiresPreSortedInputs Set excludedSymbols = ImmutableSet.builder() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughMarkDistinct.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughMarkDistinct.java index 6da535d1779..5d064f74065 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughMarkDistinct.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughMarkDistinct.java @@ -19,8 +19,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.MarkDistinctNode; @@ -75,7 +75,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) MarkDistinctNode markDistinctNode = captures.get(CHILD); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); + Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); // Exclude dereferences on distinct symbols being used in markDistinctNode. We do not need to filter // dereferences on markerSymbol since it is supposed to be of boolean type. diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughRowNumber.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughRowNumber.java index 6b9e01dd122..6ca0e9f576f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughRowNumber.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughRowNumber.java @@ -19,8 +19,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.ProjectNode; @@ -75,7 +75,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) RowNumberNode rowNumberNode = captures.get(CHILD); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); + Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); // Exclude dereferences on symbols being used in partitionBy dereferences = dereferences.stream() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughSort.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughSort.java index 26a90cd005b..3fb1e2150ee 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughSort.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughSort.java @@ -19,8 +19,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.ProjectNode; @@ -75,7 +75,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) SortNode sortNode = captures.get(CHILD); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); + Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); // Exclude dereferences on symbols used in ordering scheme to avoid replication of data dereferences = dereferences.stream() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughTopN.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughTopN.java index b724bd7a794..c36f09f39e3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughTopN.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughTopN.java @@ -19,8 +19,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.ProjectNode; @@ -75,7 +75,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) TopNNode topNNode = captures.get(CHILD); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); + Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); // Exclude dereferences on symbols being used in orderBy dereferences = dereferences.stream() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughTopNRanking.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughTopNRanking.java index eebc12f7375..616fc19c1f3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughTopNRanking.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughTopNRanking.java @@ -19,8 +19,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; @@ -78,7 +78,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) TopNRankingNode topNRankingNode = captures.get(CHILD); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); + Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false); // Exclude dereferences on symbols being used in partitionBy and orderBy DataOrganizationSpecification specification = topNRankingNode.getSpecification(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughWindow.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughWindow.java index 92d1861abe0..8b79a2103a0 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughWindow.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughWindow.java @@ -19,8 +19,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; @@ -79,7 +79,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) WindowNode windowNode = captures.get(CHILD); // Extract dereferences for pushdown - Set dereferences = extractRowSubscripts( + Set dereferences = extractRowSubscripts( ImmutableList.builder() .addAll(projectNode.getAssignments().getExpressions()) // also include dereference projections used in window functions diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapRowSubscript.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapRowSubscript.java index 9648d9b33e5..b1994080f3d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapRowSubscript.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapRowSubscript.java @@ -16,11 +16,10 @@ import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import io.trino.sql.ir.Cast; -import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.ir.ExpressionTreeRewriter; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Row; -import io.trino.sql.ir.Subscript; import io.trino.type.UnknownType; import java.util.ArrayDeque; @@ -48,7 +47,7 @@ private static class Rewriter extends io.trino.sql.ir.ExpressionRewriter { @Override - public Expression rewriteSubscript(Subscript node, Void context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteSubscript(FieldReference node, Void context, ExpressionTreeRewriter treeRewriter) { Expression base = treeRewriter.rewrite(node.base(), context); @@ -58,8 +57,7 @@ public Expression rewriteSubscript(Subscript node, Void context, ExpressionTreeR break; } - int index = (int) (long) ((Constant) node.index()).value(); - Type type = rowType.getFields().get(index - 1).getType(); + Type type = rowType.getFields().get(node.field()).getType(); if (!(type instanceof UnknownType)) { coercions.push(new Coercion(type, cast.safe())); } @@ -68,8 +66,7 @@ public Expression rewriteSubscript(Subscript node, Void context, ExpressionTreeR } if (base instanceof Row row) { - int index = (int) (long) ((Constant) node.index()).value(); - Expression result = row.items().get(index - 1); + Expression result = row.items().get(node.field()); while (!coercions.isEmpty()) { Coercion coercion = coercions.pop(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapSingleColumnRowInApply.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapSingleColumnRowInApply.java index 7792fe6cf18..0ed57fc82f9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapSingleColumnRowInApply.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapSingleColumnRowInApply.java @@ -18,9 +18,8 @@ import io.trino.matching.Pattern; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; -import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.Subscript; +import io.trino.sql.ir.FieldReference; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.ApplyNode; @@ -32,7 +31,6 @@ import java.util.Optional; import java.util.function.BiFunction; -import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.sql.planner.plan.Patterns.applyNode; import static java.util.Objects.requireNonNull; @@ -148,8 +146,8 @@ private Optional unwrapSingleColumnRow(Context context, Expression v Symbol valueSymbol = context.getSymbolAllocator().newSymbol("input", elementType); Symbol listSymbol = context.getSymbolAllocator().newSymbol("subquery", elementType); - Assignment inputAssignment = new Assignment(valueSymbol, new Subscript(value, new Constant(INTEGER, 1L))); - Assignment nestedPlanAssignment = new Assignment(listSymbol, new Subscript(list, new Constant(INTEGER, 1L))); + Assignment inputAssignment = new Assignment(valueSymbol, new FieldReference(value, 0)); + Assignment nestedPlanAssignment = new Assignment(listSymbol, new FieldReference(list, 0)); ApplyNode.SetExpression comparison = function.apply(valueSymbol, listSymbol); return Optional.of(new Unwrapping(comparison, inputAssignment, nestedPlanAssignment)); diff --git a/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java b/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java index d82ddf5e70c..5e3da5f41a3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java @@ -18,7 +18,6 @@ import io.trino.metadata.FunctionManager; import io.trino.metadata.Metadata; import io.trino.metadata.ResolvedFunction; -import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeManager; import io.trino.sql.ir.Arithmetic; @@ -32,6 +31,7 @@ import io.trino.sql.ir.Comparison.Operator; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.In; import io.trino.sql.ir.IrVisitor; import io.trino.sql.ir.IsNull; @@ -42,7 +42,6 @@ import io.trino.sql.ir.NullIf; import io.trino.sql.ir.Reference; import io.trino.sql.ir.Row; -import io.trino.sql.ir.Subscript; import io.trino.sql.ir.Switch; import io.trino.sql.ir.WhenClause; import io.trino.sql.planner.Symbol; @@ -60,7 +59,6 @@ import static io.trino.spi.function.OperatorType.INDETERMINATE; import static io.trino.spi.function.OperatorType.LESS_THAN_OR_EQUAL; import static io.trino.spi.function.OperatorType.NEGATION; -import static io.trino.spi.function.OperatorType.SUBSCRIPT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; @@ -489,20 +487,10 @@ protected RowExpression visitBetween(Between node, Void context) } @Override - protected RowExpression visitSubscript(Subscript node, Void context) + protected RowExpression visitFieldReference(FieldReference node, Void context) { RowExpression base = process(node.base(), context); - RowExpression index = process(node.index(), context); - - if (node.base().type() instanceof RowType) { - long value = (Long) ((ConstantExpression) index).getValue(); - return new SpecialForm(DEREFERENCE, ((Expression) node).type(), base, constant(value - 1, INTEGER)); - } - - return call( - metadata.resolveOperator(SUBSCRIPT, ImmutableList.of(base.getType(), index.getType())), - base, - index); + return new SpecialForm(DEREFERENCE, node.type(), base, constant((long) node.field(), INTEGER)); } @Override diff --git a/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java b/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java index 6cfd0daf084..354edd175be 100644 --- a/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java +++ b/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java @@ -28,6 +28,7 @@ import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.In; import io.trino.sql.ir.IsNull; import io.trino.sql.ir.Logical; @@ -36,7 +37,6 @@ import io.trino.sql.ir.NullIf; import io.trino.sql.ir.Reference; import io.trino.sql.ir.Row; -import io.trino.sql.ir.Subscript; import io.trino.sql.ir.Switch; import io.trino.sql.ir.WhenClause; import io.trino.sql.planner.IrExpressionInterpreter; @@ -859,12 +859,12 @@ public void testOptimizeDivideByZero() public void testRowSubscript() { assertOptimizedEquals( - new Subscript(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VARCHAR, Slices.utf8Slice("a")), TRUE)), new Constant(INTEGER, 3L)), + new FieldReference(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VARCHAR, Slices.utf8Slice("a")), TRUE)), 2), TRUE); assertOptimizedEquals( - new Subscript( - new Subscript( - new Subscript( + new FieldReference( + new FieldReference( + new FieldReference( new Row(ImmutableList.of( new Constant(INTEGER, 1L), new Constant(VARCHAR, Slices.utf8Slice("a")), @@ -872,24 +872,24 @@ public void testRowSubscript() new Constant(INTEGER, 2L), new Constant(VARCHAR, Slices.utf8Slice("b")), new Row(ImmutableList.of(new Constant(INTEGER, 3L), new Constant(VARCHAR, Slices.utf8Slice("c")))))))), - new Constant(INTEGER, 3L)), - new Constant(INTEGER, 3L)), - new Constant(INTEGER, 2L)), + 2), + 2), + 1), new Constant(VARCHAR, Slices.utf8Slice("c"))); assertOptimizedEquals( - new Subscript(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(UNKNOWN, null))), new Constant(INTEGER, 2L)), + new FieldReference(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(UNKNOWN, null))), 1), new Constant(UNKNOWN, null)); assertOptimizedEquals( - new Subscript(new Row(ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), new Constant(INTEGER, 1L)), - new Subscript(new Row(ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), new Constant(INTEGER, 1L))); + new FieldReference(new Row(ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), 0), + new FieldReference(new Row(ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), 0)); assertOptimizedEquals( - new Subscript(new Row(ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), new Constant(INTEGER, 2L)), - new Subscript(new Row(ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), new Constant(INTEGER, 2L))); + new FieldReference(new Row(ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), 1), + new FieldReference(new Row(ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), 1)); - assertTrinoExceptionThrownBy(() -> evaluate(new Subscript(new Row(ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), new Constant(INTEGER, 2L)))) + assertTrinoExceptionThrownBy(() -> evaluate(new FieldReference(new Row(ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), 1))) .hasErrorCode(DIVISION_BY_ZERO); - assertTrinoExceptionThrownBy(() -> evaluate(new Subscript(new Row(ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), new Constant(INTEGER, 2L)))) + assertTrinoExceptionThrownBy(() -> evaluate(new FieldReference(new Row(ImmutableList.of(new Arithmetic(DIVIDE_INTEGER, DIVIDE, new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), 1))) .hasErrorCode(DIVISION_BY_ZERO); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java index ee3fd1b01b1..f19b70f1034 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java @@ -38,6 +38,7 @@ import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.In; import io.trino.sql.ir.IsNull; import io.trino.sql.ir.Logical; @@ -45,7 +46,6 @@ import io.trino.sql.ir.Not; import io.trino.sql.ir.NullIf; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.testing.TestingSession; import io.trino.transaction.TestingTransactionManager; import io.trino.transaction.TransactionManager; @@ -144,9 +144,9 @@ public void testTranslateSymbol() public void testTranslateRowSubscript() { assertTranslationRoundTrips( - new Subscript( + new FieldReference( new Reference(ROW_TYPE, "row_symbol_1"), - new Constant(INTEGER, 1L)), + 0), new FieldDereference( INTEGER, new Variable("row_symbol_1", ROW_TYPE), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestDereferencePushDown.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestDereferencePushDown.java index 8e661dd6bb3..d22b5794b70 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestDereferencePushDown.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestDereferencePushDown.java @@ -24,9 +24,9 @@ import io.trino.sql.ir.Call; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Logical; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.assertions.BasePlanTest; import org.junit.jupiter.api.Test; @@ -34,7 +34,6 @@ import static io.trino.SystemSessionProperties.MERGE_PROJECT_WITH_VALUES; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.DoubleType.DOUBLE; -import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.ir.Arithmetic.Operator.MULTIPLY; import static io.trino.sql.ir.Comparison.Operator.EQUAL; @@ -68,7 +67,7 @@ public void testDereferencePushdownMultiLevel() output(ImmutableList.of("a_msg_x", "a_msg", "b_msg_y"), strictProject( ImmutableMap.of( - "a_msg_x", expression(new Subscript(new Reference(RowType.anonymousRow(BIGINT, DOUBLE), "a_msg"), new Constant(INTEGER, 1L))), + "a_msg_x", expression(new FieldReference(new Reference(RowType.anonymousRow(BIGINT, DOUBLE), "a_msg"), 0)), "a_msg", expression(new Reference(RowType.anonymousRow(BIGINT, DOUBLE), "a_msg")), "b_msg_y", expression(new Reference(DOUBLE, "b_msg_y"))), join(INNER, builder -> builder @@ -191,8 +190,8 @@ public void testDereferencePushdownWindow() "msg1", expression(new Reference(RowType.anonymousRow(BIGINT, DOUBLE), "msg1")), // not pushed down because used in partition by "msg2", expression(new Reference(RowType.anonymousRow(BIGINT, DOUBLE), "msg2")), // not pushed down because used in order by "msg3", expression(new Reference(RowType.anonymousRow(BIGINT, DOUBLE), "msg3")), // not pushed down because used in window function - "msg4_x", expression(new Subscript(new Reference(RowType.anonymousRow(BIGINT, DOUBLE), "msg4"), new Constant(INTEGER, 1L))), // pushed down because msg4.x used in window function - "msg5_x", expression(new Subscript(new Reference(RowType.anonymousRow(BIGINT, DOUBLE), "msg5"), new Constant(INTEGER, 1L)))), // pushed down because window node does not refer it + "msg4_x", expression(new FieldReference(new Reference(RowType.anonymousRow(BIGINT, DOUBLE), "msg4"), 0)), // pushed down because msg4.x used in window function + "msg5_x", expression(new FieldReference(new Reference(RowType.anonymousRow(BIGINT, DOUBLE), "msg5"), 0))), // pushed down because window node does not refer it values("msg1", "msg2", "msg3", "msg4", "msg5")))); } @@ -210,7 +209,7 @@ public void testDereferencePushdownSemiJoin() anyTree( semiJoin("a_x", "b_z", "semi_join_symbol", project( - ImmutableMap.of("a_y", expression(new Subscript(new Reference(RowType.anonymousRow(BIGINT, DOUBLE, BIGINT), "msg"), new Constant(INTEGER, 2L)))), + ImmutableMap.of("a_y", expression(new FieldReference(new Reference(RowType.anonymousRow(BIGINT, DOUBLE, BIGINT), "msg"), 1))), values(ImmutableList.of("msg", "a_x"), ImmutableList.of())), values(ImmutableList.of("b_z"), ImmutableList.of())))); } @@ -223,7 +222,7 @@ public void testDereferencePushdownLimit() anyTree( strictProject(ImmutableMap.of("x_into_3", expression(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "msg_x"), new Constant(BIGINT, 3L)))), limit(1, - strictProject(ImmutableMap.of("msg_x", expression(new Subscript(new Reference(RowType.anonymousRow(BIGINT, DOUBLE), "msg"), new Constant(INTEGER, 1L)))), + strictProject(ImmutableMap.of("msg_x", expression(new FieldReference(new Reference(RowType.anonymousRow(BIGINT, DOUBLE), "msg"), 0))), values("msg")))))); // dereference pushdown + constant folding diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java index 4d023b943b3..d16bbcc2ae9 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java @@ -36,13 +36,13 @@ import io.trino.sql.ir.Coalesce; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.In; import io.trino.sql.ir.IsNull; import io.trino.sql.ir.Logical; import io.trino.sql.ir.Not; import io.trino.sql.ir.Reference; import io.trino.sql.ir.Row; -import io.trino.sql.ir.Subscript; import io.trino.sql.ir.Switch; import io.trino.sql.ir.WhenClause; import io.trino.sql.planner.OptimizerConfig.JoinDistributionType; @@ -299,8 +299,8 @@ public void testAllFieldsDereferenceOnSubquery() any( project( ImmutableMap.of( - "output_1", expression(new Subscript(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "row"), new Constant(INTEGER, 1L))), - "output_2", expression(new Subscript(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "row"), new Constant(INTEGER, 2L)))), + "output_1", expression(new FieldReference(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "row"), 0)), + "output_2", expression(new FieldReference(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "row"), 1))), project( ImmutableMap.of("row", expression(new Row(ImmutableList.of(new Reference(BIGINT, "min"), new Reference(BIGINT, "max"))))), aggregation( @@ -328,8 +328,8 @@ public void testAllFieldsDereferenceFromNonDeterministic() any( project( ImmutableMap.of( - "output_1", expression(new Subscript(new Reference(RowType.anonymousRow(DOUBLE, DOUBLE), "row"), new Constant(INTEGER, 1L))), - "output_2", expression(new Subscript(new Reference(RowType.anonymousRow(DOUBLE, DOUBLE), "row"), new Constant(INTEGER, 2L)))), + "output_1", expression(new FieldReference(new Reference(RowType.anonymousRow(DOUBLE, DOUBLE), "row"), 0)), + "output_2", expression(new FieldReference(new Reference(RowType.anonymousRow(DOUBLE, DOUBLE), "row"), 1))), project( ImmutableMap.of("row", expression(new Row(ImmutableList.of(new Reference(DOUBLE, "rand"), new Reference(DOUBLE, "rand"))))), values( @@ -340,8 +340,8 @@ public void testAllFieldsDereferenceFromNonDeterministic() any( project( ImmutableMap.of( - "output_1", expression(new Subscript(new Reference(RowType.anonymousRow(DOUBLE, DOUBLE), "r"), new Constant(INTEGER, 1L))), - "output_2", expression(new Subscript(new Reference(RowType.anonymousRow(DOUBLE, DOUBLE), "r"), new Constant(INTEGER, 2L)))), + "output_1", expression(new FieldReference(new Reference(RowType.anonymousRow(DOUBLE, DOUBLE), "r"), 0)), + "output_2", expression(new FieldReference(new Reference(RowType.anonymousRow(DOUBLE, DOUBLE), "r"), 1))), values( ImmutableList.of("r"), ImmutableList.of(ImmutableList.of(new Row(ImmutableList.of(randomFunction, randomFunction)))))))); @@ -353,8 +353,8 @@ public void testAllFieldsDereferenceFromNonDeterministic() any( project( ImmutableMap.of( - "output_1", expression(new Subscript(new Reference(RowType.anonymousRow(DOUBLE, DOUBLE), "row"), new Constant(INTEGER, 1L))), - "output_2", expression(new Subscript(new Reference(RowType.anonymousRow(DOUBLE, DOUBLE), "row"), new Constant(INTEGER, 2L)))), + "output_1", expression(new FieldReference(new Reference(RowType.anonymousRow(DOUBLE, DOUBLE), "row"), 0)), + "output_2", expression(new FieldReference(new Reference(RowType.anonymousRow(DOUBLE, DOUBLE), "row"), 1))), values( ImmutableList.of("row"), ImmutableList.of( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java index 0dc5bf39728..63ee1ed3ccc 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java @@ -26,9 +26,9 @@ import io.trino.sql.ir.Cast; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.NodeRef; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.transaction.TransactionId; import org.junit.jupiter.api.Test; @@ -58,8 +58,8 @@ public class TestPartialTranslator public void testPartialTranslator() { Expression rowSymbolReference = new Reference(RowType.anonymousRow(INTEGER, INTEGER), "row_symbol_1"); - Expression dereferenceExpression1 = new Subscript(rowSymbolReference, new Constant(INTEGER, 1L)); - Expression dereferenceExpression2 = new Subscript(rowSymbolReference, new Constant(INTEGER, 2L)); + Expression dereferenceExpression1 = new FieldReference(rowSymbolReference, 0); + Expression dereferenceExpression2 = new FieldReference(rowSymbolReference, 1); Expression stringLiteral = new Constant(VARCHAR, Slices.utf8Slice("abcd")); Expression symbolReference1 = new Reference(INTEGER, "double_symbol_1"); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionVerifier.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionVerifier.java index bda8b37da65..58cca84fbb4 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionVerifier.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionVerifier.java @@ -22,6 +22,7 @@ import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.In; import io.trino.sql.ir.IrVisitor; import io.trino.sql.ir.IsNull; @@ -31,7 +32,6 @@ import io.trino.sql.ir.Not; import io.trino.sql.ir.Reference; import io.trino.sql.ir.Row; -import io.trino.sql.ir.Subscript; import io.trino.sql.ir.Switch; import io.trino.sql.ir.WhenClause; @@ -325,13 +325,13 @@ protected Boolean visitRow(Row actual, Expression expectedExpression) } @Override - protected Boolean visitSubscript(Subscript actual, Expression expectedExpression) + protected Boolean visitFieldReference(FieldReference actual, Expression expectedExpression) { - if (!(expectedExpression instanceof Subscript expected)) { + if (!(expectedExpression instanceof FieldReference expected)) { return false; } - return process(actual.base(), expected.base()) && process(actual.index(), expected.index()); + return process(actual.base(), expected.base()) && actual.field() == expected.field(); } private boolean process(List actuals, List expecteds) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjections.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjections.java index c616f03a1b5..35b3a6b7fa5 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjections.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjections.java @@ -22,8 +22,8 @@ import io.trino.spi.type.RowType; import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Constant; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.assertions.ExpressionMatcher; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -75,7 +75,7 @@ public void test() .put(p.symbol("complex"), new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "x"), new Constant(INTEGER, 2L))) .put(p.symbol("literal"), new Constant(INTEGER, 1L)) .put(p.symbol("complex_2"), new Arithmetic(SUBTRACT_INTEGER, SUBTRACT, new Reference(INTEGER, "x"), new Constant(INTEGER, 1L))) - .put(p.symbol("z"), new Subscript(new Reference(MSG_TYPE, "msg"), new Constant(INTEGER, 1L))) + .put(p.symbol("z"), new FieldReference(new Reference(MSG_TYPE, "msg"), 0)) .put(p.symbol("v"), new Reference(INTEGER, "x")) .build(), p.values(p.symbol("x", INTEGER), p.symbol("msg", MSG_TYPE))))) @@ -95,7 +95,7 @@ public void test() ImmutableMap.of( "x", PlanMatchPattern.expression(new Reference(INTEGER, "x")), "y", PlanMatchPattern.expression(new Arithmetic(MULTIPLY_INTEGER, MULTIPLY, new Reference(INTEGER, "x"), new Constant(INTEGER, 2L))), - "z", PlanMatchPattern.expression(new Subscript(new Reference(MSG_TYPE, "msg"), new Constant(INTEGER, 1L)))), + "z", PlanMatchPattern.expression(new FieldReference(new Reference(MSG_TYPE, "msg"), 0))), values(ImmutableMap.of("x", 0, "msg", 1))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushCastIntoRow.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushCastIntoRow.java index 80e008db012..ef64e5e33b1 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushCastIntoRow.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushCastIntoRow.java @@ -18,8 +18,8 @@ import io.trino.sql.ir.Cast; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Row; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; @@ -65,8 +65,8 @@ public void test() // expression nested in another unrelated expression test( - new Subscript(new Cast(new Row(ImmutableList.of(new Constant(INTEGER, 1L))), anonymousRow(BIGINT)), new Constant(INTEGER, 1L)), - new Subscript(new Row(ImmutableList.of(new Cast(new Constant(INTEGER, 1L), BIGINT))), new Constant(INTEGER, 1L))); + new FieldReference(new Cast(new Row(ImmutableList.of(new Constant(INTEGER, 1L))), anonymousRow(BIGINT)), 0), + new FieldReference(new Row(ImmutableList.of(new Cast(new Constant(INTEGER, 1L), BIGINT))), 0)); // don't insert CAST(x AS unknown) test( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownDereferencesRules.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownDereferencesRules.java index daefd4745f9..8746130ee67 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownDereferencesRules.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownDereferencesRules.java @@ -30,12 +30,12 @@ import io.trino.sql.ir.Cast; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.IsNull; import io.trino.sql.ir.Logical; import io.trino.sql.ir.Not; import io.trino.sql.ir.Reference; import io.trino.sql.ir.Row; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.assertions.ExpressionMatcher; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; @@ -52,7 +52,6 @@ import static io.trino.spi.connector.SortOrder.ASC_NULLS_FIRST; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; -import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.RowType.field; import static io.trino.spi.type.RowType.rowType; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; @@ -115,8 +114,8 @@ public void testDoesNotFire() .on(p -> p.project( Assignments.of( - p.symbol("expr_1"), new Subscript(new Cast(new Row(ImmutableList.of(new Reference(ROW_TYPE, "a"), new Reference(BIGINT, "b"))), rowType(field("f1", rowType(field("x", BIGINT), field("y", BIGINT))), field("f2", BIGINT))), new Constant(INTEGER, 1L)), - p.symbol("expr_2"), new Subscript(new Subscript(new Cast(new Row(ImmutableList.of(new Reference(ROW_TYPE, "a"), new Reference(BIGINT, "b"))), rowType(field("f1", rowType(field("x", BIGINT), field("y", BIGINT))), field("f2", BIGINT))), new Constant(INTEGER, 1L)), new Constant(INTEGER, 2L))), + p.symbol("expr_1"), new FieldReference(new Cast(new Row(ImmutableList.of(new Reference(ROW_TYPE, "a"), new Reference(BIGINT, "b"))), rowType(field("f1", rowType(field("x", BIGINT), field("y", BIGINT))), field("f2", BIGINT))), 0), + p.symbol("expr_2"), new FieldReference(new FieldReference(new Cast(new Row(ImmutableList.of(new Reference(ROW_TYPE, "a"), new Reference(BIGINT, "b"))), rowType(field("f1", rowType(field("x", BIGINT), field("y", BIGINT))), field("f2", BIGINT))), 0), 1)), p.project( Assignments.of( p.symbol("a", ROW_TYPE), new Reference(ROW_TYPE, "a"), @@ -130,7 +129,7 @@ public void testDoesNotFire() p.project( Assignments.of( p.symbol("expr", ROW_TYPE), new Reference(ROW_TYPE, "a"), - p.symbol("a_x"), new Subscript(new Reference(ROW_TYPE, "a"), new Constant(INTEGER, 1L))), + p.symbol("a_x"), new FieldReference(new Reference(ROW_TYPE, "a"), 0)), p.project( Assignments.of(p.symbol("a", ROW_TYPE), new Reference(ROW_TYPE, "a")), p.values(p.symbol("a", ROW_TYPE))))) @@ -143,7 +142,7 @@ public void testPushdownDereferenceThroughProject() tester().assertThat(new PushDownDereferenceThroughProject()) .on(p -> p.project( - Assignments.of(p.symbol("x"), new Subscript(new Reference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("x"), new FieldReference(new Reference(ROW_TYPE, "msg"), 0)), p.project( Assignments.of( p.symbol("y"), new Reference(BIGINT, "y"), @@ -154,7 +153,7 @@ public void testPushdownDereferenceThroughProject() ImmutableMap.of("x", expression(new Reference(BIGINT, "msg_x"))), strictProject( ImmutableMap.of( - "msg_x", expression(new Subscript(new Reference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L))), + "msg_x", expression(new FieldReference(new Reference(ROW_TYPE, "msg"), 0)), "y", expression(new Reference(BIGINT, "y")), "msg", expression(new Reference(BIGINT, "msg"))), values("msg", "y")))); @@ -167,8 +166,8 @@ public void testPushDownDereferenceThroughJoin() .on(p -> p.project( Assignments.builder() - .put(p.symbol("left_x"), new Subscript(new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) - .put(p.symbol("right_y"), new Subscript(new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 2L))) + .put(p.symbol("left_x"), new FieldReference(new Reference(ROW_TYPE, "msg1"), 0)) + .put(p.symbol("right_y"), new FieldReference(new Reference(ROW_TYPE, "msg2"), 1)) .put(p.symbol("z"), new Reference(BIGINT, "z")) .build(), p.join(INNER, @@ -185,14 +184,14 @@ public void testPushDownDereferenceThroughJoin() .left( strictProject( ImmutableMap.of( - "x", expression(new Subscript(new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))), + "x", expression(new FieldReference(new Reference(ROW_TYPE, "msg1"), 0)), "msg1", expression(new Reference(ROW_TYPE, "msg1")), "unreferenced_symbol", expression(new Reference(BIGINT, "unreferenced_symbol"))), values("msg1", "unreferenced_symbol"))) .right( strictProject( ImmutableMap.builder() - .put("y", expression(new Subscript(new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 2L)))) + .put("y", expression(new FieldReference(new Reference(ROW_TYPE, "msg2"), 1))) .put("z", expression(new Reference(BIGINT, "z"))) .put("msg2", expression(new Reference(ROW_TYPE, "msg2"))) .buildOrThrow(), @@ -203,23 +202,23 @@ public void testPushDownDereferenceThroughJoin() .on(p -> p.project( Assignments.of( - p.symbol("expr"), new Subscript(new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)), + p.symbol("expr"), new FieldReference(new Reference(ROW_TYPE, "msg1"), 0), p.symbol("expr_2"), new Reference(ROW_TYPE, "msg2")), p.join(INNER, p.values(p.symbol("msg1", ROW_TYPE)), p.values(p.symbol("msg2", ROW_TYPE)), - new Comparison(GREATER_THAN, new Arithmetic(ADD_BIGINT, ADD, new Subscript(new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)), new Subscript(new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 2L))), new Constant(BIGINT, 10L))))) + new Comparison(GREATER_THAN, new Arithmetic(ADD_BIGINT, ADD, new FieldReference(new Reference(ROW_TYPE, "msg1"), 0), new FieldReference(new Reference(ROW_TYPE, "msg2"), 1)), new Constant(BIGINT, 10L))))) .matches( project( ImmutableMap.of( "expr", expression(new Reference(BIGINT, "msg1_x")), "expr_2", expression(new Reference(ROW_TYPE, "msg2"))), join(INNER, builder -> builder - .filter(new Comparison(GREATER_THAN, new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "msg1_x"), new Subscript(new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 2L))), new Constant(BIGINT, 10L))) + .filter(new Comparison(GREATER_THAN, new Arithmetic(ADD_BIGINT, ADD, new Reference(BIGINT, "msg1_x"), new FieldReference(new Reference(ROW_TYPE, "msg2"), 1)), new Constant(BIGINT, 10L))) .left( strictProject( ImmutableMap.of( - "msg1_x", expression(new Subscript(new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))), + "msg1_x", expression(new FieldReference(new Reference(ROW_TYPE, "msg1"), 0)), "msg1", expression(new Reference(ROW_TYPE, "msg1"))), values("msg1"))) .right(values("msg2"))))); @@ -232,8 +231,8 @@ public void testPushdownDereferencesThroughSemiJoin() .on(p -> p.project( Assignments.builder() - .put(p.symbol("msg1_x"), new Subscript(new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) - .put(p.symbol("msg2_x"), new Subscript(new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L))) + .put(p.symbol("msg1_x"), new FieldReference(new Reference(ROW_TYPE, "msg1"), 0)) + .put(p.symbol("msg2_x"), new FieldReference(new Reference(ROW_TYPE, "msg2"), 0)) .build(), p.semiJoin( p.symbol("msg2", ROW_TYPE), @@ -247,7 +246,7 @@ public void testPushdownDereferencesThroughSemiJoin() strictProject( ImmutableMap.builder() .put("msg1_x", expression(new Reference(BIGINT, "expr"))) - .put("msg2_x", expression(new Subscript(new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L)))) // Not pushed down because msg2 is sourceJoinSymbol + .put("msg2_x", expression(new FieldReference(new Reference(ROW_TYPE, "msg2"), 0))) // Not pushed down because msg2 is sourceJoinSymbol .buildOrThrow(), semiJoin( "msg2", @@ -255,7 +254,7 @@ public void testPushdownDereferencesThroughSemiJoin() "match", strictProject( ImmutableMap.of( - "expr", expression(new Subscript(new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))), + "expr", expression(new FieldReference(new Reference(ROW_TYPE, "msg1"), 0)), "msg1", expression(new Reference(ROW_TYPE, "msg1")), "msg2", expression(new Reference(ROW_TYPE, "msg2"))), values("msg1", "msg2")), @@ -269,7 +268,7 @@ public void testPushdownDereferencesThroughUnnest() tester().assertThat(new PushDownDereferenceThroughUnnest()) .on(p -> p.project( - Assignments.of(p.symbol("x"), new Subscript(new Reference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("x"), new FieldReference(new Reference(ROW_TYPE, "msg"), 0)), p.unnest( ImmutableList.of(p.symbol("msg", ROW_TYPE)), ImmutableList.of(new UnnestNode.Mapping(p.symbol("arr", arrayType), ImmutableList.of(p.symbol("field")))), @@ -282,7 +281,7 @@ public void testPushdownDereferencesThroughUnnest() unnest( strictProject( ImmutableMap.of( - "msg_x", expression(new Subscript(new Reference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L))), + "msg_x", expression(new FieldReference(new Reference(ROW_TYPE, "msg"), 0)), "msg", expression(new Reference(ROW_TYPE, "msg")), "arr", expression(new Reference(arrayType, "arr"))), values("msg", "arr"))))); @@ -296,7 +295,7 @@ public void testPushdownDereferencesThroughUnnest() .on(p -> p.project( Assignments.of( - p.symbol("deref_replicate", BIGINT), new Subscript(new Reference(rowType, "replicate"), new Constant(INTEGER, 2L)), + p.symbol("deref_replicate", BIGINT), new FieldReference(new Reference(rowType, "replicate"), 1), p.symbol("deref_unnest", BIGINT), new Call(subscript, ImmutableList.of(new Reference(nestedColumnType, "unnested_row"), new Constant(BIGINT, 2L)))), p.unnest( ImmutableList.of(p.symbol("replicate", rowType)), @@ -315,7 +314,7 @@ public void testPushdownDereferencesThroughUnnest() ImmutableList.of(unnestMapping("nested", ImmutableList.of("unnested_bigint", "unnested_row"))), strictProject( ImmutableMap.of( - "symbol", expression(new Subscript(new Reference(rowType, "replicate"), new Constant(INTEGER, 2L))), + "symbol", expression(new FieldReference(new Reference(rowType, "replicate"), 1)), "replicate", expression(new Reference(rowType, "replicate")), "nested", expression(new Reference(nestedColumnType, "nested"))), values("replicate", "nested"))))); @@ -334,9 +333,9 @@ public void testExtractDereferencesFromFilterAboveScan() .on(p -> p.filter( new Logical(AND, ImmutableList.of( - new Comparison(NOT_EQUAL, new Subscript(new Subscript(new Reference(nestedRowType, "a"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L)), new Constant(BIGINT, 5L)), - new Comparison(EQUAL, new Subscript(new Reference(ROW_TYPE, "b"), new Constant(INTEGER, 2L)), new Constant(BIGINT, 2L)), - new Not(new IsNull(new Cast(new Subscript(new Reference(nestedRowType, "a"), new Constant(INTEGER, 1L)), JSON))))), + new Comparison(NOT_EQUAL, new FieldReference(new FieldReference(new Reference(nestedRowType, "a"), 0), 0), new Constant(BIGINT, 5L)), + new Comparison(EQUAL, new FieldReference(new Reference(ROW_TYPE, "b"), 1), new Constant(BIGINT, 2L)), + new Not(new IsNull(new Cast(new FieldReference(new Reference(nestedRowType, "a"), 0), JSON))))), p.tableScan( testTable, ImmutableList.of(p.symbol("a", nestedRowType), p.symbol("b", ROW_TYPE)), @@ -348,9 +347,9 @@ public void testExtractDereferencesFromFilterAboveScan() new Logical(AND, ImmutableList.of(new Comparison(NOT_EQUAL, new Reference(BIGINT, "expr"), new Constant(BIGINT, 5L)), new Comparison(EQUAL, new Reference(BIGINT, "expr_0"), new Constant(BIGINT, 2L)), new Not(new IsNull(new Cast(new Reference(ROW_TYPE, "expr_1"), JSON))))), strictProject( ImmutableMap.of( - "expr", expression(new Subscript(new Subscript(new Reference(nestedRowType, "a"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L))), - "expr_0", expression(new Subscript(new Reference(ROW_TYPE, "b"), new Constant(INTEGER, 2L))), - "expr_1", expression(new Subscript(new Reference(nestedRowType, "a"), new Constant(INTEGER, 1L))), + "expr", expression(new FieldReference(new FieldReference(new Reference(nestedRowType, "a"), 0), 0)), + "expr_0", expression(new FieldReference(new Reference(ROW_TYPE, "b"), 1)), + "expr_1", expression(new FieldReference(new Reference(nestedRowType, "a"), 0)), "a", expression(new Reference(nestedRowType, "a")), "b", expression(new Reference(ROW_TYPE, "b"))), tableScan( @@ -368,21 +367,21 @@ public void testPushdownDereferenceThroughFilter() .on(p -> p.project( Assignments.of( - p.symbol("expr", BIGINT), new Subscript(new Reference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L)), - p.symbol("expr_2", BIGINT), new Subscript(new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L))), + p.symbol("expr", BIGINT), new FieldReference(new Reference(ROW_TYPE, "msg"), 0), + p.symbol("expr_2", BIGINT), new FieldReference(new Reference(ROW_TYPE, "msg2"), 0)), p.filter( - new Logical(AND, ImmutableList.of(new Comparison(NOT_EQUAL, new Subscript(new Reference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L)), new Constant(BIGINT, 3L)), new Not(new IsNull(new Reference(ROW_TYPE, "msg2"))))), + new Logical(AND, ImmutableList.of(new Comparison(NOT_EQUAL, new FieldReference(new Reference(ROW_TYPE, "msg"), 0), new Constant(BIGINT, 3L)), new Not(new IsNull(new Reference(ROW_TYPE, "msg2"))))), p.values(p.symbol("msg", ROW_TYPE), p.symbol("msg2", ROW_TYPE))))) .matches( strictProject( ImmutableMap.of( "expr", expression(new Reference(BIGINT, "msg_x")), - "expr_2", expression(new Subscript(new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L)))), // not pushed down since predicate contains msg2 reference + "expr_2", expression(new FieldReference(new Reference(ROW_TYPE, "msg2"), 0))), // not pushed down since predicate contains msg2 reference filter( new Logical(AND, ImmutableList.of(new Comparison(NOT_EQUAL, new Reference(BIGINT, "msg_x"), new Constant(BIGINT, 3L)), new Not(new IsNull(new Reference(ROW_TYPE, "msg2"))))), strictProject( ImmutableMap.of( - "msg_x", expression(new Subscript(new Reference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L))), + "msg_x", expression(new FieldReference(new Reference(ROW_TYPE, "msg"), 0)), "msg", expression(new Reference(ROW_TYPE, "msg")), "msg2", expression(new Reference(ROW_TYPE, "msg2"))), values("msg", "msg2"))))); @@ -395,8 +394,8 @@ public void testPushDownDereferenceThroughLimit() .on(p -> p.project( Assignments.builder() - .put(p.symbol("msg1_x"), new Subscript(new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) - .put(p.symbol("msg2_y"), new Subscript(new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 2L))) + .put(p.symbol("msg1_x"), new FieldReference(new Reference(ROW_TYPE, "msg1"), 0)) + .put(p.symbol("msg2_y"), new FieldReference(new Reference(ROW_TYPE, "msg2"), 1)) .put(p.symbol("z"), new Reference(BIGINT, "z")) .build(), p.limit(10, @@ -406,7 +405,7 @@ public void testPushDownDereferenceThroughLimit() strictProject( ImmutableMap.builder() .put("msg1_x", expression(new Reference(BIGINT, "x"))) - .put("msg2_y", expression(new Subscript(new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 2L)))) + .put("msg2_y", expression(new FieldReference(new Reference(ROW_TYPE, "msg2"), 1))) .put("z", expression(new Reference(BIGINT, "z"))) .buildOrThrow(), limit( @@ -414,7 +413,7 @@ public void testPushDownDereferenceThroughLimit() ImmutableList.of(sort("msg2", ASCENDING, FIRST)), strictProject( ImmutableMap.builder() - .put("x", expression(new Subscript(new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)))) + .put("x", expression(new FieldReference(new Reference(ROW_TYPE, "msg1"), 0))) .put("z", expression(new Reference(BIGINT, "z"))) .put("msg1", expression(new Reference(ROW_TYPE, "msg1"))) .put("msg2", expression(new Reference(ROW_TYPE, "msg2"))) @@ -428,8 +427,8 @@ public void testPushDownDereferenceThroughLimitWithPreSortedInputs() tester().assertThat(new PushDownDereferencesThroughLimit()) .on(p -> p.project( Assignments.builder() - .put(p.symbol("msg1_x"), new Subscript(new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) - .put(p.symbol("msg2_y"), new Subscript(new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 2L))) + .put(p.symbol("msg1_x"), new FieldReference(new Reference(ROW_TYPE, "msg1"), 0)) + .put(p.symbol("msg2_y"), new FieldReference(new Reference(ROW_TYPE, "msg2"), 1)) .put(p.symbol("z"), new Reference(BIGINT, "z")) .build(), p.limit( @@ -441,7 +440,7 @@ public void testPushDownDereferenceThroughLimitWithPreSortedInputs() strictProject( ImmutableMap.builder() .put("msg1_x", expression(new Reference(BIGINT, "x"))) - .put("msg2_y", expression(new Subscript(new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 2L)))) + .put("msg2_y", expression(new FieldReference(new Reference(ROW_TYPE, "msg2"), 1))) .put("z", expression(new Reference(BIGINT, "z"))) .buildOrThrow(), limit( @@ -451,7 +450,7 @@ public void testPushDownDereferenceThroughLimitWithPreSortedInputs() ImmutableList.of("msg2"), strictProject( ImmutableMap.builder() - .put("x", expression(new Subscript(new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)))) + .put("x", expression(new FieldReference(new Reference(ROW_TYPE, "msg1"), 0))) .put("z", expression(new Reference(BIGINT, "z"))) .put("msg1", expression(new Reference(ROW_TYPE, "msg1"))) .put("msg2", expression(new Reference(ROW_TYPE, "msg2"))) @@ -467,8 +466,8 @@ public void testPushDownDereferenceThroughSort() .on(p -> p.project( Assignments.builder() - .put(p.symbol("msg_x"), new Subscript(new Reference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L))) - .put(p.symbol("msg_y"), new Subscript(new Reference(ROW_TYPE, "msg"), new Constant(INTEGER, 2L))) + .put(p.symbol("msg_x"), new FieldReference(new Reference(ROW_TYPE, "msg"), 0)) + .put(p.symbol("msg_y"), new FieldReference(new Reference(ROW_TYPE, "msg"), 1)) .put(p.symbol("z"), new Reference(BIGINT, "z")) .build(), p.sort( @@ -480,7 +479,7 @@ public void testPushDownDereferenceThroughSort() .on(p -> p.project( Assignments.builder() - .put(p.symbol("msg_x"), new Subscript(new Reference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L))) + .put(p.symbol("msg_x"), new FieldReference(new Reference(ROW_TYPE, "msg"), 0)) .put(p.symbol("z"), new Reference(BIGINT, "z")) .build(), p.sort( @@ -495,7 +494,7 @@ public void testPushDownDereferenceThroughSort() sort(ImmutableList.of(sort("z", ASCENDING, FIRST)), strictProject( ImmutableMap.builder() - .put("x", expression(new Subscript(new Reference(ROW_TYPE, "msg"), new Constant(INTEGER, 1L)))) + .put("x", expression(new FieldReference(new Reference(ROW_TYPE, "msg"), 0))) .put("z", expression(new Reference(BIGINT, "z"))) .put("msg", expression(new Reference(ROW_TYPE, "msg"))) .buildOrThrow(), @@ -509,8 +508,8 @@ public void testPushdownDereferenceThroughRowNumber() .on(p -> p.project( Assignments.builder() - .put(p.symbol("msg1_x"), new Subscript(new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) - .put(p.symbol("msg2_x"), new Subscript(new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L))) + .put(p.symbol("msg1_x"), new FieldReference(new Reference(ROW_TYPE, "msg1"), 0)) + .put(p.symbol("msg2_x"), new FieldReference(new Reference(ROW_TYPE, "msg2"), 0)) .build(), p.rowNumber( ImmutableList.of(p.symbol("msg1", ROW_TYPE)), @@ -520,7 +519,7 @@ public void testPushdownDereferenceThroughRowNumber() .matches( strictProject( ImmutableMap.builder() - .put("msg1_x", expression(new Subscript(new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)))) + .put("msg1_x", expression(new FieldReference(new Reference(ROW_TYPE, "msg1"), 0))) .put("msg2_x", expression(new Reference(BIGINT, "expr"))) .buildOrThrow(), rowNumber( @@ -528,7 +527,7 @@ public void testPushdownDereferenceThroughRowNumber() .partitionBy(ImmutableList.of("msg1")), strictProject( ImmutableMap.builder() - .put("expr", expression(new Subscript(new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L)))) + .put("expr", expression(new FieldReference(new Reference(ROW_TYPE, "msg2"), 0))) .put("msg1", expression(new Reference(ROW_TYPE, "msg1"))) .put("msg2", expression(new Reference(ROW_TYPE, "msg2"))) .buildOrThrow(), @@ -542,9 +541,9 @@ public void testPushdownDereferenceThroughTopNRanking() .on(p -> p.project( Assignments.builder() - .put(p.symbol("msg1_x"), new Subscript(new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) - .put(p.symbol("msg2_x"), new Subscript(new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L))) - .put(p.symbol("msg3_x"), new Subscript(new Reference(ROW_TYPE, "msg3"), new Constant(INTEGER, 1L))) + .put(p.symbol("msg1_x"), new FieldReference(new Reference(ROW_TYPE, "msg1"), 0)) + .put(p.symbol("msg2_x"), new FieldReference(new Reference(ROW_TYPE, "msg2"), 0)) + .put(p.symbol("msg3_x"), new FieldReference(new Reference(ROW_TYPE, "msg3"), 0)) .build(), p.topNRanking( new DataOrganizationSpecification( @@ -560,15 +559,15 @@ public void testPushdownDereferenceThroughTopNRanking() .matches( strictProject( ImmutableMap.builder() - .put("msg1_x", expression(new Subscript(new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)))) - .put("msg2_x", expression(new Subscript(new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L)))) + .put("msg1_x", expression(new FieldReference(new Reference(ROW_TYPE, "msg1"), 0))) + .put("msg2_x", expression(new FieldReference(new Reference(ROW_TYPE, "msg2"), 0))) .put("msg3_x", expression(new Reference(BIGINT, "expr"))) .buildOrThrow(), topNRanking( pattern -> pattern.specification(singletonList("msg1"), singletonList("msg2"), ImmutableMap.of("msg2", ASC_NULLS_FIRST)), strictProject( ImmutableMap.builder() - .put("expr", expression(new Subscript(new Reference(ROW_TYPE, "msg3"), new Constant(INTEGER, 1L)))) + .put("expr", expression(new FieldReference(new Reference(ROW_TYPE, "msg3"), 0))) .put("msg1", expression(new Reference(ROW_TYPE, "msg1"))) .put("msg2", expression(new Reference(ROW_TYPE, "msg2"))) .put("msg3", expression(new Reference(ROW_TYPE, "msg3"))) @@ -583,21 +582,21 @@ public void testPushdownDereferenceThroughTopN() .on(p -> p.project( Assignments.builder() - .put(p.symbol("msg1_x"), new Subscript(new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) - .put(p.symbol("msg2_x"), new Subscript(new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L))) + .put(p.symbol("msg1_x"), new FieldReference(new Reference(ROW_TYPE, "msg1"), 0)) + .put(p.symbol("msg2_x"), new FieldReference(new Reference(ROW_TYPE, "msg2"), 0)) .build(), p.topN(5, ImmutableList.of(p.symbol("msg1", ROW_TYPE)), p.values(p.symbol("msg1", ROW_TYPE), p.symbol("msg2", ROW_TYPE))))) .matches( strictProject( ImmutableMap.builder() - .put("msg1_x", expression(new Subscript(new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)))) + .put("msg1_x", expression(new FieldReference(new Reference(ROW_TYPE, "msg1"), 0))) .put("msg2_x", expression(new Reference(BIGINT, "expr"))) .buildOrThrow(), topN(5, ImmutableList.of(sort("msg1", ASCENDING, FIRST)), strictProject( ImmutableMap.builder() - .put("expr", expression(new Subscript(new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L)))) + .put("expr", expression(new FieldReference(new Reference(ROW_TYPE, "msg2"), 0))) .put("msg1", expression(new Reference(ROW_TYPE, "msg1"))) .put("msg2", expression(new Reference(ROW_TYPE, "msg2"))) .buildOrThrow(), @@ -611,11 +610,11 @@ public void testPushdownDereferenceThroughWindow() .on(p -> p.project( Assignments.builder() - .put(p.symbol("msg1_x"), new Subscript(new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) - .put(p.symbol("msg2_x"), new Subscript(new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L))) - .put(p.symbol("msg3_x"), new Subscript(new Reference(ROW_TYPE, "msg3"), new Constant(INTEGER, 1L))) - .put(p.symbol("msg4_x"), new Subscript(new Reference(ROW_TYPE, "msg4"), new Constant(INTEGER, 1L))) - .put(p.symbol("msg5_x"), new Subscript(new Reference(ROW_TYPE, "msg5"), new Constant(INTEGER, 1L))) + .put(p.symbol("msg1_x"), new FieldReference(new Reference(ROW_TYPE, "msg1"), 0)) + .put(p.symbol("msg2_x"), new FieldReference(new Reference(ROW_TYPE, "msg2"), 0)) + .put(p.symbol("msg3_x"), new FieldReference(new Reference(ROW_TYPE, "msg3"), 0)) + .put(p.symbol("msg4_x"), new FieldReference(new Reference(ROW_TYPE, "msg4"), 0)) + .put(p.symbol("msg5_x"), new FieldReference(new Reference(ROW_TYPE, "msg5"), 0)) .build(), p.window( new DataOrganizationSpecification( @@ -647,9 +646,9 @@ public void testPushdownDereferenceThroughWindow() .matches( strictProject( ImmutableMap.builder() - .put("msg1_x", expression(new Subscript(new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)))) // not pushed down because used in partitionBy - .put("msg2_x", expression(new Subscript(new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L)))) // not pushed down because used in orderBy - .put("msg3_x", expression(new Subscript(new Reference(ROW_TYPE, "msg3"), new Constant(INTEGER, 1L)))) // not pushed down because the whole column is used in windowNode function + .put("msg1_x", expression(new FieldReference(new Reference(ROW_TYPE, "msg1"), 0))) // not pushed down because used in partitionBy + .put("msg2_x", expression(new FieldReference(new Reference(ROW_TYPE, "msg2"), 0))) // not pushed down because used in orderBy + .put("msg3_x", expression(new FieldReference(new Reference(ROW_TYPE, "msg3"), 0))) // not pushed down because the whole column is used in windowNode function .put("msg4_x", expression(new Reference(BIGINT, "expr"))) // pushed down because msg4[1] is being used in the function .put("msg5_x", expression(new Reference(BIGINT, "expr2"))) // pushed down because not referenced in windowNode .buildOrThrow(), @@ -664,8 +663,8 @@ public void testPushdownDereferenceThroughWindow() .put("msg3", expression(new Reference(ROW_TYPE, "msg3"))) .put("msg4", expression(new Reference(ROW_TYPE, "msg4"))) .put("msg5", expression(new Reference(ROW_TYPE, "msg5"))) - .put("expr", expression(new Subscript(new Reference(ROW_TYPE, "msg4"), new Constant(INTEGER, 1L)))) - .put("expr2", expression(new Subscript(new Reference(ROW_TYPE, "msg5"), new Constant(INTEGER, 1L)))) + .put("expr", expression(new FieldReference(new Reference(ROW_TYPE, "msg4"), 0))) + .put("expr2", expression(new FieldReference(new Reference(ROW_TYPE, "msg5"), 0))) .buildOrThrow(), values("msg1", "msg2", "msg3", "msg4", "msg5"))))); } @@ -677,7 +676,7 @@ public void testPushdownDereferenceThroughAssignUniqueId() .on(p -> p.project( Assignments.builder() - .put(p.symbol("expr"), new Subscript(new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) + .put(p.symbol("expr"), new FieldReference(new Reference(ROW_TYPE, "msg1"), 0)) .build(), p.assignUniqueId( p.symbol("unique"), @@ -690,7 +689,7 @@ public void testPushdownDereferenceThroughAssignUniqueId() strictProject( ImmutableMap.builder() .put("msg1", expression(new Reference(ROW_TYPE, "msg1"))) - .put("msg1_x", expression(new Subscript(new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)))) + .put("msg1_x", expression(new FieldReference(new Reference(ROW_TYPE, "msg1"), 0))) .buildOrThrow(), values("msg1"))))); } @@ -702,8 +701,8 @@ public void testPushdownDereferenceThroughMarkDistinct() .on(p -> p.project( Assignments.builder() - .put(p.symbol("msg1_x"), new Subscript(new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L))) - .put(p.symbol("msg2_x"), new Subscript(new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L))) + .put(p.symbol("msg1_x"), new FieldReference(new Reference(ROW_TYPE, "msg1"), 0)) + .put(p.symbol("msg2_x"), new FieldReference(new Reference(ROW_TYPE, "msg2"), 0)) .build(), p.markDistinct( p.symbol("is_distinct", BOOLEAN), @@ -713,7 +712,7 @@ public void testPushdownDereferenceThroughMarkDistinct() strictProject( ImmutableMap.of( "msg1_x", expression(new Reference(BIGINT, "expr")), // pushed down - "msg2_x", expression(new Subscript(new Reference(ROW_TYPE, "msg2"), new Constant(INTEGER, 1L)))), // not pushed down because used in markDistinct + "msg2_x", expression(new FieldReference(new Reference(ROW_TYPE, "msg2"), 0))), // not pushed down because used in markDistinct markDistinct( "is_distinct", singletonList("msg2"), @@ -721,7 +720,7 @@ public void testPushdownDereferenceThroughMarkDistinct() ImmutableMap.builder() .put("msg1", expression(new Reference(ROW_TYPE, "msg1"))) .put("msg2", expression(new Reference(ROW_TYPE, "msg2"))) - .put("expr", expression(new Subscript(new Reference(ROW_TYPE, "msg1"), new Constant(INTEGER, 1L)))) + .put("expr", expression(new FieldReference(new Reference(ROW_TYPE, "msg1"), 0))) .buildOrThrow(), values("msg1", "msg2"))))); } @@ -734,7 +733,7 @@ public void testMultiLevelPushdown() .on(p -> p.project( Assignments.of( - p.symbol("expr_1"), new Subscript(new Reference(complexType, "a"), new Constant(INTEGER, 1L)), + p.symbol("expr_1"), new FieldReference(new Reference(complexType, "a"), 0), p.symbol("expr_2"), new Arithmetic( ADD_BIGINT, ADD, @@ -744,22 +743,16 @@ public void testMultiLevelPushdown() new Arithmetic( ADD_BIGINT, ADD, - new Subscript( - new Subscript( - new Reference(complexType, "a"), - new Constant(INTEGER, 1L)), - new Constant(INTEGER, 1L)), + new FieldReference( + new FieldReference(new Reference(complexType, "a"), 0), + 0), new Constant(BIGINT, 2L)), - new Subscript( - new Subscript( - new Reference(complexType, "b"), - new Constant(INTEGER, 1L)), - new Constant(INTEGER, 1L))), - new Subscript( - new Subscript( - new Reference(complexType, "b"), - new Constant(INTEGER, 1L)), - new Constant(INTEGER, 2L)))), + new FieldReference( + new FieldReference(new Reference(complexType, "b"), 0), + 0)), + new FieldReference( + new FieldReference(new Reference(complexType, "b"), 0), + 1))), p.project( Assignments.identity(ImmutableList.of(p.symbol("a", complexType), p.symbol("b", complexType))), p.values(p.symbol("a", complexType), p.symbol("b", complexType))))) @@ -767,14 +760,14 @@ public void testMultiLevelPushdown() strictProject( ImmutableMap.of( "expr_1", expression(new Reference(complexType.getFields().get(0).getType(), "a_f1")), - "expr_2", expression(new Arithmetic(ADD_BIGINT, ADD, new Arithmetic(ADD_BIGINT, ADD, new Arithmetic(ADD_BIGINT, ADD, new Subscript(new Reference(complexType.getFields().get(0).getType(), "a_f1"), new Constant(INTEGER, 1L)), new Constant(BIGINT, 2L)), new Reference(BIGINT, "b_f1_f1")), new Reference(BIGINT, "b_f1_f2")))), + "expr_2", expression(new Arithmetic(ADD_BIGINT, ADD, new Arithmetic(ADD_BIGINT, ADD, new Arithmetic(ADD_BIGINT, ADD, new FieldReference(new Reference(complexType.getFields().get(0).getType(), "a_f1"), 0), new Constant(BIGINT, 2L)), new Reference(BIGINT, "b_f1_f1")), new Reference(BIGINT, "b_f1_f2")))), strictProject( ImmutableMap.of( "a", expression(new Reference(complexType, "a")), "b", expression(new Reference(complexType, "b")), - "a_f1", expression(new Subscript(new Reference(complexType, "a"), new Constant(INTEGER, 1L))), - "b_f1_f1", expression(new Subscript(new Subscript(new Reference(complexType, "b"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L))), - "b_f1_f2", expression(new Subscript(new Subscript(new Reference(complexType, "b"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 2L)))), + "a_f1", expression(new FieldReference(new Reference(complexType, "a"), 0)), + "b_f1_f1", expression(new FieldReference(new FieldReference(new Reference(complexType, "b"), 0), 0)), + "b_f1_f2", expression(new FieldReference(new FieldReference(new Reference(complexType, "b"), 0), 1))), values("a", "b")))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushLimitThroughProject.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushLimitThroughProject.java index 23aa9e327a0..bcc70b52fbb 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushLimitThroughProject.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushLimitThroughProject.java @@ -20,9 +20,8 @@ import io.trino.spi.function.OperatorType; import io.trino.spi.type.RowType; import io.trino.sql.ir.Arithmetic; -import io.trino.sql.ir.Constant; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; @@ -31,7 +30,6 @@ import java.util.Optional; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; @@ -159,8 +157,8 @@ public void testDoesntPushDownLimitThroughExclusiveDereferences() return p.limit(1, p.project( Assignments.of( - p.symbol("b"), new Subscript(a.toSymbolReference(), new Constant(INTEGER, 1L)), - p.symbol("c"), new Subscript(a.toSymbolReference(), new Constant(INTEGER, 2L))), + p.symbol("b"), new FieldReference(a.toSymbolReference(), 0), + p.symbol("c"), new FieldReference(a.toSymbolReference(), 1)), p.values(a))); }) .doesNotFire(); @@ -222,13 +220,13 @@ public void testPushDownLimitThroughOverlappingDereferences() return p.limit(1, p.project( Assignments.of( - p.symbol("b"), new Subscript(a.toSymbolReference(), new Constant(INTEGER, 1L)), + p.symbol("b"), new FieldReference(a.toSymbolReference(), 0), p.symbol("c", rowType), a.toSymbolReference()), p.values(a))); }) .matches( project( - ImmutableMap.of("b", io.trino.sql.planner.assertions.PlanMatchPattern.expression(new Subscript(new Reference(rowType, "a"), new Constant(INTEGER, 1L))), "c", expression(new Reference(rowType, "a"))), + ImmutableMap.of("b", io.trino.sql.planner.assertions.PlanMatchPattern.expression(new FieldReference(new Reference(rowType, "a"), 0)), "c", expression(new Reference(rowType, "a"))), limit(1, values("a")))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java index f1289c0958d..b0a428f93c9 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java @@ -42,8 +42,8 @@ import io.trino.sql.PlannerContext; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.RuleTester; import io.trino.sql.planner.plan.Assignments; @@ -101,7 +101,7 @@ public void testDoesNotFire() .on(p -> { Symbol symbol = p.symbol(columnName, columnType); return p.project( - Assignments.of(p.symbol("symbol_dereference", BIGINT), new Subscript(symbol.toSymbolReference(), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("symbol_dereference", BIGINT), new FieldReference(symbol.toSymbolReference(), 0)), p.tableScan( ruleTester.getCurrentCatalogTableHandle(TEST_SCHEMA, TEST_TABLE), ImmutableList.of(symbol), @@ -136,7 +136,7 @@ public void testPushProjection() // Prepare project node assignments Assignments inputProjections = Assignments.builder() .put(identity, baseColumn.toSymbolReference()) - .put(dereference, new Subscript(baseColumn.toSymbolReference(), new Constant(INTEGER, 1L))) + .put(dereference, new FieldReference(baseColumn.toSymbolReference(), 0)) .put(constant, new Constant(INTEGER, 5L)) .build(); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java index 0cd66df9588..eedecbf42db 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java @@ -22,8 +22,8 @@ import io.trino.spi.type.RowType; import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Constant; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; @@ -96,7 +96,7 @@ public void test() return p.project( Assignments.of( cTimes3, new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, c.toSymbolReference(), new Constant(BIGINT, 3L)), - dX, new Subscript(new Reference(ROW_TYPE, "d"), new Constant(INTEGER, 1L))), + dX, new FieldReference(new Reference(ROW_TYPE, "d"), 0)), p.union( ImmutableListMultimap.builder() .put(c, a) @@ -111,10 +111,10 @@ dX, new Subscript(new Reference(ROW_TYPE, "d"), new Constant(INTEGER, 1L))), .matches( union( project( - ImmutableMap.of("a_times_3", expression(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "a"), new Constant(BIGINT, 3L))), "z_x", expression(new Subscript(new Reference(ROW_TYPE, "z"), new Constant(INTEGER, 1L)))), + ImmutableMap.of("a_times_3", expression(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "a"), new Constant(BIGINT, 3L))), "z_x", expression(new FieldReference(new Reference(ROW_TYPE, "z"), 0))), values(ImmutableList.of("a", "z"))), project( - ImmutableMap.of("b_times_3", expression(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "b"), new Constant(BIGINT, 3L))), "w_x", expression(new Subscript(new Reference(ROW_TYPE, "w"), new Constant(INTEGER, 1L)))), + ImmutableMap.of("b_times_3", expression(new Arithmetic(MULTIPLY_BIGINT, MULTIPLY, new Reference(BIGINT, "b"), new Constant(BIGINT, 3L))), "w_x", expression(new FieldReference(new Reference(ROW_TYPE, "w"), 0))), values(ImmutableList.of("b", "w")))) .withNumberOfOutputColumns(2) .withAlias("a_times_3") diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushTopNThroughProject.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushTopNThroughProject.java index 8328f14dd0f..73ff5d53ecb 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushTopNThroughProject.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushTopNThroughProject.java @@ -21,9 +21,8 @@ import io.trino.spi.type.RowType; import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Booleans; -import io.trino.sql.ir.Constant; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; @@ -33,7 +32,6 @@ import java.util.Optional; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.sql.ir.Arithmetic.Operator.ADD; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; @@ -161,8 +159,8 @@ public void testDoesntPushDownTopNThroughExclusiveDereferences() ImmutableList.of(p.symbol("c")), p.project( Assignments.builder() - .put(p.symbol("b"), new Subscript(a.toSymbolReference(), new Constant(INTEGER, 1L))) - .put(p.symbol("c"), new Subscript(a.toSymbolReference(), new Constant(INTEGER, 2L))) + .put(p.symbol("b"), new FieldReference(a.toSymbolReference(), 0)) + .put(p.symbol("c"), new FieldReference(a.toSymbolReference(), 1)) .build(), p.values(a))); }).doesNotFire(); @@ -180,7 +178,7 @@ public void testPushTopNThroughOverlappingDereferences() ImmutableList.of(d), p.project( Assignments.builder() - .put(p.symbol("b"), new Subscript(a.toSymbolReference(), new Constant(INTEGER, 1L))) + .put(p.symbol("b"), new FieldReference(a.toSymbolReference(), 0)) .put(p.symbol("c", rowType), a.toSymbolReference()) .putIdentity(d) .build(), @@ -188,7 +186,7 @@ public void testPushTopNThroughOverlappingDereferences() }) .matches( project( - ImmutableMap.of("b", expression(new Subscript(new Reference(rowType, "a"), new Constant(INTEGER, 1L))), "c", expression(new Reference(BIGINT, "a")), "d", expression(new Reference(BIGINT, "d"))), + ImmutableMap.of("b", expression(new FieldReference(new Reference(rowType, "a"), 0)), "c", expression(new Reference(BIGINT, "a")), "d", expression(new Reference(BIGINT, "d"))), topN( 1, ImmutableList.of(sort("d", ASCENDING, FIRST)), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestUnwrapRowSubscript.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestUnwrapRowSubscript.java index a9cfbe89fac..a4ff3466324 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestUnwrapRowSubscript.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestUnwrapRowSubscript.java @@ -17,8 +17,8 @@ import io.trino.sql.ir.Cast; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Row; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; @@ -41,31 +41,31 @@ public class TestUnwrapRowSubscript @Test public void testSimpleSubscript() { - test(new Subscript(new Row(ImmutableList.of(new Constant(INTEGER, 1L))), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L)); - test(new Subscript(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L)); - test(new Subscript(new Subscript(new Row(ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), new Constant(INTEGER, 3L))), new Constant(INTEGER, 1L)), new Constant(INTEGER, 2L)), new Constant(INTEGER, 2L)); + test(new FieldReference(new Row(ImmutableList.of(new Constant(INTEGER, 1L))), 0), new Constant(INTEGER, 1L)); + test(new FieldReference(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), 0), new Constant(INTEGER, 1L)); + test(new FieldReference(new FieldReference(new Row(ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), new Constant(INTEGER, 3L))), 0), 1), new Constant(INTEGER, 2L)); } @Test public void testWithCast() { test( - new Subscript(new Cast(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), rowType(field("a", BIGINT), field("b", BIGINT))), new Constant(INTEGER, 1L)), + new FieldReference(new Cast(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), rowType(field("a", BIGINT), field("b", BIGINT))), 0), new Cast(new Constant(INTEGER, 1L), BIGINT)); test( - new Subscript(new Cast(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), anonymousRow(BIGINT, BIGINT)), new Constant(INTEGER, 1L)), + new FieldReference(new Cast(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), anonymousRow(BIGINT, BIGINT)), 0), new Cast(new Constant(INTEGER, 1L), BIGINT)); test( - new Subscript( - new Cast(new Subscript( + new FieldReference( + new Cast(new FieldReference( new Cast( new Row(ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), new Constant(INTEGER, 3L))), anonymousRow(anonymousRow(SMALLINT, SMALLINT), BIGINT)), - new Constant(INTEGER, 1L)), + 0), rowType(field("x", BIGINT), field("y", BIGINT))), - new Constant(INTEGER, 2L)), + 1), new Cast(new Cast(new Constant(INTEGER, 2L), SMALLINT), BIGINT)); } @@ -73,15 +73,15 @@ public void testWithCast() public void testWithTryCast() { test( - new Subscript(new Cast(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), rowType(field("a", BIGINT), field("b", BIGINT)), true), new Constant(INTEGER, 1L)), + new FieldReference(new Cast(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), rowType(field("a", BIGINT), field("b", BIGINT)), true), 0), new Cast(new Constant(INTEGER, 1L), BIGINT, true)); test( - new Subscript(new Cast(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), anonymousRow(BIGINT, BIGINT), true), new Constant(INTEGER, 1L)), + new FieldReference(new Cast(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), anonymousRow(BIGINT, BIGINT), true), 0), new Cast(new Constant(INTEGER, 1L), BIGINT, true)); test( - new Subscript(new Cast(new Subscript(new Cast(new Row(ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), new Constant(INTEGER, 3L))), anonymousRow(anonymousRow(SMALLINT, SMALLINT), BIGINT), true), new Constant(INTEGER, 1L)), rowType(field("x", BIGINT), field("y", BIGINT)), true), new Constant(INTEGER, 2L)), + new FieldReference(new Cast(new FieldReference(new Cast(new Row(ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), new Constant(INTEGER, 3L))), anonymousRow(anonymousRow(SMALLINT, SMALLINT), BIGINT), true), 0), rowType(field("x", BIGINT), field("y", BIGINT)), true), 1), new Cast(new Cast(new Constant(INTEGER, 2L), SMALLINT, true), BIGINT, true)); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestUnwrapSingleColumnRowInApply.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestUnwrapSingleColumnRowInApply.java index 590ab0a2f8c..08c8289d33c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestUnwrapSingleColumnRowInApply.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestUnwrapSingleColumnRowInApply.java @@ -15,9 +15,8 @@ import com.google.common.collect.ImmutableMap; import io.trino.spi.type.RowType; -import io.trino.sql.ir.Constant; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.ExpressionMatcher; import io.trino.sql.planner.assertions.SetExpressionMatcher; @@ -83,13 +82,13 @@ public void testUnwrapInPredicate() .buildOrThrow(), project( ImmutableMap.builder() - .put("unwrappedValue", expression(new Subscript(new Reference(RowType.anonymousRow(INTEGER), "rowValue"), new Constant(INTEGER, 1L)))) + .put("unwrappedValue", expression(new FieldReference(new Reference(RowType.anonymousRow(INTEGER), "rowValue"), 0))) .put("nonRowValue", expression(new Reference(INTEGER, "nonRowValue"))) .buildOrThrow(), values("rowValue", "nonRowValue")), project( ImmutableMap.builder() - .put("unwrappedElement", expression(new Subscript(new Reference(RowType.anonymousRow(INTEGER), "rowElement"), new Constant(INTEGER, 1L)))) + .put("unwrappedElement", expression(new FieldReference(new Reference(RowType.anonymousRow(INTEGER), "rowElement"), 0))) .put("nonRowElement", expression(new Reference(INTEGER, "nonRowElement"))) .buildOrThrow(), values("rowElement", "nonRowElement"))))); @@ -121,13 +120,13 @@ public void testUnwrapQuantifiedComparison() .buildOrThrow(), project( ImmutableMap.builder() - .put("unwrappedValue", expression(new Subscript(new Reference(RowType.anonymousRow(INTEGER), "rowValue"), new Constant(INTEGER, 1L)))) + .put("unwrappedValue", expression(new FieldReference(new Reference(RowType.anonymousRow(INTEGER), "rowValue"), 0))) .put("nonRowValue", expression(new Reference(INTEGER, "nonRowValue"))) .buildOrThrow(), values("rowValue", "nonRowValue")), project( ImmutableMap.builder() - .put("unwrappedElement", expression(new Subscript(new Reference(RowType.anonymousRow(INTEGER), "rowElement"), new Constant(INTEGER, 1L)))) + .put("unwrappedElement", expression(new FieldReference(new Reference(RowType.anonymousRow(INTEGER), "rowElement"), 0))) .put("nonRowElement", expression(new Reference(INTEGER, "nonRowElement"))) .buildOrThrow(), values("rowElement", "nonRowElement"))))); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestPartialTopNWithPresortedInput.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestPartialTopNWithPresortedInput.java index d67dc433b9c..6bcffa25614 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestPartialTopNWithPresortedInput.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestPartialTopNWithPresortedInput.java @@ -29,8 +29,8 @@ import io.trino.spi.type.RowType; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.assertions.BasePlanTest; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.testing.PlanTester; @@ -208,9 +208,9 @@ public void testNestedField() topN(1, ImmutableList.of(sort("k", ASCENDING, LAST)), FINAL, anyTree( limit(1, ImmutableList.of(), true, ImmutableList.of("k"), - project(ImmutableMap.of("k", expression(new Subscript(new Reference(RowType.from(ImmutableList.of(RowType.field("k", INTEGER))), "nested"), new Constant(INTEGER, 1L)))), + project(ImmutableMap.of("k", expression(new FieldReference(new Reference(RowType.from(ImmutableList.of(RowType.field("k", INTEGER))), "nested"), 0))), filter( - new Comparison(EQUAL, new Subscript(new Reference(RowType.from(ImmutableList.of(RowType.field("k", INTEGER))), "nested"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L)), + new Comparison(EQUAL, new FieldReference(new Reference(RowType.from(ImmutableList.of(RowType.field("k", INTEGER))), "nested"), 0), new Constant(INTEGER, 1L)), tableScan("with_nested_field", ImmutableMap.of("nested", "nested"))))))))); } } diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeConnectorTest.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeConnectorTest.java index 9fdbdf66e47..585b5cff4e6 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeConnectorTest.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeConnectorTest.java @@ -2161,7 +2161,7 @@ public void testProjectionPushdownExplain() sessionWithoutPushdown, "EXPLAIN SELECT root.f2 FROM " + tableName, "ScanProject\\[table = (.*)]", - "expr := root\\[integer '2']", + "expr := root.1", "root := root:row\\(f1 bigint, f2 bigint\\):REGULAR"); assertUpdate("DROP TABLE " + tableName); @@ -2178,7 +2178,7 @@ public void testProjectionPushdownNonPrimitiveTypeExplain() assertExplain( "EXPLAIN SELECT id, _row.child, _array[1].child, _map[1] FROM " + tableName, "ScanProject\\[table = (.*)]", - "expr(.*) := system\\.builtin\\.\\$operator\\$subscript\\(_array_.*, bigint '1'\\)\\[integer '1']", + "expr(.*) := system\\.builtin\\.\\$operator\\$subscript\\(_array_.*, bigint '1'\\).0", "id(.*) := id:bigint:REGULAR", // _array:array\\(row\\(child bigint\\)\\) is a symbol name, not a dereference expression. "_array(.*) := _array:array\\(row\\(child bigint\\)\\):REGULAR", diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeProjectionPushdownPlans.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeProjectionPushdownPlans.java index cd3f692b848..6e451ff4d86 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeProjectionPushdownPlans.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeProjectionPushdownPlans.java @@ -34,9 +34,9 @@ import io.trino.sql.ir.Cast; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Logical; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.assertions.BasePushdownPlanTest; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.testing.PlanTester; @@ -148,7 +148,7 @@ public void testPushdownDisabled() session, any( project( - ImmutableMap.of("expr", expression(new Subscript(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), new Constant(INTEGER, 1L))), "expr_2", expression(new Subscript(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), new Constant(INTEGER, 2L)))), + ImmutableMap.of("expr", expression(new FieldReference(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), 0)), "expr_2", expression(new FieldReference(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), 1))), tableScan(testTable, ImmutableMap.of("col0", "col0"))))); } @@ -226,9 +226,9 @@ public void testDereferencePushdown() anyTree( project( ImmutableMap.of( - "expr_0_x", expression(new Subscript(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0"), new Constant(INTEGER, 1L))), + "expr_0_x", expression(new FieldReference(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0"), 0)), "expr_0", expression(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0")), - "expr_0_y", expression(new Subscript(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0"), new Constant(INTEGER, 2L)))), + "expr_0_y", expression(new FieldReference(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0"), 1))), join(INNER, builder -> { PlanMatchPattern source = tableScan( table -> { diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestConnectorPushdownRulesWithHive.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestConnectorPushdownRulesWithHive.java index ab5ce332085..77b8dde6e3f 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestConnectorPushdownRulesWithHive.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestConnectorPushdownRulesWithHive.java @@ -40,9 +40,9 @@ import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Negation; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.PruneTableScanColumns; import io.trino.sql.planner.iterative.rule.PushPredicateIntoTableScan; @@ -195,7 +195,7 @@ public void testProjectionPushdown() .on(p -> p.project( Assignments.of( - p.symbol("expr_deref", BIGINT), new Subscript(p.symbol("struct_of_int", baseType).toSymbolReference(), new Constant(INTEGER, 1L))), + p.symbol("expr_deref", BIGINT), new FieldReference(p.symbol("struct_of_int", baseType).toSymbolReference(), 0)), p.tableScan( table, ImmutableList.of(p.symbol("struct_of_int", baseType)), @@ -336,12 +336,12 @@ public void testPushdownWithDuplicateExpressions() // Test Dereference pushdown tester().assertThat(pushProjectionIntoTableScan) .on(p -> { - Subscript subscript = new Subscript(p.symbol("struct_of_bigint", ROW_TYPE).toSymbolReference(), new Constant(INTEGER, 1L)); - Expression sum = new Arithmetic(ADD_BIGINT, ADD, subscript, new Constant(BIGINT, 2L)); + FieldReference fieldReference = new FieldReference(p.symbol("struct_of_bigint", ROW_TYPE).toSymbolReference(), 0); + Expression sum = new Arithmetic(ADD_BIGINT, ADD, fieldReference, new Constant(BIGINT, 2L)); return p.project( Assignments.of( // The subscript expression instance is part of both the assignments - p.symbol("expr_deref", BIGINT), subscript, + p.symbol("expr_deref", BIGINT), fieldReference, p.symbol("expr_deref_2", BIGINT), sum), p.tableScan( table, diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestHiveProjectionPushdownIntoTableScan.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestHiveProjectionPushdownIntoTableScan.java index db16b110c32..095e536d59f 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestHiveProjectionPushdownIntoTableScan.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestHiveProjectionPushdownIntoTableScan.java @@ -37,9 +37,9 @@ import io.trino.sql.ir.Cast; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Logical; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.assertions.BasePushdownPlanTest; import io.trino.testing.PlanTester; import org.junit.jupiter.api.AfterAll; @@ -134,8 +134,8 @@ public void testPushdownDisabled() any( project( ImmutableMap.of( - "expr", expression(new Subscript(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), new Constant(INTEGER, 1L))), - "expr_2", expression(new Subscript(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), new Constant(INTEGER, 2L)))), + "expr", expression(new FieldReference(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), 0)), + "expr_2", expression(new FieldReference(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), 1))), tableScan(testTable, ImmutableMap.of("col0", "col0"))))); } @@ -214,9 +214,9 @@ public void testDereferencePushdown() anyTree( project( ImmutableMap.of( - "expr_0_x", expression(new Subscript(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0"), new Constant(INTEGER, 1L))), + "expr_0_x", expression(new FieldReference(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0"), 0)), "expr_0", expression(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0")), - "expr_0_y", expression(new Subscript(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0"), new Constant(INTEGER, 2L)))), + "expr_0_y", expression(new FieldReference(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0"), 1))), join(INNER, builder -> builder .equiCriteria("t_expr_1", "s_expr_1") .left( diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergProjectionPushdownPlans.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergProjectionPushdownPlans.java index 3cd70074706..7d079c95416 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergProjectionPushdownPlans.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergProjectionPushdownPlans.java @@ -34,9 +34,9 @@ import io.trino.sql.ir.Cast; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Logical; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.assertions.BasePushdownPlanTest; import io.trino.testing.PlanTester; import org.junit.jupiter.api.AfterAll; @@ -140,7 +140,7 @@ public void testPushdownDisabled() session, any( project( - ImmutableMap.of("expr", expression(new Subscript(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), new Constant(INTEGER, 1L))), "expr_2", expression(new Subscript(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), new Constant(INTEGER, 2L)))), + ImmutableMap.of("expr", expression(new FieldReference(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), 0)), "expr_2", expression(new FieldReference(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), 1))), tableScan(testTable, ImmutableMap.of("col0", "col0"))))); } @@ -226,9 +226,9 @@ public void testDereferencePushdown() anyTree( project( ImmutableMap.of( - "expr_0_x", expression(new Subscript(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0"), new Constant(INTEGER, 1L))), + "expr_0_x", expression(new FieldReference(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0"), 0)), "expr_0", expression(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0")), - "expr_0_y", expression(new Subscript(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0"), new Constant(INTEGER, 2L)))), + "expr_0_y", expression(new FieldReference(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "expr_0"), 1))), join(INNER, builder -> builder .equiCriteria("s_expr_1", "t_expr_1") .left( diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/optimizer/TestConnectorPushdownRulesWithIceberg.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/optimizer/TestConnectorPushdownRulesWithIceberg.java index 30a7a8b66a6..f6946df6a5c 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/optimizer/TestConnectorPushdownRulesWithIceberg.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/optimizer/TestConnectorPushdownRulesWithIceberg.java @@ -43,9 +43,9 @@ import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Negation; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.PruneTableScanColumns; import io.trino.sql.planner.iterative.rule.PushPredicateIntoTableScan; @@ -220,7 +220,7 @@ public void testProjectionPushdown() .on(p -> p.project( Assignments.of( - p.symbol("expr_deref", BIGINT), new Subscript(p.symbol("struct_of_int", baseType).toSymbolReference(), new Constant(INTEGER, 1L))), + p.symbol("expr_deref", BIGINT), new FieldReference(p.symbol("struct_of_int", baseType).toSymbolReference(), 0)), p.tableScan( table, ImmutableList.of(p.symbol("struct_of_int", baseType)), @@ -414,12 +414,12 @@ public void testPushdownWithDuplicateExpressions() // Test Dereference pushdown tester().assertThat(pushProjectionIntoTableScan) .on(p -> { - Subscript subscript = new Subscript(p.symbol("struct_of_bigint", ROW_TYPE).toSymbolReference(), new Constant(INTEGER, 1L)); - Expression sum = new Arithmetic(ADD_BIGINT, ADD, subscript, new Constant(BIGINT, 2L)); + FieldReference fieldReference = new FieldReference(p.symbol("struct_of_bigint", ROW_TYPE).toSymbolReference(), 0); + Expression sum = new Arithmetic(ADD_BIGINT, ADD, fieldReference, new Constant(BIGINT, 2L)); return p.project( Assignments.of( // The subscript expression instance is part of both the assignments - p.symbol("expr_deref", BIGINT), subscript, + p.symbol("expr_deref", BIGINT), fieldReference, p.symbol("expr_deref_2", BIGINT), sum), p.tableScan( table, diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoProjectionPushdownPlans.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoProjectionPushdownPlans.java index 0f85bad7984..a4c46d1c206 100644 --- a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoProjectionPushdownPlans.java +++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoProjectionPushdownPlans.java @@ -32,8 +32,8 @@ import io.trino.sql.ir.Arithmetic; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; +import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.Reference; -import io.trino.sql.ir.Subscript; import io.trino.sql.planner.assertions.BasePushdownPlanTest; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.testing.PlanTester; @@ -133,7 +133,7 @@ public void testPushdownDisabled() session, any( project( - ImmutableMap.of("expr_1", expression(new Subscript(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), new Constant(INTEGER, 1L))), "expr_2", expression(new Subscript(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), new Constant(INTEGER, 2L)))), + ImmutableMap.of("expr_1", expression(new FieldReference(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), 0)), "expr_2", expression(new FieldReference(new Reference(RowType.anonymousRow(BIGINT, BIGINT), "col0"), 1))), tableScan(tableName, ImmutableMap.of("col0", "col0"))))); } @@ -205,9 +205,9 @@ public void testDereferencePushdown() anyTree( project( ImmutableMap.of( - "expr_0_x", expression(new Subscript(new Reference(RowType.anonymousRow(INTEGER), "expr_0"), new Constant(INTEGER, 1L))), + "expr_0_x", expression(new FieldReference(new Reference(RowType.anonymousRow(INTEGER), "expr_0"), 0)), "expr_0", expression(new Reference(RowType.anonymousRow(INTEGER), "expr_0")), - "expr_0_y", expression(new Subscript(new Reference(RowType.anonymousRow(INTEGER, INTEGER), "expr_0"), new Constant(INTEGER, 2L)))), + "expr_0_y", expression(new FieldReference(new Reference(RowType.anonymousRow(INTEGER, INTEGER), "expr_0"), 1))), PlanMatchPattern.join(INNER, builder -> builder .equiCriteria("t_expr_1", "s_expr_1") .left(