Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions core/trino-main/src/main/java/io/trino/sql/ir/Arithmetic.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import java.util.List;

import static io.trino.sql.ir.IrUtils.validateType;
import static java.util.Objects.requireNonNull;

@JsonSerialize
Expand Down Expand Up @@ -48,10 +49,9 @@ public String getValue()

public Arithmetic
{
requireNonNull(function, "function is null");
requireNonNull(operator, "operator is null");
requireNonNull(left, "left is null");
requireNonNull(right, "right is null");
validateType(function.getSignature().getArgumentType(0), left);
validateType(function.getSignature().getArgumentType(1), right);
}

@Override
Expand Down
7 changes: 3 additions & 4 deletions core/trino-main/src/main/java/io/trino/sql/ir/Between.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import java.util.List;

import static io.trino.spi.type.BooleanType.BOOLEAN;
import static java.util.Objects.requireNonNull;
import static io.trino.sql.ir.IrUtils.validateType;

@JsonSerialize
public record Between(Expression value, Expression min, Expression max)
Expand All @@ -30,9 +30,8 @@ public record Between(Expression value, Expression min, Expression max)
@JsonCreator
public Between
{
requireNonNull(value, "value is null");
requireNonNull(min, "min is null");
requireNonNull(max, "max is null");
validateType(value.type(), min);
validateType(value.type(), max);
}

@Override
Expand Down
6 changes: 4 additions & 2 deletions core/trino-main/src/main/java/io/trino/sql/ir/Bind.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import java.util.stream.Collectors;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;
import static io.trino.sql.ir.IrUtils.validateType;

