Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.GenericLiteral;
import io.trino.sql.tree.IfExpression;
import io.trino.sql.tree.IsNotNullPredicate;
import io.trino.sql.tree.IsNullPredicate;
import io.trino.sql.tree.LikePredicate;
Expand Down Expand Up @@ -82,6 +83,7 @@
import static io.trino.spi.expression.StandardFunctions.EQUAL_OPERATOR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OPERATOR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.IF_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.IS_DISTINCT_FROM_OPERATOR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.IS_NULL_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.LESS_THAN_OPERATOR_FUNCTION_NAME;
Expand Down Expand Up @@ -259,6 +261,16 @@ protected Optional<Expression> translateCall(Call call)
return Optional.empty();
}
}
if (IF_FUNCTION_NAME.equals(call.getFunctionName())) {
switch (call.getArguments().size()) {
case 2:
return translateIf(call.getArguments().get(0), call.getArguments().get(1), Optional.empty());
case 3:
return translateIf(call.getArguments().get(0), call.getArguments().get(1), Optional.of(call.getArguments().get(2)));
default:
return Optional.empty();
}
}

QualifiedName name = QualifiedName.of(call.getFunctionName().getName());
List<TypeSignature> argumentTypes = call.getArguments().stream()
Expand Down Expand Up @@ -410,6 +422,23 @@ protected Optional<Expression> translateLike(ConnectorExpression value, Connecto

return Optional.empty();
}

private Optional<Expression> translateIf(ConnectorExpression condition, ConnectorExpression trueValue, Optional<ConnectorExpression> falseValue)
{
Optional<Expression> conditionExpression = translate(condition);
Optional<Expression> trueExpression = translate(trueValue);

if (conditionExpression.isPresent() && trueExpression.isPresent()) {
if (falseValue.isPresent()) {
Optional<Expression> falseExpression = translate(falseValue.get());
return falseExpression.map(expression -> new IfExpression(conditionExpression.get(), trueExpression.get(), expression));
}

return Optional.of(new IfExpression(conditionExpression.get(), trueExpression.get(), null));
}

return Optional.empty();
}
}

public static class SqlToConnectorExpressionTranslator
Expand Down Expand Up @@ -678,6 +707,24 @@ protected Optional<ConnectorExpression> visitNullIfExpression(NullIfExpression n
return Optional.empty();
}

@Override
protected Optional<ConnectorExpression> visitIfExpression(IfExpression node, Void context)
{
Optional<ConnectorExpression> condition = process(node.getCondition());
Optional<ConnectorExpression> trueValue = process(node.getTrueValue());
if (condition.isPresent() && trueValue.isPresent()) {
if (node.getFalseValue().isEmpty()) {
return Optional.of(new Call(typeOf(node), IF_FUNCTION_NAME, ImmutableList.of(condition.get(), trueValue.get())));
}

Optional<ConnectorExpression> falseValue = process(node.getFalseValue().get());
if (falseValue.isPresent()) {
return Optional.of(new Call(typeOf(node), IF_FUNCTION_NAME, ImmutableList.of(condition.get(), trueValue.get(), falseValue.get())));
}
}
return Optional.empty();
}

@Override
protected Optional<ConnectorExpression> visitSubscriptExpression(SubscriptExpression node, Void context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import io.trino.sql.tree.ArithmeticUnaryExpression;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.IfExpression;
import io.trino.sql.tree.IsNotNullPredicate;
import io.trino.sql.tree.IsNullPredicate;
import io.trino.sql.tree.LikePredicate;
Expand All @@ -52,6 +53,7 @@

import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static io.airlift.slice.Slices.utf8Slice;
import static io.trino.spi.expression.StandardFunctions.IF_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.IS_NULL_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.LIKE_PATTERN_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.NEGATE_FUNCTION_NAME;
Expand Down Expand Up @@ -181,6 +183,31 @@ public void testTranslateComparisonExpression(ComparisonExpression.Operator oper
List.of(new Variable("double_symbol_1", DOUBLE), new Variable("double_symbol_2", DOUBLE))));
}

@Test(dataProvider = "testTranslateComparisonExpressionDataProvider")
public void testTranslateIf(ComparisonExpression.Operator operator)
{
assertTranslationRoundTrips(
new IfExpression(
new ComparisonExpression(operator, new SymbolReference("double_symbol_1"), new SymbolReference("double_symbol_2")),
new SymbolReference("double_symbol_1"),
new SymbolReference("double_symbol_2")),
new Call(DOUBLE,
IF_FUNCTION_NAME,
List.of(new Call(BOOLEAN, ConnectorExpressionTranslator.functionNameForComparisonOperator(operator), List.of(new Variable("double_symbol_1", DOUBLE), new Variable("double_symbol_2", DOUBLE))),
new Variable("double_symbol_1", DOUBLE),
new Variable("double_symbol_2", DOUBLE))));

assertTranslationRoundTrips(
new IfExpression(
new ComparisonExpression(operator, new SymbolReference("double_symbol_1"), new SymbolReference("double_symbol_2")),
new SymbolReference("double_symbol_1"),
null),
new Call(DOUBLE,
IF_FUNCTION_NAME,
List.of(new Call(BOOLEAN, ConnectorExpressionTranslator.functionNameForComparisonOperator(operator), List.of(new Variable("double_symbol_1", DOUBLE), new Variable("double_symbol_2", DOUBLE))),
new Variable("double_symbol_1", DOUBLE))));
}

