Skip to content

Commit aa739b2

Browse files
committed
Verify IR types
1 parent 8f279a8 commit aa739b2

File tree

97 files changed

+939
-923
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

97 files changed

+939
-923
lines changed

core/trino-main/src/main/java/io/trino/sql/ir/Arithmetic.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import java.util.List;
2222

23+
import static io.trino.sql.ir.IrUtils.validateType;
2324
import static java.util.Objects.requireNonNull;
2425

2526
@JsonSerialize
@@ -48,10 +49,9 @@ public String getValue()
4849

4950
public Arithmetic
5051
{
51-
requireNonNull(function, "function is null");
5252
requireNonNull(operator, "operator is null");
53-
requireNonNull(left, "left is null");
54-
requireNonNull(right, "right is null");
53+
validateType(function.getSignature().getArgumentType(0), left);
54+
validateType(function.getSignature().getArgumentType(1), right);
5555
}
5656

5757
@Override

core/trino-main/src/main/java/io/trino/sql/ir/Between.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import java.util.List;
2222

2323
import static io.trino.spi.type.BooleanType.BOOLEAN;
24-
import static java.util.Objects.requireNonNull;
24+
import static io.trino.sql.ir.IrUtils.validateType;
2525

2626
@JsonSerialize
2727
public record Between(Expression value, Expression min, Expression max)
@@ -30,9 +30,8 @@ public record Between(Expression value, Expression min, Expression max)
3030
@JsonCreator
3131
public Between
3232
{
33-
requireNonNull(value, "value is null");
34-
requireNonNull(min, "min is null");
35-
requireNonNull(max, "max is null");
33+
validateType(value.type(), min);
34+
validateType(value.type(), max);
3635
}
3736

3837
@Override

core/trino-main/src/main/java/io/trino/sql/ir/Bind.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import java.util.stream.Collectors;
2525

2626
import static com.google.common.collect.ImmutableList.toImmutableList;
27-
import static java.util.Objects.requireNonNull;
27+
import static io.trino.sql.ir.IrUtils.validateType;
2828

