Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions lib/trino-plugin-toolkit/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
</properties>

<dependencies>
<dependency>
<groupId>io.trino</groupId>
<artifactId>trino-matching</artifactId>
</dependency>

<dependency>
<groupId>io.airlift</groupId>
<artifactId>bootstrap</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc.expression;
package io.trino.plugin.base.expression;

import com.google.common.collect.ImmutableList;
import io.trino.matching.Captures;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc.expression;
package io.trino.plugin.base.expression;

import com.google.common.collect.ImmutableSet;
import io.trino.matching.Match;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.expression.AggregateFunctionRule.RewriteContext;
import io.trino.plugin.base.expression.AggregateFunctionRule.RewriteContext;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorSession;
Expand All @@ -29,18 +28,18 @@

import static java.util.Objects.requireNonNull;

public final class AggregateFunctionRewriter
public final class AggregateFunctionRewriter<Result>
{
private final Function<String, String> identifierQuote;
private final Set<AggregateFunctionRule> rules;
private final Set<AggregateFunctionRule<Result>> rules;

public AggregateFunctionRewriter(Function<String, String> identifierQuote, Set<AggregateFunctionRule> rules)
public AggregateFunctionRewriter(Function<String, String> identifierQuote, Set<AggregateFunctionRule<Result>> rules)
{
this.identifierQuote = requireNonNull(identifierQuote, "identifierQuote is null");
this.rules = ImmutableSet.copyOf(requireNonNull(rules, "rules is null"));
}

public Optional<JdbcExpression> rewrite(ConnectorSession session, AggregateFunction aggregateFunction, Map<String, ColumnHandle> assignments)
public Optional<Result> rewrite(ConnectorSession session, AggregateFunction aggregateFunction, Map<String, ColumnHandle> assignments)
{
requireNonNull(aggregateFunction, "aggregateFunction is null");
requireNonNull(assignments, "assignments is null");
Expand All @@ -66,11 +65,11 @@ public ConnectorSession getSession()
}
};

