Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,9 @@ protected Void visitCoalesce(Coalesce node, C context)
}

@Override
protected Void visitSubscript(Subscript node, C context)
protected Void visitFieldReference(FieldReference node, C context)
{
process(node.base(), context);
process(node.index(), context);

return null;
}
Expand Down
6 changes: 3 additions & 3 deletions core/trino-main/src/main/java/io/trino/sql/ir/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@
@JsonSubTypes.Type(value = Row.class, name = "row"),
@JsonSubTypes.Type(value = Case.class, name = "case"),
@JsonSubTypes.Type(value = Switch.class, name = "switch"),
@JsonSubTypes.Type(value = Subscript.class, name = "subscript"),
@JsonSubTypes.Type(value = FieldReference.class, name = "field"),
@JsonSubTypes.Type(value = Reference.class, name = "reference"),
})
public sealed interface Expression
permits Arithmetic, Between, Bind, Call, Case, Cast, Coalesce,
Comparison, Constant, In, IsNull, Lambda, Logical, Negation,
Not, NullIf, Reference, Row, Subscript, Switch
Comparison, Constant, FieldReference, In, IsNull, Lambda, Logical,
Negation, Not, NullIf, Reference, Row, Switch
{
Type type();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ protected String visitExpression(Expression node, Void context)
}