/**
* Bind(value, targetFunction)
Expand Down Expand Up @@ -54,8 +54,10 @@ public record Bind(List<Expression> values, Lambda function)
{
public Bind
{
requireNonNull(function, "function is null");
values = ImmutableList.copyOf(values);
for (int i = 0; i < values.size(); i++) {
validateType(function.arguments().get(i).getType(), values.get(i));
}
}

@Override
Expand Down
9 changes: 7 additions & 2 deletions core/trino-main/src/main/java/io/trino/sql/ir/Call.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,21 @@
import java.util.List;
import java.util.stream.Collectors;

import static java.util.Objects.requireNonNull;
import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.sql.ir.IrUtils.validateType;

@JsonSerialize
public record Call(ResolvedFunction function, List<Expression> arguments)
implements Expression
{
public Call
{
requireNonNull(function, "function is null");
arguments = ImmutableList.copyOf(arguments);

checkArgument(function.getSignature().getArgumentTypes().size() == arguments.size(), "Expected %s arguments, found: %s", function.getSignature().getArgumentTypes().size(), arguments.size());
for (int i = 0; i < arguments.size(); i++) {
validateType(function.getSignature().getArgumentType(i), arguments.get(i));
}
}

@Override
Expand Down
14 changes: 14 additions & 0 deletions core/trino-main/src/main/java/io/trino/sql/ir/Case.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import java.util.Optional;
import java.util.stream.Collectors;

import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.sql.ir.IrUtils.validateType;
import static java.util.Objects.requireNonNull;

@JsonSerialize
Expand All @@ -31,6 +33,18 @@ public record Case(List<WhenClause> whenClauses, Optional<Expression> defaultVal
{
whenClauses = ImmutableList.copyOf(whenClauses);
requireNonNull(defaultValue, "defaultValue is null");

for (WhenClause clause : whenClauses) {
validateType(BOOLEAN, clause.getOperand());
}

for (int i = 1; i < whenClauses.size(); i++) {
validateType(whenClauses.getFirst().getResult().type(), whenClauses.get(i).getResult());
}

if (defaultValue.isPresent()) {
validateType(whenClauses.getFirst().getResult().type(), defaultValue.get());
}
}

@Override
Expand Down
5 changes: 5 additions & 0 deletions core/trino-main/src/main/java/io/trino/sql/ir/Coalesce.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.List;

import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.sql.ir.IrUtils.validateType;

@JsonSerialize
public record Coalesce(List<Expression> operands)
Expand All @@ -43,6 +44,10 @@ public Type type()
{
checkArgument(operands.size() >= 2, "must have at least two operands");
operands = ImmutableList.copyOf(operands);

for (int i = 1; i < operands.size(); i++) {
validateType(operands.getFirst().type(), operands.get(i));
}
}

@Override
Expand Down
4 changes: 2 additions & 2 deletions core/trino-main/src/main/java/io/trino/sql/ir/Comparison.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.List;

import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.sql.ir.IrUtils.validateType;
import static java.util.Objects.requireNonNull;

@JsonSerialize
Expand Down Expand Up @@ -87,8 +88,7 @@ public Operator negate()
public Comparison
{
requireNonNull(operator, "operator is null");
requireNonNull(left, "left is null");
requireNonNull(right, "right is null");
validateType(left.type(), right);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These should match the operator... matching each other it nice also

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't have the ResolvedFunction in Comparison yet. Once I add it, we can also validate against it.

}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ protected Expression visitSubscript(Subscript node, Context<C> context)
Expression index = rewrite(node.index(), context.get());

if (base != node.base() || index != node.index()) {
return new Subscript(node.type(), base, index);
return new Subscript(base, index);
}

return node;
Expand Down
5 changes: 5 additions & 0 deletions core/trino-main/src/main/java/io/trino/sql/ir/In.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.List;

import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.sql.ir.IrUtils.validateType;

@JsonSerialize
public record In(Expression value, List<Expression> valueList)
Expand All @@ -28,6 +29,10 @@ public record In(Expression value, List<Expression> valueList)
public In
{
valueList = ImmutableList.copyOf(valueList);

for (Expression item : valueList) {
validateType(value.type(), item);
}
}

@Override
Expand Down
7 changes: 7 additions & 0 deletions core/trino-main/src/main/java/io/trino/sql/ir/IrUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.google.common.graph.SuccessorsFunction;
import com.google.common.graph.Traverser;
import io.trino.metadata.Metadata;
import io.trino.spi.type.Type;
import io.trino.sql.planner.DeterminismEvaluator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
Expand All @@ -31,6 +32,7 @@
import java.util.function.Predicate;
import java.util.stream.Stream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Predicates.not;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Streams.stream;
Expand All @@ -43,6 +45,11 @@ public final class IrUtils
{
private IrUtils() {}

static void validateType(Type expected, Expression expression)
{
checkArgument(expected.equals(expression.type()), "Expected '%s' type but found '%s' for expression: %s", expected, expression.type(), expression);
}

public static List<Expression> extractConjuncts(Expression expression)
{
return extractPredicates(Logical.Operator.AND, expression);
Expand Down
6 changes: 6 additions & 0 deletions core/trino-main/src/main/java/io/trino/sql/ir/Logical.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.sql.ir.IrUtils.validateType;
import static java.util.Objects.requireNonNull;

@JsonSerialize
Expand All @@ -48,6 +49,11 @@ public Operator flip()
{
requireNonNull(operator, "operator is null");
checkArgument(terms.size() >= 2, "Expected at least 2 terms");

for (Expression term : terms) {
validateType(BOOLEAN, term);
}

terms = ImmutableList.copyOf(terms);
}

Expand Down
4 changes: 2 additions & 2 deletions core/trino-main/src/main/java/io/trino/sql/ir/Not.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@
import java.util.List;

import static io.trino.spi.type.BooleanType.BOOLEAN;
import static java.util.Objects.requireNonNull;
import static io.trino.sql.ir.IrUtils.validateType;

@JsonSerialize
public record Not(Expression value)
implements Expression
{
public Not
{
requireNonNull(value, "value is null");
validateType(BOOLEAN, value);
}

@Override
Expand Down
2 changes: 2 additions & 0 deletions core/trino-main/src/main/java/io/trino/sql/ir/NullIf.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ public record NullIf(Expression first, Expression second)
{
requireNonNull(first, "first is null");
requireNonNull(second, "second is null");

// TODO: verify that first and second can be coerced to the same type
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you could add the Optional<ResolvedFunction> for the coercion and then the verification is easier.. also you likely need the resolved equals function here

}

@Override
Expand Down
19 changes: 14 additions & 5 deletions core/trino-main/src/main/java/io/trino/sql/ir/Subscript.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,35 @@

import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.google.common.collect.ImmutableList;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;

import java.util.List;

import static java.util.Objects.requireNonNull;
import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.sql.ir.IrUtils.validateType;

@JsonSerialize
public record Subscript(Type type, Expression base, Expression index)
public record Subscript(Expression base, Expression index)
implements Expression
{
public Subscript
{
requireNonNull(base, "base is null");
requireNonNull(index, "index is null");
if (!(base.type() instanceof RowType rowType)) {
throw new IllegalArgumentException("Expected 'row' type but found '%s' for expression: %s".formatted(base.type(), base));
}

validateType(INTEGER, index);
int field = (int) (long) ((Constant) index).value() - 1;
checkArgument(field < rowType.getFields().size(), "Expected 'row' type to have at least %s fields, but has: %s", field + 1, rowType.getFields().size());
}

@Override
public Type type()
{
return type;
int field = (int) (long) ((Constant) index).value() - 1;
return ((RowType) base.type()).getFields().get(field).getType();
}

@Override
Expand Down
14 changes: 13 additions & 1 deletion core/trino-main/src/main/java/io/trino/sql/ir/Switch.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.List;
import java.util.Optional;

import static io.trino.sql.ir.IrUtils.validateType;
import static java.util.Objects.requireNonNull;

@JsonSerialize
Expand All @@ -30,7 +31,18 @@ public record Switch(Expression operand, List<WhenClause> whenClauses, Optional<
{
requireNonNull(operand, "operand is null");
whenClauses = ImmutableList.copyOf(whenClauses);
requireNonNull(defaultValue, "defaultValue is null");

for (WhenClause clause : whenClauses) {
validateType(operand.type(), clause.getOperand());
}

for (int i = 1; i < whenClauses.size(); i++) {
validateType(whenClauses.getFirst().getResult().type(), whenClauses.get(i).getResult());
}

if (defaultValue.isPresent()) {
validateType(whenClauses.getFirst().getResult().type(), defaultValue.get());
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ public Optional<Expression> translate(ConnectorExpression expression)

if (expression instanceof FieldDereference dereference) {
return translate(dereference.getTarget())
.map(base -> new Subscript(dereference.getType(), base, new Constant(INTEGER, (long) (dereference.getField() + 1))));
.map(base -> new Subscript(base, new Constant(INTEGER, (long) (dereference.getField() + 1))));
}

if (expression instanceof io.trino.spi.expression.Call call) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -898,7 +898,7 @@ protected Object visitSubscript(Subscript node, Object context)
}

if (hasUnresolvedValue(base, index)) {
return new Subscript(node.type(), toExpression(base, node.base().type()), toExpression(index, node.index().type()));
return new Subscript(toExpression(base, node.base().type()), toExpression(index, node.index().type()));
}

// Subscript on Row hasn't got a dedicated operator. It is interpreted by hand.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -889,7 +889,7 @@ public MergeWriterNode plan(Merge merge)
subPlanProject,
Assignments.builder()
.putIdentities(subPlanProject.getOutputSymbols())
.put(caseNumberSymbol, new Subscript(INTEGER, mergeRowSymbol.toSymbolReference(), new Constant(INTEGER, (long) mergeAnalysis.getMergeRowType().getFields().size())))
.put(caseNumberSymbol, new Subscript(mergeRowSymbol.toSymbolReference(), new Constant(INTEGER, (long) mergeAnalysis.getMergeRowType().getFields().size())))
.build());

// Mark distinct combinations of the unique_id value and the case_number
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,6 @@ private io.trino.sql.ir.Expression translate(DereferenceExpression expression)
checkState(index >= 0, "could not find field name: %s", fieldName);

return new Subscript(
rowType.getFields().get(index).getType(),
translateExpression(expression.getBase()),
new Constant(INTEGER, (long) (index + 1)));
}
Expand Down Expand Up @@ -959,7 +958,6 @@ private io.trino.sql.ir.Expression translate(SubscriptExpression node)
io.trino.sql.ir.Expression rewrittenBase = translateExpression(node.getBase());
LongLiteral index = (LongLiteral) node.getIndex();
return new Subscript(
analysis.getType(node),
rewrittenBase, new Constant(INTEGER, index.getParsedValue()));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ private Optional<Unwrapping> unwrapSingleColumnRow(Context context, Expression v
Symbol valueSymbol = context.getSymbolAllocator().newSymbol("input", elementType);
Symbol listSymbol = context.getSymbolAllocator().newSymbol("subquery", elementType);

Assignment inputAssignment = new Assignment(valueSymbol, new Subscript(elementType, value, new Constant(INTEGER, 1L)));
Assignment nestedPlanAssignment = new Assignment(listSymbol, new Subscript(elementType, list, new Constant(INTEGER, 1L)));
Assignment inputAssignment = new Assignment(valueSymbol, new Subscript(value, new Constant(INTEGER, 1L)));
Assignment nestedPlanAssignment = new Assignment(listSymbol, new Subscript(list, new Constant(INTEGER, 1L)));
ApplyNode.SetExpression comparison = function.apply(valueSymbol, listSymbol);

return Optional.of(new Unwrapping(comparison, inputAssignment, nestedPlanAssignment));
Expand Down
Loading