Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,34 @@

import com.google.inject.Inject;
import io.trino.Session;
import io.trino.spi.function.CatalogSchemaFunctionName;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.SmallintType;
import io.trino.spi.type.TinyintType;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.Arithmetic;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Cast;
import io.trino.sql.ir.Coalesce;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.IrVisitor;
import io.trino.sql.ir.Negation;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.IrExpressionInterpreter;
import io.trino.sql.planner.NoOpSymbolResolver;
import io.trino.sql.planner.Symbol;

import java.util.OptionalDouble;

import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName;
import static io.trino.spi.function.OperatorType.ADD;
import static io.trino.spi.function.OperatorType.DIVIDE;
import static io.trino.spi.function.OperatorType.MODULUS;
import static io.trino.spi.function.OperatorType.MULTIPLY;
import static io.trino.spi.function.OperatorType.NEGATION;
import static io.trino.spi.function.OperatorType.SUBTRACT;
import static io.trino.spi.statistics.StatsUtil.toStatsRepresentation;
import static io.trino.util.MoreMath.max;
import static io.trino.util.MoreMath.min;
Expand Down Expand Up @@ -109,6 +115,21 @@ protected SymbolStatsEstimate visitConstant(Constant node, Void context)
@Override
protected SymbolStatsEstimate visitCall(Call node, Void context)
{
if (node.function().getName().equals(builtinFunctionName(NEGATION))) {
SymbolStatsEstimate stats = process(node.arguments().getFirst());
return SymbolStatsEstimate.buildFrom(stats)
.setLowValue(-stats.getHighValue())
.setHighValue(-stats.getLowValue())
.build();
}
else if (node.function().getName().equals(builtinFunctionName(ADD)) ||
node.function().getName().equals(builtinFunctionName(SUBTRACT)) ||
node.function().getName().equals(builtinFunctionName(MULTIPLY)) ||
node.function().getName().equals(builtinFunctionName(DIVIDE)) ||
node.function().getName().equals(builtinFunctionName(MODULUS))) {
return processArithmetic(node);
}

IrExpressionInterpreter interpreter = new IrExpressionInterpreter(node, plannerContext, session);
Object value = interpreter.optimize(NoOpSymbolResolver.INSTANCE);

Expand Down Expand Up @@ -175,22 +196,11 @@ private boolean isIntegralType(Type type)
return false;
}

@Override
protected SymbolStatsEstimate visitNegation(Negation node, Void context)
{
SymbolStatsEstimate stats = process(node.value());
return SymbolStatsEstimate.buildFrom(stats)
.setLowValue(-stats.getHighValue())
.setHighValue(-stats.getLowValue())
.build();
}