@Override
protected String visitSubscript(Subscript node, Void context)
protected String visitFieldReference(FieldReference node, Void context)
{
return formatExpression(node.base()) + "[" + formatExpression(node.index()) + "]";
return formatExpression(node.base()) + "." + node.field();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ public Expression rewriteConstant(Constant node, C context, ExpressionTreeRewrit
return rewriteExpression(node, context, treeRewriter);
}

public Expression rewriteSubscript(Subscript node, C context, ExpressionTreeRewriter<C> treeRewriter)
public Expression rewriteSubscript(FieldReference node, C context, ExpressionTreeRewriter<C> treeRewriter)
{
return rewriteExpression(node, context, treeRewriter);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ public Expression visitArithmetic(Arithmetic node, Context<C> context)
}

@Override
protected Expression visitSubscript(Subscript node, Context<C> context)
protected Expression visitFieldReference(FieldReference node, Context<C> context)
{
if (!context.isDefaultRewrite()) {
Expression result = rewriter.rewriteSubscript(node, context.get(), ExpressionTreeRewriter.this);
Expand All @@ -146,10 +146,9 @@ protected Expression visitSubscript(Subscript node, Context<C> context)
}

Expression base = rewrite(node.base(), context.get());
Expression index = rewrite(node.index(), context.get());

if (base != node.base() || index != node.index()) {
return new Subscript(base, index);
if (base != node.base()) {
return new FieldReference(base, node.field());
}

return node;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,40 +21,35 @@
import java.util.List;

import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.sql.ir.IrUtils.validateType;

@JsonSerialize
public record Subscript(Expression base, Expression index)
public record FieldReference(Expression base, int field)
implements Expression
{
public Subscript
public FieldReference
{
if (!(base.type() instanceof RowType rowType)) {
throw new IllegalArgumentException("Expected 'row' type but found '%s' for expression: %s".formatted(base.type(), base));
}

validateType(INTEGER, index);
int field = (int) (long) ((Constant) index).value() - 1;
checkArgument(field < rowType.getFields().size(), "Expected 'row' type to have at least %s fields, but has: %s", field + 1, rowType.getFields().size());
}

@Override
public Type type()
{
int field = (int) (long) ((Constant) index).value() - 1;
return ((RowType) base.type()).getFields().get(field).getType();
}

@Override
public <R, C> R accept(IrVisitor<R, C> visitor, C context)
{
return visitor.visitSubscript(this, context);
return visitor.visitFieldReference(this, context);
}

@Override
public List<? extends Expression> children()
{
return ImmutableList.of(base, index);
return ImmutableList.of(base);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ protected R visitIsNull(IsNull node, C context)
return visitExpression(node, context);
}

protected R visitSubscript(Subscript node, C context)
protected R visitFieldReference(FieldReference node, C context)
{
return visitExpression(node, context);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.FieldReference;
import io.trino.sql.ir.In;
import io.trino.sql.ir.IrVisitor;
import io.trino.sql.ir.IsNull;
Expand All @@ -51,7 +52,6 @@
import io.trino.sql.ir.Not;
import io.trino.sql.ir.NullIf;
import io.trino.sql.ir.Reference;
import io.trino.sql.ir.Subscript;
import io.trino.sql.tree.QualifiedName;
import io.trino.type.JoniRegexp;
import io.trino.type.JsonPathType;
Expand Down Expand Up @@ -94,7 +94,6 @@
import static io.trino.spi.expression.StandardFunctions.OR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.SUBTRACT_FUNCTION_NAME;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.spi.type.VarcharType.createVarcharType;
import static io.trino.sql.DynamicFilters.isDynamicFilterFunction;
Expand Down Expand Up @@ -208,7 +207,7 @@ public Optional<Expression> translate(ConnectorExpression expression)

if (expression instanceof FieldDereference dereference) {
return translate(dereference.getTarget())
.map(base -> new Subscript(base, new Constant(INTEGER, (long) (dereference.getField() + 1))));
.map(base -> new FieldReference(base, dereference.getField()));
}

if (expression instanceof io.trino.spi.expression.Call call) {
Expand Down Expand Up @@ -793,7 +792,7 @@ protected Optional<ConnectorExpression> visitNullIf(NullIf node, Void context)
}

@Override
protected Optional<ConnectorExpression> visitSubscript(Subscript node, Void context)
protected Optional<ConnectorExpression> visitFieldReference(FieldReference node, Void context)
{
if (!(node.base().type() instanceof RowType)) {
return Optional.empty();
Expand All @@ -804,7 +803,7 @@ protected Optional<ConnectorExpression> visitSubscript(Subscript node, Void cont
return Optional.empty();
}

return Optional.of(new FieldDereference(((Expression) node).type(), translatedBase.get(), (int) ((long) ((Constant) node.index()).value() - 1)));
return Optional.of(new FieldDereference(((Expression) node).type(), translatedBase.get(), node.field()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import io.trino.Session;
import io.trino.metadata.Metadata;
import io.trino.metadata.ResolvedFunction;
import io.trino.operator.scalar.ArraySubscriptOperator;
import io.trino.spi.TrinoException;
import io.trino.spi.block.SqlRow;
import io.trino.spi.connector.ConnectorSession;
Expand All @@ -43,6 +42,7 @@
import io.trino.sql.ir.Comparison.Operator;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.FieldReference;
import io.trino.sql.ir.In;
import io.trino.sql.ir.IrVisitor;
import io.trino.sql.ir.IsNull;
Expand All @@ -53,7 +53,6 @@
import io.trino.sql.ir.NullIf;
import io.trino.sql.ir.Reference;
import io.trino.sql.ir.Row;
import io.trino.sql.ir.Subscript;
import io.trino.sql.ir.Switch;
import io.trino.sql.ir.WhenClause;
import io.trino.type.FunctionType;
Expand Down Expand Up @@ -82,7 +81,6 @@
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName;
import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.trino.spi.block.RowValueBuilder.buildRowValue;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
Expand All @@ -95,7 +93,6 @@
import static io.trino.sql.DynamicFilters.isDynamicFilter;
import static io.trino.sql.gen.VarArgsToMapAdapterGenerator.generateVarArgsToMapAdapter;
import static io.trino.sql.planner.DeterminismEvaluator.isDeterministic;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;

Expand Down Expand Up @@ -883,36 +880,20 @@ protected Object visitRow(Row node, Object context)
}

@Override
protected Object visitSubscript(Subscript node, Object context)
protected Object visitFieldReference(FieldReference node, Object context)
{
Object base = processWithExceptionHandling(node.base(), context);
if (base == null) {
return null;
}
Object index = processWithExceptionHandling(node.index(), context);
if (index == null) {
return null;
}
if ((index instanceof Long) && isArray(node.base().type())) {
ArraySubscriptOperator.checkArrayIndex((Long) index);
}

if (hasUnresolvedValue(base, index)) {
return new Subscript(toExpression(base, node.base().type()), toExpression(index, node.index().type()));
}

// Subscript on Row hasn't got a dedicated operator. It is interpreted by hand.
if (base instanceof SqlRow row) {
int fieldIndex = toIntExact((long) index - 1);
if (fieldIndex < 0 || fieldIndex >= row.getFieldCount()) {
throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "ROW index out of bounds: " + (fieldIndex + 1));
}
Type returnType = node.base().type().getTypeParameters().get(fieldIndex);
return readNativeValue(returnType, row.getRawFieldBlock(fieldIndex), row.getRawIndex());
if (hasUnresolvedValue(base)) {
return new FieldReference(toExpression(base, node.base().type()), node.field());
}

// Subscript on Array or Map is interpreted using operator.
return invokeOperator(OperatorType.SUBSCRIPT, types(node.base(), node.index()), ImmutableList.of(base, index));
SqlRow row = (SqlRow) base;
Type returnType = node.base().type().getTypeParameters().get(node.field());
return readNativeValue(returnType, row.getRawFieldBlock(node.field()), row.getRawIndex());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.DefaultTraversalVisitor;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.FieldReference;
import io.trino.sql.ir.In;
import io.trino.sql.ir.NullIf;
import io.trino.sql.ir.Subscript;
import io.trino.sql.ir.Switch;

import java.util.concurrent.atomic.AtomicBoolean;
Expand Down Expand Up @@ -97,7 +97,7 @@ protected Void visitSwitch(Switch node, AtomicBoolean result)
}

@Override
protected Void visitSubscript(Subscript node, AtomicBoolean result)
protected Void visitFieldReference(FieldReference node, AtomicBoolean result)
{
result.set(true);
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.FieldReference;
import io.trino.sql.ir.IsNull;
import io.trino.sql.ir.Logical;
import io.trino.sql.ir.Not;
import io.trino.sql.ir.Row;
import io.trino.sql.ir.Subscript;
import io.trino.sql.ir.WhenClause;
import io.trino.sql.planner.RelationPlanner.PatternRecognitionComponents;
import io.trino.sql.planner.plan.AggregationNode;
Expand Down Expand Up @@ -81,7 +81,6 @@
import io.trino.sql.planner.plan.WindowNode;
import io.trino.sql.tree.Delete;
import io.trino.sql.tree.FetchFirst;
import io.trino.sql.tree.FieldReference;
import io.trino.sql.tree.FrameBound;
import io.trino.sql.tree.FunctionCall.NullTreatment;
import io.trino.sql.tree.Join;
Expand Down Expand Up @@ -513,7 +512,7 @@ public PlanNode plan(Delete node)
builder = filter(builder, node.getWhere().get(), node);
}

FieldReference reference = analysis.getRowIdField(table);
io.trino.sql.tree.FieldReference reference = analysis.getRowIdField(table);
Symbol rowIdSymbol = builder.translate(reference);
List<Symbol> outputs = ImmutableList.of(
symbolAllocator.newSymbol("partialrows", BIGINT),
Expand Down Expand Up @@ -655,7 +654,7 @@ public PlanNode plan(Update node)
}
}

FieldReference rowIdReference = analysis.getRowIdField(mergeAnalysis.getTargetTable());
io.trino.sql.tree.FieldReference rowIdReference = analysis.getRowIdField(mergeAnalysis.getTargetTable());
assignments.putIdentity(relationPlan.getFieldMappings().get(rowIdReference.getFieldIndex()));

// Add the "present" field
Expand Down Expand Up @@ -770,7 +769,7 @@ public MergeWriterNode plan(Merge merge)

PlanBuilder subPlan = newPlanBuilder(joinPlan, analysis, lambdaDeclarationToSymbolMap, session, plannerContext);

FieldReference rowIdReference = analysis.getRowIdField(mergeAnalysis.getTargetTable());
io.trino.sql.tree.FieldReference rowIdReference = analysis.getRowIdField(mergeAnalysis.getTargetTable());
Symbol rowIdSymbol = planWithPresentColumn.getFieldMappings().get(rowIdReference.getFieldIndex());

// Build the SearchedCaseExpression that creates the project merge_row
Expand Down Expand Up @@ -889,7 +888,7 @@ public MergeWriterNode plan(Merge merge)
subPlanProject,
Assignments.builder()
.putIdentities(subPlanProject.getOutputSymbols())
.put(caseNumberSymbol, new Subscript(mergeRowSymbol.toSymbolReference(), new Constant(INTEGER, (long) mergeAnalysis.getMergeRowType().getFields().size())))
.put(caseNumberSymbol, new FieldReference(mergeRowSymbol.toSymbolReference(), mergeAnalysis.getMergeRowType().getFields().size() - 1))
.build());

// Mark distinct combinations of the unique_id value and the case_number
Expand Down
Loading