2929
/**
3030
* Bind(value, targetFunction)
@@ -54,8 +54,10 @@ public record Bind(List<Expression> values, Lambda function)
5454
{
5555
public Bind
5656
{
57-
requireNonNull(function, "function is null");
5857
values = ImmutableList.copyOf(values);
58+
for (int i = 0; i < values.size(); i++) {
59+
validateType(function.arguments().get(i).getType(), values.get(i));
60+
}
5961
}
6062

6163
@Override

core/trino-main/src/main/java/io/trino/sql/ir/Call.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,21 @@
2121
import java.util.List;
2222
import java.util.stream.Collectors;
2323

24-
import static java.util.Objects.requireNonNull;
24+
import static com.google.common.base.Preconditions.checkArgument;
25+
import static io.trino.sql.ir.IrUtils.validateType;
2526

2627
@JsonSerialize
2728
public record Call(ResolvedFunction function, List<Expression> arguments)
2829
implements Expression
2930
{
3031
public Call
3132
{
32-
requireNonNull(function, "function is null");
3333
arguments = ImmutableList.copyOf(arguments);
34+
35+
checkArgument(function.getSignature().getArgumentTypes().size() == arguments.size(), "Expected %s arguments, found: %s", function.getSignature().getArgumentTypes().size(), arguments.size());
36+
for (int i = 0; i < arguments.size(); i++) {
37+
validateType(function.getSignature().getArgumentType(i), arguments.get(i));
38+
}
3439
}
3540

3641
@Override

core/trino-main/src/main/java/io/trino/sql/ir/Case.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import java.util.Optional;
2222
import java.util.stream.Collectors;
2323

24+
import static io.trino.spi.type.BooleanType.BOOLEAN;
25+
import static io.trino.sql.ir.IrUtils.validateType;
2426
import static java.util.Objects.requireNonNull;
2527

2628
@JsonSerialize
@@ -31,6 +33,18 @@ public record Case(List<WhenClause> whenClauses, Optional<Expression> defaultVal
3133
{
3234
whenClauses = ImmutableList.copyOf(whenClauses);
3335
requireNonNull(defaultValue, "defaultValue is null");
36+
37+
for (WhenClause clause : whenClauses) {
38+
validateType(BOOLEAN, clause.getOperand());
39+
}
40+
41+
for (int i = 1; i < whenClauses.size(); i++) {
42+
validateType(whenClauses.getFirst().getResult().type(), whenClauses.get(i).getResult());
43+
}
44+
45+
if (defaultValue.isPresent()) {
46+
validateType(whenClauses.getFirst().getResult().type(), defaultValue.get());
47+
}
3448
}
3549

3650
@Override

core/trino-main/src/main/java/io/trino/sql/ir/Coalesce.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.util.List;
2121

2222
import static com.google.common.base.Preconditions.checkArgument;
23+
import static io.trino.sql.ir.IrUtils.validateType;
2324

2425
@JsonSerialize
2526
public record Coalesce(List<Expression> operands)
@@ -43,6 +44,10 @@ public Type type()
4344
{
4445
checkArgument(operands.size() >= 2, "must have at least two operands");
4546
operands = ImmutableList.copyOf(operands);
47+
48+
for (int i = 1; i < operands.size(); i++) {
49+
validateType(operands.getFirst().type(), operands.get(i));
50+
}
4651
}
4752

4853
@Override

core/trino-main/src/main/java/io/trino/sql/ir/Comparison.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.util.List;
2121

2222
import static io.trino.spi.type.BooleanType.BOOLEAN;
23+
import static io.trino.sql.ir.IrUtils.validateType;
2324
import static java.util.Objects.requireNonNull;
2425

2526
@JsonSerialize
@@ -87,8 +88,7 @@ public Operator negate()
8788
public Comparison
8889
{
8990
requireNonNull(operator, "operator is null");
90-
requireNonNull(left, "left is null");
91-
requireNonNull(right, "right is null");
91+
validateType(left.type(), right);
9292
}
9393

9494
@Override

core/trino-main/src/main/java/io/trino/sql/ir/ExpressionTreeRewriter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ protected Expression visitSubscript(Subscript node, Context<C> context)
149149
Expression index = rewrite(node.index(), context.get());
150150

151151
if (base != node.base() || index != node.index()) {
152-
return new Subscript(node.type(), base, index);
152+
return new Subscript(base, index);
153153
}
154154

155155
return node;

core/trino-main/src/main/java/io/trino/sql/ir/In.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.util.List;
2121

2222
import static io.trino.spi.type.BooleanType.BOOLEAN;
23+
import static io.trino.sql.ir.IrUtils.validateType;
2324

2425
@JsonSerialize
2526
public record In(Expression value, List<Expression> valueList)
@@ -28,6 +29,10 @@ public record In(Expression value, List<Expression> valueList)
2829
public In
2930
{
3031
valueList = ImmutableList.copyOf(valueList);
32+
33+
for (Expression item : valueList) {
34+
validateType(value.type(), item);
35+
}
3136
}
3237

3338
@Override

core/trino-main/src/main/java/io/trino/sql/ir/IrUtils.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import com.google.common.graph.SuccessorsFunction;
1919
import com.google.common.graph.Traverser;
2020
import io.trino.metadata.Metadata;
21+
import io.trino.spi.type.Type;
2122
import io.trino.sql.planner.DeterminismEvaluator;
2223
import io.trino.sql.planner.Symbol;
2324
import io.trino.sql.planner.SymbolsExtractor;
@@ -31,6 +32,7 @@
3132
import java.util.function.Predicate;
3233
import java.util.stream.Stream;
3334

35+
import static com.google.common.base.Preconditions.checkArgument;
3436
import static com.google.common.base.Predicates.not;
3537
import static com.google.common.collect.ImmutableList.toImmutableList;
3638
import static com.google.common.collect.Streams.stream;
@@ -43,6 +45,11 @@ public final class IrUtils
4345
{
4446
private IrUtils() {}
4547

48+
static void validateType(Type expected, Expression expression)
49+
{
50+
checkArgument(expected.equals(expression.type()), "Expected '%s' type but found '%s' for expression: %s", expected, expression.type(), expression);
51+
}
52+
4653
public static List<Expression> extractConjuncts(Expression expression)
4754
{
4855
return extractPredicates(Logical.Operator.AND, expression);

0 commit comments

Comments
 (0)