Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc.expression;

import com.google.common.collect.ImmutableList;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.ConnectorExpressionRule;
import io.trino.plugin.jdbc.QueryParameter;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;

import java.util.List;
import java.util.Optional;

import static com.google.common.base.Verify.verify;
import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName;
import static io.trino.spi.expression.StandardFunctions.COALESCE_FUNCTION_NAME;
import static java.lang.String.join;

public class RewriteCoalesce
implements ConnectorExpressionRule<Call, ParameterizedExpression>
{
private static final Capture<Call> CALL = newCapture();
private final Pattern<Call> pattern;

public RewriteCoalesce()
{
this.pattern = call()
.with(functionName().matching(name -> name.equals(COALESCE_FUNCTION_NAME)))
.capturedAs(CALL);
}

@Override
public Pattern<Call> getPattern()
{
return pattern;
}

@Override
public Optional<ParameterizedExpression> rewrite(Call call, Captures captures, RewriteContext<ParameterizedExpression> context)
{
verify(call.getArguments().size() >= 2, "Function 'coalesce' expects more than or equals to two arguments");

ImmutableList.Builder<String> rewrittenArguments = ImmutableList.builderWithExpectedSize(call.getArguments().size());
ImmutableList.Builder<QueryParameter> parameters = ImmutableList.builder();
for (ConnectorExpression expression : captures.get(CALL).getArguments()) {
Optional<ParameterizedExpression> rewritten = context.defaultRewrite(expression);
if (rewritten.isEmpty()) {
return Optional.empty();
}
rewrittenArguments.add(rewritten.get().expression());
parameters.addAll(rewritten.get().parameters());
}

List<String> arguments = rewrittenArguments.build();
verify(arguments.size() >= 2, "Function 'coalesce' expects more than or equals to two arguments");
return Optional.of(new ParameterizedExpression("COALESCE(%s)".formatted(join(",", arguments)), parameters.build()));
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this lazy eval for PostgreSQL too?
We don't seem to have any test coverage for this.

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
import io.trino.plugin.jdbc.aggregation.ImplementVarianceSamp;
import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder;
import io.trino.plugin.jdbc.expression.ParameterizedExpression;
import io.trino.plugin.jdbc.expression.RewriteCoalesce;
import io.trino.plugin.jdbc.expression.RewriteIn;
import io.trino.plugin.jdbc.logging.RemoteQueryModifier;
import io.trino.plugin.postgresql.PostgreSqlConfig.ArrayMapping;
Expand Down Expand Up @@ -162,7 +163,6 @@
import static io.trino.geospatial.serde.JtsGeometrySerde.serialize;
import static io.trino.plugin.base.util.JsonTypeUtil.jsonParse;
import static io.trino.plugin.base.util.JsonTypeUtil.toJsonValue;
import static io.trino.plugin.jdbc.DecimalConfig.DecimalMapping.ALLOW_OVERFLOW;
import static io.trino.plugin.jdbc.DecimalSessionSessionProperties.getDecimalDefaultScale;
import static io.trino.plugin.jdbc.DecimalSessionSessionProperties.getDecimalRounding;
import static io.trino.plugin.jdbc.DecimalSessionSessionProperties.getDecimalRoundingMode;
Expand Down Expand Up @@ -334,6 +334,7 @@ public PostgreSqlClient(
this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder()
.addStandardRules(this::quoted)
.add(new RewriteIn())
.add(new RewriteCoalesce())
.withTypeClass("integer_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint"))
.withTypeClass("numeric_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint", "decimal", "real", "double"))
.map("$equal(left, right)").to("left = right")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import io.trino.spi.function.OperatorType;
import io.trino.spi.session.PropertyMetadata;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Coalesce;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
Expand Down Expand Up @@ -85,6 +86,13 @@ public class TestPostgreSqlClient
.setJdbcTypeHandle(new JdbcTypeHandle(Types.BIGINT, Optional.of("int8"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()))
.build();

private static final JdbcColumnHandle BIGINT_COLUMN2 =
JdbcColumnHandle.builder()
.setColumnName("c_bigint2")
.setColumnType(BIGINT)
.setJdbcTypeHandle(new JdbcTypeHandle(Types.BIGINT, Optional.of("int8"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()))
.build();

private static final JdbcColumnHandle DOUBLE_COLUMN =
JdbcColumnHandle.builder()
.setColumnName("c_double")
Expand Down Expand Up @@ -411,6 +419,33 @@ public void testConvertIn()
new QueryParameter(createVarcharType(10), Optional.of(utf8Slice("value2")))));
}