for (AggregateFunctionRule rule : rules) {
for (AggregateFunctionRule<Result> rule : rules) {
Iterator<Match> matches = rule.getPattern().match(aggregateFunction, context).iterator();
while (matches.hasNext()) {
Match match = matches.next();
Optional<JdbcExpression> rewritten = rule.rewrite(aggregateFunction, match.captures(), context);
Optional<Result> rewritten = rule.rewrite(aggregateFunction, match.captures(), context);
if (rewritten.isPresent()) {
return rewritten;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc.expression;
package io.trino.plugin.base.expression;

import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorSession;
Expand All @@ -27,11 +26,11 @@
import static com.google.common.base.Verify.verifyNotNull;
import static java.util.Objects.requireNonNull;

public interface AggregateFunctionRule
public interface AggregateFunctionRule<Result>
{
Pattern<AggregateFunction> getPattern();

Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context);
Optional<Result> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context);

interface RewriteContext
Comment thread
hashhar marked this conversation as resolved.
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.AggregateFunctionRule;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.spi.connector.AggregateFunction;
Expand All @@ -26,18 +27,18 @@

import static com.google.common.base.Verify.verify;
import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.expressionType;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.singleInput;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.variable;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.expressionType;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.singleInput;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.variable;
import static java.lang.String.format;

/**
* Implements {@code avg(decimal(p, s)}
*/
public class ImplementAvgDecimal
implements AggregateFunctionRule
implements AggregateFunctionRule<JdbcExpression>
{
private static final Capture<Variable> INPUT = newCapture();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.AggregateFunctionRule;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.spi.connector.AggregateFunction;
Expand All @@ -25,11 +26,11 @@

import static com.google.common.base.Verify.verify;
import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.expressionType;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.singleInput;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.variable;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.expressionType;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.singleInput;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.variable;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.RealType.REAL;
import static java.lang.String.format;
Expand All @@ -38,7 +39,7 @@
* Implements {@code avg(float)}
*/
public class ImplementAvgFloatingPoint
implements AggregateFunctionRule
implements AggregateFunctionRule<JdbcExpression>
{
private static final Capture<Variable> INPUT = newCapture();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.AggregateFunctionRule;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.spi.connector.AggregateFunction;
Expand All @@ -26,17 +27,17 @@

import static com.google.common.base.Verify.verify;
import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.expressionTypes;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.inputs;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.variables;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.expressionTypes;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.inputs;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.variables;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.RealType.REAL;
import static java.lang.String.format;

public class ImplementCorr
implements AggregateFunctionRule
implements AggregateFunctionRule<JdbcExpression>
{
private static final Capture<List<Variable>> INPUTS = newCapture();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.AggregateFunctionRule;
import io.trino.plugin.jdbc.JdbcClient;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
Expand All @@ -28,10 +29,10 @@

import static com.google.common.base.Verify.verify;
import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.singleInput;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.variable;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.singleInput;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.variable;
import static io.trino.spi.type.BigintType.BIGINT;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
Expand All @@ -40,7 +41,7 @@
* Implements {@code count(x)}.
*/
public class ImplementCount
implements AggregateFunctionRule
implements AggregateFunctionRule<JdbcExpression>
{
private static final Capture<Variable> INPUT = newCapture();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.AggregateFunctionRule;
import io.trino.plugin.jdbc.JdbcClient;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcTypeHandle;
Expand All @@ -25,17 +26,17 @@
import java.util.Optional;

import static com.google.common.base.Verify.verify;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.inputs;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.inputs;
import static io.trino.spi.type.BigintType.BIGINT;
import static java.util.Objects.requireNonNull;

/**
* Implements {@code count(*)}.
*/
public class ImplementCountAll
implements AggregateFunctionRule
implements AggregateFunctionRule<JdbcExpression>
{
private final JdbcTypeHandle bigintTypeHandle;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.AggregateFunctionRule;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.spi.connector.AggregateFunction;
Expand All @@ -26,17 +27,17 @@

import static com.google.common.base.Verify.verify;
import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.expressionTypes;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.inputs;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.variables;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.expressionTypes;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.inputs;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.variables;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.RealType.REAL;
import static java.lang.String.format;

public class ImplementCovariancePop
implements AggregateFunctionRule
implements AggregateFunctionRule<JdbcExpression>
{
private static final Capture<List<Variable>> INPUTS = newCapture();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.AggregateFunctionRule;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.spi.connector.AggregateFunction;
Expand All @@ -26,17 +27,17 @@

import static com.google.common.base.Verify.verify;
import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.expressionTypes;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.inputs;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.variables;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.expressionTypes;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.inputs;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.variables;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.RealType.REAL;
import static java.lang.String.format;

public class ImplementCovarianceSamp
implements AggregateFunctionRule
implements AggregateFunctionRule<JdbcExpression>
{
private static final Capture<List<Variable>> INPUTS = newCapture();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.AggregateFunctionRule;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.spi.connector.AggregateFunction;
Expand All @@ -26,17 +27,17 @@

import static com.google.common.base.Verify.verify;
import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.singleInput;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.variable;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.singleInput;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.variable;
import static java.lang.String.format;

/**
* Implements {@code min(x)}, {@code max(x)}.
*/
public class ImplementMinMax
implements AggregateFunctionRule
implements AggregateFunctionRule<JdbcExpression>
{
private static final Capture<Variable> INPUT = newCapture();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.AggregateFunctionRule;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.spi.connector.AggregateFunction;
Expand All @@ -26,17 +27,17 @@

import static com.google.common.base.Verify.verify;
import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.expressionTypes;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.inputs;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.variables;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.expressionTypes;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.inputs;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.variables;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.RealType.REAL;
import static java.lang.String.format;

public class ImplementRegrIntercept
implements AggregateFunctionRule
implements AggregateFunctionRule<JdbcExpression>
{
private static final Capture<List<Variable>> INPUTS = newCapture();

Expand Down
Loading