@DataProvider
public static Object[][] testTranslateComparisonExpressionDataProvider()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ private StandardFunctions() {}
public static final FunctionName GREATER_THAN_OPERATOR_FUNCTION_NAME = new FunctionName("$greater_than");
public static final FunctionName GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME = new FunctionName("$greater_than_or_equal");
public static final FunctionName IS_DISTINCT_FROM_OPERATOR_FUNCTION_NAME = new FunctionName("$is_distinct_from");
/**
* $if is a function accepting 2 arguments - condition and trueValue, or 3 arguments - condition, trueValue and falseValue.
* Evaluates and returns true_value if condition is true, otherwise evaluates and returns false_value.
*/
public static final FunctionName IF_FUNCTION_NAME = new FunctionName("$if");
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@martint and I discussed this and the conclusion was to model IF via CASE.
see #11699 (comment) for more details.


/**
* Arithmetic addition.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,8 @@ public PostgreSqlClient(
.map("$not(value: boolean)").to("NOT value")
.map("$is_null(value)").to("value IS NULL")
.map("$nullif(first, second)").to("NULLIF(first, second)")
.map("$if(condition, trueValue)").to("CASE WHEN condition THEN trueValue END")
.map("$if(condition, trueValue, falseValue)").to("CASE WHEN condition THEN trueValue ELSE falseValue END")
.build();

JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.IfExpression;
import io.trino.sql.tree.IsNotNullPredicate;
import io.trino.sql.tree.IsNullPredicate;
import io.trino.sql.tree.LikePredicate;
Expand Down Expand Up @@ -248,6 +249,56 @@ public void testConvertComparison(ComparisonExpression.Operator operator)
throw new UnsupportedOperationException("Unsupported operator: " + operator);
}

@Test(dataProvider = "testConvertComparisonDataProvider")
public void testConvertIf(ComparisonExpression.Operator operator)
Comment on lines +252 to +253
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the <test-name>DataProvider is intended to be used by <test-name> test method only

why do you need it here?
just use equality

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I put it to catch future changes. I will just use equality.

{
// if(c_bigint_symbol = 42, a_varchar, b_varchar)
Optional<String> convertedWithFalseArgument = JDBC_CLIENT.convertPredicate(
SESSION,
translateToConnectorExpression(
new IfExpression(
new ComparisonExpression(
operator,
new SymbolReference("c_bigint_symbol"),
LITERAL_ENCODER.toExpression(TEST_SESSION, 42L, BIGINT)),
new SymbolReference("a_varchar_symbol"),
new SymbolReference("b_varchar_symbol")),
ImmutableMap.of("c_bigint_symbol", BIGINT, "a_varchar_symbol", VARCHAR_COLUMN.getColumnType(), "b_varchar_symbol", VARCHAR_COLUMN.getColumnType())),
ImmutableMap.of("c_bigint_symbol", DOUBLE_COLUMN, "a_varchar_symbol", VARCHAR_COLUMN, "b_varchar_symbol", VARCHAR_COLUMN));

// if(c_bigint_symbol = 42, a_varchar)
Optional<String> convertedWithoutFalseArgument = JDBC_CLIENT.convertPredicate(
SESSION,
translateToConnectorExpression(
new IfExpression(
new ComparisonExpression(
operator,
new SymbolReference("c_bigint_symbol"),
LITERAL_ENCODER.toExpression(TEST_SESSION, 42L, BIGINT)),
new SymbolReference("a_varchar_symbol"),
null),
ImmutableMap.of("c_bigint_symbol", BIGINT, "a_varchar_symbol", VARCHAR_COLUMN.getColumnType())),
ImmutableMap.of("c_bigint_symbol", DOUBLE_COLUMN, "a_varchar_symbol", VARCHAR_COLUMN));

switch (operator) {
case EQUAL:
case NOT_EQUAL:
assertThat(convertedWithFalseArgument).hasValue(format("CASE WHEN ((\"c_double\") %s (42)) THEN (\"c_varchar\") ELSE (\"c_varchar\") END", operator.getValue()));
assertThat(convertedWithoutFalseArgument).hasValue(format("CASE WHEN ((\"c_double\") %s (42)) THEN (\"c_varchar\") END", operator.getValue()));
return;
case LESS_THAN:
case LESS_THAN_OR_EQUAL:
case GREATER_THAN:
case GREATER_THAN_OR_EQUAL:
case IS_DISTINCT_FROM:
// Not supported yet, even for bigint
assertThat(convertedWithFalseArgument).isEmpty();
assertThat(convertedWithoutFalseArgument).isEmpty();
return;
}
throw new UnsupportedOperationException("Unsupported operator: " + operator);
}

@DataProvider
public static Object[][] testConvertComparisonDataProvider()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,26 @@ public void testNotExpressionPushdown()
}
}

@Test
public void testIfPredicatePushdown()
{
assertThat(query("SELECT nationkey FROM nation WHERE IF(name = 'ALGERIA', true, false)"))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IF(condition, true, false) can be simplified by the engine to condition.

IF(name = 'ALGERIA', regionkey, nationkey) = <regionekey of Algeria>

.matches("VALUES BIGINT '0'")
.isFullyPushedDown();

assertThat(query("SELECT name FROM nation WHERE IF(nationkey = 0, true, false)"))
.matches("VALUES CAST('ALGERIA' AS varchar(25))")
.isFullyPushedDown();

assertThat(query("SELECT name FROM nation WHERE IF(nationkey <> 0, true, false)"))
.matches("SELECT name FROM nation WHERE nationkey <> 0")
.isFullyPushedDown();

assertThat(query("SELECT nationkey FROM nation WHERE IF(name = 'Algeria', true, false)"))
.returnsEmptyResult()
.isFullyPushedDown();
}

@Override
protected String errorMessageForInsertIntoNotNullColumn(String columnName)
{
Expand Down