diff --git a/docs/src/main/sphinx/connector/pinot.rst b/docs/src/main/sphinx/connector/pinot.rst index 52c9ff5943d8..fcf3fa93e442 100644 --- a/docs/src/main/sphinx/connector/pinot.rst +++ b/docs/src/main/sphinx/connector/pinot.rst @@ -10,7 +10,7 @@ Requirements To connect to Pinot, you need: -* Pinot 0.1.0 or higher. +* Pinot 0.8.0 or higher. * Network access from the Trino coordinator and workers to the Pinot controller nodes. Port 8098 is the default port. diff --git a/plugin/trino-pinot/pom.xml b/plugin/trino-pinot/pom.xml index a07a55e6d7d1..b28d6edd8806 100755 --- a/plugin/trino-pinot/pom.xml +++ b/plugin/trino-pinot/pom.xml @@ -14,7 +14,7 @@ ${project.parent.basedir} - 0.6.0 + 0.8.0 @@ -103,12 +103,6 @@ guice - - com.yammer.metrics - metrics-core - 2.2.0 - - commons-codec commons-codec @@ -325,10 +319,18 @@ com.fasterxml.jackson.core jackson-annotations + + jakarta.ws.rs + jakarta.ws.rs-api + javax.validation validation-api + + org.glassfish.hk2.external + jakarta.inject + org.apache.lucene lucene-analyzers-common @@ -340,6 +342,28 @@ + + org.apache.pinot + pinot-segment-local + ${dep.pinot.version} + + + org.apache.lucene + lucene-analyzers-common + + + org.apache.lucene + lucene-core + + + + + + org.apache.pinot + pinot-segment-spi + ${dep.pinot.version} + + org.apache.pinot pinot-spi @@ -382,6 +406,13 @@ runtime + + org.apache.pinot + pinot-yammer + ${dep.pinot.version} + runtime + + io.trino diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotBrokerPageSource.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotBrokerPageSource.java index 5fc1e8f46b5b..e422bae25a21 100644 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotBrokerPageSource.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotBrokerPageSource.java @@ -17,7 +17,7 @@ import io.trino.plugin.pinot.client.PinotClient.BrokerResultRow; import io.trino.plugin.pinot.decoders.Decoder; import 
io.trino.plugin.pinot.decoders.DecoderFactory; -import io.trino.plugin.pinot.query.PinotQuery; +import io.trino.plugin.pinot.query.PinotQueryInfo; import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; @@ -39,7 +39,7 @@ public class PinotBrokerPageSource implements ConnectorPageSource { - private final PinotQuery query; + private final PinotQueryInfo query; private final PinotClient pinotClient; private final ConnectorSession session; private final List columnHandles; @@ -56,7 +56,7 @@ public class PinotBrokerPageSource public PinotBrokerPageSource( ConnectorSession session, - PinotQuery query, + PinotQueryInfo query, List columnHandles, PinotClient pinotClient, int limitForBrokerQueries) diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotColumn.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotColumn.java deleted file mode 100755 index e63dcbe493ec..000000000000 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotColumn.java +++ /dev/null @@ -1,135 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package io.trino.plugin.pinot; - -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonProperty; -import io.trino.spi.type.ArrayType; -import io.trino.spi.type.BigintType; -import io.trino.spi.type.BooleanType; -import io.trino.spi.type.DoubleType; -import io.trino.spi.type.IntegerType; -import io.trino.spi.type.RealType; -import io.trino.spi.type.Type; -import io.trino.spi.type.VarbinaryType; -import io.trino.spi.type.VarcharType; -import org.apache.pinot.spi.data.FieldSpec; -import org.apache.pinot.spi.data.FieldSpec.DataType; -import org.apache.pinot.spi.data.Schema; - -import java.util.List; -import java.util.Objects; -import java.util.Optional; - -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Strings.isNullOrEmpty; -import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.trino.plugin.pinot.PinotErrorCode.PINOT_UNSUPPORTED_COLUMN_TYPE; -import static java.util.Objects.requireNonNull; - -public class PinotColumn -{ - private final String name; - private final Type type; - - @JsonCreator - public PinotColumn( - @JsonProperty("name") String name, - @JsonProperty("type") Type type) - { - checkArgument(!isNullOrEmpty(name), "name is null or is empty"); - this.name = name; - this.type = requireNonNull(type, "type is null"); - } - - @JsonProperty - public String getName() - { - return name; - } - - @JsonProperty - public Type getType() - { - return type; - } - - @Override - public int hashCode() - { - return Objects.hash(name, type); - } - - @Override - public boolean equals(Object obj) - { - if (this == obj) { - return true; - } - if (obj == null || getClass() != obj.getClass()) { - return false; - } - - PinotColumn other = (PinotColumn) obj; - return Objects.equals(this.name, other.name) && Objects.equals(this.type, other.type); - } - - @Override - public String toString() - { - return name + ":" + type; - } - - public static 
List getPinotColumnsForPinotSchema(Schema pinotTableSchema) - { - return pinotTableSchema.getColumnNames().stream() - .filter(columnName -> !columnName.startsWith("$")) // Hidden columns starts with "$", ignore them as we can't use them in PQL - .map(columnName -> new PinotColumn(columnName, getTrinoTypeFromPinotType(pinotTableSchema.getFieldSpecFor(columnName)))) - .collect(toImmutableList()); - } - - public static Type getTrinoTypeFromPinotType(FieldSpec field) - { - Type type = getTrinoTypeFromPinotType(field.getDataType()); - if (field.isSingleValueField()) { - return type; - } - else { - return new ArrayType(type); - } - } - - public static Type getTrinoTypeFromPinotType(DataType dataType) - { - switch (dataType) { - case BOOLEAN: - return BooleanType.BOOLEAN; - case FLOAT: - return RealType.REAL; - case DOUBLE: - return DoubleType.DOUBLE; - case INT: - return IntegerType.INTEGER; - case LONG: - return BigintType.BIGINT; - case STRING: - return VarcharType.VARCHAR; - case BYTES: - return VarbinaryType.VARBINARY; - default: - break; - } - throw new PinotException(PINOT_UNSUPPORTED_COLUMN_TYPE, Optional.empty(), "Unsupported type conversion for pinot data type: " + dataType); - } -} diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotColumnHandle.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotColumnHandle.java index ba7bb18a36ba..7b70ca43b3f8 100755 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotColumnHandle.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotColumnHandle.java @@ -17,11 +17,28 @@ import com.fasterxml.jackson.annotation.JsonProperty; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ColumnMetadata; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.BigintType; +import io.trino.spi.type.BooleanType; +import io.trino.spi.type.DoubleType; +import io.trino.spi.type.IntegerType; +import io.trino.spi.type.RealType; import 
io.trino.spi.type.Type; +import io.trino.spi.type.VarbinaryType; +import io.trino.spi.type.VarcharType; +import org.apache.pinot.core.operator.transform.TransformResultMetadata; +import org.apache.pinot.spi.data.FieldSpec; +import org.apache.pinot.spi.data.Schema; +import java.util.List; import java.util.Objects; +import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.plugin.pinot.PinotErrorCode.PINOT_UNSUPPORTED_COLUMN_TYPE; +import static io.trino.plugin.pinot.query.DynamicTablePqlExtractor.quoteIdentifier; import static java.util.Objects.requireNonNull; public class PinotColumnHandle @@ -29,22 +46,97 @@ public class PinotColumnHandle { private final String columnName; private final Type dataType; + private final String expression; + private final boolean aliased; + private final boolean aggregate; private final boolean returnNullOnEmptyGroup; + private final Optional pushedDownAggregateFunctionName; + private final Optional pushedDownAggregateFunctionArgument; public PinotColumnHandle(String columnName, Type dataType) { - this(columnName, dataType, true); + this(columnName, dataType, columnName, false, false, true, Optional.empty(), Optional.empty()); } @JsonCreator public PinotColumnHandle( @JsonProperty("columnName") String columnName, @JsonProperty("dataType") Type dataType, - @JsonProperty("returnNullOnEmptyGroup") boolean returnNullOnEmptyGroup) + @JsonProperty("expression") String expression, + @JsonProperty("aliased") boolean aliased, + @JsonProperty("aggregate") boolean aggregate, + @JsonProperty("returnNullOnEmptyGroup") boolean returnNullOnEmptyGroup, + @JsonProperty("pushedDownAggregateFunctionName") Optional pushedDownAggregateFunctionName, + @JsonProperty("pushedDownAggregateFunctionArgument") Optional pushedDownAggregateFunctionArgument) { this.columnName 
= requireNonNull(columnName, "columnName is null"); this.dataType = requireNonNull(dataType, "dataType is null"); + this.expression = requireNonNull(expression, "expression is null"); + this.aliased = aliased; + this.aggregate = aggregate; this.returnNullOnEmptyGroup = returnNullOnEmptyGroup; + requireNonNull(pushedDownAggregateFunctionName, "pushedDownaAggregateFunctionName is null"); + requireNonNull(pushedDownAggregateFunctionArgument, "pushedDownaAggregateFunctionArgument is null"); + checkState(pushedDownAggregateFunctionName.isPresent() == pushedDownAggregateFunctionArgument.isPresent(), "Unexpected arguments: Either pushedDownaAggregateFunctionName and pushedDownaAggregateFunctionArgument must both be present or both be empty."); + checkState((pushedDownAggregateFunctionName.isPresent() && aggregate) || pushedDownAggregateFunctionName.isEmpty(), "Unexpected arguments: aggregate is false but pushed down aggregation is present"); + this.pushedDownAggregateFunctionName = pushedDownAggregateFunctionName; + this.pushedDownAggregateFunctionArgument = pushedDownAggregateFunctionArgument; + } + + public static PinotColumnHandle fromNonAggregateColumnHandle(PinotColumnHandle columnHandle) + { + return new PinotColumnHandle(columnHandle.getColumnName(), columnHandle.getDataType(), quoteIdentifier(columnHandle.getColumnName()), false, false, true, Optional.empty(), Optional.empty()); + } + + public static List getPinotColumnsForPinotSchema(Schema pinotTableSchema) + { + return pinotTableSchema.getColumnNames().stream() + .filter(columnName -> !columnName.startsWith("$")) // Hidden columns start with "$", ignore them as we can't use them in PQL + .map(columnName -> new PinotColumnHandle(columnName, getTrinoTypeFromPinotType(pinotTableSchema.getFieldSpecFor(columnName)))) + .collect(toImmutableList()); + } + + public static Type getTrinoTypeFromPinotType(FieldSpec field) + { + Type type = getTrinoTypeFromPinotType(field.getDataType()); + if (field.isSingleValueField())
{ + return type; + } + else { + return new ArrayType(type); + } + } + + public static Type getTrinoTypeFromPinotType(TransformResultMetadata transformResultMetadata) + { + Type type = getTrinoTypeFromPinotType(transformResultMetadata.getDataType()); + if (transformResultMetadata.isSingleValue()) { + return type; + } + return new ArrayType(type); + } + + public static Type getTrinoTypeFromPinotType(FieldSpec.DataType dataType) + { + switch (dataType) { + case BOOLEAN: + return BooleanType.BOOLEAN; + case FLOAT: + return RealType.REAL; + case DOUBLE: + return DoubleType.DOUBLE; + case INT: + return IntegerType.INTEGER; + case LONG: + return BigintType.BIGINT; + case STRING: + return VarcharType.VARCHAR; + case BYTES: + return VarbinaryType.VARBINARY; + default: + break; + } + throw new PinotException(PINOT_UNSUPPORTED_COLUMN_TYPE, Optional.empty(), "Unsupported type conversion for pinot data type: " + dataType); } @JsonProperty @@ -59,9 +151,26 @@ public Type getDataType() return dataType; } - public ColumnMetadata getColumnMetadata() + @JsonProperty + public String getExpression() { - return new ColumnMetadata(getColumnName(), getDataType()); + return expression; + } + + // Keep track of whether this column is aliased, it will determine how the pinot sql query is built + // The reason is that pinot parses the broker request into pinot pql but expects pinot sql. + // In some cases the parsed pql expression is an invalid sql expression. + @JsonProperty + public boolean isAliased() + { + return aliased; + } + + // True if this is an aggregate column for both passthrough query and pushed down aggregate expressions. + @JsonProperty + public boolean isAggregate() + { + return aggregate; } // Some aggregations should return null on empty group, ex. 
min/max @@ -72,6 +181,32 @@ public boolean isReturnNullOnEmptyGroup() return returnNullOnEmptyGroup; } + // If the aggregate expression is pushed down store the function name + // If the argument is an alias the pinot expression will use the original + // column name in the expression and alias it. + // + // Example: SELECT MAX(bar) FROM "SELECT foo AS bar FROM table" + // Will translate to the pinot query "SELECT MAX(foo) AS \"max(bar)\"" + // + // Note: Pinot omits quotes on the autogenerated column name "max(bar)" + @JsonProperty + public Optional getPushedDownAggregateFunctionName() + { + return pushedDownAggregateFunctionName; + } + + // See comment for getPushedDownAggregateFunctionName() + @JsonProperty + public Optional getPushedDownAggregateFunctionArgument() + { + return pushedDownAggregateFunctionArgument; + } + + public ColumnMetadata getColumnMetadata() + { + return new ColumnMetadata(getColumnName(), getDataType()); + } + @Override public boolean equals(Object o) { @@ -99,7 +234,12 @@ public String toString() return toStringHelper(this) .add("columnName", columnName) .add("dataType", dataType) + .add("expression", expression) + .add("aliased", aliased) + .add("aggregate", aggregate) .add("returnNullOnEmptyGroup", returnNullOnEmptyGroup) + .add("pushedDownaAggregateFunctionName", pushedDownAggregateFunctionName) + .add("pushedDownaAggregateFunctionArgument", pushedDownAggregateFunctionArgument) .toString(); } } diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotInsufficientServerResponseException.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotInsufficientServerResponseException.java index 3a2888d8b1c5..0823fef68fd6 100644 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotInsufficientServerResponseException.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotInsufficientServerResponseException.java @@ -13,7 +13,7 @@ */ package io.trino.plugin.pinot; -import
io.trino.plugin.pinot.query.PinotQuery; +import io.trino.plugin.pinot.query.PinotQueryInfo; import java.util.Optional; @@ -23,12 +23,12 @@ public class PinotInsufficientServerResponseException extends PinotException { - public PinotInsufficientServerResponseException(PinotQuery query, int numberOfServersResponded, int numberOfServersQueried) + public PinotInsufficientServerResponseException(PinotQueryInfo query, int numberOfServersResponded, int numberOfServersQueried) { this(query, format("Only %s out of %s servers responded for query %s", numberOfServersResponded, numberOfServersQueried, query.getQuery())); } - public PinotInsufficientServerResponseException(PinotQuery query, String message) + public PinotInsufficientServerResponseException(PinotQueryInfo query, String message) { super(PINOT_INSUFFICIENT_SERVER_RESPONSE, Optional.of(query.getQuery()), message, true); } diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotMetadata.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotMetadata.java index cac8641167a0..3e92b54bfb4c 100755 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotMetadata.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotMetadata.java @@ -24,6 +24,7 @@ import io.trino.plugin.base.expression.AggregateFunctionRewriter; import io.trino.plugin.base.expression.AggregateFunctionRule; import io.trino.plugin.pinot.client.PinotClient; +import io.trino.plugin.pinot.query.AggregateExpression; import io.trino.plugin.pinot.query.DynamicTable; import io.trino.plugin.pinot.query.DynamicTableBuilder; import io.trino.plugin.pinot.query.aggregation.ImplementApproxDistinct; @@ -37,7 +38,6 @@ import io.trino.spi.connector.Assignment; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ColumnMetadata; -import io.trino.spi.connector.ColumnNotFoundException; import io.trino.spi.connector.ConnectorMetadata; import io.trino.spi.connector.ConnectorSession; import 
io.trino.spi.connector.ConnectorTableHandle; @@ -72,9 +72,11 @@ import static com.google.common.base.Verify.verify; import static com.google.common.cache.CacheLoader.asyncReloading; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.Iterables.getOnlyElement; -import static io.trino.plugin.pinot.PinotColumn.getPinotColumnsForPinotSchema; +import static io.trino.plugin.pinot.PinotColumnHandle.getPinotColumnsForPinotSchema; import static io.trino.plugin.pinot.PinotSessionProperties.isAggregationPushdownEnabled; +import static io.trino.plugin.pinot.query.AggregateExpression.replaceIdentifier; import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; import static java.util.function.UnaryOperator.identity; @@ -88,11 +90,12 @@ public class PinotMetadata private static final String SCHEMA_NAME = "default"; private static final String PINOT_COLUMN_NAME_PROPERTY = "pinotColumnName"; - private final LoadingCache> pinotTableColumnCache; + private final LoadingCache> pinotTableColumnCache; private final LoadingCache> allTablesCache; private final int maxRowsPerBrokerQuery; - private final AggregateFunctionRewriter aggregateFunctionRewriter; + private final AggregateFunctionRewriter aggregateFunctionRewriter; private final ImplementCountDistinct implementCountDistinct; + private final PinotClient pinotClient; @Inject public PinotMetadata( @@ -101,6 +104,7 @@ public PinotMetadata( @ForPinot Executor executor) { requireNonNull(pinotConfig, "pinot config"); + this.pinotClient = requireNonNull(pinotClient, "pinotClient is null"); long metadataCacheExpiryMillis = pinotConfig.getMetadataCacheExpiry().roundTo(TimeUnit.MILLISECONDS); this.allTablesCache = CacheBuilder.newBuilder() .refreshAfterWrite(metadataCacheExpiryMillis, TimeUnit.MILLISECONDS) @@ -111,7 +115,7 @@ public PinotMetadata( .build(asyncReloading(new 
CacheLoader<>() { @Override - public List load(String tableName) + public List load(String tableName) throws Exception { Schema tablePinotSchema = pinotClient.getTableSchema(tableName); @@ -142,7 +146,7 @@ public List listSchemaNames(ConnectorSession session) public PinotTableHandle getTableHandle(ConnectorSession session, SchemaTableName tableName) { if (tableName.getTableName().trim().startsWith("select ")) { - DynamicTable dynamicTable = DynamicTableBuilder.buildFromPql(this, tableName); + DynamicTable dynamicTable = DynamicTableBuilder.buildFromPql(this, tableName, pinotClient); return new PinotTableHandle(tableName.getSchemaName(), dynamicTable.getTableName(), TupleDomain.all(), OptionalLong.empty(), Optional.of(dynamicTable)); } @@ -162,18 +166,12 @@ public ConnectorTableMetadata getTableMetadata(ConnectorSession session, Connect PinotTableHandle pinotTableHandle = (PinotTableHandle) table; if (pinotTableHandle.getQuery().isPresent()) { DynamicTable dynamicTable = pinotTableHandle.getQuery().get(); - Map columnHandles = getColumnHandles(session, table); ImmutableList.Builder columnMetadataBuilder = ImmutableList.builder(); - for (String columnName : dynamicTable.getSelections()) { - PinotColumnHandle pinotColumnHandle = (PinotColumnHandle) columnHandles.get(columnName.toLowerCase(ENGLISH)); + for (PinotColumnHandle pinotColumnHandle : dynamicTable.getProjections()) { columnMetadataBuilder.add(pinotColumnHandle.getColumnMetadata()); } - - for (String columnName : dynamicTable.getGroupingColumns()) { - PinotColumnHandle pinotColumnHandle = (PinotColumnHandle) columnHandles.get(columnName.toLowerCase(ENGLISH)); - columnMetadataBuilder.add(pinotColumnHandle.getColumnMetadata()); - } - dynamicTable.getAggregateColumns().forEach(handle -> columnMetadataBuilder.add(handle.getColumnMetadata())); + dynamicTable.getAggregateColumns() + .forEach(columnHandle -> columnMetadataBuilder.add(columnHandle.getColumnMetadata())); SchemaTableName schemaTableName = new 
SchemaTableName(pinotTableHandle.getSchemaName(), dynamicTable.getTableName()); return new ConnectorTableMetadata(schemaTableName, columnMetadataBuilder.build()); } @@ -266,7 +264,7 @@ public Optional> applyLimit(Connect (dynamicTable.get().getLimit().isEmpty() || dynamicTable.get().getLimit().getAsLong() > limit)) { dynamicTable = Optional.of(new DynamicTable(dynamicTable.get().getTableName(), dynamicTable.get().getSuffix(), - dynamicTable.get().getSelections(), + dynamicTable.get().getProjections(), dynamicTable.get().getFilter(), dynamicTable.get().getGroupingColumns(), dynamicTable.get().getAggregateColumns(), @@ -349,31 +347,35 @@ public Optional> applyAggrega } PinotTableHandle tableHandle = (PinotTableHandle) handle; - // If aggregate are present than no further aggregations - // can be pushed down: there are currently no subqueries in pinot + // If aggregates are present then no further aggregations + // can be pushed down: there are currently no subqueries in pinot. + // If there is an offset then do not push the aggregation down as the results will not be correct if (tableHandle.getQuery().isPresent() && - !tableHandle.getQuery().get().getAggregateColumns().isEmpty()) { + (!tableHandle.getQuery().get().getAggregateColumns().isEmpty() || + tableHandle.getQuery().get().isAggregateInProjections() || + tableHandle.getQuery().get().getOffset().isPresent())) { return Optional.empty(); } ImmutableList.Builder projections = ImmutableList.builder(); ImmutableList.Builder resultAssignments = ImmutableList.builder(); - ImmutableList.Builder aggregationExpressions = ImmutableList.builder(); + ImmutableList.Builder aggregateColumnsBuilder = ImmutableList.builder(); for (AggregateFunction aggregate : aggregates) { - Optional rewriteResult = aggregateFunctionRewriter.rewrite(session, aggregate, assignments); + Optional rewriteResult = aggregateFunctionRewriter.rewrite(session, aggregate, assignments); rewriteResult = applyCountDistinct(session, aggregate, assignments,
tableHandle, rewriteResult); if (rewriteResult.isEmpty()) { return Optional.empty(); } - PinotColumnHandle pinotColumnHandle = rewriteResult.get(); - aggregationExpressions.add(pinotColumnHandle); + AggregateExpression aggregateExpression = rewriteResult.get(); + PinotColumnHandle pinotColumnHandle = new PinotColumnHandle(aggregateExpression.toFieldName(), aggregate.getOutputType(), aggregateExpression.toExpression(), false, true, aggregateExpression.isReturnNullOnEmptyGroup(), Optional.of(aggregateExpression.getFunction()), Optional.of(aggregateExpression.getArgument())); + aggregateColumnsBuilder.add(pinotColumnHandle); projections.add(new Variable(pinotColumnHandle.getColumnName(), pinotColumnHandle.getDataType())); resultAssignments.add(new Assignment(pinotColumnHandle.getColumnName(), pinotColumnHandle, pinotColumnHandle.getDataType())); } - List groupingColumns = getOnlyElement(groupingSets).stream() + List groupingColumns = getOnlyElement(groupingSets).stream() .map(PinotColumnHandle.class::cast) - .map(PinotColumnHandle::getColumnName) + .map(PinotColumnHandle::fromNonAggregateColumnHandle) .collect(toImmutableList()); OptionalLong limitForDynamicTable = OptionalLong.empty(); // Ensure that pinot default limit of 10 rows is not used @@ -382,23 +384,43 @@ public Optional> applyAggrega if (tableHandle.getLimit().isEmpty() && !groupingColumns.isEmpty()) { limitForDynamicTable = OptionalLong.of(maxRowsPerBrokerQuery + 1); } + List aggregationColumns = aggregateColumnsBuilder.build(); + String newQuery = ""; + List newSelections = groupingColumns; + if (tableHandle.getQuery().isPresent()) { + newQuery = tableHandle.getQuery().get().getQuery(); + Map projectionsMap = tableHandle.getQuery().get().getProjections().stream() + .collect(toImmutableMap(PinotColumnHandle::getColumnName, identity())); + groupingColumns = groupingColumns.stream() + .map(groupIngColumn -> projectionsMap.getOrDefault(groupIngColumn.getColumnName(), groupIngColumn)) + 
.collect(toImmutableList()); + ImmutableList.Builder newSelectionsBuilder = ImmutableList.builder() + .addAll(groupingColumns); + + aggregationColumns = aggregationColumns.stream() + .map(aggregateExpression -> resolveAggregateExpressionWithAlias(aggregateExpression, projectionsMap)) + .collect(toImmutableList()); + + newSelections = newSelectionsBuilder.build(); + } + DynamicTable dynamicTable = new DynamicTable( tableHandle.getTableName(), Optional.empty(), - ImmutableList.of(), + newSelections, tableHandle.getQuery().flatMap(DynamicTable::getFilter), groupingColumns, - aggregationExpressions.build(), + aggregationColumns, ImmutableList.of(), limitForDynamicTable, OptionalLong.empty(), - ""); + newQuery); tableHandle = new PinotTableHandle(tableHandle.getSchemaName(), tableHandle.getTableName(), tableHandle.getConstraint(), tableHandle.getLimit(), Optional.of(dynamicTable)); return Optional.of(new AggregationApplicationResult<>(tableHandle, projections.build(), resultAssignments.build(), ImmutableMap.of(), false)); } - private Optional applyCountDistinct(ConnectorSession session, AggregateFunction aggregate, Map assignments, PinotTableHandle tableHandle, Optional rewriteResult) + private Optional applyCountDistinct(ConnectorSession session, AggregateFunction aggregate, Map assignments, PinotTableHandle tableHandle, Optional rewriteResult) { AggregateFunctionRule.RewriteContext context = new AggregateFunctionRule.RewriteContext() { @@ -428,13 +450,36 @@ public ConnectorSession getSession() // otherwise do not push down the aggregation. // This is to avoid count(column_name) being pushed into pinot, which is currently unsupported. // Currently Pinot treats count(column_name) as count(*), i.e. it counts nulls. 
- if (tableHandle.getQuery().isEmpty() || !tableHandle.getQuery().get().getGroupingColumns().contains(input.getName())) { + if (tableHandle.getQuery().isEmpty() || tableHandle.getQuery().get().getGroupingColumns().stream() + .noneMatch(groupingExpression -> groupingExpression.getColumnName().equals(input.getName()))) { return Optional.empty(); } } return rewriteResult; } + private static PinotColumnHandle resolveAggregateExpressionWithAlias(PinotColumnHandle aggregateColumn, Map projectionsMap) + { + checkState(aggregateColumn.isAggregate() && aggregateColumn.getPushedDownAggregateFunctionName().isPresent() && aggregateColumn.getPushedDownAggregateFunctionArgument().isPresent(), "Column is not a pushed down aggregate column"); + PinotColumnHandle selection = projectionsMap.get(aggregateColumn.getPushedDownAggregateFunctionArgument().get()); + if (selection != null && selection.isAliased()) { + AggregateExpression pushedDownAggregateExpression = new AggregateExpression(aggregateColumn.getPushedDownAggregateFunctionName().get(), + aggregateColumn.getPushedDownAggregateFunctionArgument().get(), + aggregateColumn.isReturnNullOnEmptyGroup()); + AggregateExpression newPushedDownAggregateExpression = replaceIdentifier(pushedDownAggregateExpression, selection); + + return new PinotColumnHandle(pushedDownAggregateExpression.toFieldName(), + aggregateColumn.getDataType(), + newPushedDownAggregateExpression.toExpression(), + true, + aggregateColumn.isAggregate(), + aggregateColumn.isReturnNullOnEmptyGroup(), + aggregateColumn.getPushedDownAggregateFunctionName(), + Optional.of(newPushedDownAggregateExpression.getArgument())); + } + return aggregateColumn; + } + @Override public boolean usesLegacyTableLayouts() { @@ -442,7 +487,7 @@ public boolean usesLegacyTableLayouts() } @VisibleForTesting - public List getPinotColumns(String tableName) + public List getPinotColumns(String tableName) { String pinotTableName = getPinotTableNameFromTrinoTableName(tableName); return 
getFromCache(pinotTableColumnCache, pinotTableName); @@ -486,28 +531,14 @@ private String getPinotTableNameFromTrinoTableName(String trinoTableName) private Map getDynamicTableColumnHandles(PinotTableHandle pinotTableHandle) { checkState(pinotTableHandle.getQuery().isPresent(), "dynamic table not present"); - String schemaName = pinotTableHandle.getSchemaName(); DynamicTable dynamicTable = pinotTableHandle.getQuery().get(); - Map columnHandles = getPinotColumnHandles(dynamicTable.getTableName()); ImmutableMap.Builder columnHandlesBuilder = ImmutableMap.builder(); - for (String columnName : dynamicTable.getSelections()) { - PinotColumnHandle columnHandle = (PinotColumnHandle) columnHandles.get(columnName.toLowerCase(ENGLISH)); - if (columnHandle == null) { - throw new ColumnNotFoundException(new SchemaTableName(schemaName, dynamicTable.getTableName()), columnName); - } - columnHandlesBuilder.put(columnName.toLowerCase(ENGLISH), columnHandle); - } - - for (String columnName : dynamicTable.getGroupingColumns()) { - PinotColumnHandle columnHandle = (PinotColumnHandle) columnHandles.get(columnName.toLowerCase(ENGLISH)); - if (columnHandle == null) { - throw new ColumnNotFoundException(new SchemaTableName(schemaName, dynamicTable.getTableName()), columnName); - } - columnHandlesBuilder.put(columnName.toLowerCase(ENGLISH), columnHandle); + for (PinotColumnHandle pinotColumnHandle : dynamicTable.getProjections()) { + columnHandlesBuilder.put(pinotColumnHandle.getColumnName().toLowerCase(ENGLISH), pinotColumnHandle); } dynamicTable.getAggregateColumns() - .forEach(handle -> columnHandlesBuilder.put(handle.getColumnName().toLowerCase(ENGLISH), handle)); + .forEach(columnHandle -> columnHandlesBuilder.put(columnHandle.getColumnName().toLowerCase(ENGLISH), columnHandle)); return columnHandlesBuilder.build(); } @@ -518,19 +549,19 @@ private ConnectorTableMetadata getTableMetadata(SchemaTableName tableName) private List getColumnsMetadata(String tableName) { - List columns = 
getPinotColumns(tableName); + List columns = getPinotColumns(tableName); return columns.stream() .map(PinotMetadata::createPinotColumnMetadata) .collect(toImmutableList()); } - private static ColumnMetadata createPinotColumnMetadata(PinotColumn pinotColumn) + private static ColumnMetadata createPinotColumnMetadata(PinotColumnHandle pinotColumn) { return ColumnMetadata.builder() - .setName(pinotColumn.getName().toLowerCase(ENGLISH)) - .setType(pinotColumn.getType()) + .setName(pinotColumn.getColumnName().toLowerCase(ENGLISH)) + .setType(pinotColumn.getDataType()) .setProperties(ImmutableMap.builder() - .put(PINOT_COLUMN_NAME_PROPERTY, pinotColumn.getName()) + .put(PINOT_COLUMN_NAME_PROPERTY, pinotColumn.getColumnName()) .build()) .build(); } diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotPageSourceProvider.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotPageSourceProvider.java index 52040646dbf8..d69816bc321b 100755 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotPageSourceProvider.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotPageSourceProvider.java @@ -16,7 +16,7 @@ import io.trino.plugin.pinot.client.PinotClient; import io.trino.plugin.pinot.client.PinotQueryClient; import io.trino.plugin.pinot.query.DynamicTable; -import io.trino.plugin.pinot.query.PinotQuery; +import io.trino.plugin.pinot.query.PinotQueryInfo; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorPageSource; import io.trino.spi.connector.ConnectorPageSourceProvider; @@ -89,20 +89,20 @@ public ConnectorPageSource createPageSource( handles, query); case BROKER: - PinotQuery pinotQuery; + PinotQueryInfo pinotQueryInfo; if (pinotTableHandle.getQuery().isPresent()) { DynamicTable dynamicTable = pinotTableHandle.getQuery().get(); - pinotQuery = new PinotQuery(dynamicTable.getTableName(), + pinotQueryInfo = new PinotQueryInfo(dynamicTable.getTableName(), extractPql(dynamicTable, 
pinotTableHandle.getConstraint(), handles), dynamicTable.getGroupingColumns().size()); } else { - pinotQuery = new PinotQuery(pinotTableHandle.getTableName(), query, 0); + pinotQueryInfo = new PinotQueryInfo(pinotTableHandle.getTableName(), query, 0); } return new PinotBrokerPageSource( session, - pinotQuery, + pinotQueryInfo, handles, clusterInfoFetcher, limitForBrokerQueries); diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotSegmentPageSource.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotSegmentPageSource.java index 08f411b4b53e..bde45122c292 100755 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotSegmentPageSource.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotSegmentPageSource.java @@ -325,7 +325,7 @@ Type getType(int columnIndex) boolean getBoolean(int rowIdx, int columnIndex) { - return Boolean.getBoolean(currentDataTable.getDataTable().getString(rowIdx, columnIndex)); + return currentDataTable.getDataTable().getInt(rowIdx, columnIndex) != 0; } long getLong(int rowIndex, int columnIndex) diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/client/PinotClient.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/client/PinotClient.java index 3e7ebc69c362..e7f5159e8751 100755 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/client/PinotClient.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/client/PinotClient.java @@ -41,7 +41,7 @@ import io.trino.plugin.pinot.PinotException; import io.trino.plugin.pinot.PinotInsufficientServerResponseException; import io.trino.plugin.pinot.PinotSessionProperties; -import io.trino.plugin.pinot.query.PinotQuery; +import io.trino.plugin.pinot.query.PinotQueryInfo; import io.trino.spi.connector.ConnectorSession; import org.apache.pinot.common.response.broker.BrokerResponseNative; import org.apache.pinot.common.response.broker.ResultTable; @@ -49,8 +49,6 @@ import javax.inject.Inject; -import 
java.io.IOException; -import java.io.UncheckedIOException; import java.net.URI; import java.nio.charset.StandardCharsets; import java.util.ArrayList; @@ -73,6 +71,7 @@ import static com.google.common.collect.Iterables.getOnlyElement; import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom; import static io.airlift.http.client.JsonResponseHandler.createJsonResponseHandler; +import static io.airlift.json.JsonCodec.jsonCodec; import static io.airlift.json.JsonCodec.listJsonCodec; import static io.airlift.json.JsonCodec.mapJsonCodec; import static io.trino.plugin.pinot.PinotErrorCode.PINOT_EXCEPTION; @@ -91,7 +90,7 @@ public class PinotClient private static final Pattern BROKER_PATTERN = Pattern.compile("Broker_(.*)_(\\d+)"); private static final String TIME_BOUNDARY_NOT_FOUND_ERROR_CODE = "404"; private static final JsonCodec>>> ROUTING_TABLE_CODEC = mapJsonCodec(String.class, mapJsonCodec(String.class, listJsonCodec(String.class))); - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private static final JsonCodec QUERY_REQUEST_JSON_CODEC = jsonCodec(QueryRequest.class); private static final String GET_ALL_TABLES_API_TEMPLATE = "tables"; private static final String TABLE_INSTANCES_API_TEMPLATE = "tables/%s/instances"; @@ -377,6 +376,23 @@ public TimeBoundary getTimeBoundaryForTable(String table) } } + public static class QueryRequest + { + private final String sql; + + @JsonCreator + public QueryRequest(@JsonProperty String sql) + { + this.sql = requireNonNull(sql, "sql is null"); + } + + @JsonProperty + public String getSql() + { + return sql; + } + } + public interface BrokerResultRow { Object getField(int index); @@ -424,14 +440,15 @@ protected BrokerResultRow computeNext() } } - private BrokerResponseNative submitBrokerQueryJson(ConnectorSession session, PinotQuery query) + private BrokerResponseNative submitBrokerQueryJson(ConnectorSession session, PinotQueryInfo query) { + String queryRequest = 
QUERY_REQUEST_JSON_CODEC.toJson(new QueryRequest(query.getQuery())); return doWithRetries(PinotSessionProperties.getPinotRetryCount(session), retryNumber -> { String queryHost = getBrokerHost(query.getTable()); LOG.info("Query '%s' on broker host '%s'", queryHost, query.getQuery()); Request.Builder builder = Request.Builder.preparePost() .setUri(URI.create(format(QUERY_URL_TEMPLATE, queryHost))); - BrokerResponseNative response = doHttpActionWithHeadersJson(builder, Optional.of(buildRequest(query.getQuery())), brokerResponseCodec); + BrokerResponseNative response = doHttpActionWithHeadersJson(builder, Optional.of(queryRequest), brokerResponseCodec); if (response.getExceptionsSize() > 0 && response.getProcessingExceptions() != null && !response.getProcessingExceptions().isEmpty()) { // Pinot is known to return exceptions with benign errorcodes like 200 @@ -449,16 +466,6 @@ private BrokerResponseNative submitBrokerQueryJson(ConnectorSession session, Pin }); } - private static String buildRequest(String sql) - { - try { - return OBJECT_MAPPER.writeValueAsString(ImmutableMap.of("sql", sql)); - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - /** * columnIndices: column name -> column index from column handles * indiceToGroupByFunction (groupByFunctions): aggregationIndex -> groupByFunctionName(columnName) @@ -472,7 +479,7 @@ private static String buildRequest(String sql) * Results: aggregationResults.get(0..aggregationResults.size()) * Result: function, value means columnName -> columnValue */ - public Iterator createResultIterator(ConnectorSession session, PinotQuery query, List columnHandles) + public Iterator createResultIterator(ConnectorSession session, PinotQueryInfo query, List columnHandles) { BrokerResponseNative response = submitBrokerQueryJson(session, query); return fromResultTable(response, columnHandles, query.getGroupByClauses()); diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/client/PinotQueryClient.java 
b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/client/PinotQueryClient.java index cb5451919fb8..8a13aa57bc9c 100644 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/client/PinotQueryClient.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/client/PinotQueryClient.java @@ -13,10 +13,10 @@ */ package io.trino.plugin.pinot.client; -import com.yammer.metrics.core.MetricsRegistry; import io.trino.plugin.pinot.PinotException; import org.apache.helix.model.InstanceConfig; import org.apache.pinot.common.metrics.BrokerMetrics; +import org.apache.pinot.common.metrics.PinotMetricUtils; import org.apache.pinot.common.request.BrokerRequest; import org.apache.pinot.common.utils.DataTable; import org.apache.pinot.core.transport.AsyncQueryResponse; @@ -24,6 +24,7 @@ import org.apache.pinot.core.transport.ServerInstance; import org.apache.pinot.core.transport.ServerResponse; import org.apache.pinot.core.transport.ServerRoutingInstance; +import org.apache.pinot.spi.metrics.PinotMetricsRegistry; import org.apache.pinot.spi.utils.builder.TableNameBuilder; import org.apache.pinot.sql.parsers.CalciteSqlCompiler; import org.apache.pinot.sql.parsers.SqlCompilationException; @@ -62,7 +63,7 @@ public PinotQueryClient(PinotHostMapper pinotHostMapper) { trinoHostId = getDefaultTrinoId(); this.pinotHostMapper = requireNonNull(pinotHostMapper, "pinotHostMapper is null"); - MetricsRegistry registry = new MetricsRegistry(); + PinotMetricsRegistry registry = PinotMetricUtils.getPinotMetricsRegistry(); this.brokerMetrics = new BrokerMetrics(registry); brokerMetrics.initializeGlobalMeters(); queryRouter = new QueryRouter(trinoHostId, brokerMetrics); diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/decoders/BooleanDecoder.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/decoders/BooleanDecoder.java index 9e020a6f6b5f..8e0fb023d6f9 100644 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/decoders/BooleanDecoder.java +++ 
b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/decoders/BooleanDecoder.java @@ -13,12 +13,14 @@ */ package io.trino.plugin.pinot.decoders; +import io.trino.spi.TrinoException; import io.trino.spi.block.BlockBuilder; import java.util.function.Supplier; +import static io.trino.spi.StandardErrorCode.TYPE_MISMATCH; import static io.trino.spi.type.BooleanType.BOOLEAN; -import static java.lang.Boolean.parseBoolean; +import static java.lang.String.format; public class BooleanDecoder implements Decoder @@ -30,8 +32,11 @@ public void decode(Supplier getter, BlockBuilder output) if (value == null) { output.appendNull(); } + else if (value instanceof Boolean) { + BOOLEAN.writeBoolean(output, (Boolean) value); + } else { - BOOLEAN.writeBoolean(output, parseBoolean(value.toString())); + throw new TrinoException(TYPE_MISMATCH, format("Expected a boolean value of type BOOLEAN: %s [%s]", value, value.getClass().getSimpleName())); } } } diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/AggregateExpression.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/AggregateExpression.java new file mode 100644 index 000000000000..8a9918a77af0 --- /dev/null +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/AggregateExpression.java @@ -0,0 +1,111 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.pinot.query; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.trino.plugin.pinot.PinotColumnHandle; + +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static io.trino.plugin.pinot.query.DynamicTablePqlExtractor.quoteIdentifier; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class AggregateExpression +{ + private final String function; + private final String argument; + private final boolean returnNullOnEmptyGroup; + + public static AggregateExpression replaceIdentifier(AggregateExpression aggregationExpression, PinotColumnHandle columnHandle) + { + return new AggregateExpression(aggregationExpression.getFunction(), stripDoubleQuotes(columnHandle.getExpression()), aggregationExpression.isReturnNullOnEmptyGroup()); + } + + private static String stripDoubleQuotes(String expression) + { + checkState(expression.startsWith("\"") && expression.endsWith("\"") && expression.length() >= 3, "expression is not enclosed in double quotes"); + return expression.substring(1, expression.length() - 1).replaceAll("\"\"", "\""); + } + + @JsonCreator + public AggregateExpression(@JsonProperty String function, @JsonProperty String argument, @JsonProperty boolean returnNullOnEmptyGroup) + { + this.function = requireNonNull(function, "function is null"); + this.argument = requireNonNull(argument, "argument is null"); + this.returnNullOnEmptyGroup = returnNullOnEmptyGroup; + } + + @JsonProperty + public String getFunction() + { + return function; + } + + @JsonProperty + public String getArgument() + { + return argument; + } + + @JsonProperty + public boolean isReturnNullOnEmptyGroup() + { + return returnNullOnEmptyGroup; + } + + public String toFieldName() + { + return format("%s(%s)", function, argument); + } + + public 
String toExpression() + { + return format("%s(%s)", function, quoteIdentifier(argument)); + } + + @Override + public boolean equals(Object other) + { + if (this == other) { + return true; + } + if (!(other instanceof AggregateExpression)) { + return false; + } + AggregateExpression that = (AggregateExpression) other; + return that.function.equals(function) && + that.argument.equals(argument) && + that.returnNullOnEmptyGroup == returnNullOnEmptyGroup; + } + + @Override + public int hashCode() + { + return Objects.hash(function, argument, returnNullOnEmptyGroup); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("function", function) + .add("argument", argument) + .add("returnNullOnEmptyGroup", returnNullOnEmptyGroup) + .toString(); + } +} diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/DynamicTable.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/DynamicTable.java index 35ebb2b22a62..094eec6575ba 100755 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/DynamicTable.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/DynamicTable.java @@ -32,12 +32,12 @@ public final class DynamicTable private final Optional suffix; - private final List selections; + private final List projections; private final Optional filter; // semantically aggregation is applied after constraint - private final List groupingColumns; + private final List groupingColumns; private final List aggregateColumns; // semantically sorting is applied after aggregation @@ -49,13 +49,15 @@ public final class DynamicTable private final String query; + private final boolean isAggregateInProjections; + @JsonCreator public DynamicTable( @JsonProperty("tableName") String tableName, @JsonProperty("suffix") Optional suffix, - @JsonProperty("selections") List selections, + @JsonProperty("projections") List projections, @JsonProperty("filter") Optional filter, - @JsonProperty("groupingColumns") List 
groupingColumns, + @JsonProperty("groupingColumns") List groupingColumns, @JsonProperty("aggregateColumns") List aggregateColumns, @JsonProperty("orderBy") List orderBy, @JsonProperty("limit") OptionalLong limit, @@ -64,7 +66,7 @@ public DynamicTable( { this.tableName = requireNonNull(tableName, "tableName is null"); this.suffix = requireNonNull(suffix, "suffix is null"); - this.selections = ImmutableList.copyOf(requireNonNull(selections, "selections is null")); + this.projections = ImmutableList.copyOf(requireNonNull(projections, "projections is null")); this.filter = requireNonNull(filter, "filter is null"); this.groupingColumns = ImmutableList.copyOf(requireNonNull(groupingColumns, "groupingColumns is null")); this.aggregateColumns = ImmutableList.copyOf(requireNonNull(aggregateColumns, "aggregateColumns is null")); @@ -72,6 +74,8 @@ public DynamicTable( this.limit = requireNonNull(limit, "limit is null"); this.offset = requireNonNull(offset, "offset is null"); this.query = requireNonNull(query, "query is null"); + this.isAggregateInProjections = projections.stream() + .anyMatch(PinotColumnHandle::isAggregate); } @JsonProperty @@ -87,9 +91,9 @@ public Optional getSuffix() } @JsonProperty - public List getSelections() + public List getProjections() { - return selections; + return projections; } @JsonProperty @@ -99,7 +103,7 @@ public Optional getFilter() } @JsonProperty - public List getGroupingColumns() + public List getGroupingColumns() { return groupingColumns; } @@ -134,6 +138,11 @@ public String getQuery() return query; } + public boolean isAggregateInProjections() + { + return isAggregateInProjections; + } + @Override public boolean equals(Object other) { @@ -147,7 +156,7 @@ public boolean equals(Object other) DynamicTable that = (DynamicTable) other; return tableName.equals(that.tableName) && - selections.equals(that.selections) && + projections.equals(that.projections) && filter.equals(that.filter) && groupingColumns.equals(that.groupingColumns) && 
aggregateColumns.equals(that.aggregateColumns) && @@ -160,7 +169,7 @@ public boolean equals(Object other) @Override public int hashCode() { - return Objects.hash(tableName, selections, filter, groupingColumns, aggregateColumns, orderBy, limit, offset, query); + return Objects.hash(tableName, projections, filter, groupingColumns, aggregateColumns, orderBy, limit, offset, query); } @Override @@ -168,7 +177,7 @@ public String toString() { return toStringHelper(this) .add("tableName", tableName) - .add("selections", selections) + .add("projections", projections) .add("filter", filter) .add("groupingColumns", groupingColumns) .add("aggregateColumns", aggregateColumns) diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/DynamicTableBuilder.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/DynamicTableBuilder.java index cf31da4625ec..920a4d819e30 100755 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/DynamicTableBuilder.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/DynamicTableBuilder.java @@ -14,16 +14,20 @@ package io.trino.plugin.pinot.query; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import io.trino.plugin.pinot.PinotColumnHandle; import io.trino.plugin.pinot.PinotException; import io.trino.plugin.pinot.PinotMetadata; +import io.trino.plugin.pinot.client.PinotClient; import io.trino.spi.connector.ColumnHandle; -import io.trino.spi.connector.ColumnNotFoundException; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; import org.apache.pinot.common.request.BrokerRequest; -import org.apache.pinot.common.request.SelectionSort; +import org.apache.pinot.common.request.PinotQuery; +import org.apache.pinot.common.request.context.ExpressionContext; +import org.apache.pinot.common.request.context.FunctionContext; +import 
org.apache.pinot.common.request.context.OrderByExpressionContext; import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.core.query.aggregation.function.AggregationFunction; import org.apache.pinot.core.query.request.context.QueryContext; @@ -35,22 +39,28 @@ import java.util.Optional; import java.util.OptionalLong; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.plugin.pinot.PinotColumnHandle.fromNonAggregateColumnHandle; +import static io.trino.plugin.pinot.PinotColumnHandle.getTrinoTypeFromPinotType; import static io.trino.plugin.pinot.PinotErrorCode.PINOT_UNSUPPORTED_COLUMN_TYPE; -import static io.trino.plugin.pinot.query.FilterToPinotSqlConverter.convertFilter; +import static io.trino.plugin.pinot.query.PinotExpressionRewriter.rewriteExpression; +import static io.trino.plugin.pinot.query.PinotPatterns.WILDCARD; +import static io.trino.plugin.pinot.query.PinotSqlFormatter.formatExpression; +import static io.trino.plugin.pinot.query.PinotSqlFormatter.formatFilter; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.RealType.REAL; import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.spi.type.VarcharType.VARCHAR; +import static java.lang.String.format; import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; public final class DynamicTableBuilder { private static final CalciteSqlCompiler REQUEST_COMPILER = new CalciteSqlCompiler(); - private static final String WILDCARD = "*"; public static final String OFFLINE_SUFFIX = "_OFFLINE"; public static final String REALTIME_SUFFIX = "_REALTIME"; @@ -58,59 +68,57 @@ private DynamicTableBuilder() { } - public static DynamicTable buildFromPql(PinotMetadata pinotMetadata, SchemaTableName schemaTableName) + 
public static DynamicTable buildFromPql(PinotMetadata pinotMetadata, SchemaTableName schemaTableName, PinotClient pinotClient) { requireNonNull(pinotMetadata, "pinotMetadata is null"); requireNonNull(schemaTableName, "schemaTableName is null"); String query = schemaTableName.getTableName(); BrokerRequest request = REQUEST_COMPILER.compileToBrokerRequest(query); + PinotQuery pinotQuery = request.getPinotQuery(); + QueryContext queryContext = BrokerRequestToQueryContextConverter.convert(request); String pinotTableName = stripSuffix(request.getQuerySource().getTableName()); Optional suffix = getSuffix(request.getQuerySource().getTableName()); Map columnHandles = pinotMetadata.getPinotColumnHandles(pinotTableName); - List selectionColumns = ImmutableList.of(); List orderBy = ImmutableList.of(); - if (request.getSelections() != null) { - selectionColumns = resolvePinotColumns(schemaTableName, request.getSelections().getSelectionColumns(), columnHandles); - if (request.getSelections().getSelectionSortSequence() != null) { - ImmutableList.Builder orderByBuilder = ImmutableList.builder(); - for (SelectionSort sortItem : request.getSelections().getSelectionSortSequence()) { - PinotColumnHandle columnHandle = (PinotColumnHandle) columnHandles.get(sortItem.getColumn()); - if (columnHandle == null) { - throw new ColumnNotFoundException(schemaTableName, sortItem.getColumn()); - } - orderByBuilder.add(new OrderByExpression(columnHandle.getColumnName(), sortItem.isIsAsc())); - } - orderBy = orderByBuilder.build(); + PinotTypeResolver pinotTypeResolver = new PinotTypeResolver(pinotClient, pinotTableName); + List selectColumns = ImmutableList.of(); + + ImmutableMap.Builder aggregateTypesBuilder = ImmutableMap.builder(); + if (queryContext.getAggregationFunctions() != null) { + checkState(queryContext.getAggregationFunctions().length > 0, "Aggregation Functions is empty"); + for (AggregationFunction aggregationFunction : queryContext.getAggregationFunctions()) { + 
aggregateTypesBuilder.put(aggregationFunction.getResultColumnName(), toTrinoType(aggregationFunction.getFinalResultColumnType())); } } - - List groupByColumns; - if (request.getGroupBy() == null) { - groupByColumns = ImmutableList.of(); - } - else { - groupByColumns = resolvePinotColumns(schemaTableName, request.getGroupBy().getExpressions(), columnHandles); + Map aggregateTypes = aggregateTypesBuilder.build(); + if (queryContext.getSelectExpressions() != null) { + checkState(!queryContext.getSelectExpressions().isEmpty(), "Pinot selections is empty"); + selectColumns = getPinotColumns(schemaTableName, queryContext.getSelectExpressions(), queryContext.getAliasList(), columnHandles, pinotTypeResolver, aggregateTypes); } - Optional filter; - if (request.getFilterQuery() != null) { - filter = Optional.of(convertFilter(request.getPinotQuery(), columnHandles)); + if (queryContext.getOrderByExpressions() != null) { + ImmutableList.Builder orderByBuilder = ImmutableList.builder(); + for (OrderByExpressionContext orderByExpressionContext : queryContext.getOrderByExpressions()) { + ExpressionContext expressionContext = orderByExpressionContext.getExpression(); + PinotColumnHandle pinotColumnHandle = getPinotColumnHandle(schemaTableName, expressionContext, Optional.empty(), columnHandles, pinotTypeResolver, aggregateTypes); + orderByBuilder.add(new OrderByExpression(pinotColumnHandle.getExpression(), orderByExpressionContext.isAsc())); + } + orderBy = orderByBuilder.build(); } - else { - filter = Optional.empty(); + + List groupByColumns = ImmutableList.of(); + if (queryContext.getGroupByExpressions() != null) { + groupByColumns = getPinotColumns(schemaTableName, queryContext.getGroupByExpressions(), ImmutableList.of(), columnHandles, pinotTypeResolver, aggregateTypes); } - QueryContext queryContext = BrokerRequestToQueryContextConverter.convert(request); - ImmutableList.Builder aggregateColumnsBuilder = ImmutableList.builder(); - if (request.getAggregationsInfo() != null) { 
- for (AggregationFunction aggregationFunction : queryContext.getAggregationFunctions()) { - aggregateColumnsBuilder.add(new PinotColumnHandle( - aggregationFunction.getResultColumnName(), - toTrinoType(aggregationFunction.getFinalResultColumnType()))); - } + + Optional filter = Optional.empty(); + if (pinotQuery.getFilterExpression() != null) { + String formatted = formatFilter(schemaTableName, queryContext.getFilter(), columnHandles); + filter = Optional.of(formatted); } - return new DynamicTable(pinotTableName, suffix, selectionColumns, filter, groupByColumns, aggregateColumnsBuilder.build(), orderBy, getTopNOrLimit(request), getOffset(request), query); + return new DynamicTable(pinotTableName, suffix, selectColumns, filter, groupByColumns, ImmutableList.of(), orderBy, OptionalLong.of(queryContext.getLimit()), getOffset(queryContext), query); } private static Type toTrinoType(DataSchema.ColumnDataType columnDataType) @@ -142,41 +150,62 @@ private static Type toTrinoType(DataSchema.ColumnDataType columnDataType) throw new PinotException(PINOT_UNSUPPORTED_COLUMN_TYPE, Optional.empty(), "Unsupported column data type: " + columnDataType); } - private static List resolvePinotColumns(SchemaTableName schemaTableName, List trinoColumnNames, Map columnHandles) + private static List getPinotColumns(SchemaTableName schemaTableName, List expressions, List aliases, Map columnHandles, PinotTypeResolver pinotTypeResolver, Map aggregateTypes) { - ImmutableList.Builder pinotColumnNamesBuilder = ImmutableList.builder(); - for (String trinoColumnName : trinoColumnNames) { - if (trinoColumnName.equals(WILDCARD)) { - pinotColumnNamesBuilder.addAll(columnHandles.values().stream().map(handle -> ((PinotColumnHandle) handle).getColumnName()).collect(toImmutableList())); + ImmutableList.Builder pinotColumnsBuilder = ImmutableList.builder(); + for (int index = 0; index < expressions.size(); index++) { + ExpressionContext expressionContext = expressions.get(index); + Optional alias = 
getAlias(aliases, index); + // Only substitute * with columns for top level SELECT *. + // Since Pinot doesn't support subqueries yet, we can only have one occurrence of SELECT * + if (expressionContext.getType() == ExpressionContext.Type.IDENTIFIER && expressionContext.getIdentifier().equals(WILDCARD)) { + pinotColumnsBuilder.addAll(columnHandles.values().stream() + .map(handle -> fromNonAggregateColumnHandle((PinotColumnHandle) handle)) + .collect(toImmutableList())); } else { - PinotColumnHandle columnHandle = (PinotColumnHandle) columnHandles.get(trinoColumnName); - if (columnHandle == null) { - throw new ColumnNotFoundException(schemaTableName, trinoColumnName); - } - pinotColumnNamesBuilder.add(columnHandle.getColumnName()); + pinotColumnsBuilder.add(getPinotColumnHandle(schemaTableName, expressionContext, alias, columnHandles, pinotTypeResolver, aggregateTypes)); } } - return pinotColumnNamesBuilder.build(); + return pinotColumnsBuilder.build(); } - private static OptionalLong getTopNOrLimit(BrokerRequest request) + private static PinotColumnHandle getPinotColumnHandle(SchemaTableName schemaTableName, ExpressionContext expressionContext, Optional alias, Map columnHandles, PinotTypeResolver pinotTypeResolver, Map aggregateTypes) { - if (request.getGroupBy() != null) { - return OptionalLong.of(request.getGroupBy().getTopN()); - } - else if (request.getSelections() != null) { - return OptionalLong.of(request.getSelections().getSize()); + ExpressionContext rewritten = rewriteExpression(schemaTableName, expressionContext, columnHandles); + // If there is no alias, pinot autogenerates the column name: + String columnName = rewritten.toString(); + String pinotExpression = formatExpression(schemaTableName, rewritten); + Type trinoType; + boolean isAggregate = isAggregate(rewritten); + if (isAggregate) { + trinoType = requireNonNull(aggregateTypes.get(columnName.toLowerCase(ENGLISH)), format("Unexpected aggregate expression: '%s'", rewritten)); } else { - return 
OptionalLong.empty(); + trinoType = getTrinoTypeFromPinotType(pinotTypeResolver.resolveExpressionType(rewritten, schemaTableName, columnHandles)); + } + + return new PinotColumnHandle(alias.orElse(columnName), trinoType, pinotExpression, alias.isPresent(), isAggregate, true, Optional.empty(), Optional.empty()); + } + + private static Optional getAlias(List aliases, int index) + { + // SELECT * is expanded to all columns with no aliases + if (index >= aliases.size()) { + return Optional.empty(); } + return Optional.ofNullable(aliases.get(index)); + } + + private static boolean isAggregate(ExpressionContext expressionContext) + { + return expressionContext.getType() == ExpressionContext.Type.FUNCTION && expressionContext.getFunction().getType() == FunctionContext.Type.AGGREGATION; } - private static OptionalLong getOffset(BrokerRequest request) + private static OptionalLong getOffset(QueryContext queryContext) { - if (request.getSelections() != null && request.getSelections().getOffset() > 0) { - return OptionalLong.of(request.getSelections().getOffset()); + if (queryContext.getOffset() > 0) { + return OptionalLong.of(queryContext.getOffset()); } else { return OptionalLong.empty(); diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/DynamicTablePqlExtractor.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/DynamicTablePqlExtractor.java index 7f993b659aa7..0231a2044661 100755 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/DynamicTablePqlExtractor.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/DynamicTablePqlExtractor.java @@ -20,7 +20,6 @@ import java.util.List; import java.util.Optional; -import static com.google.common.base.Preconditions.checkArgument; import static io.trino.plugin.pinot.query.PinotQueryBuilder.getFilterClause; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -36,22 +35,21 @@ public static String extractPql(DynamicTable 
table, TupleDomain tu { StringBuilder builder = new StringBuilder(); builder.append("select "); - if (!table.getSelections().isEmpty()) { - builder.append(table.getSelections().stream() - .map(DynamicTablePqlExtractor::quoteIdentifier) + if (!table.getProjections().isEmpty()) { + builder.append(table.getProjections().stream() + .map(DynamicTablePqlExtractor::formatExpression) .collect(joining(", "))); } - if (!table.getGroupingColumns().isEmpty()) { - builder.append(table.getGroupingColumns().stream() - .map(DynamicTablePqlExtractor::quoteIdentifier) - .collect(joining(", "))); - if (!table.getAggregateColumns().isEmpty()) { + + if (!table.getAggregateColumns().isEmpty()) { + // If there are only pushed down aggregate expressions + if (!table.getProjections().isEmpty()) { builder.append(", "); } + builder.append(table.getAggregateColumns().stream() + .map(DynamicTablePqlExtractor::formatExpression) + .collect(joining(", "))); } - builder.append(table.getAggregateColumns().stream() - .map(PinotColumnHandle::getColumnName) - .collect(joining(", "))); builder.append(" from "); builder.append(table.getTableName()); builder.append(table.getSuffix().orElse("")); @@ -64,7 +62,7 @@ public static String extractPql(DynamicTable table, TupleDomain tu if (!table.getGroupingColumns().isEmpty()) { builder.append(" group by "); builder.append(table.getGroupingColumns().stream() - .map(DynamicTablePqlExtractor::quoteIdentifier) + .map(PinotColumnHandle::getExpression) .collect(joining(", "))); } if (!table.getOrderBy().isEmpty()) { @@ -74,12 +72,12 @@ public static String extractPql(DynamicTable table, TupleDomain tu .collect(joining(", "))); } if (table.getLimit().isPresent()) { - builder.append(" limit ") - .append(table.getLimit().getAsLong()); - if (!table.getSelections().isEmpty() && table.getOffset().isPresent()) { - builder.append(", ") - .append(table.getOffset().getAsLong()); + builder.append(" limit "); + if (table.getOffset().isPresent()) { + 
builder.append(table.getOffset().getAsLong()) + .append(", "); } + builder.append(table.getLimit().getAsLong()); } return builder.toString(); } @@ -106,7 +104,7 @@ private static String convertOrderByExpressionToPql(OrderByExpression orderByExp { requireNonNull(orderByExpression, "orderByExpression is null"); StringBuilder builder = new StringBuilder() - .append(quoteIdentifier(orderByExpression.getColumn())); + .append(orderByExpression.getExpression()); if (!orderByExpression.isAsc()) { builder.append(" desc"); } @@ -118,9 +116,16 @@ public static String encloseInParentheses(String value) return format("(%s)", value); } - private static String quoteIdentifier(String identifier) + private static String formatExpression(PinotColumnHandle pinotColumnHandle) + { + if (pinotColumnHandle.isAliased()) { + return pinotColumnHandle.getExpression() + " AS " + quoteIdentifier(pinotColumnHandle.getColumnName()); + } + return pinotColumnHandle.getExpression(); + } + + public static String quoteIdentifier(String identifier) { - checkArgument(!identifier.contains("\""), "Identifier contains double quotes: '%s'", identifier); - return format("\"%s\"", identifier); + return format("\"%s\"", identifier.replaceAll("\"", "\"\"")); } } diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/FilterToPinotSqlConverter.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/FilterToPinotSqlConverter.java deleted file mode 100644 index 08a43f5dd242..000000000000 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/FilterToPinotSqlConverter.java +++ /dev/null @@ -1,208 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.plugin.pinot.query; - -import com.google.common.collect.ImmutableMap; -import io.trino.plugin.pinot.PinotColumnHandle; -import io.trino.spi.connector.ColumnHandle; -import org.apache.commons.codec.binary.Hex; -import org.apache.pinot.common.request.Expression; -import org.apache.pinot.common.request.Function; -import org.apache.pinot.common.request.Identifier; -import org.apache.pinot.common.request.Literal; - -import java.text.DecimalFormat; -import java.text.DecimalFormatSymbols; -import java.util.Locale; -import java.util.Map; -import java.util.Optional; - -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkState; -import static java.lang.String.format; -import static java.util.Locale.ENGLISH; -import static java.util.Objects.requireNonNull; -import static java.util.stream.Collectors.joining; -import static org.apache.pinot.common.request.ExpressionType.FUNCTION; -import static org.apache.pinot.common.request.ExpressionType.LITERAL; - -public class FilterToPinotSqlConverter -{ - private static final Map BINARY_OPERATORS = ImmutableMap.builder() - .put("equals", "=") - .put("not_equals", "!=") - .put("greater_than", ">") - .put("less_than", "<") - .put("greater_than_or_equal", ">=") - .put("less_than_or_equal", "<=") - .put("plus", "+") - .put("minus", "-") - .put("times", "*") - .put("divide", "/") - .build(); - - // Pinot does not recognize double literals with scientific notation - private static final ThreadLocal doubleFormatter = 
ThreadLocal.withInitial( - () -> { DecimalFormat decimalFormat = new DecimalFormat("0", new DecimalFormatSymbols(Locale.US)); - decimalFormat.setMaximumFractionDigits(340); - return decimalFormat; }); - - private final Map columnHandles; - - private FilterToPinotSqlConverter(Map columnHandles) - { - this.columnHandles = requireNonNull(columnHandles, "columnHandles is null"); - } - - public static String convertFilter(org.apache.pinot.common.request.PinotQuery pinotQuery, Map columnHandles) - { - return new FilterToPinotSqlConverter(columnHandles).formatExpression(pinotQuery.getFilterExpression()); - } - - private String formatExpression(Expression expression) - { - switch (expression.getType()) { - case FUNCTION: - return formatFunction(expression.getFunctionCall()); - case LITERAL: - return formatLiteral(expression.getLiteral()); - case IDENTIFIER: - return formatIdentifier(expression.getIdentifier()); - default: - throw new UnsupportedOperationException(format("Unknown type: '%s'", expression.getType())); - } - } - - private String formatFunction(Function functionCall) - { - String binaryOperator = BINARY_OPERATORS.get(functionCall.getOperator().toLowerCase(ENGLISH)); - if (binaryOperator != null) { - return formatEqualsMinusZero(functionCall).orElse(format("(%s) %s (%s)", - formatExpression(functionCall.getOperands().get(0)), - binaryOperator, - formatExpression(functionCall.getOperands().get(1)))); - } - else if (functionCall.getOperator().equalsIgnoreCase("cast")) { - checkState(functionCall.getOperands().size() == 2, "Unexpected size for cast operator"); - return format("CAST(%s AS %s)", formatExpression(functionCall.getOperands().get(0)), functionCall.getOperands().get(1).getLiteral().getStringValue()); - } - else if (functionCall.getOperator().equalsIgnoreCase("in") || functionCall.getOperator().equalsIgnoreCase("not_in")) { - return formatInClause(functionCall); - } - else if (functionCall.getOperator().equalsIgnoreCase("case")) { - return 
formatCaseStatement(functionCall); - } - return functionCall.getOperator() + "(" + functionCall.getOperands().stream().map(this::formatExpression).collect(joining(", ")) + ")"; - } - - // Pinot parses "a = b" as "a - b = 0" which can result in invalid sql - private Optional formatEqualsMinusZero(Function functionCall) - { - if (!functionCall.getOperator().equalsIgnoreCase("equals")) { - return Optional.empty(); - } - - Expression left = functionCall.getOperands().get(0); - if (left.getType() != FUNCTION || !left.getFunctionCall().getOperator().equalsIgnoreCase("minus")) { - return Optional.empty(); - } - - Expression right = functionCall.getOperands().get(1); - if (right.getType() != LITERAL || !formatLiteral(right.getLiteral()).equals("0")) { - return Optional.empty(); - } - Function minus = left.getFunctionCall(); - return Optional.of(format("(%s) = (%s)", formatExpression(minus.getOperands().get(0)), formatExpression(minus.getOperands().get(1)))); - } - - private String formatInClause(Function functionCall) - { - checkState(functionCall.getOperator().equalsIgnoreCase("in") || - functionCall.getOperator().equalsIgnoreCase("not_in"), - "Unexpected operator '%s'", functionCall.getOperator()); - checkState(functionCall.getOperands().size() > 1, "Unexpected expression"); - String operator; - if (functionCall.getOperator().equalsIgnoreCase("in")) { - operator = "IN"; - } - else { - operator = "NOT IN"; - } - return format("%s %s (%s)", formatExpression(functionCall.getOperands().get(0)), - operator, - functionCall.getOperands().subList(1, functionCall.getOperands().size()).stream() - .map(this::formatExpression) - .collect(joining(", "))); - } - - private String formatCaseStatement(Function functionCall) - { - checkState(functionCall.getOperator().equalsIgnoreCase("case"), "Unexpected operator '%s'", functionCall.getOperator()); - checkState(functionCall.getOperands().size() >= 2, "Unexpected expression"); - int whenStatements = functionCall.getOperands().size() / 2; 
- StringBuilder builder = new StringBuilder("CASE "); - - builder.append("WHEN ") - .append(formatExpression(functionCall.getOperands().get(0))) - .append(" THEN ") - .append(formatExpression(functionCall.getOperands().get(whenStatements))); - - for (int index = 1; index < whenStatements; index++) { - builder.append(" ") - .append("WHEN ") - .append(formatExpression(functionCall.getOperands().get(index))) - .append(" THEN ") - .append(formatExpression(functionCall.getOperands().get(index + whenStatements))); - } - - if (functionCall.getOperands().size() % 2 != 0) { - builder.append(" ELSE " + formatExpression(functionCall.getOperands().get(functionCall.getOperands().size() - 1))); - } - return builder.append(" END").toString(); - } - - private String formatLiteral(Literal literal) - { - if (!literal.isSet()) { - return "null"; - } - switch (literal.getSetField()) { - case LONG_VALUE: - return String.valueOf(literal.getLongValue()); - case INT_VALUE: - return String.valueOf(literal.getIntValue()); - case BOOL_VALUE: - return String.valueOf(literal.getBoolValue()); - case STRING_VALUE: - return format("'%s'", literal.getStringValue().replaceAll("'", "''")); - case BYTE_VALUE: - return String.valueOf(literal.getByteValue()); - case BINARY_VALUE: - return Hex.encodeHexString(literal.getBinaryValue()); - case DOUBLE_VALUE: - return doubleFormatter.get().format(literal.getDoubleValue()); - case SHORT_VALUE: - return String.valueOf(literal.getShortValue()); - default: - throw new UnsupportedOperationException(format("Unknown literal type: '%s'", literal.getSetField())); - } - } - - private String formatIdentifier(Identifier identifier) - { - PinotColumnHandle pinotColumnHandle = (PinotColumnHandle) requireNonNull(columnHandles.get(identifier.getName()), "Column not found"); - checkArgument(!pinotColumnHandle.getColumnName().contains("\""), "Column name contains double quotes: '%s'", pinotColumnHandle.getColumnName()); - return format("\"%s\"", 
pinotColumnHandle.getColumnName()); - } -} diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/OrderByExpression.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/OrderByExpression.java index e508276fd7a4..998aaf4c2a44 100755 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/OrderByExpression.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/OrderByExpression.java @@ -23,22 +23,22 @@ public final class OrderByExpression { - private final String column; + private final String expression; private final boolean asc; @JsonCreator public OrderByExpression( - @JsonProperty("column") String column, + @JsonProperty("expression") String expression, @JsonProperty("asc") boolean asc) { - this.column = requireNonNull(column, "column is null"); + this.expression = requireNonNull(expression, "expression is null"); this.asc = asc; } @JsonProperty - public String getColumn() + public String getExpression() { - return column; + return expression; } @JsonProperty @@ -57,21 +57,21 @@ public boolean equals(Object other) return false; } OrderByExpression that = (OrderByExpression) other; - return column.equals(that.column) && + return expression.equals(that.expression) && asc == that.asc; } @Override public int hashCode() { - return Objects.hash(column, asc); + return Objects.hash(expression, asc); } @Override public String toString() { return toStringHelper(this) - .add("column", column) + .add("expression", expression) .add("asc", asc) .toString(); } diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/PinotExpressionRewriter.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/PinotExpressionRewriter.java new file mode 100644 index 000000000000..9ed9abd33ac8 --- /dev/null +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/PinotExpressionRewriter.java @@ -0,0 +1,342 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use
this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.pinot.query; + +import com.google.common.collect.ImmutableList; +import io.trino.matching.Captures; +import io.trino.matching.Match; +import io.trino.matching.Pattern; +import io.trino.plugin.pinot.PinotException; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.SchemaTableName; +import org.apache.pinot.common.function.TransformFunctionType; +import org.apache.pinot.common.request.context.ExpressionContext; +import org.apache.pinot.common.request.context.FunctionContext; +import org.apache.pinot.segment.spi.AggregationFunctionType; +import org.apache.pinot.spi.data.DateTimeFormatSpec; + +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.TimeUnit; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Maps.immutableEnumMap; +import static io.trino.plugin.pinot.PinotErrorCode.PINOT_EXCEPTION; +import static io.trino.plugin.pinot.query.PinotPatterns.WILDCARD; +import static io.trino.plugin.pinot.query.PinotPatterns.aggregationFunction; +import static io.trino.plugin.pinot.query.PinotPatterns.aggregationFunctionType; +import static io.trino.plugin.pinot.query.PinotPatterns.expression; +import static io.trino.plugin.pinot.query.PinotPatterns.expressionType; +import 
static io.trino.plugin.pinot.query.PinotPatterns.function; +import static io.trino.plugin.pinot.query.PinotPatterns.identifier; +import static io.trino.plugin.pinot.query.PinotPatterns.singleInput; +import static io.trino.plugin.pinot.query.PinotPatterns.transformFunction; +import static io.trino.plugin.pinot.query.PinotPatterns.transformFunctionType; +import static io.trino.plugin.pinot.query.PinotSqlFormatter.getColumnHandle; +import static java.lang.String.format; +import static java.util.Locale.ENGLISH; +import static java.util.Objects.requireNonNull; +import static org.apache.pinot.common.function.TransformFunctionType.DATETIMECONVERT; +import static org.apache.pinot.common.function.TransformFunctionType.DATETRUNC; +import static org.apache.pinot.common.function.TransformFunctionType.TIMECONVERT; +import static org.apache.pinot.common.request.context.ExpressionContext.Type.FUNCTION; +import static org.apache.pinot.common.request.context.ExpressionContext.Type.IDENTIFIER; +import static org.apache.pinot.common.request.context.ExpressionContext.Type.LITERAL; +import static org.apache.pinot.common.request.context.ExpressionContext.forFunction; +import static org.apache.pinot.common.request.context.ExpressionContext.forIdentifier; +import static org.apache.pinot.common.request.context.ExpressionContext.forLiteral; +import static org.apache.pinot.core.operator.transform.function.DateTruncTransformFunction.EXAMPLE_INVOCATION; +import static org.apache.pinot.core.operator.transform.transformer.timeunit.TimeUnitTransformerFactory.getTimeUnitTransformer; +import static org.apache.pinot.segment.spi.AggregationFunctionType.COUNT; +import static org.apache.pinot.segment.spi.AggregationFunctionType.getAggregationFunctionType; +import static org.apache.pinot.spi.data.DateTimeFormatSpec.validateFormat; +import static org.apache.pinot.spi.data.DateTimeGranularitySpec.validateGranularity; + +public class PinotExpressionRewriter +{ + private static final Map> FUNCTION_RULE_MAP; 
+ private static final Map> AGGREGATION_FUNCTION_RULE_MAP; + private static final RewriteRule DEFAULT_REWRITE_RULE = new DefaultRewriteRule(); + + private PinotExpressionRewriter() {} + + static { + Map> functionMap = new HashMap<>(); + functionMap.put(DATETIMECONVERT, new DateTimeConvertRewriteRule()); + functionMap.put(TIMECONVERT, new TimeConvertRewriteRule()); + functionMap.put(DATETRUNC, new DateTruncRewriteRule()); + FUNCTION_RULE_MAP = immutableEnumMap(functionMap); + + Map> aggregationFunctionMap = new HashMap<>(); + aggregationFunctionMap.put(COUNT, new CountStarRewriteRule()); + AGGREGATION_FUNCTION_RULE_MAP = immutableEnumMap(aggregationFunctionMap); + } + + public static ExpressionContext rewriteExpression(SchemaTableName schemaTableName, ExpressionContext expressionContext, Map columnHandles) + { + requireNonNull(expressionContext, "expressionContext is null"); + Context context = new Context() { + @Override + public SchemaTableName getSchemaTableName() + { + return schemaTableName; + } + + @Override + public Map getColumnHandles() + { + return columnHandles; + } + }; + return rewriteExpression(expressionContext, context); + } + + private static ExpressionContext rewriteExpression(ExpressionContext expressionContext, Context context) + { + switch (expressionContext.getType()) { + case LITERAL: + return expressionContext; + case IDENTIFIER: + return forIdentifier(getColumnHandle(expressionContext.getIdentifier(), context.getSchemaTableName(), context.getColumnHandles()).getColumnName()); + case FUNCTION: + return forFunction(rewriteFunction(expressionContext.getFunction(), context)); + } + throw new PinotException(PINOT_EXCEPTION, Optional.empty(), format("Unsupported expression type '%s'", expressionContext.getType())); + } + + private static FunctionContext rewriteFunction(FunctionContext functionContext, Context context) + { + Optional result = Optional.empty(); + if (functionContext.getType() == FunctionContext.Type.TRANSFORM) { + RewriteRule rule = 
FUNCTION_RULE_MAP.get(TransformFunctionType.getTransformFunctionType(functionContext.getFunctionName())); + if (rule != null) { + result = applyRule(rule, functionContext, context); + } + } + else { + checkState(functionContext.getType() == FunctionContext.Type.AGGREGATION, "Unexpected function type for '%s'", functionContext); + RewriteRule rule = AGGREGATION_FUNCTION_RULE_MAP.get(getAggregationFunctionType(functionContext.getFunctionName())); + if (rule != null) { + result = applyRule(rule, functionContext, context); + } + } + if (result.isPresent()) { + return result.get(); + } + result = applyRule(DEFAULT_REWRITE_RULE, functionContext, context); + if (result.isPresent()) { + return result.get(); + } + throw new PinotException(PINOT_EXCEPTION, Optional.empty(), format("Unsupported function expression '%s'", functionContext)); + } + + private static Optional applyRule(RewriteRule rule, T object, Context context) + { + Iterator iterator = rule.getPattern().match(object).iterator(); + while (iterator.hasNext()) { + Match match = iterator.next(); + return Optional.of(rule.rewrite(object, match.captures(), context)); + } + return Optional.empty(); + } + + private static class DateTimeConvertRewriteRule + implements RewriteRule + { + @Override + public Pattern getPattern() + { + return transformFunction().with(transformFunctionType().equalTo(DATETIMECONVERT)); + } + + @Override + public FunctionContext rewrite(FunctionContext object, Captures captures, Context context) + { + // Extracted from org.apache.pinot.core.operator.transform.function.DateTimeConversionTransformFunction + // The first argument must be an identifier or function and the 2nd, 3rd and 4th arguments must be literals + verify(object.getArguments().size() == 4); + verifyIsIdentifierOrFunction(object.getArguments().get(0)); + verifyTailArgumentsAllLiteral(object.getArguments()); + + ImmutableList.Builder argumentsBuilder = ImmutableList.builder(); + 
argumentsBuilder.add(rewriteExpression(object.getArguments().get(0), context)); + String inputFormat = object.getArguments().get(1).getLiteral().toUpperCase(ENGLISH); + checkDateTimeFormatSpec(inputFormat); + argumentsBuilder.add(forLiteral(inputFormat)); + String outputFormat = object.getArguments().get(2).getLiteral().toUpperCase(ENGLISH); + checkDateTimeFormatSpec(outputFormat); + argumentsBuilder.add(forLiteral(outputFormat)); + String granularity = object.getArguments().get(3).getLiteral().toUpperCase(ENGLISH); + validateGranularity(granularity); + argumentsBuilder.add(forLiteral(granularity)); + return new FunctionContext(object.getType(), object.getFunctionName(), argumentsBuilder.build()); + } + } + + private static class TimeConvertRewriteRule + implements RewriteRule + { + @Override + public Pattern getPattern() + { + return transformFunction().with(transformFunctionType().equalTo(TIMECONVERT)); + } + + @Override + public FunctionContext rewrite(FunctionContext object, Captures captures, Context context) + { + // Extracted from org.apache.pinot.core.operator.transform.function.TimeConversionTransformFunction + // The first argument must be an identifier or function and the 2nd and 3rd arguments must be literals + verify(object.getArguments().size() == 3); + verifyIsIdentifierOrFunction(object.getArguments().get(0)); + verifyTailArgumentsAllLiteral(object.getArguments()); + + ImmutableList.Builder argumentsBuilder = ImmutableList.builder(); + argumentsBuilder.add(rewriteExpression(object.getArguments().get(0), context)); + String inputTimeUnitArgument = object.getArguments().get(1).getLiteral().toUpperCase(ENGLISH); + TimeUnit inputTimeUnit = TimeUnit.valueOf(inputTimeUnitArgument); + String outputTimeUnitArgument = object.getArguments().get(2).getLiteral().toUpperCase(ENGLISH); + // Check that this is a valid time unit transform + getTimeUnitTransformer(inputTimeUnit, outputTimeUnitArgument); + argumentsBuilder.add(forLiteral(inputTimeUnitArgument));
+ argumentsBuilder.add(forLiteral(outputTimeUnitArgument)); + return new FunctionContext(object.getType(), object.getFunctionName(), argumentsBuilder.build()); + } + } + + private static class DateTruncRewriteRule + implements RewriteRule + { + @Override + public Pattern getPattern() + { + return transformFunction().with(transformFunctionType().equalTo(DATETRUNC)); + } + + @Override + public FunctionContext rewrite(FunctionContext object, Captures captures, Context context) + { + // Extracted from org.apache.pinot.core.operator.transform.function.DateTruncTransformFunction + List arguments = object.getArguments(); + checkState(arguments.size() >= 2 && arguments.size() <= 5, + "Between two to five arguments are required, example: %s", EXAMPLE_INVOCATION); + + ImmutableList.Builder argumentsBuilder = ImmutableList.builder(); + + checkState(arguments.get(0).getType() == LITERAL, "First argument must be a literal"); + String unit = arguments.get(0).getLiteral().toLowerCase(ENGLISH); + argumentsBuilder.add(forLiteral(unit)); + verifyIsIdentifierOrFunction(object.getArguments().get(1)); + ExpressionContext valueArgument = rewriteExpression(arguments.get(1), context); + argumentsBuilder.add(valueArgument); + if (arguments.size() >= 3) { + checkState(arguments.get(2).getType() == LITERAL, "Unexpected 3rd argument: '%s'", arguments.get(2)); + String inputTimeUnitArgument = arguments.get(2).getLiteral().toUpperCase(ENGLISH); + // Ensure this is a valid TimeUnit + TimeUnit inputTimeUnit = TimeUnit.valueOf(inputTimeUnitArgument); + argumentsBuilder.add(forLiteral(inputTimeUnit.name())); + if (arguments.size() >= 4) { + checkState(arguments.get(3).getType() == LITERAL, "Unexpected 4th argument '%s'", arguments.get(3)); + // Time zone is lower cased inside Pinot + argumentsBuilder.add(arguments.get(3)); + if (arguments.size() >= 5) { + checkState(arguments.get(4).getType() == LITERAL, "Unexpected 5th argument: '%s'", arguments.get(4)); + String outputTimeUnitArgument = 
arguments.get(4).getLiteral().toUpperCase(ENGLISH); + // Ensure this is a valid TimeUnit + TimeUnit outputTimeUnit = TimeUnit.valueOf(outputTimeUnitArgument); + argumentsBuilder.add(forLiteral(outputTimeUnit.name())); + } + } + } + return new FunctionContext(object.getType(), object.getFunctionName(), argumentsBuilder.build()); + } + } + + private static class CountStarRewriteRule + implements RewriteRule + { + @Override + public Pattern getPattern() + { + return aggregationFunction() + .with(aggregationFunctionType().equalTo(COUNT)) + .with(singleInput().matching(expression() + .with(expressionType().equalTo(IDENTIFIER)) + .with(identifier().equalTo(WILDCARD)))); + } + + @Override + public FunctionContext rewrite(FunctionContext object, Captures captures, Context context) + { + return object; + } + } + + private static class DefaultRewriteRule + implements RewriteRule + { + @Override + public Pattern getPattern() + { + return function(); + } + + @Override + public FunctionContext rewrite(FunctionContext object, Captures captures, Context context) + { + List arguments = object.getArguments().stream().map(argument -> rewriteExpression(argument, context)) + .collect(toImmutableList()); + return new FunctionContext(object.getType(), object.getFunctionName(), arguments); + } + } + + private static void checkDateTimeFormatSpec(String dateTimeFormat) + { + requireNonNull(dateTimeFormat, "dateTimeFormat is null"); + validateFormat(dateTimeFormat); + // Even if the format is valid, make sure it is not a simple date format: format characters can be ambiguous due to lower casing + DateTimeFormatSpec dateTimeFormatSpec = new DateTimeFormatSpec(dateTimeFormat); + checkState(dateTimeFormatSpec.getSDFPattern() == null, "Unsupported date format: simple date format not supported"); + } + + private static void verifyIsIdentifierOrFunction(ExpressionContext expressionContext) + { + verify(expressionContext.getType() == IDENTIFIER || expressionContext.getType() == FUNCTION); + } + + 
private static void verifyTailArgumentsAllLiteral(List arguments) + { + arguments.stream().skip(1) + .forEach(argument -> verify(argument.getType() == LITERAL)); + } + + private interface Context + { + SchemaTableName getSchemaTableName(); + + Map getColumnHandles(); + } + + private interface RewriteRule + { + Pattern getPattern(); + + T rewrite(T object, Captures captures, Context context); + } +} diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/PinotPatterns.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/PinotPatterns.java new file mode 100644 index 000000000000..6a038382381e --- /dev/null +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/PinotPatterns.java @@ -0,0 +1,294 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.pinot.query; + +import io.trino.matching.Pattern; +import io.trino.matching.Property; +import org.apache.pinot.common.function.TransformFunctionType; +import org.apache.pinot.common.request.context.ExpressionContext; +import org.apache.pinot.common.request.context.FilterContext; +import org.apache.pinot.common.request.context.FunctionContext; +import org.apache.pinot.common.request.context.predicate.EqPredicate; +import org.apache.pinot.common.request.context.predicate.InPredicate; +import org.apache.pinot.common.request.context.predicate.JsonMatchPredicate; +import org.apache.pinot.common.request.context.predicate.NotEqPredicate; +import org.apache.pinot.common.request.context.predicate.NotInPredicate; +import org.apache.pinot.common.request.context.predicate.Predicate; +import org.apache.pinot.common.request.context.predicate.RangePredicate; +import org.apache.pinot.common.request.context.predicate.RegexpLikePredicate; +import org.apache.pinot.common.request.context.predicate.TextMatchPredicate; +import org.apache.pinot.segment.spi.AggregationFunctionType; + +import java.util.List; +import java.util.Optional; + +import static io.trino.matching.Pattern.typeOf; +import static org.apache.pinot.common.function.TransformFunctionType.getTransformFunctionType; +import static org.apache.pinot.common.request.context.ExpressionContext.Type.FUNCTION; +import static org.apache.pinot.common.request.context.ExpressionContext.Type.IDENTIFIER; +import static org.apache.pinot.common.request.context.FunctionContext.Type.AGGREGATION; +import static org.apache.pinot.common.request.context.FunctionContext.Type.TRANSFORM; +import static org.apache.pinot.common.request.context.predicate.RangePredicate.UNBOUNDED; +import static org.apache.pinot.segment.spi.AggregationFunctionType.getAggregationFunctionType; + +public class PinotPatterns +{ + public static final String WILDCARD = "*"; + + private PinotPatterns() {} + + public static Pattern filter() + { + 
return typeOf(FilterContext.class); + } + + public static Pattern predicate() + { + return typeOf(Predicate.class); + } + + public static Pattern expression() + { + return typeOf(ExpressionContext.class); + } + + public static Pattern function() + { + return typeOf(FunctionContext.class); + } + + public static Pattern transformFunction() + { + return function() + .with(functionType().equalTo(TRANSFORM)); + } + + public static Pattern aggregationFunction() + { + return function() + .with(functionType().equalTo(AGGREGATION)); + } + + public static Pattern binaryFunction() + { + return transformFunction() + .with(arguments().matching(arguments -> arguments.size() == 2)); + } + + // Filter Properties + public static Property filterType() + { + return Property.property("filterContextType", FilterContext::getType); + } + + public static Property> childFilters() + { + return Property.optionalProperty("childFilters", context -> { + if (context.getType() == FilterContext.Type.AND || context.getType() == FilterContext.Type.OR) { + return Optional.ofNullable(context.getChildren()); + } + return Optional.empty(); + }); + } + + public static Property filterPredicate() + { + return Property.optionalProperty("filterPredicate", context -> { + if (context.getType() == FilterContext.Type.PREDICATE) { + return Optional.ofNullable(context.getPredicate()); + } + return Optional.empty(); + }); + } + + // Predicate Properties + public static Property predicateType() + { + return Property.property("predicateType", Predicate::getType); + } + + public static Property predicateExpression() + { + return Property.property("predicateExpression", Predicate::getLhs); + } + + public static Property binaryOperatorValue() + { + return Property.optionalProperty("binaryOperatorValue", predicate -> { + switch (predicate.getType()) { + case EQ: + return Optional.of(((EqPredicate) predicate).getValue()); + case NOT_EQ: + return Optional.of(((NotEqPredicate) predicate).getValue()); + case RANGE: +
RangePredicate rangePredicate = (RangePredicate) predicate; + if (rangePredicate.getLowerBound().equals(UNBOUNDED)) { + return Optional.of(rangePredicate.getUpperBound()); + } + if (rangePredicate.getUpperBound().equals(UNBOUNDED)) { + return Optional.of(rangePredicate.getLowerBound()); + } + return Optional.empty(); + default: + return Optional.empty(); + } + }); + } + + public static Property binaryOperator() + { + return Property.optionalProperty("binaryOperator", predicate -> { + switch (predicate.getType()) { + case EQ: + return Optional.of("="); + case NOT_EQ: + return Optional.of("!="); + case RANGE: + RangePredicate rangePredicate = (RangePredicate) predicate; + if (rangePredicate.getLowerBound().equals(UNBOUNDED)) { + if (rangePredicate.isUpperInclusive()) { + return Optional.of("<="); + } + return Optional.of("<"); + } + if (rangePredicate.getUpperBound().equals(UNBOUNDED)) { + if (rangePredicate.isLowerInclusive()) { + return Optional.of(">="); + } + return Optional.of(">"); + } + return Optional.empty(); + default: + return Optional.empty(); + } + }); + } + + public static Property> predicateValuesList() + { + return Property.optionalProperty("predicateValuesList", predicate -> { + if (predicate.getType() == Predicate.Type.IN) { + return Optional.of(((InPredicate) predicate).getValues()); + } + else if (predicate.getType() == Predicate.Type.NOT_IN) { + return Optional.of(((NotInPredicate) predicate).getValues()); + } + return Optional.empty(); + }); + } + + public static Property binaryFunctionPredicateValue() + { + return Property.optionalProperty("binaryFunctionPredicateValue", predicate -> { + switch (predicate.getType()) { + case REGEXP_LIKE: + return Optional.of(((RegexpLikePredicate) predicate).getValue()); + case TEXT_MATCH: + return Optional.of(((TextMatchPredicate) predicate).getValue()); + case JSON_MATCH: + return Optional.of(((JsonMatchPredicate) predicate).getValue()); + default: + return Optional.empty(); + } + }); + } + + // Expression 
Properties + public static Property functionContext() + { + return Property.optionalProperty("functionContext", expressionContext -> { + if (expressionContext.getType() == FUNCTION) { + return Optional.of(expressionContext.getFunction()); + } + return Optional.empty(); + }); + } + + public static Property expressionType() + { + return Property.property("expressionType", ExpressionContext::getType); + } + + public static Property identifier() + { + return Property.optionalProperty("identifier", expressionContext -> { + if (expressionContext.getType() == IDENTIFIER) { + return Optional.of(expressionContext.getIdentifier()); + } + return Optional.empty(); + }); + } + + // Function Properties + public static Property transformFunctionType() + { + return Property.optionalProperty("transformFunctionType", functionContext -> { + if (functionContext.getType() == TRANSFORM) { + return Optional.of(getTransformFunctionType(functionContext.getFunctionName())); + } + return Optional.empty(); + }); + } + + // AggregationFunction Properties + public static Property aggregationFunctionType() + { + return Property.optionalProperty("aggregationFunctionType", functionContext -> { + if (functionContext.getType() == AGGREGATION) { + return Optional.of(getAggregationFunctionType(functionContext.getFunctionName())); + } + return Optional.empty(); + }); + } + + public static Property functionType() + { + return Property.property("functionType", FunctionContext::getType); + } + + public static Property> arguments() + { + return Property.property("arguments", FunctionContext::getArguments); + } + + public static Property singleInput() + { + return Property.optionalProperty("singleInput", functionContext -> { + if (functionContext.getArguments().size() == 1) { + return Optional.of(functionContext.getArguments().get(0)); + } + return Optional.empty(); + }); + } + + public static Property firstArgument() + { + return Property.optionalProperty("firstArgument", functionContext -> { + if 
(!functionContext.getArguments().isEmpty()) { + return Optional.of(functionContext.getArguments().get(0)); + } + return Optional.empty(); + }); + } + + public static Property secondArgument() + { + return Property.optionalProperty("secondArgument", functionContext -> { + if (functionContext.getArguments().size() > 1) { + return Optional.of(functionContext.getArguments().get(1)); + } + return Optional.empty(); + }); + } +} diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/PinotQueryBuilder.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/PinotQueryBuilder.java index 0a561062d4b7..d204ad41af05 100755 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/PinotQueryBuilder.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/PinotQueryBuilder.java @@ -34,7 +34,6 @@ import java.util.Optional; import java.util.OptionalLong; -import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; @@ -187,7 +186,6 @@ private static String singleQuote(Object value) private static String quoteIdentifier(String identifier) { - checkArgument(!identifier.contains("\""), "Identifier contains double quotes: '%s'", identifier); - return format("\"%s\"", identifier); + return format("\"%s\"", identifier.replaceAll("\"", "\"\"")); } } diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/PinotQuery.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/PinotQueryInfo.java similarity index 92% rename from plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/PinotQuery.java rename to plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/PinotQueryInfo.java index 7b80f187ec49..92103b1b0d3a 100755 --- 
a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/PinotQuery.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/PinotQueryInfo.java @@ -20,14 +20,14 @@ import static com.google.common.base.MoreObjects.toStringHelper; -public final class PinotQuery +public final class PinotQueryInfo { private final String table; private final String query; private final int groupByClauses; @JsonCreator - public PinotQuery( + public PinotQueryInfo( @JsonProperty("table") String table, @JsonProperty("query") String query, @JsonProperty("groupByClauses") int groupByClauses) @@ -61,10 +61,10 @@ public boolean equals(Object other) if (this == other) { return true; } - if (!(other instanceof PinotQuery)) { + if (!(other instanceof PinotQueryInfo)) { return false; } - PinotQuery that = (PinotQuery) other; + PinotQueryInfo that = (PinotQueryInfo) other; return table.equals(that.table) && query.equals(that.query) && groupByClauses == that.groupByClauses; diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/PinotSqlFormatter.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/PinotSqlFormatter.java new file mode 100644 index 000000000000..a669a8107772 --- /dev/null +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/PinotSqlFormatter.java @@ -0,0 +1,687 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.pinot.query; + +import com.google.common.collect.ImmutableList; +import io.trino.matching.Capture; +import io.trino.matching.Captures; +import io.trino.matching.Match; +import io.trino.matching.Pattern; +import io.trino.plugin.pinot.PinotColumnHandle; +import io.trino.plugin.pinot.PinotException; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ColumnNotFoundException; +import io.trino.spi.connector.SchemaTableName; +import org.apache.pinot.common.function.TransformFunctionType; +import org.apache.pinot.common.request.context.ExpressionContext; +import org.apache.pinot.common.request.context.FilterContext; +import org.apache.pinot.common.request.context.FunctionContext; +import org.apache.pinot.common.request.context.predicate.Predicate; +import org.apache.pinot.common.request.context.predicate.RangePredicate; +import org.apache.pinot.segment.spi.AggregationFunctionType; + +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Maps.immutableEnumMap; +import static io.trino.matching.Capture.newCapture; +import static io.trino.plugin.pinot.PinotErrorCode.PINOT_EXCEPTION; +import static io.trino.plugin.pinot.PinotErrorCode.PINOT_INVALID_PQL_GENERATED; +import static io.trino.plugin.pinot.query.DynamicTablePqlExtractor.quoteIdentifier; +import static io.trino.plugin.pinot.query.PinotPatterns.WILDCARD; +import static io.trino.plugin.pinot.query.PinotPatterns.aggregationFunction; +import static io.trino.plugin.pinot.query.PinotPatterns.aggregationFunctionType; +import static io.trino.plugin.pinot.query.PinotPatterns.arguments; +import static io.trino.plugin.pinot.query.PinotPatterns.binaryFunction; +import static 
io.trino.plugin.pinot.query.PinotPatterns.binaryFunctionPredicateValue; +import static io.trino.plugin.pinot.query.PinotPatterns.binaryOperator; +import static io.trino.plugin.pinot.query.PinotPatterns.binaryOperatorValue; +import static io.trino.plugin.pinot.query.PinotPatterns.childFilters; +import static io.trino.plugin.pinot.query.PinotPatterns.expression; +import static io.trino.plugin.pinot.query.PinotPatterns.expressionType; +import static io.trino.plugin.pinot.query.PinotPatterns.filter; +import static io.trino.plugin.pinot.query.PinotPatterns.filterPredicate; +import static io.trino.plugin.pinot.query.PinotPatterns.filterType; +import static io.trino.plugin.pinot.query.PinotPatterns.firstArgument; +import static io.trino.plugin.pinot.query.PinotPatterns.function; +import static io.trino.plugin.pinot.query.PinotPatterns.functionContext; +import static io.trino.plugin.pinot.query.PinotPatterns.identifier; +import static io.trino.plugin.pinot.query.PinotPatterns.predicate; +import static io.trino.plugin.pinot.query.PinotPatterns.predicateExpression; +import static io.trino.plugin.pinot.query.PinotPatterns.predicateType; +import static io.trino.plugin.pinot.query.PinotPatterns.predicateValuesList; +import static io.trino.plugin.pinot.query.PinotPatterns.secondArgument; +import static io.trino.plugin.pinot.query.PinotPatterns.singleInput; +import static io.trino.plugin.pinot.query.PinotPatterns.transformFunction; +import static io.trino.plugin.pinot.query.PinotPatterns.transformFunctionType; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.joining; +import static org.apache.pinot.common.function.TransformFunctionType.CASE; +import static org.apache.pinot.common.function.TransformFunctionType.CAST; +import static org.apache.pinot.common.function.TransformFunctionType.MINUS; +import static org.apache.pinot.common.request.context.ExpressionContext.Type.IDENTIFIER; +import static 
org.apache.pinot.common.request.context.predicate.RangePredicate.UNBOUNDED; +import static org.apache.pinot.segment.spi.AggregationFunctionType.COUNT; +import static org.apache.pinot.segment.spi.AggregationFunctionType.getAggregationFunctionType; + +public class PinotSqlFormatter +{ + private static final List> FILTER_RULES = ImmutableList.>builder() + .add(new AndOrFilterRule()) + .add(new PredicateFilterRule()) + .build(); + + private static final List> GLOBAL_PREDICATE_RULES = ImmutableList.>builder() + .add(new MinusZeroPredicateRule()) + .add(new BinaryOperatorPredicateRule()) + .build(); + + private static final Map> PREDICATE_RULE_MAP; + private static final Map> FUNCTION_RULE_MAP; + private static final Map> AGGREGATION_FUNCTION_RULE_MAP; + private static final Rule DEFAULT_FUNCTION_RULE = new DefaultFunctionRule(); + + static { + Map> predicateMap = new HashMap<>(); + predicateMap.put(Predicate.Type.IN, new ValuesListPredicateRule(Predicate.Type.IN, "IN")); + predicateMap.put(Predicate.Type.NOT_IN, new ValuesListPredicateRule(Predicate.Type.NOT_IN, "NOT IN")); + predicateMap.put(Predicate.Type.RANGE, new RangePredicateRule()); + predicateMap.put(Predicate.Type.REGEXP_LIKE, new BinaryFunctionPredicateRule(Predicate.Type.REGEXP_LIKE, "regexp_like")); + predicateMap.put(Predicate.Type.TEXT_MATCH, new BinaryFunctionPredicateRule(Predicate.Type.TEXT_MATCH, "text_match")); + predicateMap.put(Predicate.Type.JSON_MATCH, new BinaryFunctionPredicateRule(Predicate.Type.JSON_MATCH, "json_match")); + predicateMap.put(Predicate.Type.IS_NULL, new ExpressionOnlyPredicate(Predicate.Type.IS_NULL, "IS NULL")); + predicateMap.put(Predicate.Type.IS_NOT_NULL, new ExpressionOnlyPredicate(Predicate.Type.IS_NOT_NULL, "IS NOT NULL")); + PREDICATE_RULE_MAP = immutableEnumMap(predicateMap); + + Map> functionMap = new HashMap<>(); + functionMap.put(CASE, new CaseFunctionRule()); + functionMap.put(CAST, new CastFunctionRule()); + functionMap.put(MINUS, new MinusFunctionRule()); + 
FUNCTION_RULE_MAP = immutableEnumMap(functionMap); + + Map> aggregationFunctionMap = new HashMap<>(); + aggregationFunctionMap.put(COUNT, new CountStarFunctionRule()); + AGGREGATION_FUNCTION_RULE_MAP = immutableEnumMap(aggregationFunctionMap); + } + + private PinotSqlFormatter() {} + + public static String formatFilter(SchemaTableName schemaTableName, FilterContext filterContext, Map columnHandles) + { + requireNonNull(filterContext, "filterContext is null"); + Context context = new Context() { + @Override + public SchemaTableName getSchemaTableName() + { + return schemaTableName; + } + + @Override + public Optional> getColumnHandles() + { + return Optional.of(columnHandles); + } + }; + return formatFilter(filterContext, context); + } + + private static String formatFilter(FilterContext filterContext, Context context) + { + Optional result = applyRules(FILTER_RULES, filterContext, context); + if (result.isPresent()) { + return result.get(); + } + throw new PinotException(PINOT_INVALID_PQL_GENERATED, Optional.empty(), format("Unexpected filter type: '%s'", filterContext.getType())); + } + + private static String formatPredicate(Predicate predicate, Context context) + { + Optional result = applyRules(GLOBAL_PREDICATE_RULES, predicate, context); + if (result.isPresent()) { + return result.get(); + } + Rule rule = PREDICATE_RULE_MAP.get(predicate.getType()); + if (rule != null) { + result = applyRule(rule, predicate, context); + } + if (result.isPresent()) { + return result.get(); + } + throw new PinotException(PINOT_EXCEPTION, Optional.empty(), format("Unsupported predicate type '%s'", predicate.getType())); + } + + public static String formatExpression(SchemaTableName schemaTableName, ExpressionContext expressionContext) + { + return formatExpression(schemaTableName, expressionContext, Optional.empty()); + } + + public static String formatExpression(SchemaTableName schemaTableName, ExpressionContext expressionContext, Optional> columnHandles) + { + 
requireNonNull(expressionContext, "expressionContext is null"); + Context context = new Context() { + @Override + public SchemaTableName getSchemaTableName() + { + return schemaTableName; + } + + @Override + public Optional> getColumnHandles() + { + return columnHandles; + } + }; + return formatExpression(expressionContext, context); + } + + private static String formatExpression(ExpressionContext expressionContext, Context context) + { + switch (expressionContext.getType()) { + case LITERAL: + return singleQuoteValue(expressionContext.getLiteral()); + case IDENTIFIER: + if (context.getColumnHandles().isPresent()) { + return quoteIdentifier(getColumnHandle(expressionContext.getIdentifier(), context.getSchemaTableName(), context.getColumnHandles().get()).getColumnName()); + } + return quoteIdentifier(expressionContext.getIdentifier()); + case FUNCTION: + return formatFunction(expressionContext.getFunction(), context); + } + throw new PinotException(PINOT_EXCEPTION, Optional.empty(), format("Unsupported expression type '%s'", expressionContext.getType())); + } + + private static String formatFunction(FunctionContext functionContext, Context context) + { + Optional result = Optional.empty(); + if (functionContext.getType() == FunctionContext.Type.TRANSFORM) { + Rule rule = FUNCTION_RULE_MAP.get(TransformFunctionType.getTransformFunctionType(functionContext.getFunctionName())); + + if (rule != null) { + result = applyRule(rule, functionContext, context); + } + } + else { + checkState(functionContext.getType() == FunctionContext.Type.AGGREGATION, "Unexpected function type for '%s'", functionContext); + Rule rule = AGGREGATION_FUNCTION_RULE_MAP.get(getAggregationFunctionType(functionContext.getFunctionName())); + if (rule != null) { + result = applyRule(rule, functionContext, context); + } + } + if (result.isPresent()) { + return result.get(); + } + result = applyRule(DEFAULT_FUNCTION_RULE, functionContext, context); + if (result.isPresent()) { + return result.get(); + } 
+ throw new PinotException(PINOT_EXCEPTION, Optional.empty(), format("Unsupported function expression '%s'", functionContext)); + } + + private static Optional applyRule(Rule rule, T object, Context context) + { + Iterator iterator = rule.getPattern().match(object).iterator(); + while (iterator.hasNext()) { + Match match = iterator.next(); + return Optional.of(rule.formatToSql(object, match.captures(), context)); + } + return Optional.empty(); + } + + private static Optional applyRules(List> rules, T object, Context context) + { + Optional result = Optional.empty(); + for (Rule rule : rules) { + result = applyRule(rule, object, context); + if (result.isPresent()) { + break; + } + } + return result; + } + + private static String singleQuoteValue(String value) + { + return "'" + value.replaceAll("'", "''") + "'"; + } + + private static String singleQuoteValues(List values) + { + return values.stream() + .map(PinotSqlFormatter::singleQuoteValue) + .collect(joining(", ")); + } + + public static String stripQuotes(String value) + { + if (value.startsWith("'") && value.endsWith("'")) { + return value.substring(1, value.length() - 1); + } + return value; + } + + public static PinotColumnHandle getColumnHandle(String name, SchemaTableName schemaTableName, Map columnHandles) + { + PinotColumnHandle columnHandle = (PinotColumnHandle) columnHandles.get(name); + if (columnHandle == null) { + throw new ColumnNotFoundException(schemaTableName, name); + } + return columnHandle; + } + + private interface Context + { + SchemaTableName getSchemaTableName(); + + Optional> getColumnHandles(); + } + + private interface Rule + { + Pattern getPattern(); + + String formatToSql(T object, Captures captures, Context context); + } + + private static class AndOrFilterRule + implements Rule + { + private static final Capture FILTER_TYPE = newCapture(); + private static final Capture> CHILD_FILTERS = newCapture(); + + private static final Pattern PATTERN = filter() + 
.with(filterType().matching(contextType -> contextType == FilterContext.Type.AND || contextType == FilterContext.Type.OR)) + .with(filterType().capturedAs(FILTER_TYPE)) + .with(childFilters().capturedAs(CHILD_FILTERS)); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public String formatToSql(FilterContext object, Captures captures, Context context) + { + FilterContext.Type filterType = captures.get(FILTER_TYPE); + List childFilters = captures.get(CHILD_FILTERS); + return format("%s(%s)", filterType.name(), childFilters.stream() + .map(filterContext -> formatFilter(filterContext, context)) + .collect(joining(", "))); + } + } + + private static class PredicateFilterRule + implements Rule + { + private static final Capture PREDICATE = newCapture(); + private static final Pattern PATTERN = filter() + .with(filterPredicate().capturedAs(PREDICATE)); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public String formatToSql(FilterContext object, Captures captures, Context context) + { + Predicate predicate = captures.get(PREDICATE); + return formatPredicate(predicate, context); + } + } + + // Pinot parses predicates like [=|!=|>|<|>=|<=] + // as equals(minus(x, y), 0) which is not valid pql or valid pinot sql. + // These patterns need to be rewritten to x op y here. 
+ private static class MinusZeroPredicateRule + implements Rule + { + private static final Capture FIRST_ARGUMENT = newCapture(); + private static final Capture SECOND_ARGUMENT = newCapture(); + private static final Capture BINARY_OPERATOR_NAME = newCapture(); + private static final Pattern PATTERN = predicate() + .with(binaryOperatorValue().equalTo("0")) + .with(binaryOperator().capturedAs(BINARY_OPERATOR_NAME)) + .with(predicateExpression().matching(expression() + .with(functionContext().matching(binaryFunction() + .with(firstArgument().capturedAs(FIRST_ARGUMENT)) + .with(secondArgument().capturedAs(SECOND_ARGUMENT)) + .with(transformFunctionType().equalTo(MINUS)))))); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public String formatToSql(Predicate object, Captures captures, Context context) + { + ExpressionContext first = captures.get(FIRST_ARGUMENT); + ExpressionContext second = captures.get(SECOND_ARGUMENT); + String operator = captures.get(BINARY_OPERATOR_NAME); + return format("(%s) %s (%s)", formatExpression(first, context), operator, formatExpression(second, context)); + } + } + + private static class BinaryOperatorPredicateRule + implements Rule + { + private static final Capture BINARY_OPERATOR_NAME = newCapture(); + private static final Capture BINARY_OPERATOR_VALUE = newCapture(); + private static final Capture PREDICATE_EXPRESSION = newCapture(); + private static final Pattern PATTERN = predicate() + .with(binaryOperatorValue().capturedAs(BINARY_OPERATOR_VALUE)) + .with(binaryOperator().capturedAs(BINARY_OPERATOR_NAME)) + .with(predicateExpression().capturedAs(PREDICATE_EXPRESSION)); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public String formatToSql(Predicate object, Captures captures, Context context) + { + ExpressionContext predicateExpression = captures.get(PREDICATE_EXPRESSION); + String singleValue = captures.get(BINARY_OPERATOR_VALUE); + String operator = 
captures.get(BINARY_OPERATOR_NAME); + return format("(%s) %s %s", formatExpression(predicateExpression, context), operator, singleQuoteValue(singleValue)); + } + } + + private static class ValuesListPredicateRule + implements Rule + { + private static final Capture> VALUES_LIST = newCapture(); + private static final Capture PREDICATE_EXPRESSION = newCapture(); + private static final Pattern VALUES_LIST_PATTERN = predicate() + .with(predicateValuesList().capturedAs(VALUES_LIST)) + .with(predicateExpression().capturedAs(PREDICATE_EXPRESSION)); + + private final Pattern pattern; + private final String operator; + + public ValuesListPredicateRule(Predicate.Type predicateType, String operator) + { + requireNonNull(predicateType, "predicateType is null"); + this.operator = requireNonNull(operator, "operator is null"); + pattern = VALUES_LIST_PATTERN.with(predicateType().equalTo(predicateType)); + } + + @Override + public Pattern getPattern() + { + return pattern; + } + + @Override + public String formatToSql(Predicate object, Captures captures, Context context) + { + ExpressionContext predicateExpression = captures.get(PREDICATE_EXPRESSION); + List values = captures.get(VALUES_LIST); + return format("%s %s (%s)", formatExpression(predicateExpression, context), operator, singleQuoteValues(values)); + } + } + + private static class RangePredicateRule + implements Rule + { + private static final Capture PREDICATE_EXPRESSION = newCapture(); + + private static final Pattern PATTERN = predicate() + .with(predicateType().equalTo(Predicate.Type.RANGE)) + .with(predicateExpression().capturedAs(PREDICATE_EXPRESSION)); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public String formatToSql(Predicate object, Captures captures, Context context) + { + RangePredicate rangePredicate = (RangePredicate) object; + ExpressionContext predicateExpression = captures.get(PREDICATE_EXPRESSION); + String expression = formatExpression(predicateExpression, 
context); + + // Single value range should have been rewritten in formatBinaryOperatorPredicate + checkState(!rangePredicate.getLowerBound().equals(UNBOUNDED) && !rangePredicate.getUpperBound().equals(UNBOUNDED), "Unexpected range predicate '%s'", rangePredicate); + if (rangePredicate.isUpperInclusive() && rangePredicate.isLowerInclusive()) { + return format("(%s) BETWEEN %s AND %s", expression, singleQuoteValue(rangePredicate.getLowerBound()), singleQuoteValue(rangePredicate.getUpperBound())); + } + String leftOperator = rangePredicate.isLowerInclusive() ? ">=" : ">"; + String rightOperator = rangePredicate.isUpperInclusive() ? "<=" : "<"; + return format("(%1$s) %2$s %3$s AND (%1$s) %4$s %5$s", expression, leftOperator, singleQuoteValue(rangePredicate.getLowerBound()), rightOperator, singleQuoteValue(rangePredicate.getUpperBound())); + } + } + + private static class BinaryFunctionPredicateRule + implements Rule + { + private static final Capture BINARY_FUNCTION_VALUE = newCapture(); + private static final Capture PREDICATE_EXPRESSION = newCapture(); + private static final Pattern BINARY_FUNCTION_PREDICATE = predicate() + .with(binaryFunctionPredicateValue().capturedAs(BINARY_FUNCTION_VALUE)) + .with(predicateExpression().capturedAs(PREDICATE_EXPRESSION)); + + private final Pattern pattern; + private final String functionName; + + public BinaryFunctionPredicateRule(Predicate.Type predicateType, String functionName) + { + requireNonNull(predicateType, "predicateType is null"); + this.functionName = requireNonNull(functionName, "functionName is null"); + this.pattern = BINARY_FUNCTION_PREDICATE.with(predicateType().equalTo(predicateType)); + } + + @Override + public Pattern getPattern() + { + return pattern; + } + + @Override + public String formatToSql(Predicate object, Captures captures, Context context) + { + String value = captures.get(BINARY_FUNCTION_VALUE); + ExpressionContext predicateExpression = captures.get(PREDICATE_EXPRESSION); + return format("%s(%s, 
%s)", functionName, formatExpression(predicateExpression, context), singleQuoteValue(value)); + } + } + + private static class ExpressionOnlyPredicate + implements Rule + { + private static final Capture PREDICATE_EXPRESSION = newCapture(); + private static final Pattern PREDICATE_PATTERN = predicate() + .with(predicateExpression().capturedAs(PREDICATE_EXPRESSION)); + + private final Pattern pattern; + private final String operator; + + public ExpressionOnlyPredicate(Predicate.Type predicateType, String operator) + { + requireNonNull(predicateType, "predicateType is null"); + this.operator = requireNonNull(operator, "operator is null"); + this.pattern = PREDICATE_PATTERN.with(predicateType().equalTo(predicateType)); + } + + @Override + public Pattern getPattern() + { + return pattern; + } + + @Override + public String formatToSql(Predicate object, Captures captures, Context context) + { + ExpressionContext predicateExpression = captures.get(PREDICATE_EXPRESSION); + return format("%s %s", formatExpression(predicateExpression, context), operator); + } + } + + // This is necessary because pinot renders - + // as minus(x, y) which is valid pql for the broker but not valid sql for the pinot parser. 
+ private static class MinusFunctionRule + implements Rule + { + private static final Capture FIRST_ARGUMENT = newCapture(); + private static final Capture SECOND_ARGUMENT = newCapture(); + private static final Pattern PATTERN = binaryFunction() + .with(transformFunctionType().equalTo(MINUS)) + .with(firstArgument().capturedAs(FIRST_ARGUMENT)) + .with(secondArgument().capturedAs(SECOND_ARGUMENT)); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public String formatToSql(FunctionContext object, Captures captures, Context context) + { + ExpressionContext first = captures.get(FIRST_ARGUMENT); + ExpressionContext second = captures.get(SECOND_ARGUMENT); + return format("%s - %s", formatExpression(first, context), formatExpression(second, context)); + } + } + + // Pinot parses cast as a function with the second argument being a literal instead of a type + // The broker request parses it this way, so the reverse needs to be done here + private static class CastFunctionRule + implements Rule + { + private static final Capture FIRST_ARGUMENT = newCapture(); + private static final Capture SECOND_ARGUMENT = newCapture(); + private static final Pattern PATTERN = binaryFunction() + .with(transformFunctionType().equalTo(CAST)) + .with(firstArgument().capturedAs(FIRST_ARGUMENT)) + .with(secondArgument().capturedAs(SECOND_ARGUMENT)); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public String formatToSql(FunctionContext object, Captures captures, Context context) + { + ExpressionContext first = captures.get(FIRST_ARGUMENT); + // Pinot interprets the second argument as a literal instead of a type + ExpressionContext second = captures.get(SECOND_ARGUMENT); + return format("CAST(%s AS %s)", formatExpression(first, context), stripQuotes(formatExpression(second, context))); + } + } + + // Pinot parses case statements as a function case(,... , ,... 
, ) + // This is valid pql for the pinot broker but not valid sql for the pinot sql parser, so this needs to be rewritten here. + private static class CaseFunctionRule + implements Rule + { + private static final Capture> ARGUMENTS = newCapture(); + private static final Pattern PATTERN = transformFunction() + .with(transformFunctionType().equalTo(CASE)) + .with(arguments().capturedAs(ARGUMENTS)); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public String formatToSql(FunctionContext object, Captures captures, Context context) + { + List arguments = captures.get(ARGUMENTS).stream() + .map(expressionContext -> formatExpression(expressionContext, context)) + .collect(toImmutableList()); + checkState(arguments.size() >= 2, "Unexpected expression '%s'", object); + int whenStatements = arguments.size() / 2; + StringBuilder builder = new StringBuilder("CASE "); + builder.append("WHEN ") + .append(arguments.get(0)) + .append(" THEN ") + .append(arguments.get(whenStatements)); + + for (int index = 1; index < whenStatements; index++) { + builder.append(" WHEN ") + .append(arguments.get(index)) + .append(" THEN ") + .append(arguments.get(index + whenStatements)); + } + + if (arguments.size() % 2 != 0) { + builder.append(" ELSE ") + .append(arguments.get(arguments.size() - 1)); + } + return builder.append(" END").toString(); + } + } + + private static class CountStarFunctionRule + implements Rule + { + private static final Pattern PATTERN = aggregationFunction() + .with(aggregationFunctionType().equalTo(COUNT)) + .with(singleInput().matching(expression() + .with(expressionType().equalTo(IDENTIFIER)) + .with(identifier().equalTo(WILDCARD)))); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public String formatToSql(FunctionContext object, Captures captures, Context context) + { + return format("%s(%s)", object.getFunctionName(), WILDCARD); + } + } + + private static class DefaultFunctionRule + 
implements Rule + { + private static final Capture> ARGUMENTS = newCapture(); + private static final Pattern PATTERN = function() + .with(arguments().capturedAs(ARGUMENTS)); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public String formatToSql(FunctionContext object, Captures captures, Context context) + { + return format("%s(%s)", object.getFunctionName(), captures.get(ARGUMENTS).stream() + .map(expressionContext -> formatExpression(expressionContext, context)) + .collect(joining(", "))); + } + } +} diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/PinotTypeResolver.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/PinotTypeResolver.java new file mode 100644 index 000000000000..2f4e3a023f65 --- /dev/null +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/PinotTypeResolver.java @@ -0,0 +1,130 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.pinot.query; + +import io.trino.plugin.pinot.PinotColumnHandle; +import io.trino.plugin.pinot.PinotException; +import io.trino.plugin.pinot.client.PinotClient; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ColumnNotFoundException; +import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.BigintType; +import io.trino.spi.type.BooleanType; +import io.trino.spi.type.DoubleType; +import io.trino.spi.type.IntegerType; +import io.trino.spi.type.RealType; +import io.trino.spi.type.Type; +import io.trino.spi.type.VarbinaryType; +import io.trino.spi.type.VarcharType; +import org.apache.pinot.common.request.context.ExpressionContext; +import org.apache.pinot.core.operator.transform.TransformResultMetadata; +import org.apache.pinot.core.operator.transform.function.LiteralTransformFunction; +import org.apache.pinot.core.operator.transform.function.TransformFunctionFactory; +import org.apache.pinot.segment.local.segment.index.datasource.EmptyDataSource; +import org.apache.pinot.segment.spi.datasource.DataSource; +import org.apache.pinot.segment.spi.index.metadata.ColumnMetadataImpl; +import org.apache.pinot.spi.data.FieldSpec; + +import java.util.Map; +import java.util.Optional; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.trino.plugin.pinot.PinotErrorCode.PINOT_INVALID_PQL_GENERATED; +import static io.trino.plugin.pinot.PinotErrorCode.PINOT_UNSUPPORTED_COLUMN_TYPE; +import static java.lang.String.format; +import static java.util.Locale.ENGLISH; +import static java.util.Objects.requireNonNull; + +public class PinotTypeResolver +{ + private final Map datasourceMap; + + public PinotTypeResolver(PinotClient pinotClient, String pinotTableName) + { + requireNonNull(pinotClient, "pinotClient is null"); + this.datasourceMap = getDataSourceMap(pinotClient, pinotTableName); + } + + private static Map getDataSourceMap(PinotClient 
pinotClient, String pinotTableName) + { + try { + return pinotClient.getTableSchema(pinotTableName).getFieldSpecMap().entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, + entry -> new EmptyDataSource(new ColumnMetadataImpl.Builder() + .setFieldSpec(entry.getValue()) + .build()))); + } + catch (Exception e) { + throw new RuntimeException(e); + } + } + + public TransformResultMetadata resolveExpressionType(ExpressionContext expression, SchemaTableName schemaTableName, Map columnHandles) + { + switch (expression.getType()) { + case IDENTIFIER: + PinotColumnHandle columnHandle = (PinotColumnHandle) columnHandles.get(expression.getIdentifier().toLowerCase(ENGLISH)); + if (columnHandle == null) { + throw new ColumnNotFoundException(schemaTableName, expression.getIdentifier()); + } + return fromTrinoType(columnHandle.getDataType()); + case FUNCTION: + return TransformFunctionFactory.get(expression, datasourceMap).getResultMetadata(); + case LITERAL: + FieldSpec.DataType literalDataType = LiteralTransformFunction.inferLiteralDataType(new LiteralTransformFunction(expression.getLiteral())); + return new TransformResultMetadata(literalDataType, true, false); + default: + throw new PinotException(PINOT_INVALID_PQL_GENERATED, Optional.empty(), format("Unsupported expression: '%s'", expression)); + } + } + + public static TransformResultMetadata fromTrinoType(Type type) + { + if (type instanceof ArrayType) { + ArrayType arrayType = (ArrayType) type; + Type elementType = arrayType.getElementType(); + return new TransformResultMetadata(fromPrimitiveTrinoType(elementType), false, false); + } + else { + return new TransformResultMetadata(fromPrimitiveTrinoType(type), true, false); + } + } + + private static FieldSpec.DataType fromPrimitiveTrinoType(Type type) + { + if (type instanceof VarcharType) { + return FieldSpec.DataType.STRING; + } + if (type instanceof BigintType) { + return FieldSpec.DataType.LONG; + } + if (type instanceof IntegerType) { + return 
FieldSpec.DataType.INT; + } + if (type instanceof DoubleType) { + return FieldSpec.DataType.DOUBLE; + } + if (type instanceof RealType) { + return FieldSpec.DataType.FLOAT; + } + if (type instanceof BooleanType) { + return FieldSpec.DataType.BOOLEAN; + } + if (type instanceof VarbinaryType) { + return FieldSpec.DataType.BYTES; + } + throw new PinotException(PINOT_UNSUPPORTED_COLUMN_TYPE, Optional.empty(), "Unsupported column data type: " + type); + } +} diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/aggregation/ImplementApproxDistinct.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/aggregation/ImplementApproxDistinct.java index 168869844906..87be6b436cd6 100644 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/aggregation/ImplementApproxDistinct.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/aggregation/ImplementApproxDistinct.java @@ -17,7 +17,7 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.plugin.base.expression.AggregateFunctionRule; -import io.trino.plugin.pinot.PinotColumnHandle; +import io.trino.plugin.pinot.query.AggregateExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; @@ -30,10 +30,9 @@ import static io.trino.plugin.base.expression.AggregateFunctionPatterns.singleInput; import static io.trino.plugin.base.expression.AggregateFunctionPatterns.variable; import static io.trino.spi.type.BigintType.BIGINT; -import static java.lang.String.format; public class ImplementApproxDistinct - implements AggregateFunctionRule + implements AggregateFunctionRule { // Extracted from io.trino.plugin.jdbc.expression private static final Capture INPUT = newCapture(); @@ -48,9 +47,9 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, 
Captures captures, RewriteContext context) { Variable input = captures.get(INPUT); - return Optional.of(new PinotColumnHandle(format("distinctcounthll(%s)", context.getIdentifierQuote().apply(input.getName())), aggregateFunction.getOutputType(), false)); + return Optional.of(new AggregateExpression("distinctcounthll", context.getIdentifierQuote().apply(input.getName()), false)); } } diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/aggregation/ImplementAvg.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/aggregation/ImplementAvg.java index 9b86443fb344..2bb6407eea37 100644 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/aggregation/ImplementAvg.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/aggregation/ImplementAvg.java @@ -18,7 +18,7 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.plugin.base.expression.AggregateFunctionRule; -import io.trino.plugin.pinot.PinotColumnHandle; +import io.trino.plugin.pinot.query.AggregateExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; import io.trino.spi.type.Type; @@ -36,10 +36,9 @@ import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.RealType.REAL; -import static java.lang.String.format; public class ImplementAvg - implements AggregateFunctionRule + implements AggregateFunctionRule { private static final Capture INPUT = newCapture(); private static final Set SUPPORTED_INPUT_TYPES = ImmutableSet.of(INTEGER, BIGINT, REAL, DOUBLE); @@ -56,9 +55,9 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { Variable input = captures.get(INPUT); - return Optional.of(new 
PinotColumnHandle(format("avg(%s)", context.getIdentifierQuote().apply(input.getName())), aggregateFunction.getOutputType())); + return Optional.of(new AggregateExpression(aggregateFunction.getFunctionName(), context.getIdentifierQuote().apply(input.getName()), true)); } } diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/aggregation/ImplementCountAll.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/aggregation/ImplementCountAll.java index d84835b52343..225a58128237 100644 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/aggregation/ImplementCountAll.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/aggregation/ImplementCountAll.java @@ -16,7 +16,7 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.plugin.base.expression.AggregateFunctionRule; -import io.trino.plugin.pinot.PinotColumnHandle; +import io.trino.plugin.pinot.query.AggregateExpression; import io.trino.spi.connector.AggregateFunction; import java.util.List; @@ -32,7 +32,7 @@ * Implements {@code count(*)}. 
*/ public class ImplementCountAll - implements AggregateFunctionRule + implements AggregateFunctionRule { @Override public Pattern getPattern() @@ -44,8 +44,8 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { - return Optional.of(new PinotColumnHandle("count(*)", aggregateFunction.getOutputType(), false)); + return Optional.of(new AggregateExpression("count", "*", false)); } } diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/aggregation/ImplementCountDistinct.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/aggregation/ImplementCountDistinct.java index fc45976db9b2..beba0855ff89 100644 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/aggregation/ImplementCountDistinct.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/aggregation/ImplementCountDistinct.java @@ -17,7 +17,7 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.plugin.base.expression.AggregateFunctionRule; -import io.trino.plugin.pinot.PinotColumnHandle; +import io.trino.plugin.pinot.query.AggregateExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; @@ -32,10 +32,9 @@ import static io.trino.plugin.base.expression.AggregateFunctionPatterns.variable; import static io.trino.plugin.pinot.PinotSessionProperties.isCountDistinctPushdownEnabled; import static io.trino.spi.type.BigintType.BIGINT; -import static java.lang.String.format; public class ImplementCountDistinct - implements AggregateFunctionRule + implements AggregateFunctionRule { private static final Capture INPUT = newCapture(); @@ -49,13 +48,13 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, 
RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { if (!isCountDistinctPushdownEnabled(context.getSession())) { return Optional.empty(); } Variable input = captures.get(INPUT); verify(aggregateFunction.getOutputType() == BIGINT); - return Optional.of(new PinotColumnHandle(format("distinctcount(%s)", context.getIdentifierQuote().apply(input.getName())), aggregateFunction.getOutputType(), false)); + return Optional.of(new AggregateExpression("distinctcount", context.getIdentifierQuote().apply(input.getName()), false)); } } diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/aggregation/ImplementMinMax.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/aggregation/ImplementMinMax.java index 7541078d3398..65b93d24da98 100644 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/aggregation/ImplementMinMax.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/aggregation/ImplementMinMax.java @@ -19,6 +19,7 @@ import io.trino.matching.Pattern; import io.trino.plugin.base.expression.AggregateFunctionRule; import io.trino.plugin.pinot.PinotColumnHandle; +import io.trino.plugin.pinot.query.AggregateExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; import io.trino.spi.type.Type; @@ -37,13 +38,12 @@ import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.RealType.REAL; -import static java.lang.String.format; /** * Implements {@code min(x)}, {@code max(x)}. 
*/ public class ImplementMinMax - implements AggregateFunctionRule + implements AggregateFunctionRule { private static final Capture INPUT = newCapture(); private static final Set SUPPORTED_INPUT_TYPES = ImmutableSet.of(INTEGER, BIGINT, REAL, DOUBLE); @@ -60,11 +60,11 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { Variable input = captures.get(INPUT); PinotColumnHandle columnHandle = (PinotColumnHandle) context.getAssignment(input.getName()); verify(columnHandle.getDataType().equals(aggregateFunction.getOutputType())); - return Optional.of(new PinotColumnHandle(format("%s(%s)", aggregateFunction.getFunctionName(), context.getIdentifierQuote().apply(columnHandle.getColumnName())), aggregateFunction.getOutputType())); + return Optional.of(new AggregateExpression(aggregateFunction.getFunctionName(), context.getIdentifierQuote().apply(columnHandle.getColumnName()), true)); } } diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/aggregation/ImplementSum.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/aggregation/ImplementSum.java index 9070185c8e3e..a690398e9278 100644 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/aggregation/ImplementSum.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/aggregation/ImplementSum.java @@ -19,6 +19,7 @@ import io.trino.matching.Pattern; import io.trino.plugin.base.expression.AggregateFunctionRule; import io.trino.plugin.pinot.PinotColumnHandle; +import io.trino.plugin.pinot.query.AggregateExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; import io.trino.spi.type.Type; @@ -36,13 +37,12 @@ import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; import 
static io.trino.spi.type.RealType.REAL; -import static java.lang.String.format; /** * Implements {@code sum(x)} */ public class ImplementSum - implements AggregateFunctionRule + implements AggregateFunctionRule { private static final Capture INPUT = newCapture(); private static final Set SUPPORTED_INPUT_TYPES = ImmutableSet.of(INTEGER, BIGINT, REAL, DOUBLE); @@ -59,10 +59,10 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { Variable input = captures.get(INPUT); PinotColumnHandle columnHandle = (PinotColumnHandle) context.getAssignment(input.getName()); - return Optional.of(new PinotColumnHandle(format("sum(%s)", context.getIdentifierQuote().apply(columnHandle.getColumnName())), aggregateFunction.getOutputType())); + return Optional.of(new AggregateExpression(aggregateFunction.getFunctionName(), context.getIdentifierQuote().apply(columnHandle.getColumnName()), true)); } } diff --git a/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestBrokerQueries.java b/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestBrokerQueries.java index 5e025a7a202a..668a6c33d775 100644 --- a/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestBrokerQueries.java +++ b/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestBrokerQueries.java @@ -17,7 +17,7 @@ import io.trino.plugin.pinot.client.PinotClient; import io.trino.plugin.pinot.client.PinotClient.BrokerResultRow; import io.trino.plugin.pinot.client.PinotClient.ResultsIterator; -import io.trino.plugin.pinot.query.PinotQuery; +import io.trino.plugin.pinot.query.PinotQueryInfo; import io.trino.spi.Page; import io.trino.spi.block.Block; import org.apache.pinot.common.response.broker.BrokerResponseNative; @@ -109,7 +109,7 @@ public void testBrokerQuery() .add(new PinotColumnHandle("col_3", VARCHAR)) 
.build(); PinotBrokerPageSource pageSource = new PinotBrokerPageSource(createSessionWithNumSplits(1, false, pinotConfig), - new PinotQuery("test_table", "SELECT col_1, col_2, col_3 FROM test_table", 0), + new PinotQueryInfo("test_table", "SELECT col_1, col_2, col_3 FROM test_table", 0), columnHandles, testingPinotClient, LIMIT_FOR_BROKER_QUERIES); @@ -131,7 +131,7 @@ public void testBrokerQuery() public void testCountStarBrokerQuery() { PinotBrokerPageSource pageSource = new PinotBrokerPageSource(createSessionWithNumSplits(1, false, pinotConfig), - new PinotQuery("test_table", "SELECT COUNT(*) FROM test_table", 0), + new PinotQueryInfo("test_table", "SELECT COUNT(*) FROM test_table", 0), ImmutableList.of(), testingPinotClient, LIMIT_FOR_BROKER_QUERIES); @@ -163,7 +163,7 @@ public void testBrokerResponseHasTooManyRows() .add(new PinotColumnHandle("col_3", VARCHAR)) .build(); PinotBrokerPageSource pageSource = new PinotBrokerPageSource(createSessionWithNumSplits(1, false, pinotConfig), - new PinotQuery("test_table", "SELECT col_1, col_2, col_3 FROM test_table", 0), + new PinotQueryInfo("test_table", "SELECT col_1, col_2, col_3 FROM test_table", 0), columnHandles, testingPinotClient, LIMIT_FOR_BROKER_QUERIES); diff --git a/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestDynamicTable.java b/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestDynamicTable.java index 16a804b3db50..7123cb59b22c 100755 --- a/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestDynamicTable.java +++ b/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestDynamicTable.java @@ -28,17 +28,16 @@ import java.util.List; +import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.plugin.pinot.query.DynamicTableBuilder.OFFLINE_SUFFIX; import static io.trino.plugin.pinot.query.DynamicTableBuilder.REALTIME_SUFFIX; import static io.trino.plugin.pinot.query.DynamicTableBuilder.buildFromPql; import static 
io.trino.plugin.pinot.query.DynamicTablePqlExtractor.extractPql; -import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.VarcharType.VARCHAR; import static java.lang.String.format; import static java.util.Locale.ENGLISH; import static java.util.stream.Collectors.joining; import static java.util.stream.Collectors.toList; -import static org.assertj.core.api.Assertions.assertThatNullPointerException; import static org.testng.Assert.assertEquals; public class TestDynamicTable @@ -52,7 +51,7 @@ public void testSelectNoFilter() List orderByColumns = columnNames.subList(0, 5); List orderByExpressions = orderByColumns.stream() .limit(4) - .map(columnName -> new OrderByExpression(columnName, true)) + .map(columnName -> new OrderByExpression(quoteIdentifier(columnName), true)) .collect(toList()); long limit = 230; String query = format("select %s from %s order by %s limit %s", @@ -62,9 +61,12 @@ public void testSelectNoFilter() orderByColumns.stream() .collect(joining(", ")) + " desc", limit); - DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query)); - assertEquals(dynamicTable.getSelections(), columnNames); - orderByExpressions.add(new OrderByExpression(orderByColumns.get(4), false)); + DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query), mockClusterInfoFetcher); + assertEquals(dynamicTable.getProjections().stream() + .map(PinotColumnHandle::getColumnName) + .collect(toImmutableList()), + columnNames); + orderByExpressions.add(new OrderByExpression(quoteIdentifier(orderByColumns.get(4)), false)); assertEquals(dynamicTable.getOrderBy(), orderByExpressions); assertEquals(dynamicTable.getLimit().getAsLong(), limit); } @@ -75,15 +77,14 @@ public void testGroupBy() String tableName = realtimeOnlyTable.getTableName(); long limit = 25; String query = format("SELECT Origin, AirlineID, max(CarrierDelay), avg(CarrierDelay) FROM %s GROUP BY Origin, AirlineID LIMIT %s", 
tableName, limit); - DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query)); - assertEquals(dynamicTable.getGroupingColumns(), ImmutableList.builder() - .add("Origin") - .add("AirlineID") - .build()); - assertEquals(dynamicTable.getAggregateColumns(), ImmutableList.builder() - .add(new PinotColumnHandle("max(carrierdelay)", DOUBLE)) - .add(new PinotColumnHandle("avg(carrierdelay)", DOUBLE)) - .build()); + DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query), mockClusterInfoFetcher); + assertEquals(dynamicTable.getGroupingColumns().stream() + .map(PinotColumnHandle::getColumnName) + .collect(toImmutableList()), + ImmutableList.builder() + .add("Origin") + .add("AirlineID") + .build()); assertEquals(dynamicTable.getLimit().getAsLong(), limit); } @@ -97,11 +98,11 @@ public void testFilter() "OR ((DepDelayMinutes < 10) AND (Distance >= 3) AND (ArrDelay > 4) AND (SecurityDelay < 5) AND (LateAircraftDelay <= 7)) limit 60", tableName.toLowerCase(ENGLISH)); - String expected = format("select \"FlightNum\", \"AirlineID\" from %s where OR(AND(\"CancellationCode\" IN ('strike', 'weather', 'pilot_bac'), (\"Origin\") = ('jfk'))," + - " AND((\"OriginCityName\") != ('catfish paradise'), (\"OriginState\") != ('az'), BETWEEN(\"AirTime\", 1, 5), \"AirTime\" NOT IN (7, 8, 9))," + - " AND((\"DepDelayMinutes\") < (10), (\"Distance\") >= (3), (\"ArrDelay\") > (4), (\"SecurityDelay\") < (5), (\"LateAircraftDelay\") <= (7))) limit 60", + String expected = format("select \"FlightNum\", \"AirlineID\" from %s where OR(AND(\"CancellationCode\" IN ('strike', 'weather', 'pilot_bac'), (\"Origin\") = 'jfk'), " + + "AND((\"OriginCityName\") != 'catfish paradise', (\"OriginState\") != 'az', (\"AirTime\") BETWEEN '1' AND '5', \"AirTime\" NOT IN ('7', '8', '9')), " + + "AND((\"DepDelayMinutes\") < '10', (\"Distance\") >= '3', (\"ArrDelay\") > '4', (\"SecurityDelay\") < '5', (\"LateAircraftDelay\") <= '7')) limit 60", 
tableName.toLowerCase(ENGLISH)); - DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query)); + DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query), mockClusterInfoFetcher); assertEquals(extractPql(dynamicTable, TupleDomain.all(), ImmutableList.of()), expected); } @@ -118,21 +119,21 @@ public void testPrimitiveTypes() " FROM " + tableName + " WHERE string_col = 'string' AND long_col = 12345678901 AND int_col = 123456789" + " AND double_col = 3.56 AND float_col = 3.56 AND bytes_col = 'abcd' LIMIT 60"; String expected = "select \"string_col\", \"long_col\", \"int_col\", \"bool_col\", \"double_col\", \"float_col\", \"bytes_col\"" + - " from primitive_types_table where AND((\"string_col\") = ('string'), (\"long_col\") = (12345678901)," + - " (\"int_col\") = (123456789), (\"double_col\") = (3.56)," + - " (\"float_col\") = (3.56), (\"bytes_col\") = ('abcd')) limit 60"; - DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query)); + " from primitive_types_table where AND((\"string_col\") = 'string', (\"long_col\") = '12345678901'," + + " (\"int_col\") = '123456789', (\"double_col\") = '3.56', (\"float_col\") = '3.56', (\"bytes_col\") = 'abcd') limit 60"; + DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query), mockClusterInfoFetcher); assertEquals(extractPql(dynamicTable, TupleDomain.all(), ImmutableList.of()), expected); } @Test public void testDoubleWithScientificNotation() { - // Pinot does not recognize double literals with scientific notation + // Pinot recognizes double literals with scientific notation as of version 0.8.0 String tableName = "primitive_types_table"; String query = "SELECT string_col FROM " + tableName + " WHERE double_col = 3.5E5"; - assertThatNullPointerException().isThrownBy(() -> buildFromPql(pinotMetadata, new SchemaTableName("default", query))) - .withMessage(Runtime.version().feature() < 15 
? null : "Cannot invoke \"java.lang.Integer.intValue()\" because \"this.scale\" is null"); + String expected = "select \"string_col\" from primitive_types_table where (\"double_col\") = '350000.0' limit 10"; + DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query), mockClusterInfoFetcher); + assertEquals(extractPql(dynamicTable, TupleDomain.all(), ImmutableList.of()), expected); } @Test @@ -141,9 +142,9 @@ public void testFilterWithCast() String tableName = "primitive_types_table"; String query = "SELECT string_col, long_col" + " FROM " + tableName + " WHERE string_col = CAST(123 AS STRING) AND long_col = CAST('123' AS LONG) LIMIT 60"; - String expected = "select \"string_col\", \"long_col\" from primitive_types_table" + - " where AND((\"string_col\") = (CAST(123 AS string)), (\"long_col\") = (CAST('123' AS long))) limit 60"; - DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query)); + String expected = "select \"string_col\", \"long_col\" from primitive_types_table " + + "where AND((\"string_col\") = (CAST('123' AS string)), (\"long_col\") = (CAST('123' AS long))) limit 60"; + DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query), mockClusterInfoFetcher); assertEquals(extractPql(dynamicTable, TupleDomain.all(), ImmutableList.of()), expected); } @@ -155,12 +156,12 @@ public void testFilterWithCaseStatements() "where case when cancellationcode = 'strike' then 3 else 4 end != 5 " + "AND case origincityname when 'nyc' then 'pizza' when 'la' then 'burrito' when 'boston' then 'clam chowder' " + "else 'burger' end != 'salad'", tableName.toLowerCase(ENGLISH)); - String expected = format("select \"FlightNum\", \"AirlineID\" from %s where AND((CASE WHEN (\"CancellationCode\") = ('strike')" + - " THEN 3 ELSE 4 END) != (5), (CASE WHEN (\"OriginCityName\") = ('nyc') THEN 'pizza'" + - " WHEN (\"OriginCityName\") = ('la') THEN 'burrito' WHEN (\"OriginCityName\") 
= ('boston') THEN 'clam chowder'" + - " ELSE 'burger' END) != ('salad')) limit 10", + String expected = format("select \"FlightNum\", \"AirlineID\" from %s where AND((CASE WHEN equals(\"CancellationCode\", 'strike') " + + "THEN '3' ELSE '4' END) != '5', (CASE WHEN equals(\"OriginCityName\", 'nyc') " + + "THEN 'pizza' WHEN equals(\"OriginCityName\", 'la') THEN 'burrito' WHEN equals(\"OriginCityName\", 'boston') " + + "THEN 'clam chowder' ELSE 'burger' END) != 'salad') limit 10", tableName.toLowerCase(ENGLISH)); - DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query)); + DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query), mockClusterInfoFetcher); assertEquals(extractPql(dynamicTable, TupleDomain.all(), ImmutableList.of()), expected); } @@ -169,7 +170,7 @@ public void testFilterWithPushdownConstraint() { String tableName = realtimeOnlyTable.getTableName(); String query = format("select FlightNum from %s limit 60", tableName.toLowerCase(ENGLISH)); - DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query)); + DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query), mockClusterInfoFetcher); PinotColumnHandle columnHandle = new PinotColumnHandle("OriginCityName", VARCHAR); TupleDomain tupleDomain = TupleDomain.withColumnDomains(ImmutableMap.builder() .put(columnHandle, @@ -185,9 +186,9 @@ public void testFilterWithPushdownConstraint() public void testFilterWithUdf() { String tableName = realtimeOnlyTable.getTableName(); - String query = format("select FlightNum from %s where DivLongestGTimes = POW(3, 2) limit 60", tableName.toLowerCase(ENGLISH)); - DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query)); - String expectedPql = "select \"FlightNum\" from realtimeonly where (\"DivLongestGTimes\") = (POW(3, 2)) limit 60"; + String query = format("select FlightNum from %s where 
DivLongestGTimes = FLOOR(EXP(2 * LN(3))) AND 5 < EXP(CarrierDelay) limit 60", tableName.toLowerCase(ENGLISH)); + DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query), mockClusterInfoFetcher); + String expectedPql = "select \"FlightNum\" from realtimeonly where AND((\"DivLongestGTimes\") = '9.0', (exp(\"CarrierDelay\")) > '5') limit 60"; assertEquals(extractPql(dynamicTable, TupleDomain.all(), ImmutableList.of()), expectedPql); } @@ -196,7 +197,7 @@ public void testSelectStarDynamicTable() { String tableName = realtimeOnlyTable.getTableName(); String query = format("select * from %s limit 70", tableName.toLowerCase(ENGLISH)); - DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query)); + DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query), mockClusterInfoFetcher); String expectedPql = format("select %s from %s limit 70", getColumnNames(tableName).stream().map(TestDynamicTable::quoteIdentifier).collect(joining(", ")), tableName.toLowerCase(ENGLISH)); assertEquals(extractPql(dynamicTable, TupleDomain.all(), ImmutableList.of()), expectedPql); } @@ -207,7 +208,7 @@ public void testOfflineDynamicTable() String tableName = hybridTable.getTableName(); String tableNameWithSuffix = tableName + OFFLINE_SUFFIX; String query = format("select * from %s limit 70", tableNameWithSuffix); - DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query)); + DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query), mockClusterInfoFetcher); String expectedPql = format("select %s from %s limit 70", getColumnNames(tableName).stream().map(TestDynamicTable::quoteIdentifier).collect(joining(", ")), tableNameWithSuffix); assertEquals(extractPql(dynamicTable, TupleDomain.all(), ImmutableList.of()), expectedPql); assertEquals(dynamicTable.getTableName(), tableName); @@ -219,14 +220,170 @@ public void 
testRealtimeOnlyDynamicTable() String tableName = hybridTable.getTableName(); String tableNameWithSuffix = tableName + REALTIME_SUFFIX; String query = format("select * from %s limit 70", tableNameWithSuffix); - DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query)); + DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query), mockClusterInfoFetcher); String expectedPql = format("select %s from %s limit 70", getColumnNames(tableName).stream().map(TestDynamicTable::quoteIdentifier).collect(joining(", ")), tableNameWithSuffix); assertEquals(extractPql(dynamicTable, TupleDomain.all(), ImmutableList.of()), expectedPql); assertEquals(dynamicTable.getTableName(), tableName); } + @Test + public void testLimitAndOffset() + { + String tableName = hybridTable.getTableName(); + String tableNameWithSuffix = tableName + REALTIME_SUFFIX; + String query = format("select * from %s limit 70, 40", tableNameWithSuffix); + DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query), mockClusterInfoFetcher); + String expectedPql = format("select %s from %s limit 70, 40", getColumnNames(tableName).stream().map(TestDynamicTable::quoteIdentifier).collect(joining(", ")), tableNameWithSuffix); + assertEquals(extractPql(dynamicTable, TupleDomain.all(), ImmutableList.of()), expectedPql); + assertEquals(dynamicTable.getTableName(), tableName); + } + private static String quoteIdentifier(String identifier) { - return "\"" + identifier + "\""; + return "\"" + identifier.replaceAll("\"", "\"\"") + "\""; + } + + @Test + public void testRegexpLike() + { + String tableName = hybridTable.getTableName(); + String tableNameWithSuffix = tableName + REALTIME_SUFFIX; + String query = format("select origincityname from %s where regexp_like(origincityname, '.*york.*') limit 70", tableNameWithSuffix); + DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query), 
mockClusterInfoFetcher); + String expectedPql = format("select \"OriginCityName\" from %s where regexp_like(\"OriginCityName\", '.*york.*') limit 70", tableNameWithSuffix); + assertEquals(extractPql(dynamicTable, TupleDomain.all(), ImmutableList.of()), expectedPql); + assertEquals(dynamicTable.getTableName(), tableName); + } + + @Test + public void testTextMatch() + { + String tableName = hybridTable.getTableName(); + String tableNameWithSuffix = tableName + REALTIME_SUFFIX; + String query = format("select origincityname from %s where text_match(origincityname, 'new AND york') limit 70", tableNameWithSuffix); + DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query), mockClusterInfoFetcher); + String expectedPql = format("select \"OriginCityName\" from %s where text_match(\"OriginCityName\", 'new and york') limit 70", tableNameWithSuffix); + assertEquals(extractPql(dynamicTable, TupleDomain.all(), ImmutableList.of()), expectedPql); + assertEquals(dynamicTable.getTableName(), tableName); + } + + @Test + public void testJsonMatch() + { + String tableName = hybridTable.getTableName(); + String tableNameWithSuffix = tableName + REALTIME_SUFFIX; + String query = format("select origincityname from %s where json_match(origincityname, '\"$.name\"=''new york''') limit 70", tableNameWithSuffix); + DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query), mockClusterInfoFetcher); + String expectedPql = format("select \"OriginCityName\" from %s where json_match(\"OriginCityName\", '\"$.name\"=''new york''') limit 70", tableNameWithSuffix); + assertEquals(extractPql(dynamicTable, TupleDomain.all(), ImmutableList.of()), expectedPql); + assertEquals(dynamicTable.getTableName(), tableName); + } + + @Test + public void testSelectExpressionsWithAliases() + { + String tableName = hybridTable.getTableName(); + String tableNameWithSuffix = tableName + REALTIME_SUFFIX; + String query = format("select 
datetimeconvert(dayssinceEpoch, '1:seconds:epoch', '1:milliseconds:epoch', '15:minutes'), " + + "case origincityname when 'nyc' then 'pizza' when 'la' then 'burrito' when 'boston' then 'clam chowder'" + + " else 'burger' end != 'salad'," + + " timeconvert(dayssinceEpoch, 'seconds', 'minutes') as foo" + + " from %s limit 70", tableNameWithSuffix); + + DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query), mockClusterInfoFetcher); + String expectedPql = format("select datetimeconvert(\"DaysSinceEpoch\", '1:SECONDS:EPOCH', '1:MILLISECONDS:EPOCH', '15:MINUTES')," + + " not_equals(CASE WHEN equals(\"OriginCityName\", 'nyc') THEN 'pizza' WHEN equals(\"OriginCityName\", 'la') THEN 'burrito' WHEN equals(\"OriginCityName\", 'boston') THEN 'clam chowder' ELSE 'burger' END, 'salad')," + + " timeconvert(\"DaysSinceEpoch\", 'SECONDS', 'MINUTES') AS \"foo\" from %s limit 70", tableNameWithSuffix); + assertEquals(extractPql(dynamicTable, TupleDomain.all(), ImmutableList.of()), expectedPql); + assertEquals(dynamicTable.getTableName(), tableName); + } + + @Test + public void testAggregateExpressionsWithAliases() + { + String tableName = hybridTable.getTableName(); + String tableNameWithSuffix = tableName + REALTIME_SUFFIX; + String query = format("select datetimeconvert(dayssinceEpoch, '1:seconds:epoch', '1:milliseconds:epoch', '15:minutes'), " + + " count(*) as bar," + + " case origincityname when 'nyc' then 'pizza' when 'la' then 'burrito' when 'boston' then 'clam chowder'" + + " else 'burger' end != 'salad'," + + " timeconvert(dayssinceEpoch, 'seconds', 'minutes') as foo," + + " max(airtime) as baz" + + " from %s group by 1, 3, 4 limit 70", tableNameWithSuffix); + + DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query), mockClusterInfoFetcher); + String expectedPql = format("select datetimeconvert(\"DaysSinceEpoch\", '1:SECONDS:EPOCH'," + + " '1:MILLISECONDS:EPOCH', '15:MINUTES'), count(*) AS 
\"bar\"," + + " not_equals(CASE WHEN equals(\"OriginCityName\", 'nyc') THEN 'pizza' WHEN equals(\"OriginCityName\", 'la') THEN 'burrito'" + + " WHEN equals(\"OriginCityName\", 'boston') THEN 'clam chowder' ELSE 'burger' END, 'salad')," + + " timeconvert(\"DaysSinceEpoch\", 'SECONDS', 'MINUTES') AS \"foo\"," + + " max(\"AirTime\") AS \"baz\"" + + " from %s" + + " group by datetimeconvert(\"DaysSinceEpoch\", '1:SECONDS:EPOCH', '1:MILLISECONDS:EPOCH', '15:MINUTES')," + + " not_equals(CASE WHEN equals(\"OriginCityName\", 'nyc') THEN 'pizza' WHEN equals(\"OriginCityName\", 'la') THEN 'burrito' WHEN equals(\"OriginCityName\", 'boston') THEN 'clam chowder' ELSE 'burger' END, 'salad')," + + " timeconvert(\"DaysSinceEpoch\", 'SECONDS', 'MINUTES')" + + " limit 70", tableNameWithSuffix); + assertEquals(extractPql(dynamicTable, TupleDomain.all(), ImmutableList.of()), expectedPql); + assertEquals(dynamicTable.getTableName(), tableName); + } + + @Test + public void testOrderBy() + { + String tableName = hybridTable.getTableName(); + String tableNameWithSuffix = tableName + REALTIME_SUFFIX; + String query = format("select ArrDelay + 34 - DaysSinceEpoch, FlightNum from %s order by ArrDelay asc, DaysSinceEpoch desc", tableNameWithSuffix); + DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query), mockClusterInfoFetcher); + String expectedPql = format("select plus(\"ArrDelay\", '34') - \"DaysSinceEpoch\", \"FlightNum\" from %s order by \"ArrDelay\", \"DaysSinceEpoch\" desc limit 10", tableNameWithSuffix); + assertEquals(extractPql(dynamicTable, TupleDomain.all(), ImmutableList.of()), expectedPql); + assertEquals(dynamicTable.getTableName(), tableName); + } + + @Test + public void testOrderByCountStar() + { + String tableName = hybridTable.getTableName(); + String tableNameWithSuffix = tableName + REALTIME_SUFFIX; + String query = format("select count(*) from %s order by count(*)", tableNameWithSuffix); + DynamicTable dynamicTable = 
buildFromPql(pinotMetadata, new SchemaTableName("default", query), mockClusterInfoFetcher); + String expectedPql = format("select count(*) from %s order by count(*) limit 10", tableNameWithSuffix); + assertEquals(extractPql(dynamicTable, TupleDomain.all(), ImmutableList.of()), expectedPql); + assertEquals(dynamicTable.getTableName(), tableName); + } + + @Test + public void testOrderByExpression() + { + String tableName = hybridTable.getTableName(); + String tableNameWithSuffix = tableName + REALTIME_SUFFIX; + String query = format("select ArrDelay + 34 - DaysSinceEpoch, FlightNum from %s order by ArrDelay + 34 - DaysSinceEpoch desc", tableNameWithSuffix); + DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query), mockClusterInfoFetcher); + String expectedPql = format("select plus(\"ArrDelay\", '34') - \"DaysSinceEpoch\", \"FlightNum\" from %s order by plus(\"ArrDelay\", '34') - \"DaysSinceEpoch\" desc limit 10", tableNameWithSuffix); + assertEquals(extractPql(dynamicTable, TupleDomain.all(), ImmutableList.of()), expectedPql); + assertEquals(dynamicTable.getTableName(), tableName); + } + + @Test + public void testQuotesInAlias() + { + String tableName = "quotes_in_column_names"; + String tableNameWithSuffix = tableName + REALTIME_SUFFIX; + String query = format("select non_quoted AS \"non\"\"quoted\" from %s limit 50", tableNameWithSuffix); + DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query), mockClusterInfoFetcher); + String expectedPql = format("select \"non_quoted\" AS \"non\"\"quoted\" from %s limit 50", tableNameWithSuffix); + assertEquals(extractPql(dynamicTable, TupleDomain.all(), ImmutableList.of()), expectedPql); + assertEquals(dynamicTable.getTableName(), tableName); + } + + @Test + public void testQuotesInColumnName() + { + String tableName = "quotes_in_column_names"; + String tableNameWithSuffix = tableName + REALTIME_SUFFIX; + String query = format("select 
\"qu\"\"ot\"\"ed\" from %s limit 50", tableNameWithSuffix); + DynamicTable dynamicTable = buildFromPql(pinotMetadata, new SchemaTableName("default", query), mockClusterInfoFetcher); + String expectedPql = format("select \"qu\"\"ot\"\"ed\" from %s limit 50", tableNameWithSuffix); + assertEquals(extractPql(dynamicTable, TupleDomain.all(), ImmutableList.of()), expectedPql); + assertEquals(dynamicTable.getTableName(), tableName); } } diff --git a/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotIntegrationSmokeTest.java b/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotIntegrationSmokeTest.java index a5e3e57bb45b..546db5ed70ee 100644 --- a/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotIntegrationSmokeTest.java +++ b/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotIntegrationSmokeTest.java @@ -73,15 +73,18 @@ public class TestPinotIntegrationSmokeTest // If a broker query does not supply a limit, pinot defaults to 10 rows private static final int DEFAULT_PINOT_LIMIT_FOR_BROKER_QUERIES = 10; private static final String ALL_TYPES_TABLE = "alltypes"; + private static final String DATE_TIME_FIELDS_TABLE = "date_time_fields"; private static final String MIXED_CASE_COLUMN_NAMES_TABLE = "mixed_case"; private static final String MIXED_CASE_DISTINCT_TABLE = "mixed_case_distinct"; private static final String TOO_MANY_ROWS_TABLE = "too_many_rows"; private static final String TOO_MANY_BROKER_ROWS_TABLE = "too_many_broker_rows"; private static final String JSON_TABLE = "my_table"; private static final String RESERVED_KEYWORD_TABLE = "reserved_keyword"; - + private static final String QUOTES_IN_COLUMN_NAME_TABLE = "quotes_in_column_name"; // Use a recent value for updated_at to ensure Pinot doesn't clean up records older than retentionTimeValue as defined in the table specs private static final Instant initialUpdatedAt = Instant.now().minus(Duration.ofDays(1)).truncatedTo(SECONDS); + // Use a fixed instant for testing 
date time functions + private static final Instant CREATED_AT_INSTANT = Instant.parse("2021-05-10T00:00:00.00Z"); @Override protected QueryRunner createQueryRunner() @@ -101,8 +104,8 @@ protected QueryRunner createQueryRunner() allTypesRecordsBuilder.add(new ProducerRecord<>(ALL_TYPES_TABLE, "key" + i * step, createTestRecord( Arrays.asList("string_" + (offset), "string1_" + (offset + 1), "string2_" + (offset + 2)), - Arrays.asList(false, true, true), - Arrays.asList(54, -10001, 1000), + true, + Arrays.asList(54 + i / 3, -10001, 1000), Arrays.asList(-7.33F + i, Float.POSITIVE_INFINITY, 17.034F + i), Arrays.asList(-17.33D + i, Double.POSITIVE_INFINITY, 10596.034D + i), Arrays.asList(-3147483647L + i, 12L - i, 4147483647L + i), @@ -220,6 +223,34 @@ protected QueryRunner createQueryRunner() pinot.createSchema(getClass().getClassLoader().getResourceAsStream("too_many_broker_rows_schema.json"), TOO_MANY_BROKER_ROWS_TABLE); pinot.addRealTimeTable(getClass().getClassLoader().getResourceAsStream("too_many_broker_rows_realtimeSpec.json"), TOO_MANY_BROKER_ROWS_TABLE); + // Create and populate date time fields table and topic + kafka.createTopic(DATE_TIME_FIELDS_TABLE); + Schema dateTimeFieldsAvroSchema = SchemaBuilder.record(DATE_TIME_FIELDS_TABLE).fields() + .name("string_col").type().stringType().noDefault() + .name("created_at").type().longType().noDefault() + .name("updated_at").type().longType().noDefault() + .endRecord(); + List> dateTimeFieldsProducerRecords = ImmutableList.>builder() + .add(new ProducerRecord<>(DATE_TIME_FIELDS_TABLE, "string_0", new GenericRecordBuilder(dateTimeFieldsAvroSchema) + .set("string_col", "string_0") + .set("created_at", CREATED_AT_INSTANT.toEpochMilli()) + .set("updated_at", initialUpdatedAt.toEpochMilli()) + .build())) + .add(new ProducerRecord<>(DATE_TIME_FIELDS_TABLE, "string_1", new GenericRecordBuilder(dateTimeFieldsAvroSchema) + .set("string_col", "string_1") + .set("created_at", CREATED_AT_INSTANT.plusMillis(1000).toEpochMilli()) 
+ .set("updated_at", initialUpdatedAt.plusMillis(1000).toEpochMilli()) + .build())) + .add(new ProducerRecord<>(DATE_TIME_FIELDS_TABLE, "string_2", new GenericRecordBuilder(dateTimeFieldsAvroSchema) + .set("string_col", "string_2") + .set("created_at", CREATED_AT_INSTANT.plusMillis(2000).toEpochMilli()) + .set("updated_at", initialUpdatedAt.plusMillis(2000).toEpochMilli()) + .build())) + .build(); + kafka.sendMessages(dateTimeFieldsProducerRecords.stream(), schemaRegistryAwareProducer(kafka)); + pinot.createSchema(getClass().getClassLoader().getResourceAsStream("date_time_fields_schema.json"), DATE_TIME_FIELDS_TABLE); + pinot.addRealTimeTable(getClass().getClassLoader().getResourceAsStream("date_time_fields_realtimeSpec.json"), DATE_TIME_FIELDS_TABLE); + // Create json table kafka.createTopic(JSON_TABLE); long key = 0L; @@ -239,15 +270,29 @@ protected QueryRunner createQueryRunner() kafka.createTopic(RESERVED_KEYWORD_TABLE); Schema reservedKeywordAvroSchema = SchemaBuilder.record(RESERVED_KEYWORD_TABLE).fields() .name("date").type().optional().stringType() + .name("as").type().optional().stringType() .name("updatedAt").type().optional().longType() .endRecord(); ImmutableList.Builder> reservedKeywordRecordsBuilder = ImmutableList.builder(); - reservedKeywordRecordsBuilder.add(new ProducerRecord<>(RESERVED_KEYWORD_TABLE, "key0", new GenericRecordBuilder(reservedKeywordAvroSchema).set("date", "2021-09-30").set("updatedAt", initialUpdatedAt.plusMillis(1000).toEpochMilli()).build())); - reservedKeywordRecordsBuilder.add(new ProducerRecord<>(RESERVED_KEYWORD_TABLE, "key1", new GenericRecordBuilder(reservedKeywordAvroSchema).set("date", "2021-10-01").set("updatedAt", initialUpdatedAt.plusMillis(2000).toEpochMilli()).build())); + reservedKeywordRecordsBuilder.add(new ProducerRecord<>(RESERVED_KEYWORD_TABLE, "key0", new GenericRecordBuilder(reservedKeywordAvroSchema).set("date", "2021-09-30").set("as", "foo").set("updatedAt", 
initialUpdatedAt.plusMillis(1000).toEpochMilli()).build())); + reservedKeywordRecordsBuilder.add(new ProducerRecord<>(RESERVED_KEYWORD_TABLE, "key1", new GenericRecordBuilder(reservedKeywordAvroSchema).set("date", "2021-10-01").set("as", "bar").set("updatedAt", initialUpdatedAt.plusMillis(2000).toEpochMilli()).build())); kafka.sendMessages(reservedKeywordRecordsBuilder.build().stream(), schemaRegistryAwareProducer(kafka)); pinot.createSchema(getClass().getClassLoader().getResourceAsStream("reserved_keyword_schema.json"), RESERVED_KEYWORD_TABLE); pinot.addRealTimeTable(getClass().getClassLoader().getResourceAsStream("reserved_keyword_realtimeSpec.json"), RESERVED_KEYWORD_TABLE); + // Create a table having quotes in column names + kafka.createTopic(QUOTES_IN_COLUMN_NAME_TABLE); + Schema quotesInColumnNameAvroSchema = SchemaBuilder.record(QUOTES_IN_COLUMN_NAME_TABLE).fields() + .name("non_quoted").type().optional().stringType() + .name("updatedAt").type().optional().longType() + .endRecord(); + ImmutableList.Builder> quotesInColumnNameRecordsBuilder = ImmutableList.builder(); + quotesInColumnNameRecordsBuilder.add(new ProducerRecord<>(QUOTES_IN_COLUMN_NAME_TABLE, "key0", new GenericRecordBuilder(quotesInColumnNameAvroSchema).set("non_quoted", "Foo").set("updatedAt", initialUpdatedAt.plusMillis(1000).toEpochMilli()).build())); + quotesInColumnNameRecordsBuilder.add(new ProducerRecord<>(QUOTES_IN_COLUMN_NAME_TABLE, "key1", new GenericRecordBuilder(quotesInColumnNameAvroSchema).set("non_quoted", "Bar").set("updatedAt", initialUpdatedAt.plusMillis(2000).toEpochMilli()).build())); + kafka.sendMessages(quotesInColumnNameRecordsBuilder.build().stream(), schemaRegistryAwareProducer(kafka)); + pinot.createSchema(getClass().getClassLoader().getResourceAsStream("quotes_in_column_name_schema.json"), QUOTES_IN_COLUMN_NAME_TABLE); + pinot.addRealTimeTable(getClass().getClassLoader().getResourceAsStream("quotes_in_column_name_realtimeSpec.json"), QUOTES_IN_COLUMN_NAME_TABLE); + Map 
pinotProperties = ImmutableMap.builder() .put("pinot.controller-urls", pinot.getControllerConnectString()) .put("pinot.max-rows-per-split-for-segment-queries", String.valueOf(MAX_ROWS_PER_SPLIT_FOR_SEGMENT_QUERIES)) @@ -272,7 +317,7 @@ private static Map schemaRegistryAwareProducer(TestingKafka test private static GenericRecord createTestRecord( List stringArrayColumn, - List booleanArrayColumn, + Boolean booleanColumn, List intArrayColumn, List floatArrayColumn, List doubleArrayColumn, @@ -283,10 +328,9 @@ private static GenericRecord createTestRecord( return new GenericRecordBuilder(schema) .set("string_col", stringArrayColumn.get(0)) - .set("bool_col", booleanArrayColumn.get(0)) + .set("bool_col", booleanColumn) .set("bytes_col", Hex.toHexString(stringArrayColumn.get(0).getBytes(StandardCharsets.UTF_8))) .set("string_array_col", stringArrayColumn) - .set("bool_array_col", booleanArrayColumn) .set("int_array_col", intArrayColumn) .set("int_array_col_with_pinot_default", intArrayColumn) .set("float_array_col", floatArrayColumn) @@ -313,7 +357,6 @@ private static GenericRecord createArrayNullRecord() { Schema schema = getAllTypesAvroSchema(); List stringList = Arrays.asList("string_0", null, "string_2", null, "string_4"); - List booleanList = Arrays.asList(true, null, false, null, true); List integerList = new ArrayList<>(); integerList.addAll(Arrays.asList(null, null, null, null, null)); List integerWithDefaultList = Arrays.asList(-1112, null, 753, null, -9238); @@ -325,7 +368,6 @@ private static GenericRecord createArrayNullRecord() return new GenericRecordBuilder(schema) .set("string_col", "array_null") .set("string_array_col", stringList) - .set("bool_array_col", booleanList) .set("int_array_col", integerList) .set("int_array_col_with_pinot_default", integerWithDefaultList) .set("float_array_col", floatList) @@ -351,7 +393,6 @@ private static Schema getAllTypesAvroSchema() .name("bool_col").type().optional().booleanType() 
.name("bytes_col").type().optional().stringType() .name("string_array_col").type().optional().array().items().nullable().stringType() - .name("bool_array_col").type().optional().array().items().nullable().booleanType() .name("int_array_col").type().optional().array().items().nullable().intType() .name("int_array_col_with_pinot_default").type().optional().array().items().nullable().intType() .name("float_array_col").type().optional().array().items().nullable().floatType() @@ -646,6 +687,29 @@ public void testReservedKeywordColumnNames() assertThat(query("SELECT \"count(*)\" FROM \"SELECT COUNT(*) FROM " + RESERVED_KEYWORD_TABLE + " ORDER BY COUNT(*)\"")) .matches("VALUES BIGINT '2'") .isFullyPushedDown(); + + assertQuery("SELECT \"as\" FROM " + RESERVED_KEYWORD_TABLE + " WHERE \"as\" = 'foo'", "VALUES 'foo'"); + assertQuery("SELECT \"as\" FROM " + RESERVED_KEYWORD_TABLE + " WHERE \"as\" IN ('foo', 'bar')", "VALUES 'foo', 'bar'"); + + assertThat(query("SELECT \"as\" FROM \"SELECT \"\"as\"\" FROM " + RESERVED_KEYWORD_TABLE + "\"")) + .matches("VALUES VARCHAR 'foo', VARCHAR 'bar'") + .isFullyPushedDown(); + + assertThat(query("SELECT \"as\" FROM \"SELECT \"\"as\"\" FROM " + RESERVED_KEYWORD_TABLE + " WHERE \"\"as\"\" = 'foo'\"")) + .matches("VALUES VARCHAR 'foo'") + .isFullyPushedDown(); + + assertThat(query("SELECT \"as\" FROM \"SELECT \"\"as\"\" FROM " + RESERVED_KEYWORD_TABLE + " WHERE \"\"as\"\" IN ('foo', 'bar')\"")) + .matches("VALUES VARCHAR 'foo', VARCHAR 'bar'") + .isFullyPushedDown(); + + assertThat(query("SELECT \"as\" FROM \"SELECT \"\"as\"\" FROM " + RESERVED_KEYWORD_TABLE + " ORDER BY \"\"as\"\"\"")) + .matches("VALUES VARCHAR 'foo', VARCHAR 'bar'") + .isFullyPushedDown(); + + assertThat(query("SELECT \"as\", \"count(*)\" FROM \"SELECT \"\"as\"\", COUNT(*) FROM " + RESERVED_KEYWORD_TABLE + " GROUP BY \"\"as\"\"\"")) + .matches("VALUES (VARCHAR 'foo', BIGINT '1'), (VARCHAR 'bar', BIGINT '1')") + .isFullyPushedDown(); } @Test @@ -690,8 +754,8 @@ public void 
testMaxLimitForPassthroughQueries() "Broker query returned '13' rows, maximum allowed is '12' rows. with query \"select \"updated_at_seconds\", \"string_col\" from too_many_broker_rows limit 13\""); // Pinot issue preventing Integer.MAX_VALUE from being a limit: https://github.com/apache/incubator-pinot/issues/7110 - assertQueryFails("SELECT * FROM \"SELECT string_col, long_col FROM " + ALL_TYPES_TABLE + " LIMIT " + Integer.MAX_VALUE + "\"", - "Unexpected response status: 500 for request \\{\"sql\":\"select \\\\\"string_col\\\\\", \\\\\"long_col\\\\\" from alltypes limit 2147483647\"\\} to url http://localhost:\\d+/query/sql, with headers \\{Accept=\\[application/json\\], Content-Type=\\[application/json\\]\\}, full response null"); + // This is now resolved in pinot 0.8.0 + assertQuerySucceeds("SELECT * FROM \"SELECT string_col, long_col FROM " + ALL_TYPES_TABLE + " LIMIT " + Integer.MAX_VALUE + "\""); // Pinot broker requests do not handle limits greater than Integer.MAX_VALUE // Note that -2147483648 is due to an integer overflow in Pinot: https://github.com/apache/pinot/issues/7242 @@ -778,33 +842,25 @@ public void testNullBehavior() // Default null value for strings is the string 'null' assertThat(query("SELECT string_col" + " FROM " + ALL_TYPES_TABLE + - " WHERE bytes_col = X'' AND element_at(bool_array_col, 1) = 'null'")) + " WHERE bytes_col = X'' AND element_at(string_array_col, 1) = 'null'")) .matches("VALUES (VARCHAR 'null')") .isNotFullyPushedDown(FilterNode.class); // Default array null value for strings is the string 'null' assertThat(query("SELECT element_at(string_array_col, 1)" + " FROM " + ALL_TYPES_TABLE + - " WHERE bytes_col = X'' AND element_at(bool_array_col, 1) = 'null'")) + " WHERE bytes_col = X'' AND string_col = 'null'")) .matches("VALUES (VARCHAR 'null')") - .isNotFullyPushedDown(ExchangeNode.class, ProjectNode.class, FilterNode.class); + .isNotFullyPushedDown(ExchangeNode.class, ProjectNode.class); // Default null value for booleans is 
the string 'null' // Booleans are treated as a string assertThat(query("SELECT bool_col" + " FROM " + ALL_TYPES_TABLE + " WHERE string_col = 'null'")) - .matches("VALUES (VARCHAR 'null')") + .matches("VALUES (false)") .isFullyPushedDown(); - // Default array null value for booleans is the string 'null' - // Boolean are treated as a string - assertThat(query("SELECT element_at(bool_array_col, 1)" + - " FROM " + ALL_TYPES_TABLE + - " WHERE string_col = 'null'")) - .matches("VALUES (VARCHAR 'null')") - .isNotFullyPushedDown(ProjectNode.class); - // Default null value for pinot BYTES type (varbinary) is the string 'null' // BYTES values are treated as a strings // BYTES arrays are not supported @@ -846,7 +902,6 @@ public void testNullBehavior() // Default value for a "null" array is 1 element with default null array value, // Values are tested above, this test is to verify pinot returns an array with 1 element. assertThat(query("SELECT CARDINALITY(string_array_col)," + - " CARDINALITY(bool_array_col)," + " CARDINALITY(int_array_col_with_pinot_default)," + " CARDINALITY(int_array_col)," + " CARDINALITY(float_array_col)," + @@ -854,13 +909,12 @@ public void testNullBehavior() " CARDINALITY(long_array_col)" + " FROM " + ALL_TYPES_TABLE + " WHERE string_col = 'null'")) - .matches("VALUES (BIGINT '1', BIGINT '1', BIGINT '1', BIGINT '1', BIGINT '1', BIGINT '1', BIGINT '1')") + .matches("VALUES (BIGINT '1', BIGINT '1', BIGINT '1', BIGINT '1', BIGINT '1', BIGINT '1')") .isNotFullyPushedDown(ProjectNode.class); // If an array contains both null and non-null values, the null values are omitted: // There are 5 values in the avro records, but only the 3 non-null values are in pinot assertThat(query("SELECT CARDINALITY(string_array_col)," + - " CARDINALITY(bool_array_col)," + " CARDINALITY(int_array_col_with_pinot_default)," + " CARDINALITY(int_array_col)," + " CARDINALITY(float_array_col)," + @@ -868,7 +922,7 @@ public void testNullBehavior() " CARDINALITY(long_array_col)" + " 
FROM " + ALL_TYPES_TABLE + " WHERE string_col = 'array_null'")) - .matches("VALUES (BIGINT '3', BIGINT '3', BIGINT '3', BIGINT '1', BIGINT '1', BIGINT '1', BIGINT '1')") + .matches("VALUES (BIGINT '3', BIGINT '3', BIGINT '1', BIGINT '1', BIGINT '1', BIGINT '1')") .isNotFullyPushedDown(ProjectNode.class); } @@ -925,9 +979,9 @@ public void testArrayFilter() public void testLimitPushdown() { assertThat(query("SELECT string_col, long_col FROM " + "\"SELECT string_col, long_col, bool_col FROM " + ALL_TYPES_TABLE + " WHERE int_col > 0\" " + - " WHERE bool_col = 'false' LIMIT " + MAX_ROWS_PER_SPLIT_FOR_SEGMENT_QUERIES)) + " WHERE bool_col = false LIMIT " + MAX_ROWS_PER_SPLIT_FOR_SEGMENT_QUERIES)) .isFullyPushedDown(); - assertThat(query("SELECT string_col, long_col FROM " + ALL_TYPES_TABLE + " WHERE int_col >0 AND bool_col = 'false' LIMIT " + MAX_ROWS_PER_SPLIT_FOR_SEGMENT_QUERIES)) + assertThat(query("SELECT string_col, long_col FROM " + ALL_TYPES_TABLE + " WHERE int_col >0 AND bool_col = false LIMIT " + MAX_ROWS_PER_SPLIT_FOR_SEGMENT_QUERIES)) .isNotFullyPushedDown(LimitNode.class); } @@ -1365,4 +1419,334 @@ public void testDoubleWithInfinity() " WHERE string_col = 'string_0'\"")) .matches("VALUES (POWER(0, -1))"); } + + @Test + public void testTransformFunctions() + { + // Test that time units and formats are correctly uppercased. + // The dynamic table, i.e. the query between the quotes, will be lowercased since it is passed as a SchemaTableName. 
+ assertThat(query("SELECT hours_col, hours_col2 FROM \"SELECT timeconvert(created_at_seconds, 'SECONDS', 'HOURS') as hours_col," + + " CAST(FLOOR(created_at_seconds / 3600) as long) as hours_col2 from " + DATE_TIME_FIELDS_TABLE + "\"")) + .matches("VALUES (BIGINT '450168', BIGINT '450168')," + + " (BIGINT '450168', BIGINT '450168')," + + " (BIGINT '450168', BIGINT '450168')"); + assertThat(query("SELECT \"datetimeconvert(created_at_seconds,'1:seconds:epoch','1:days:epoch','1:days')\" FROM \"SELECT datetimeconvert(created_at_seconds, '1:SECONDS:EPOCH', '1:DAYS:EPOCH', '1:DAYS')" + + " FROM " + DATE_TIME_FIELDS_TABLE + "\"")) + .matches("VALUES (BIGINT '18757'), (BIGINT '18757'), (BIGINT '18757')"); + // Multiple forms of datetrunc from 2-5 arguments + assertThat(query("SELECT \"datetrunc('hour',created_at)\" FROM \"SELECT datetrunc('hour', created_at)" + + " FROM " + DATE_TIME_FIELDS_TABLE + "\"")) + .matches("VALUES (BIGINT '1620604800000'), (BIGINT '1620604800000'), (BIGINT '1620604800000')"); + assertThat(query("SELECT \"datetrunc('hour',created_at_seconds,'seconds')\" FROM \"SELECT datetrunc('hour', created_at_seconds, 'SECONDS')" + + " FROM " + DATE_TIME_FIELDS_TABLE + "\"")) + .matches("VALUES (BIGINT '1620604800'), (BIGINT '1620604800'), (BIGINT '1620604800')"); + assertThat(query("SELECT \"datetrunc('hour',created_at_seconds,'seconds','utc')\" FROM \"SELECT datetrunc('hour', created_at_seconds, 'SECONDS', 'UTC')" + + " FROM " + DATE_TIME_FIELDS_TABLE + "\"")) + .matches("VALUES (BIGINT '1620604800'), (BIGINT '1620604800'), (BIGINT '1620604800')"); + + assertThat(query("SELECT \"datetrunc('quarter',created_at_seconds,'seconds','america/los_angeles','hours')\" FROM \"SELECT datetrunc('quarter', created_at_seconds, 'SECONDS', 'America/Los_Angeles', 'HOURS')" + + " FROM " + DATE_TIME_FIELDS_TABLE + "\"")) + .matches("VALUES (BIGINT '449239'), (BIGINT '449239'), (BIGINT '449239')"); + assertThat(query("SELECT \"arraylength(double_array_col)\" FROM " + + 
"\"SELECT arraylength(double_array_col)" + + " FROM " + ALL_TYPES_TABLE + + " WHERE string_col in ('string_0', 'array_null')\"")) + .matches("VALUES (3), (1)"); + + assertThat(query("SELECT \"cast(floor(arrayaverage(long_array_col)),'long')\" FROM " + + "\"SELECT cast(floor(arrayaverage(long_array_col)) as long)" + + " FROM " + ALL_TYPES_TABLE + + " WHERE double_array_col is not null and double_col != -17.33\"")) + .matches("VALUES (BIGINT '333333337')," + + " (BIGINT '333333338')," + + " (BIGINT '333333338')," + + " (BIGINT '333333338')," + + " (BIGINT '333333339')," + + " (BIGINT '333333339')," + + " (BIGINT '333333339')," + + " (BIGINT '333333340')"); + + assertThat(query("SELECT \"arraymax(long_array_col)\" FROM " + + "\"SELECT arraymax(long_array_col)" + + " FROM " + ALL_TYPES_TABLE + + " WHERE string_col is not null and string_col != 'array_null'\"")) + .matches("VALUES (BIGINT '4147483647')," + + " (BIGINT '4147483648')," + + " (BIGINT '4147483649')," + + " (BIGINT '4147483650')," + + " (BIGINT '4147483651')," + + " (BIGINT '4147483652')," + + " (BIGINT '4147483653')," + + " (BIGINT '4147483654')," + + " (BIGINT '4147483655')"); + + assertThat(query("SELECT \"arraymin(long_array_col)\" FROM " + + "\"SELECT arraymin(long_array_col)" + + " FROM " + ALL_TYPES_TABLE + + " WHERE string_col is not null and string_col != 'array_null'\"")) + .matches("VALUES (BIGINT '-3147483647')," + + " (BIGINT '-3147483646')," + + " (BIGINT '-3147483645')," + + " (BIGINT '-3147483644')," + + " (BIGINT '-3147483643')," + + " (BIGINT '-3147483642')," + + " (BIGINT '-3147483641')," + + " (BIGINT '-3147483640')," + + " (BIGINT '-3147483639')"); + } + + @Test + public void testPassthroughQueriesWithAliases() + { + assertThat(query("SELECT hours_col, hours_col2 FROM " + + "\"SELECT timeconvert(created_at_seconds, 'SECONDS', 'HOURS') AS hours_col," + + " CAST(FLOOR(created_at_seconds / 3600) as long) as hours_col2" + + " FROM " + DATE_TIME_FIELDS_TABLE + "\"")) + .matches("VALUES 
(BIGINT '450168', BIGINT '450168'), (BIGINT '450168', BIGINT '450168'), (BIGINT '450168', BIGINT '450168')"); + + // Test without aliases to verify fieldName is correctly handled + assertThat(query("SELECT \"timeconvert(created_at_seconds,'seconds','hours')\"," + + " \"cast(floor(divide(created_at_seconds,'3600')),'long')\" FROM " + + "\"SELECT timeconvert(created_at_seconds, 'SECONDS', 'HOURS')," + + " CAST(FLOOR(created_at_seconds / 3600) as long)" + + " FROM " + DATE_TIME_FIELDS_TABLE + "\"")) + .matches("VALUES (BIGINT '450168', BIGINT '450168'), (BIGINT '450168', BIGINT '450168'), (BIGINT '450168', BIGINT '450168')"); + + assertThat(query("SELECT int_col2, long_col2 FROM " + + "\"SELECT int_col AS int_col2, long_col AS long_col2" + + " FROM " + ALL_TYPES_TABLE + + " WHERE string_col IS NOT null AND string_col != 'array_null'\"")) + .matches("VALUES (54, BIGINT '-3147483647')," + + " (54, BIGINT '-3147483646')," + + " (54, BIGINT '-3147483645')," + + " (55, BIGINT '-3147483644')," + + " (55, BIGINT '-3147483643')," + + " (55, BIGINT '-3147483642')," + + " (56, BIGINT '-3147483641')," + + " (56, BIGINT '-3147483640')," + + " (56, BIGINT '-3147483639')"); + + assertThat(query("SELECT int_col2, long_col2 FROM " + + "\"SELECT int_col AS int_col2, long_col AS long_col2 " + + " FROM " + ALL_TYPES_TABLE + + " WHERE string_col IS NOT null AND string_col != 'array_null'\"")) + .matches("VALUES (54, BIGINT '-3147483647')," + + " (54, BIGINT '-3147483646')," + + " (54, BIGINT '-3147483645')," + + " (55, BIGINT '-3147483644')," + + " (55, BIGINT '-3147483643')," + + " (55, BIGINT '-3147483642')," + + " (56, BIGINT '-3147483641')," + + " (56, BIGINT '-3147483640')," + + " (56, BIGINT '-3147483639')"); + + // Query with a function on a column and an alias with the same column name fails + // For more details see https://github.com/apache/pinot/issues/7545 + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> query("SELECT int_col FROM " + + "\"SELECT 
floor(int_col / 3) AS int_col" + + " FROM " + ALL_TYPES_TABLE + + " WHERE string_col IS NOT null AND string_col != 'array_null'\"")) + .withRootCauseInstanceOf(RuntimeException.class) + .withMessage("Alias int_col cannot be referred in SELECT Clause"); + } + + @Test + public void testPassthroughQueriesWithPushdowns() + { + assertThat(query("SELECT DISTINCT \"timeconvert(created_at_seconds,'seconds','hours')\"," + + " \"cast(floor(divide(created_at_seconds,'3600')),'long')\" FROM " + + "\"SELECT timeconvert(created_at_seconds, 'SECONDS', 'HOURS')," + + " CAST(FLOOR(created_at_seconds / 3600) AS long)" + + " FROM " + DATE_TIME_FIELDS_TABLE + "\"")) + .matches("VALUES (BIGINT '450168', BIGINT '450168')"); + + assertThat(query("SELECT DISTINCT \"timeconvert(created_at_seconds,'seconds','milliseconds')\"," + + " \"cast(floor(divide(created_at_seco" + + "nds,'3600')),'long')\" FROM " + + "\"SELECT timeconvert(created_at_seconds, 'SECONDS', 'MILLISECONDS')," + + " CAST(FLOOR(created_at_seconds / 3600) as long)" + + " FROM " + DATE_TIME_FIELDS_TABLE + "\"")) + .matches("VALUES (BIGINT '1620604802000', BIGINT '450168')," + + " (BIGINT '1620604801000', BIGINT '450168')," + + " (BIGINT '1620604800000', BIGINT '450168')"); + + assertThat(query("SELECT int_col, sum(long_col) FROM " + + "\"SELECT int_col, long_col" + + " FROM " + ALL_TYPES_TABLE + + " WHERE string_col IS NOT null AND string_col != 'array_null'\"" + + " GROUP BY int_col")) + .isFullyPushedDown(); + + assertThat(query("SELECT DISTINCT int_col, long_col FROM " + + "\"SELECT int_col, long_col FROM " + ALL_TYPES_TABLE + + " WHERE string_col IS NOT null AND string_col != 'array_null'\"")) + .isFullyPushedDown(); + + assertThat(query("SELECT int_col2, long_col2, count(*) FROM " + + "\"SELECT int_col AS int_col2, long_col AS long_col2" + + " FROM " + ALL_TYPES_TABLE + + " WHERE string_col IS NOT null AND string_col != 'array_null'\"" + + " GROUP BY int_col2, long_col2")) + .isFullyPushedDown(); + + // Query with 
grouping columns but no aggregates ignores aliases. + // For more details see: https://github.com/apache/pinot/issues/7546 + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> query("SELECT DISTINCT int_col2, long_col2 FROM " + + "\"SELECT int_col AS int_col2, long_col AS long_col2" + + " FROM " + ALL_TYPES_TABLE + + " WHERE string_col IS NOT null AND string_col != 'array_null'\"")) + .withRootCauseInstanceOf(RuntimeException.class) + .withMessage("java.lang.NullPointerException"); + + assertThat(query("SELECT int_col2, count(*) FROM " + + "\"SELECT int_col AS int_col2, long_col AS long_col2" + + " FROM " + ALL_TYPES_TABLE + + " WHERE string_col IS NOT null AND string_col != 'array_null'\"" + + " GROUP BY int_col2")) + .isFullyPushedDown(); + } + + @Test + public void testColumnNamesWithDoubleQuotes() + { + assertThat(query("select \"double\"\"\"\"qu\"\"ot\"\"ed\"\"\" from quotes_in_column_name")) + .matches("VALUES (VARCHAR 'foo'), (VARCHAR 'bar')") + .isFullyPushedDown(); + + assertThat(query("select \"qu\"\"ot\"\"ed\" from quotes_in_column_name")) + .matches("VALUES (VARCHAR 'FOO'), (VARCHAR 'BAR')") + .isFullyPushedDown(); + + assertThat(query("select non_quoted from \"select \"\"qu\"\"\"\"ot\"\"\"\"ed\"\" as non_quoted from quotes_in_column_name\"")) + .matches("VALUES (VARCHAR 'FOO'), (VARCHAR 'BAR')") + .isFullyPushedDown(); + + assertThat(query("select \"qu\"\"ot\"\"ed\" from \"select non_quoted as \"\"qu\"\"\"\"ot\"\"\"\"ed\"\" from quotes_in_column_name\"")) + .matches("VALUES (VARCHAR 'Foo'), (VARCHAR 'Bar')") + .isFullyPushedDown(); + + assertThat(query("select \"double\"\"\"\"qu\"\"ot\"\"ed\"\"\" from \"select \"\"double\"\"\"\"\"\"\"\"qu\"\"\"\"ot\"\"\"\"ed\"\"\"\"\"\" from quotes_in_column_name\"")) + .matches("VALUES (VARCHAR 'foo'), (VARCHAR 'bar')") + .isFullyPushedDown(); + + assertThat(query("select \"qu\"\"oted\" from \"select \"\"double\"\"\"\"\"\"\"\"qu\"\"\"\"ot\"\"\"\"ed\"\"\"\"\"\" as \"\"qu\"\"\"\"oted\"\" from 
quotes_in_column_name\"")) + .matches("VALUES (VARCHAR 'foo'), (VARCHAR 'bar')") + .isFullyPushedDown(); + + assertThat(query("select \"date\" from \"select \"\"qu\"\"\"\"ot\"\"\"\"ed\"\" as \"\"date\"\" from quotes_in_column_name\"")) + .matches("VALUES (VARCHAR 'FOO'), (VARCHAR 'BAR')") + .isFullyPushedDown(); + + assertThat(query("select \"date\" from \"select non_quoted as \"\"date\"\" from quotes_in_column_name\"")) + .matches("VALUES (VARCHAR 'Foo'), (VARCHAR 'Bar')") + .isFullyPushedDown(); + + /// Test aggregations with double quoted columns + assertThat(query("select non_quoted, COUNT(DISTINCT \"date\") from \"select non_quoted, non_quoted as \"\"date\"\" from quotes_in_column_name\" GROUP BY non_quoted")) + .isFullyPushedDown(); + + assertThat(query("select non_quoted, COUNT(DISTINCT \"double\"\"\"\"qu\"\"ot\"\"ed\"\"\") from \"select non_quoted, \"\"double\"\"\"\"\"\"\"\"qu\"\"\"\"ot\"\"\"\"ed\"\"\"\"\"\" from quotes_in_column_name\" GROUP BY non_quoted")) + .isFullyPushedDown(); + + assertThat(query("select non_quoted, COUNT(DISTINCT \"qu\"\"ot\"\"ed\") from \"select non_quoted, \"\"qu\"\"\"\"ot\"\"\"\"ed\"\" from quotes_in_column_name\" GROUP BY non_quoted")) + .isFullyPushedDown(); + + assertThat(query("select non_quoted, COUNT(DISTINCT \"qu\"\"ot\"\"ed\") from \"select non_quoted, non_quoted as \"\"qu\"\"\"\"ot\"\"\"\"ed\"\" from quotes_in_column_name\" GROUP BY non_quoted")) + .isFullyPushedDown(); + + assertThat(query("select \"qu\"\"ot\"\"ed\", COUNT(DISTINCT \"date\") from \"select \"\"qu\"\"\"\"ot\"\"\"\"ed\"\", non_quoted as \"\"date\"\" from quotes_in_column_name\" GROUP BY \"qu\"\"ot\"\"ed\"")) + .isFullyPushedDown(); + + assertThat(query("select \"qu\"\"ot\"\"ed\", COUNT(DISTINCT \"double\"\"\"\"qu\"\"ot\"\"ed\"\"\") from \"select \"\"qu\"\"\"\"ot\"\"\"\"ed\"\", \"\"double\"\"\"\"\"\"\"\"qu\"\"\"\"ot\"\"\"\"ed\"\"\"\"\"\" from quotes_in_column_name\" GROUP BY \"qu\"\"ot\"\"ed\"")) + .isFullyPushedDown(); + + // Test with grouping column that 
has double quotes aliased to a name without double quotes + assertThat(query("select non_quoted, COUNT(DISTINCT \"qu\"\"ot\"\"ed\") from \"select \"\"double\"\"\"\"\"\"\"\"qu\"\"\"\"ot\"\"\"\"ed\"\"\"\"\"\" as non_quoted, \"\"qu\"\"\"\"ot\"\"\"\"ed\"\" from quotes_in_column_name\" GROUP BY non_quoted")) + .isFullyPushedDown(); + + // Test with grouping column that has no double quotes aliased to a name with double quotes + assertThat(query("select \"qu\"\"oted\", COUNT(DISTINCT \"qu\"\"ot\"\"ed\") from \"select non_quoted as \"\"qu\"\"\"\"oted\"\", \"\"qu\"\"\"\"ot\"\"\"\"ed\"\" from quotes_in_column_name\" GROUP BY \"qu\"\"oted\"")) + .isFullyPushedDown(); + + assertThat(query("select \"qu\"\"oted\", COUNT(DISTINCT \"qu\"\"oted\") from \"select \"\"qu\"\"\"\"ot\"\"\"\"ed\"\", non_quoted as \"\"qu\"\"\"\"oted\"\" from quotes_in_column_name\" GROUP BY \"qu\"\"oted\"")) + .isFullyPushedDown(); + + /// Test aggregations with double quoted columns and no grouping sets + assertThat(query("select COUNT(DISTINCT \"date\") from \"select non_quoted as \"\"date\"\" from quotes_in_column_name\"")) + .isFullyPushedDown(); + + assertThat(query("select COUNT(DISTINCT \"double\"\"\"\"qu\"\"ot\"\"ed\"\"\") from \"select \"\"double\"\"\"\"\"\"\"\"qu\"\"\"\"ot\"\"\"\"ed\"\"\"\"\"\" from quotes_in_column_name\"")) + .isFullyPushedDown(); + + assertThat(query("select COUNT(DISTINCT \"qu\"\"ot\"\"ed\") from \"select \"\"qu\"\"\"\"ot\"\"\"\"ed\"\" from quotes_in_column_name\"")) + .isFullyPushedDown(); + + assertThat(query("select COUNT(DISTINCT \"qu\"\"ot\"\"ed\") from \"select non_quoted as \"\"qu\"\"\"\"ot\"\"\"\"ed\"\" from quotes_in_column_name\"")) + .isFullyPushedDown(); + } + + /** + * Checks passthrough Pinot queries that use the "LIMIT offset, count" form. + * Aggregations issued on top of such a query are asserted to be not fully pushed down, + * while plain column selections over the same kind of query are asserted to be fully pushed down. + */ + @Test + public void testLimitAndOffsetWithPushedDownAggregates() + { + // Aggregation pushdown must be disabled when the inner query has an offset ("LIMIT 5, 6" is offset 5, count 6), + // as the results will not be correct + assertThat(query("SELECT COUNT(*), MAX(long_col)" + + " FROM \"SELECT long_col FROM " + ALL_TYPES_TABLE + + " WHERE long_col < 0" + +
" ORDER BY long_col " + + " LIMIT 5, 6\"")) + .matches("VALUES (BIGINT '4', BIGINT '-3147483639')") + .isNotFullyPushedDown(AggregationNode.class, ExchangeNode.class, ExchangeNode.class, AggregationNode.class); + + assertThat(query("SELECT long_col, COUNT(*), MAX(long_col)" + + " FROM \"SELECT long_col FROM " + ALL_TYPES_TABLE + + " WHERE long_col < 0" + + " ORDER BY long_col " + + " LIMIT 5, 6\" GROUP BY long_col")) + .matches("VALUES (BIGINT '-3147483642', BIGINT '1', BIGINT '-3147483642')," + + " (BIGINT '-3147483640', BIGINT '1', BIGINT '-3147483640')," + + " (BIGINT '-3147483641', BIGINT '1', BIGINT '-3147483641')," + + " (BIGINT '-3147483639', BIGINT '1', BIGINT '-3147483639')") + .isNotFullyPushedDown(ExchangeNode.class, AggregationNode.class, ExchangeNode.class, ExchangeNode.class, ProjectNode.class, AggregationNode.class); + + assertThat(query("SELECT long_col, string_col, COUNT(*), MAX(long_col)" + + " FROM \"SELECT * FROM " + ALL_TYPES_TABLE + + " WHERE long_col < 0" + + " ORDER BY long_col, string_col" + + " LIMIT 5, 6\" GROUP BY long_col, string_col")) + .matches("VALUES (BIGINT '-3147483641', VARCHAR 'string_7200', BIGINT '1', BIGINT '-3147483641')," + + " (BIGINT '-3147483640', VARCHAR 'string_8400', BIGINT '1', BIGINT '-3147483640')," + + " (BIGINT '-3147483642', VARCHAR 'string_6000', BIGINT '1', BIGINT '-3147483642')," + + " (BIGINT '-3147483639', VARCHAR 'string_9600', BIGINT '1', BIGINT '-3147483639')") + .isNotFullyPushedDown(ExchangeNode.class, ProjectNode.class, AggregationNode.class, ExchangeNode.class, ExchangeNode.class, AggregationNode.class, ProjectNode.class); + + // Note that the offset is the first parameter: "LIMIT 2, 6" is offset 2, count 6. + // No aggregation is applied on top of the passthrough query here, and it is asserted to be fully pushed down + assertThat(query("SELECT long_col" + + " FROM \"SELECT long_col FROM " + ALL_TYPES_TABLE + + " WHERE long_col < 0" + + " ORDER BY long_col " + + " LIMIT 2, 6\"")) + .matches("VALUES (BIGINT '-3147483645')," + + " (BIGINT '-3147483644')," + + " (BIGINT '-3147483643')," + + " (BIGINT '-3147483642')," + + " (BIGINT '-3147483641'),"
+ + " (BIGINT '-3147483640')") + .isFullyPushedDown(); + + // Note that the offset is the first parameter: "LIMIT 2, 6" is offset 2, count 6. + // Same check as above with a second (string) column selected; still asserted to be fully pushed down + assertThat(query("SELECT long_col, string_col" + + " FROM \"SELECT long_col, string_col FROM " + ALL_TYPES_TABLE + + " WHERE long_col < 0" + + " ORDER BY long_col " + + " LIMIT 2, 6\"")) + .matches("VALUES (BIGINT '-3147483645', VARCHAR 'string_2400')," + + " (BIGINT '-3147483644', VARCHAR 'string_3600')," + + " (BIGINT '-3147483643', VARCHAR 'string_4800')," + + " (BIGINT '-3147483642', VARCHAR 'string_6000')," + + " (BIGINT '-3147483641', VARCHAR 'string_7200')," + + " (BIGINT '-3147483640', VARCHAR 'string_8400')") + .isFullyPushedDown(); + } } diff --git a/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotQueryBase.java b/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotQueryBase.java index 2eaed88e38a7..2b2d8ebfccc8 100755 --- a/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotQueryBase.java +++ b/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotQueryBase.java @@ -42,7 +42,7 @@ public class TestPinotQueryBase protected List getColumnNames(String table) { return pinotMetadata.getPinotColumns(table).stream() - .map(PinotColumn::getName) + .map(PinotColumnHandle::getColumnName) .collect(toImmutableList()); } @@ -104,6 +104,10 @@ public static Map getTestingMetadata() .addSingleValueDimension("float_col", DataType.FLOAT) .addSingleValueDimension("bytes_col", DataType.BYTES) .build()) + .put("quotes_in_column_names", new SchemaBuilder().setSchemaName("quotes_in_column_names") + .addSingleValueDimension("non_quoted", DataType.STRING) + .addSingleValueDimension("qu\"ot\"ed", DataType.STRING) + .build()) .build(); } } diff --git a/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotSplitManager.java b/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotSplitManager.java index 70359cbc99ed..32d8da633000 100755 --- 
a/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotSplitManager.java +++ b/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotSplitManager.java @@ -52,7 +52,7 @@ public class TestPinotSplitManager public void testSplitsBroker() { SchemaTableName schemaTableName = new SchemaTableName("default", format("SELECT %s, %s FROM %s LIMIT %d", "AirlineID", "OriginStateName", "airlineStats", 100)); - DynamicTable dynamicTable = buildFromPql(pinotMetadata, schemaTableName); + DynamicTable dynamicTable = buildFromPql(pinotMetadata, schemaTableName, mockClusterInfoFetcher); PinotTableHandle pinotTableHandle = new PinotTableHandle("default", dynamicTable.getTableName(), TupleDomain.all(), OptionalLong.empty(), Optional.of(dynamicTable)); List splits = getSplitsHelper(pinotTableHandle, 1, false); diff --git a/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestingPinotCluster.java b/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestingPinotCluster.java index cc3602c84b09..399029a3d84d 100644 --- a/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestingPinotCluster.java +++ b/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestingPinotCluster.java @@ -48,7 +48,7 @@ public class TestingPinotCluster implements Closeable { - private static final String BASE_IMAGE = "apachepinot/pinot:0.6.0"; + private static final String BASE_IMAGE = "apachepinot/pinot:0.8.0-jdk11"; private static final String ZOOKEEPER_INTERNAL_HOST = "zookeeper"; private static final JsonCodec> LIST_JSON_CODEC = listJsonCodec(String.class); private static final JsonCodec PINOT_SUCCESS_RESPONSE_JSON_CODEC = jsonCodec(PinotSuccessResponse.class); diff --git a/plugin/trino-pinot/src/test/resources/alltypes_schema.json b/plugin/trino-pinot/src/test/resources/alltypes_schema.json index 47a8ba15ac19..174e056de911 100644 --- a/plugin/trino-pinot/src/test/resources/alltypes_schema.json +++ b/plugin/trino-pinot/src/test/resources/alltypes_schema.json @@ -18,11 +18,6 
@@ "dataType": "STRING", "singleValueField": false }, - { - "name": "bool_array_col", - "dataType": "BOOLEAN", - "singleValueField": false - }, { "name": "int_array_col", "dataType": "INT", diff --git a/plugin/trino-pinot/src/test/resources/date_time_fields_realtimeSpec.json b/plugin/trino-pinot/src/test/resources/date_time_fields_realtimeSpec.json new file mode 100644 index 000000000000..dacc316a659a --- /dev/null +++ b/plugin/trino-pinot/src/test/resources/date_time_fields_realtimeSpec.json @@ -0,0 +1,45 @@ +{ + "tableName": "date_time_fields", + "tableType": "REALTIME", + "segmentsConfig": { + "timeColumnName": "updated_at_seconds", + "timeType": "SECONDS", + "retentionTimeUnit": "DAYS", + "retentionTimeValue": "365", + "segmentPushType": "APPEND", + "segmentPushFrequency": "daily", + "segmentAssignmentStrategy": "BalanceNumSegmentAssignmentStrategy", + "schemaName": "date_time_fields", + "replicasPerPartition": "1" + }, + "tenants": { + "broker": "DefaultTenant", + "server": "DefaultTenant" + }, + "tableIndexConfig": { + "loadMode": "MMAP", + "invertedIndexColumns": ["string_col"], + "sortedColumn": ["updated_at_seconds"], + "starTreeIndexConfigs": [], + "nullHandlingEnabled": "true", + "streamConfigs": { + "streamType": "kafka", + "stream.kafka.consumer.type": "LowLevel", + "stream.kafka.topic.name": "date_time_fields", + "stream.kafka.decoder.class.name": "org.apache.pinot.plugin.inputformat.avro.confluent.KafkaConfluentSchemaRegistryAvroMessageDecoder", + "stream.kafka.consumer.factory.class.name": "org.apache.pinot.plugin.stream.kafka20.KafkaConsumerFactory", + "stream.kafka.decoder.prop.schema.registry.rest.url": "http://schema-registry:8081", + "stream.kafka.zk.broker.url": "zookeeper:2181/", + "stream.kafka.broker.list": "kafka:9092", + "realtime.segment.flush.threshold.time": "1m", + "realtime.segment.flush.threshold.size": "0", + "realtime.segment.flush.desired.size": "1M", + "isolation.level": "read_committed", + 
"stream.kafka.consumer.prop.auto.offset.reset": "smallest", + "stream.kafka.consumer.prop.group.id": "pinot_date_time_fields" + } + }, + "metadata": { + "customConfigs": {} + } +} diff --git a/plugin/trino-pinot/src/test/resources/date_time_fields_schema.json b/plugin/trino-pinot/src/test/resources/date_time_fields_schema.json new file mode 100644 index 000000000000..18e989b06b00 --- /dev/null +++ b/plugin/trino-pinot/src/test/resources/date_time_fields_schema.json @@ -0,0 +1,34 @@ +{ + "schemaName": "date_time_fields", + "dimensionFieldSpecs": [ + { + "name": "string_col", + "dataType": "STRING" + } + ], + "dateTimeFieldSpecs": [ + { + "name": "updated_at_seconds", + "dataType": "LONG", + "defaultNullValue" : 0, + "format": "1:SECONDS:EPOCH", + "transformFunction": "toEpochSeconds(updated_at)", + "granularity" : "1:SECONDS" + }, + { + "name": "created_at", + "dataType": "LONG", + "defaultNullValue" : 0, + "format": "1:MILLISECONDS:EPOCH", + "granularity" : "1:MILLISECONDS" + }, + { + "name": "created_at_seconds", + "dataType": "LONG", + "defaultNullValue" : 0, + "format": "1:SECONDS:EPOCH", + "transformFunction": "toEpochSeconds(created_at)", + "granularity" : "1:SECONDS" + } + ] +} diff --git a/plugin/trino-pinot/src/test/resources/quotes_in_column_name_realtimeSpec.json b/plugin/trino-pinot/src/test/resources/quotes_in_column_name_realtimeSpec.json new file mode 100644 index 000000000000..a8e8a1604c57 --- /dev/null +++ b/plugin/trino-pinot/src/test/resources/quotes_in_column_name_realtimeSpec.json @@ -0,0 +1,43 @@ +{ + "tableName": "quotes_in_column_name", + "tableType": "REALTIME", + "segmentsConfig": { + "timeColumnName": "updated_at_seconds", + "timeType": "SECONDS", + "retentionTimeUnit": "DAYS", + "retentionTimeValue": "365", + "segmentPushType": "APPEND", + "segmentPushFrequency": "daily", + "segmentAssignmentStrategy": "BalanceNumSegmentAssignmentStrategy", + "schemaName": "quotes_in_column_name", + "replicasPerPartition": "1" + }, + "tenants": { + 
"broker": "DefaultTenant", + "server": "DefaultTenant" + }, + "tableIndexConfig": { + "loadMode": "MMAP", + "invertedIndexColumns": [], + "sortedColumn": ["updated_at_seconds"], + "streamConfigs": { + "streamType": "kafka", + "stream.kafka.consumer.type": "LowLevel", + "stream.kafka.topic.name": "quotes_in_column_name", + "stream.kafka.decoder.class.name": "org.apache.pinot.plugin.inputformat.avro.confluent.KafkaConfluentSchemaRegistryAvroMessageDecoder", + "stream.kafka.consumer.factory.class.name": "org.apache.pinot.plugin.stream.kafka20.KafkaConsumerFactory", + "stream.kafka.decoder.prop.schema.registry.rest.url": "http://schema-registry:8081", + "stream.kafka.zk.broker.url": "zookeeper:2181/", + "stream.kafka.broker.list": "kafka:9092", + "realtime.segment.flush.threshold.time": "1m", + "realtime.segment.flush.threshold.size": "0", + "realtime.segment.flush.desired.size": "1M", + "isolation.level": "read_committed", + "stream.kafka.consumer.prop.auto.offset.reset": "smallest", + "stream.kafka.consumer.prop.group.id": "pinot_quotes_in_column_name" + } + }, + "metadata": { + "customConfigs": {} + } +} diff --git a/plugin/trino-pinot/src/test/resources/quotes_in_column_name_schema.json b/plugin/trino-pinot/src/test/resources/quotes_in_column_name_schema.json new file mode 100644 index 000000000000..9a9a8008dcbf --- /dev/null +++ b/plugin/trino-pinot/src/test/resources/quotes_in_column_name_schema.json @@ -0,0 +1,29 @@ +{ + "schemaName": "quotes_in_column_name", + "dimensionFieldSpecs": [ + { + "name": "qu\"ot\"ed", + "dataType": "STRING", + "transformFunction": "upper(non_quoted)" + }, + { + "name": "double\"\"qu\"ot\"ed\"", + "dataType": "STRING", + "transformFunction": "lower(non_quoted)" + }, + { + "name": "non_quoted", + "dataType": "STRING" + } + ], + "dateTimeFieldSpecs": [ + { + "name": "updated_at_seconds", + "dataType": "LONG", + "defaultNullValue" : 0, + "format": "1:SECONDS:EPOCH", + "transformFunction": "toEpochSeconds(updatedAt)", + "granularity" : 
"1:SECONDS" + } + ] +} diff --git a/plugin/trino-pinot/src/test/resources/reserved_keyword_schema.json b/plugin/trino-pinot/src/test/resources/reserved_keyword_schema.json index fbded6212a24..fbe8f11bfc7b 100644 --- a/plugin/trino-pinot/src/test/resources/reserved_keyword_schema.json +++ b/plugin/trino-pinot/src/test/resources/reserved_keyword_schema.json @@ -4,6 +4,10 @@ { "name": "date", "dataType": "STRING" + }, + { + "name": "as", + "dataType": "STRING" } ], "dateTimeFieldSpecs": [