@Override
protected SymbolStatsEstimate visitArithmetic(Arithmetic node, Void context)
protected SymbolStatsEstimate processArithmetic(Call node)
{
requireNonNull(node, "node is null");
SymbolStatsEstimate left = process(node.left());
SymbolStatsEstimate right = process(node.right());
SymbolStatsEstimate left = process(node.arguments().get(0));
SymbolStatsEstimate right = process(node.arguments().get(1));
if (left.isUnknown() || right.isUnknown()) {
return SymbolStatsEstimate.unknown();
}
Expand All @@ -208,11 +218,11 @@ protected SymbolStatsEstimate visitArithmetic(Arithmetic node, Void context)
result.setLowValue(NaN)
.setHighValue(NaN);
}
else if (node.operator() == Arithmetic.Operator.DIVIDE && rightLow < 0 && rightHigh > 0) {
else if (node.function().getName().equals(builtinFunctionName(DIVIDE)) && rightLow < 0 && rightHigh > 0) {
result.setLowValue(Double.NEGATIVE_INFINITY)
.setHighValue(Double.POSITIVE_INFINITY);
}
else if (node.operator() == Arithmetic.Operator.MODULUS) {
else if (node.function().getName().equals(builtinFunctionName(MODULUS))) {
double maxDivisor = max(abs(rightLow), abs(rightHigh));
if (leftHigh <= 0) {
result.setLowValue(max(-maxDivisor, leftLow))
Expand All @@ -228,10 +238,10 @@ else if (leftLow >= 0) {
}
}
else {
double v1 = operate(node.operator(), leftLow, rightLow);
double v2 = operate(node.operator(), leftLow, rightHigh);
double v3 = operate(node.operator(), leftHigh, rightLow);
double v4 = operate(node.operator(), leftHigh, rightHigh);
double v1 = operate(node.function().getName(), leftLow, rightLow);
double v2 = operate(node.function().getName(), leftLow, rightHigh);
double v3 = operate(node.function().getName(), leftHigh, rightLow);
double v4 = operate(node.function().getName(), leftHigh, rightHigh);
double lowValue = min(v1, v2, v3, v4);
double highValue = max(v1, v2, v3, v4);

Expand All @@ -242,21 +252,16 @@ else if (leftLow >= 0) {
return result.build();
}

private double operate(Arithmetic.Operator operator, double left, double right)
private double operate(CatalogSchemaFunctionName function, double left, double right)
{
switch (operator) {
case ADD:
return left + right;
case SUBTRACT:
return left - right;
case MULTIPLY:
return left * right;
case DIVIDE:
return left / right;
case MODULUS:
return left % right;
}
throw new IllegalStateException("Unsupported ArithmeticBinaryExpression.Operator: " + operator);
return switch (function) {
case CatalogSchemaFunctionName name when name.equals(builtinFunctionName(ADD)) -> left + right;
case CatalogSchemaFunctionName name when name.equals(builtinFunctionName(SUBTRACT)) -> left - right;
case CatalogSchemaFunctionName name when name.equals(builtinFunctionName(MULTIPLY)) -> left * right;
case CatalogSchemaFunctionName name when name.equals(builtinFunctionName(DIVIDE)) -> left / right;
case CatalogSchemaFunctionName name when name.equals(builtinFunctionName(MODULUS)) -> left % right;
default -> throw new IllegalStateException("Unsupported binary arithmetic operation: " + function);
};
}

@Override
Expand Down
80 changes: 0 additions & 80 deletions core/trino-main/src/main/java/io/trino/sql/ir/Arithmetic.java

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,6 @@ protected Void visitCast(Cast node, C context)
return null;
}

@Override
protected Void visitArithmetic(Arithmetic node, C context)
{
process(node.left(), context);
process(node.right(), context);

return null;
}

@Override
protected Void visitBetween(Between node, C context)
{
Expand Down Expand Up @@ -125,13 +116,6 @@ protected Void visitBind(Bind node, C context)
return null;
}

@Override
protected Void visitNegation(Negation node, C context)
{
process(node.value(), context);
return null;
}

@Override
protected Void visitNot(Not node, C context)
{
Expand Down
6 changes: 2 additions & 4 deletions core/trino-main/src/main/java/io/trino/sql/ir/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
@Immutable
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME)
@JsonSubTypes({
@JsonSubTypes.Type(value = Arithmetic.class, name = "arithmetic"),
@JsonSubTypes.Type(value = Negation.class, name = "negation"),
@JsonSubTypes.Type(value = Between.class, name = "between"),
@JsonSubTypes.Type(value = Bind.class, name = "bind"),
@JsonSubTypes.Type(value = Cast.class, name = "cast"),
Expand All @@ -46,9 +44,9 @@
@JsonSubTypes.Type(value = Reference.class, name = "reference"),
})
public sealed interface Expression
permits Arithmetic, Between, Bind, Call, Case, Cast, Coalesce,
permits Between, Bind, Call, Case, Cast, Coalesce,
Comparison, Constant, FieldReference, In, IsNull, Lambda, Logical,
Negation, Not, NullIf, Reference, Row, Switch
Not, NullIf, Reference, Row, Switch
{
Type type();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,18 +163,6 @@ protected String visitCoalesce(Coalesce node, Void context)
return "COALESCE(" + joinExpressions(node.operands()) + ")";
}

@Override
protected String visitNegation(Negation node, Void context)
{
return "-(" + process(node.value(), context) + ")";
}

@Override
protected String visitArithmetic(Arithmetic node, Void context)
{
return formatBinaryExpression(node.operator().getValue(), node.left(), node.right());
}

@Override
public String visitCast(Cast node, Void context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,6 @@ public Expression rewriteRow(Row node, C context, ExpressionTreeRewriter<C> tree
return rewriteExpression(node, context, treeRewriter);
}

public Expression rewriteNegation(Negation node, C context, ExpressionTreeRewriter<C> treeRewriter)
{
return rewriteExpression(node, context, treeRewriter);
}

public Expression rewriteArithmetic(Arithmetic node, C context, ExpressionTreeRewriter<C> treeRewriter)
{
return rewriteExpression(node, context, treeRewriter);
}

public Expression rewriteComparison(Comparison node, C context, ExpressionTreeRewriter<C> treeRewriter)
{
return rewriteExpression(node, context, treeRewriter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,44 +97,6 @@ protected Expression visitRow(Row node, Context<C> context)
return node;
}

@Override
protected Expression visitNegation(Negation node, Context<C> context)
{
if (!context.isDefaultRewrite()) {
Expression result = rewriter.rewriteNegation(node, context.get(), ExpressionTreeRewriter.this);
if (result != null) {
return result;
}
}

Expression child = rewrite(node.value(), context.get());
if (child != node.value()) {
return new Negation(child);
}

return node;
}

@Override
public Expression visitArithmetic(Arithmetic node, Context<C> context)
{
if (!context.isDefaultRewrite()) {
Expression result = rewriter.rewriteArithmetic(node, context.get(), ExpressionTreeRewriter.this);
if (result != null) {
return result;
}
}

Expression left = rewrite(node.left(), context.get());
Expression right = rewrite(node.right(), context.get());

if (left != node.left() || right != node.right()) {
return new Arithmetic(node.function(), node.operator(), left, right);
}

return node;
}

@Override
protected Expression visitFieldReference(FieldReference node, Context<C> context)
{
Expand Down
10 changes: 0 additions & 10 deletions core/trino-main/src/main/java/io/trino/sql/ir/IrVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,6 @@ protected R visitExpression(Expression node, C context)
return null;
}

protected R visitArithmetic(Arithmetic node, C context)
{
return visitExpression(node, context);
}

protected R visitBetween(Between node, C context)
{
return visitExpression(node, context);
Expand Down Expand Up @@ -82,11 +77,6 @@ protected R visitNullIf(NullIf node, C context)
return visitExpression(node, context);
}

protected R visitNegation(Negation node, C context)
{
return visitExpression(node, context);
}

protected R visitNot(Not node, C context)
{
return visitExpression(node, context);
Expand Down
Loading