@Test
public void testConvertCoalesce()
{
// COALESCE(varchar, varchar)
ParameterizedExpression converted = JDBC_CLIENT.convertPredicate(SESSION,
translateToConnectorExpression(
new Coalesce(
new Reference(VARCHAR, "c_varchar_symbol"),
new Reference(VARCHAR, "c_varchar_symbol_2"))),
Map.of("c_varchar_symbol", VARCHAR_COLUMN, "c_varchar_symbol_2", VARCHAR_COLUMN2))
.orElseThrow();
assertThat(converted.expression()).isEqualTo("COALESCE(\"c_varchar\",\"c_varchar2\")");
assertThat(converted.parameters()).isEqualTo(List.of());

// COALESCE(bigint, bigint, bigint)
converted = JDBC_CLIENT.convertPredicate(SESSION,
translateToConnectorExpression(
new Coalesce(
new Reference(BIGINT, "c_bigint_symbol"),
new Reference(BIGINT, "c_bigint_symbol_2"),
new Constant(BIGINT, 123L))),
Map.of("c_bigint_symbol", BIGINT_COLUMN, "c_bigint_symbol_2", BIGINT_COLUMN2))
.orElseThrow();
assertThat(converted.expression()).isEqualTo("COALESCE(\"c_bigint\",\"c_bigint2\",?)");
assertThat(converted.parameters()).isEqualTo(List.of(new QueryParameter(BIGINT, Optional.of(123L))));
}

private ConnectorExpression translateToConnectorExpression(Expression expression)
{
return ConnectorExpressionTranslator.translate(TEST_SESSION, expression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,34 @@ public void testOrPredicatePushdown()
assertThat(query("SELECT * FROM nation WHERE name = NULL OR regionkey = 4")).isFullyPushedDown();
}

@Test
public void testCoalescePredicatePushdown()
{
assertThat(query("SELECT * FROM nation WHERE COALESCE(nationkey, 1) = nationkey"))
.isFullyPushedDown();
assertThat(query("SELECT * FROM nation WHERE COALESCE(nationkey, regionkey, 1) = nationkey"))
.isFullyPushedDown();

try (TestTable table = new TestTable(
getQueryRunner()::execute,
"test_coalesce_predicate_pushdown",
"(a_varchar varchar, b_varchar varchar, c_varchar varchar)",
List.of(
"NULL, NULL, 'third not null'",
"'1', '2', 'first and second not null'",
"NULL, '2', 'second not null'"))) {
assertThat(query("SELECT c_varchar FROM " + table.getName() + " WHERE COALESCE(a_varchar, b_varchar) = '1'"))
.matches("VALUES VARCHAR 'first and second not null'")
.isFullyPushedDown();
assertThat(query("SELECT c_varchar FROM " + table.getName() + " WHERE COALESCE(a_varchar, b_varchar) = '2'"))
.matches("VALUES VARCHAR 'second not null'")
.isFullyPushedDown();
assertThat(query("SELECT c_varchar FROM " + table.getName() + " WHERE COALESCE(a_varchar, b_varchar, c_varchar) = 'third not null'"))
.matches("VALUES VARCHAR 'third not null'")
.isFullyPushedDown();
}
}

@Test
public void testLikePredicatePushdown()
{
Expand Down