diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 3c3bafebf56e..622e2bc998a2 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,42 +1,20 @@ - + ## Description - - - - -> Is this change a fix, improvement, new feature, refactoring, or other? - -> Is this a change to the core query engine, a connector, client library, or the SPI interfaces? (be specific) - -> How would you describe this change to a non-technical end user or system administrator? - -## Related issues, pull requests, and links - - + +## Non-technical explanation -## Documentation -( ) No documentation is needed. -( ) Sufficient documentation is included in this PR. -( ) Documentation PR is available with #prnumber. -( ) Documentation issue #issuenumber is filed, and can be handled later. + ## Release notes -( ) No release notes entries required. -( ) Release notes entries required with the following suggested text: +( ) This is not user-visible and no release notes are required. +( ) Release notes are required, please propose a release note for me. +( ) Release notes are required, with the following suggested text: ```markdown # Section diff --git a/client/trino-cli/pom.xml b/client/trino-cli/pom.xml index 72a6696e7dcc..67529ddc8522 100644 --- a/client/trino-cli/pom.xml +++ b/client/trino-cli/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/client/trino-cli/src/main/java/io/trino/cli/ClientOptions.java b/client/trino-cli/src/main/java/io/trino/cli/ClientOptions.java index 182ca96830d1..8bdf65a1f94d 100644 --- a/client/trino-cli/src/main/java/io/trino/cli/ClientOptions.java +++ b/client/trino-cli/src/main/java/io/trino/cli/ClientOptions.java @@ -39,7 +39,6 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Strings.nullToEmpty; import static io.trino.client.KerberosUtil.defaultCredentialCachePath; -import static java.util.Collections.emptyMap; import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; import static picocli.CommandLine.Option; @@ -219,27 +218,25 @@ public String getKeyMap() public ClientSession toClientSession() { - return new ClientSession( - parseServer(server), - user, - sessionUser, - source, - Optional.ofNullable(traceToken), - parseClientTags(nullToEmpty(clientTags)), - clientInfo, - catalog, - schema, - null, - timeZone, - Locale.getDefault(), - toResourceEstimates(resourceEstimates), - toProperties(sessionProperties), - emptyMap(), - emptyMap(), - toExtraCredentials(extraCredentials), - null, - clientRequestTimeout, - disableCompression); + return ClientSession.builder() + .server(parseServer(server)) + .principal(user) + .user(sessionUser) + .source(source) + .traceToken(Optional.ofNullable(traceToken)) + .clientTags(parseClientTags(nullToEmpty(clientTags))) + .clientInfo(clientInfo) + .catalog(catalog) + .schema(schema) + .timeZone(timeZone) + .locale(Locale.getDefault()) + .resourceEstimates(toResourceEstimates(resourceEstimates)) + .properties(toProperties(sessionProperties)) + .credentials(toExtraCredentials(extraCredentials)) + .transactionId(null) + .clientRequestTimeout(clientRequestTimeout) + .compressionDisabled(disableCompression) + .build(); } public static URI parseServer(String server) diff --git a/client/trino-cli/src/main/java/io/trino/cli/Console.java b/client/trino-cli/src/main/java/io/trino/cli/Console.java index 180d7b7f0f35..b79a347c881f 100644 --- a/client/trino-cli/src/main/java/io/trino/cli/Console.java +++ b/client/trino-cli/src/main/java/io/trino/cli/Console.java @@ -366,8 +366,8 @@ private static boolean process( // update catalog and schema if present if (query.getSetCatalog().isPresent() || query.getSetSchema().isPresent()) { session = ClientSession.builder(session) - .withCatalog(query.getSetCatalog().orElse(session.getCatalog())) - .withSchema(query.getSetSchema().orElse(session.getSchema())) + .catalog(query.getSetCatalog().orElse(session.getCatalog())) + .schema(query.getSetSchema().orElse(session.getSchema())) .build(); } @@ -379,12 +379,12 @@ private static boolean process( ClientSession.Builder builder = ClientSession.builder(session); if (query.getStartedTransactionId() != null) { - builder = builder.withTransactionId(query.getStartedTransactionId()); + builder = builder.transactionId(query.getStartedTransactionId()); } // update path if present if (query.getSetPath().isPresent()) { - builder = builder.withPath(query.getSetPath().get()); + builder = builder.path(query.getSetPath().get()); } // update session properties if present @@ -392,14 +392,14 @@ private static boolean process( Map sessionProperties = new HashMap<>(session.getProperties()); sessionProperties.putAll(query.getSetSessionProperties()); sessionProperties.keySet().removeAll(query.getResetSessionProperties()); - builder = builder.withProperties(sessionProperties); + builder = builder.properties(sessionProperties); } // update session roles if (!query.getSetRoles().isEmpty()) { Map roles = new HashMap<>(session.getRoles()); roles.putAll(query.getSetRoles()); - builder = builder.withRoles(roles); + builder = builder.roles(roles); } // update prepared statements if present @@ -407,7 +407,7 @@ private static boolean process( Map preparedStatements = new HashMap<>(session.getPreparedStatements()); preparedStatements.putAll(query.getAddedPreparedStatements()); preparedStatements.keySet().removeAll(query.getDeallocatedPreparedStatements()); - builder = builder.withPreparedStatements(preparedStatements); + builder = builder.preparedStatements(preparedStatements); } session = builder.build(); diff --git a/client/trino-cli/src/main/java/io/trino/cli/OutputHandler.java b/client/trino-cli/src/main/java/io/trino/cli/OutputHandler.java index 8a13cf6ed2e0..73a2ecd3bed5 100644 --- a/client/trino-cli/src/main/java/io/trino/cli/OutputHandler.java +++ b/client/trino-cli/src/main/java/io/trino/cli/OutputHandler.java @@ -94,7 +94,7 @@ public void processRows(StatementClient client) if (row == END_TOKEN) { break; } - else if (row != null) { + if (row != null) { rowBuffer.add(row); } } diff --git a/client/trino-cli/src/test/java/io/trino/cli/TestQueryRunner.java b/client/trino-cli/src/test/java/io/trino/cli/TestQueryRunner.java index 536e23700595..e7acc3e0fc22 100644 --- a/client/trino-cli/src/test/java/io/trino/cli/TestQueryRunner.java +++ b/client/trino-cli/src/test/java/io/trino/cli/TestQueryRunner.java @@ -14,8 +14,6 @@ package io.trino.cli; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; import io.airlift.json.JsonCodec; import io.airlift.units.Duration; import io.trino.client.ClientSession; @@ -102,27 +100,19 @@ public void testCookie() static ClientSession createClientSession(MockWebServer server) { - return new ClientSession( - server.url("/").uri(), - Optional.of("user"), - Optional.empty(), - "source", - Optional.empty(), - ImmutableSet.of(), - "clientInfo", - "catalog", - "schema", - null, - ZoneId.of("America/Los_Angeles"), - Locale.ENGLISH, - ImmutableMap.of(), - ImmutableMap.of(), - ImmutableMap.of(), - ImmutableMap.of(), - ImmutableMap.of(), - null, - new Duration(2, MINUTES), - true); + return ClientSession.builder() + .server(server.url("/").uri()) + .principal(Optional.of("user")) + .source("source") + .clientInfo("clientInfo") + .catalog("catalog") + .schema("schema") + .timeZone(ZoneId.of("America/Los_Angeles")) + .locale(Locale.ENGLISH) + .transactionId(null) + .clientRequestTimeout(new Duration(2, MINUTES)) + .compressionDisabled(true) + .build(); } static String createResults(MockWebServer server) diff --git a/client/trino-client/pom.xml b/client/trino-client/pom.xml index 55c35a139283..f0f43a825022 100644 --- a/client/trino-client/pom.xml +++ b/client/trino-client/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/client/trino-client/src/main/java/io/trino/client/ClientSession.java b/client/trino-client/src/main/java/io/trino/client/ClientSession.java index 06819da0ef35..8fd96c6c7e53 100644 --- a/client/trino-client/src/main/java/io/trino/client/ClientSession.java +++ b/client/trino-client/src/main/java/io/trino/client/ClientSession.java @@ -54,6 +54,11 @@ public class ClientSession private final Duration clientRequestTimeout; private final boolean compressionDisabled; + public static Builder builder() + { + return new Builder(); + } + public static Builder builder(ClientSession clientSession) { return new Builder(clientSession); @@ -62,11 +67,11 @@ public static Builder builder(ClientSession clientSession) public static ClientSession stripTransactionId(ClientSession session) { return ClientSession.builder(session) - .withoutTransactionId() + .transactionId(null) .build(); } - public ClientSession( + private ClientSession( URI server, Optional principal, Optional user, @@ -89,8 +94,8 @@ public ClientSession( boolean compressionDisabled) { this.server = requireNonNull(server, "server is null"); - this.principal = principal; - this.user = user; + this.principal = requireNonNull(principal, "principal is null"); + this.user = requireNonNull(user, "user is null"); this.source = source; this.traceToken = requireNonNull(traceToken, "traceToken is null"); this.clientTags = ImmutableSet.copyOf(requireNonNull(clientTags, "clientTags is null")); @@ -270,26 +275,28 @@ public String toString() public static final class Builder { private URI server; - private Optional principal; - private Optional user; + private Optional principal = Optional.empty(); + private Optional user = Optional.empty(); private String source; - private Optional traceToken; - private Set clientTags; + private Optional traceToken = Optional.empty(); + private Set clientTags = ImmutableSet.of(); private String clientInfo; private String catalog; private String schema; private String path; private ZoneId timeZone; private Locale locale; - private Map resourceEstimates; - private Map properties; - private Map preparedStatements; - private Map roles; - private Map credentials; + private Map resourceEstimates = ImmutableMap.of(); + private Map properties = ImmutableMap.of(); + private Map preparedStatements = ImmutableMap.of(); + private Map roles = ImmutableMap.of(); + private Map credentials = ImmutableMap.of(); private String transactionId; private Duration clientRequestTimeout; private boolean compressionDisabled; + private Builder() {} + private Builder(ClientSession clientSession) { requireNonNull(clientSession, "clientSession is null"); @@ -315,61 +322,121 @@ private Builder(ClientSession clientSession) compressionDisabled = clientSession.isCompressionDisabled(); } - public Builder withCatalog(String catalog) + public Builder server(URI server) + { + this.server = server; + return this; + } + + public Builder user(Optional user) + { + this.user = user; + return this; + } + + public Builder principal(Optional principal) + { + this.principal = principal; + return this; + } + + public Builder source(String source) + { + this.source = source; + return this; + } + + public Builder traceToken(Optional traceToken) + { + this.traceToken = traceToken; + return this; + } + + public Builder clientTags(Set clientTags) + { + this.clientTags = clientTags; + return this; + } + + public Builder clientInfo(String clientInfo) + { + this.clientInfo = clientInfo; + return this; + } + + public Builder catalog(String catalog) + { + this.catalog = catalog; + return this; + } + + public Builder schema(String schema) + { + this.schema = schema; + return this; + } + + public Builder path(String path) + { + this.path = path; + return this; + } + + public Builder timeZone(ZoneId timeZone) { - this.catalog = requireNonNull(catalog, "catalog is null"); + this.timeZone = timeZone; return this; } - public Builder withSchema(String schema) + public Builder locale(Locale locale) { - this.schema = requireNonNull(schema, "schema is null"); + this.locale = locale; return this; } - public Builder withPath(String path) + public Builder resourceEstimates(Map resourceEstimates) { - this.path = requireNonNull(path, "path is null"); + this.resourceEstimates = resourceEstimates; return this; } - public Builder withProperties(Map properties) + public Builder properties(Map properties) { - this.properties = requireNonNull(properties, "properties is null"); + this.properties = properties; return this; } - public Builder withRoles(Map roles) + public Builder roles(Map roles) { this.roles = roles; return this; } - public Builder withCredentials(Map credentials) + public Builder credentials(Map credentials) { - this.credentials = requireNonNull(credentials, "credentials is null"); + this.credentials = credentials; return this; } - public Builder withPreparedStatements(Map preparedStatements) + public Builder preparedStatements(Map preparedStatements) { - this.preparedStatements = requireNonNull(preparedStatements, "preparedStatements is null"); + this.preparedStatements = preparedStatements; return this; } - public Builder withTransactionId(String transactionId) + public Builder transactionId(String transactionId) { - this.transactionId = requireNonNull(transactionId, "transactionId is null"); + this.transactionId = transactionId; return this; } - public Builder withoutTransactionId() + public Builder clientRequestTimeout(Duration clientRequestTimeout) { - this.transactionId = null; + this.clientRequestTimeout = clientRequestTimeout; return this; } - public Builder withCompressionDisabled(boolean compressionDisabled) + public Builder compressionDisabled(boolean compressionDisabled) { this.compressionDisabled = compressionDisabled; return this; diff --git a/client/trino-jdbc/pom.xml b/client/trino-jdbc/pom.xml index 1bbcbab656af..ef88af985efa 100644 --- a/client/trino-jdbc/pom.xml +++ b/client/trino-jdbc/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java index c841627b37f4..b2675de1180b 100644 --- a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java +++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java @@ -733,27 +733,27 @@ StatementClient startQuery(String sql, Map sessionPropertiesOver int millis = networkTimeoutMillis.get(); Duration timeout = (millis > 0) ? new Duration(millis, MILLISECONDS) : new Duration(999, DAYS); - ClientSession session = new ClientSession( - httpUri, - user, - sessionUser, - source, - Optional.ofNullable(clientInfo.get(TRACE_TOKEN)), - ImmutableSet.copyOf(clientTags), - clientInfo.get(CLIENT_INFO), - catalog.get(), - schema.get(), - path.get(), - timeZoneId.get(), - locale.get(), - ImmutableMap.of(), - ImmutableMap.copyOf(allProperties), - ImmutableMap.copyOf(preparedStatements), - ImmutableMap.copyOf(roles), - extraCredentials, - transactionId.get(), - timeout, - compressionDisabled); + ClientSession session = ClientSession.builder() + .server(httpUri) + .principal(user) + .user(sessionUser) + .source(source) + .traceToken(Optional.ofNullable(clientInfo.get(TRACE_TOKEN))) + .clientTags(ImmutableSet.copyOf(clientTags)) + .clientInfo(clientInfo.get(CLIENT_INFO)) + .catalog(catalog.get()) + .schema(schema.get()) + .path(path.get()) + .timeZone(timeZoneId.get()) + .locale(locale.get()) + .properties(ImmutableMap.copyOf(allProperties)) + .preparedStatements(ImmutableMap.copyOf(preparedStatements)) + .roles(ImmutableMap.copyOf(roles)) + .credentials(extraCredentials) + .transactionId(transactionId.get()) + .clientRequestTimeout(timeout) + .compressionDisabled(compressionDisabled) + .build(); return newStatementClient(httpClient, session, sql); } diff --git a/core/trino-main/pom.xml b/core/trino-main/pom.xml index 24b6512d88d9..c59081575ce5 100644 --- a/core/trino-main/pom.xml +++ b/core/trino-main/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/core/trino-main/src/main/java/io/trino/cost/CostCalculatorWithEstimatedExchanges.java b/core/trino-main/src/main/java/io/trino/cost/CostCalculatorWithEstimatedExchanges.java index dfaeea21c274..58bc37d5c921 100644 --- a/core/trino-main/src/main/java/io/trino/cost/CostCalculatorWithEstimatedExchanges.java +++ b/core/trino-main/src/main/java/io/trino/cost/CostCalculatorWithEstimatedExchanges.java @@ -280,12 +280,10 @@ private static LocalCostEstimate calculateJoinExchangeCost( LocalCostEstimate localRepartitionCost = calculateLocalRepartitionCost(buildSizeInBytes); return addPartialComponents(replicateCost, localRepartitionCost); } - else { - LocalCostEstimate probeCost = calculateRemoteRepartitionCost(probeSizeInBytes); - LocalCostEstimate buildRemoteRepartitionCost = calculateRemoteRepartitionCost(buildSizeInBytes); - LocalCostEstimate buildLocalRepartitionCost = calculateLocalRepartitionCost(buildSizeInBytes); - return addPartialComponents(probeCost, buildRemoteRepartitionCost, buildLocalRepartitionCost); - } + LocalCostEstimate probeCost = calculateRemoteRepartitionCost(probeSizeInBytes); + LocalCostEstimate buildRemoteRepartitionCost = calculateRemoteRepartitionCost(buildSizeInBytes); + LocalCostEstimate buildLocalRepartitionCost = calculateLocalRepartitionCost(buildSizeInBytes); + return addPartialComponents(probeCost, buildRemoteRepartitionCost, buildLocalRepartitionCost); } public static LocalCostEstimate calculateJoinInputCost( diff --git a/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java b/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java index ec6ab4c6b196..9075b27d6e23 100644 --- a/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java +++ b/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java @@ -312,20 +312,18 @@ private SymbolStatsEstimate estimateCoalesce(SymbolStatsEstimate left, SymbolSta if (left.getNullsFraction() == 0) { return left; } - else if (left.getNullsFraction() == 1.0) { + if (left.getNullsFraction() == 1.0) { return right; } - else { - return SymbolStatsEstimate.builder() - .setLowValue(min(left.getLowValue(), right.getLowValue())) - .setHighValue(max(left.getHighValue(), right.getHighValue())) - .setDistinctValuesCount(left.getDistinctValuesCount() + - min(right.getDistinctValuesCount(), input.getOutputRowCount() * left.getNullsFraction())) - .setNullsFraction(left.getNullsFraction() * right.getNullsFraction()) - // TODO check if dataSize estimation method is correct - .setAverageRowSize(max(left.getAverageRowSize(), right.getAverageRowSize())) - .build(); - } + return SymbolStatsEstimate.builder() + .setLowValue(min(left.getLowValue(), right.getLowValue())) + .setHighValue(max(left.getHighValue(), right.getHighValue())) + .setDistinctValuesCount(left.getDistinctValuesCount() + + min(right.getDistinctValuesCount(), input.getOutputRowCount() * left.getNullsFraction())) + .setNullsFraction(left.getNullsFraction() * right.getNullsFraction()) + // TODO check if dataSize estimation method is correct + .setAverageRowSize(max(left.getAverageRowSize(), right.getAverageRowSize())) + .build(); } } diff --git a/core/trino-main/src/main/java/io/trino/dispatcher/QueuedStatementResource.java b/core/trino-main/src/main/java/io/trino/dispatcher/QueuedStatementResource.java index 5f31eef50fd2..79e7a0c83a6b 100644 --- a/core/trino-main/src/main/java/io/trino/dispatcher/QueuedStatementResource.java +++ b/core/trino-main/src/main/java/io/trino/dispatcher/QueuedStatementResource.java @@ -391,15 +391,13 @@ public QueryResults getQueryResults(long token, UriInfo uriInfo) DispatchInfo.queued(NO_DURATION, NO_DURATION)); } - Optional dispatchInfo = dispatchManager.getDispatchInfo(queryId); - if (dispatchInfo.isEmpty()) { - // query should always be found, but it may have just been determined to be abandoned - throw new WebApplicationException(Response - .status(NOT_FOUND) - .build()); - } + DispatchInfo dispatchInfo = dispatchManager.getDispatchInfo(queryId) + // query should always be found, but it may have just been determined to be abandoned + .orElseThrow(() -> new WebApplicationException(Response + .status(NOT_FOUND) + .build())); - return createQueryResults(token + 1, uriInfo, dispatchInfo.get()); + return createQueryResults(token + 1, uriInfo, dispatchInfo); } public void cancel() diff --git a/core/trino-main/src/main/java/io/trino/exchange/LazyExchangeDataSource.java b/core/trino-main/src/main/java/io/trino/exchange/LazyExchangeDataSource.java index 6325ba91afb9..58e697d7bf71 100644 --- a/core/trino-main/src/main/java/io/trino/exchange/LazyExchangeDataSource.java +++ b/core/trino-main/src/main/java/io/trino/exchange/LazyExchangeDataSource.java @@ -116,6 +116,7 @@ public void addInput(ExchangeInput input) return; } ExchangeDataSource dataSource = delegate.get(); + boolean inputAdded = false; if (dataSource == null) { if (input instanceof DirectExchangeInput) { DirectExchangeClient client = directExchangeClientSupplier.get(queryId, exchangeId, systemMemoryContext, taskFailureListener, retryPolicy); @@ -126,7 +127,8 @@ else if (input instanceof SpoolingExchangeInput) { ExchangeManager exchangeManager = exchangeManagerRegistry.getExchangeManager(); List sourceHandles = spoolingExchangeInput.getExchangeSourceHandles(); ExchangeSource exchangeSource = exchangeManager.createSource(sourceHandles); - dataSource = new SpoolingExchangeDataSource(exchangeSource, sourceHandles, systemMemoryContext); + dataSource = new SpoolingExchangeDataSource(exchangeSource, systemMemoryContext); + inputAdded = true; } else { throw new IllegalArgumentException("Unexpected input: " + input); @@ -134,7 +136,9 @@ else if (input instanceof SpoolingExchangeInput) { delegate.set(dataSource); initialized = true; } - dataSource.addInput(input); + if (!inputAdded) { + dataSource.addInput(input); + } } if (initialized) { diff --git a/core/trino-main/src/main/java/io/trino/exchange/SpoolingExchangeDataSource.java b/core/trino-main/src/main/java/io/trino/exchange/SpoolingExchangeDataSource.java index 51b5cc624444..62a44ea83c4f 100644 --- a/core/trino-main/src/main/java/io/trino/exchange/SpoolingExchangeDataSource.java +++ b/core/trino-main/src/main/java/io/trino/exchange/SpoolingExchangeDataSource.java @@ -13,18 +13,13 @@ */ package io.trino.exchange; -import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.log.Logger; import io.airlift.slice.Slice; import io.trino.memory.context.LocalMemoryContext; import io.trino.operator.OperatorInfo; import io.trino.spi.exchange.ExchangeSource; -import io.trino.spi.exchange.ExchangeSourceHandle; -import java.util.List; - -import static com.google.common.base.Preconditions.checkState; import static com.google.common.util.concurrent.Futures.immediateVoidFuture; import static io.airlift.concurrent.MoreFutures.toListenableFuture; import static java.util.Objects.requireNonNull; @@ -39,18 +34,13 @@ public class SpoolingExchangeDataSource // It doesn't have to be declared as volatile as the nullification of this variable doesn't have to be immediately visible to other threads. // However since close can be called at any moment this variable has to be accessed in a safe way (avoiding "check-then-use"). private ExchangeSource exchangeSource; - private final List exchangeSourceHandles; private final LocalMemoryContext systemMemoryContext; private volatile boolean closed; - public SpoolingExchangeDataSource( - ExchangeSource exchangeSource, - List exchangeSourceHandles, - LocalMemoryContext systemMemoryContext) + public SpoolingExchangeDataSource(ExchangeSource exchangeSource, LocalMemoryContext systemMemoryContext) { // this assignment is expected to be followed by an assignment of a final field to ensure safe publication this.exchangeSource = requireNonNull(exchangeSource, "exchangeSource is null"); - this.exchangeSourceHandles = ImmutableList.copyOf(requireNonNull(exchangeSourceHandles, "exchangeSourceHandles is null")); this.systemMemoryContext = requireNonNull(systemMemoryContext, "systemMemoryContext is null"); } @@ -96,16 +86,7 @@ public ListenableFuture isBlocked() @Override public void addInput(ExchangeInput input) { - SpoolingExchangeInput exchangeInput = (SpoolingExchangeInput) input; - // Only a single input is expected when the spooling exchange is used. - // The engine adds the same input to every instance of the ExchangeOperator. - // Since the ExchangeDataSource is shared between ExchangeOperator instances - // the same input may be delivered multiple times. - checkState( - exchangeInput.getExchangeSourceHandles().equals(exchangeSourceHandles), - "exchange input is expected to contain an identical exchangeSourceHandles list: %s != %s", - exchangeInput.getExchangeSourceHandles(), - exchangeSourceHandles); + throw new UnsupportedOperationException("only a single input is expected"); } @Override diff --git a/core/trino-main/src/main/java/io/trino/execution/CommentTask.java b/core/trino-main/src/main/java/io/trino/execution/CommentTask.java index bf408789aa9d..cd29320675dd 100644 --- a/core/trino-main/src/main/java/io/trino/execution/CommentTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/CommentTask.java @@ -30,7 +30,6 @@ import java.util.List; import java.util.Map; -import java.util.Optional; import static com.google.common.util.concurrent.Futures.immediateVoidFuture; import static io.trino.metadata.MetadataUtil.createQualifiedObjectName; @@ -133,12 +132,10 @@ else if (metadata.getTableHandle(session, viewName).isPresent()) { private void commentOnColumn(Comment statement, Session session) { - Optional prefix = statement.getName().getPrefix(); - if (prefix.isEmpty()) { - throw semanticException(MISSING_TABLE, statement, "Table must be specified"); - } + QualifiedName prefix = statement.getName().getPrefix() + .orElseThrow(() -> semanticException(MISSING_TABLE, statement, "Table must be specified")); - QualifiedObjectName originalTableName = createQualifiedObjectName(session, statement, prefix.get()); + QualifiedObjectName originalTableName = createQualifiedObjectName(session, statement, prefix); RedirectionAwareTableHandle redirectionAwareTableHandle = metadata.getRedirectionAwareTableHandle(session, originalTableName); if (redirectionAwareTableHandle.getTableHandle().isEmpty()) { throw semanticException(TABLE_NOT_FOUND, statement, "Table does not exist: " + originalTableName); diff --git a/core/trino-main/src/main/java/io/trino/execution/CommitTask.java b/core/trino-main/src/main/java/io/trino/execution/CommitTask.java index f777048525c4..43438e9adb7b 100644 --- a/core/trino-main/src/main/java/io/trino/execution/CommitTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/CommitTask.java @@ -54,10 +54,7 @@ public ListenableFuture execute( WarningCollector warningCollector) { Session session = stateMachine.getSession(); - if (session.getTransactionId().isEmpty()) { - throw new TrinoException(NOT_IN_TRANSACTION, "No transaction in progress"); - } - TransactionId transactionId = session.getTransactionId().get(); + TransactionId transactionId = session.getTransactionId().orElseThrow(() -> new TrinoException(NOT_IN_TRANSACTION, "No transaction in progress")); stateMachine.clearTransactionId(); return transactionManager.asyncCommit(transactionId); diff --git a/core/trino-main/src/main/java/io/trino/execution/DataDefinitionExecution.java b/core/trino-main/src/main/java/io/trino/execution/DataDefinitionExecution.java index 9aacb600f157..521926d8dd6b 100644 --- a/core/trino-main/src/main/java/io/trino/execution/DataDefinitionExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/DataDefinitionExecution.java @@ -165,7 +165,7 @@ public void onFailure(Throwable throwable) } @Override - public void addOutputInfoListener(Consumer listener) + public void setOutputInfoListener(Consumer listener) { // DDL does not have an output } diff --git a/core/trino-main/src/main/java/io/trino/execution/DenyTask.java b/core/trino-main/src/main/java/io/trino/execution/DenyTask.java index 2a4cd0032a16..943c03d84210 100644 --- a/core/trino-main/src/main/java/io/trino/execution/DenyTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/DenyTask.java @@ -19,7 +19,6 @@ import io.trino.metadata.Metadata; import io.trino.metadata.QualifiedObjectName; import io.trino.metadata.RedirectionAwareTableHandle; -import io.trino.metadata.TableHandle; import io.trino.security.AccessControl; import io.trino.spi.connector.CatalogSchemaName; import io.trino.spi.security.Privilege; @@ -97,8 +96,7 @@ private static void executeDenyOnTable(Session session, Deny statement, Metadata { QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getName()); RedirectionAwareTableHandle redirection = metadata.getRedirectionAwareTableHandle(session, tableName); - Optional tableHandle = redirection.getTableHandle(); - if (tableHandle.isEmpty()) { + if (redirection.getTableHandle().isEmpty()) { throw semanticException(TABLE_NOT_FOUND, statement, "Table '%s' does not exist", tableName); } if (redirection.getRedirectedTableName().isPresent()) { diff --git a/core/trino-main/src/main/java/io/trino/execution/GrantTask.java b/core/trino-main/src/main/java/io/trino/execution/GrantTask.java index ad5d2cbf2ba7..806912b3f8d9 100644 --- a/core/trino-main/src/main/java/io/trino/execution/GrantTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/GrantTask.java @@ -19,7 +19,6 @@ import io.trino.metadata.Metadata; import io.trino.metadata.QualifiedObjectName; import io.trino.metadata.RedirectionAwareTableHandle; -import io.trino.metadata.TableHandle; import io.trino.security.AccessControl; import io.trino.spi.connector.CatalogSchemaName; import io.trino.spi.security.Privilege; @@ -101,8 +100,7 @@ private void executeGrantOnTable(Session session, Grant statement) { QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getName()); RedirectionAwareTableHandle redirection = metadata.getRedirectionAwareTableHandle(session, tableName); - Optional tableHandle = redirection.getTableHandle(); - if (tableHandle.isEmpty()) { + if (redirection.getTableHandle().isEmpty()) { throw semanticException(TABLE_NOT_FOUND, statement, "Table '%s' does not exist", tableName); } if (redirection.getRedirectedTableName().isPresent()) { diff --git a/core/trino-main/src/main/java/io/trino/execution/QueryExecution.java b/core/trino-main/src/main/java/io/trino/execution/QueryExecution.java index f0608f70b7ab..38c5a85358fa 100644 --- a/core/trino-main/src/main/java/io/trino/execution/QueryExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/QueryExecution.java @@ -28,6 +28,7 @@ import io.trino.sql.planner.Plan; import java.util.List; +import java.util.Queue; import java.util.function.Consumer; import static java.util.Objects.requireNonNull; @@ -41,7 +42,7 @@ public interface QueryExecution void addStateChangeListener(StateChangeListener stateChangeListener); - void addOutputInfoListener(Consumer listener); + void setOutputInfoListener(Consumer listener); void outputTaskFailed(TaskId taskId, Throwable failure); @@ -86,23 +87,23 @@ interface QueryExecutionFactory } /** - * Output schema and buffer URIs for query. The info will always contain column names and types. Buffer locations will always - * contain the full location set, but may be empty. Users of this data should keep a private copy of the seen buffers to - * handle out of order events from the listener. Once noMoreBufferLocations is set the locations will never change, and - * it is guaranteed that all previously sent locations are contained in the buffer locations. + * The info will always contain column names and types. + * The {@code inputsQueue} is shared between {@link QueryOutputInfo} instances. + * It is guaranteed that no new entries will be added to {@code inputsQueue} after {@link QueryOutputInfo} + * with {@link #isNoMoreInputs()} {@code == true} is created. */ class QueryOutputInfo { private final List columnNames; private final List columnTypes; - private final List inputs; + private final Queue inputsQueue; private final boolean noMoreInputs; - public QueryOutputInfo(List columnNames, List columnTypes, List inputs, boolean noMoreInputs) + public QueryOutputInfo(List columnNames, List columnTypes, Queue inputsQueue, boolean noMoreInputs) { this.columnNames = ImmutableList.copyOf(requireNonNull(columnNames, "columnNames is null")); this.columnTypes = ImmutableList.copyOf(requireNonNull(columnTypes, "columnTypes is null")); - this.inputs = ImmutableList.copyOf(requireNonNull(inputs, "inputs is null")); + this.inputsQueue = requireNonNull(inputsQueue, "inputsQueue is null"); this.noMoreInputs = noMoreInputs; } @@ -116,9 +117,15 @@ public List getColumnTypes() return columnTypes; } - public List getInputs() + public void drainInputs(Consumer consumer) { - return inputs; + while (true) { + ExchangeInput input = inputsQueue.poll(); + if (input == null) { + break; + } + consumer.accept(input); + } } public boolean isNoMoreInputs() diff --git a/core/trino-main/src/main/java/io/trino/execution/QueryManager.java b/core/trino-main/src/main/java/io/trino/execution/QueryManager.java index 1543d7c5dc39..c0d3c588e389 100644 --- a/core/trino-main/src/main/java/io/trino/execution/QueryManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/QueryManager.java @@ -33,7 +33,7 @@ public interface QueryManager * * @throws NoSuchElementException if query does not exist */ - void addOutputInfoListener(QueryId queryId, Consumer listener) + void setOutputInfoListener(QueryId queryId, Consumer listener) throws NoSuchElementException; /** diff --git a/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java b/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java index fdf0e81164c8..7eb106408385 100644 --- a/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java +++ b/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java @@ -64,8 +64,10 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Queue; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; @@ -284,6 +286,7 @@ static QueryStateMachine beginWithTicker( QUERY_STATE_LOG.debug("Query %s is %s", queryStateMachine.getQueryId(), newState); if (newState.isDone()) { queryStateMachine.getSession().getTransactionId().ifPresent(transactionManager::trySetInactive); + queryStateMachine.getOutputManager().setQueryCompleted(); } }); @@ -711,9 +714,9 @@ private QueryStats getQueryStats(Optional rootStage, List operatorStatsSummary.build()); } - public void addOutputInfoListener(Consumer listener) + public void setOutputInfoListener(Consumer listener) { - outputManager.addOutputInfoListener(listener); + outputManager.setOutputInfoListener(listener); } public void addOutputTaskFailureListener(TaskFailureListener listener) @@ -1282,21 +1285,28 @@ private static QueryStats pruneQueryStats(QueryStats queryStats) ImmutableList.of()); // Remove the operator summaries as OperatorInfo (especially DirectExchangeClientStatus) can hold onto a large amount of memory } + private QueryOutputManager getOutputManager() + { + return outputManager; + } + public static class QueryOutputManager { private final Executor executor; @GuardedBy("this") - private final List> outputInfoListeners = new ArrayList<>(); + private Optional> listener = Optional.empty(); @GuardedBy("this") private List columnNames; @GuardedBy("this") private List columnTypes; @GuardedBy("this") - private final List inputs = new ArrayList<>(); - @GuardedBy("this") private boolean noMoreInputs; + @GuardedBy("this") + private boolean queryCompleted; + + private final Queue inputsQueue = new ConcurrentLinkedQueue<>(); @GuardedBy("this") private final Map outputTaskFailures = new HashMap<>(); @@ -1308,16 +1318,17 @@ public QueryOutputManager(Executor executor) this.executor = requireNonNull(executor, "executor is null"); } - public void addOutputInfoListener(Consumer listener) + public void setOutputInfoListener(Consumer listener) { requireNonNull(listener, "listener is null"); Optional queryOutputInfo; synchronized (this) { - outputInfoListeners.add(listener); + checkState(this.listener.isEmpty(), "listener is already set"); + this.listener = Optional.of(listener); queryOutputInfo = getQueryOutputInfo(); } - queryOutputInfo.ifPresent(info -> executor.execute(() -> listener.accept(info))); + fireStateChangedIfReady(queryOutputInfo, Optional.of(listener)); } public void setColumns(List columnNames, List columnTypes) @@ -1327,16 +1338,16 @@ public void setColumns(List columnNames, List columnTypes) checkArgument(columnNames.size() == columnTypes.size(), "columnNames and columnTypes must be the same size"); Optional queryOutputInfo; - List> outputInfoListeners; + Optional> listener; synchronized (this) { checkState(this.columnNames == null && this.columnTypes == null, "output fields already set"); this.columnNames = ImmutableList.copyOf(columnNames); this.columnTypes = ImmutableList.copyOf(columnTypes); queryOutputInfo = getQueryOutputInfo(); - outputInfoListeners = ImmutableList.copyOf(this.outputInfoListeners); + listener = this.listener; } - queryOutputInfo.ifPresent(info -> fireStateChanged(info, outputInfoListeners)); + fireStateChangedIfReady(queryOutputInfo, listener); } public void updateInputsForQueryResults(List newInputs, boolean noMoreInputs) @@ -1344,16 +1355,28 @@ public void updateInputsForQueryResults(List newInputs, boolean n requireNonNull(newInputs, "newInputs is null"); Optional queryOutputInfo; - List> outputInfoListeners; + Optional> listener; synchronized (this) { - // noMoreInputs can be set more than once - checkState(newInputs.isEmpty() || !this.noMoreInputs, "new inputs added after no more inputs set"); - inputs.addAll(newInputs); - this.noMoreInputs = noMoreInputs; + if (!queryCompleted) { + // noMoreInputs can be set more than once + checkState(newInputs.isEmpty() || !this.noMoreInputs, "new inputs added after no more inputs set"); + inputsQueue.addAll(newInputs); + this.noMoreInputs = noMoreInputs; + } queryOutputInfo = getQueryOutputInfo(); - outputInfoListeners = ImmutableList.copyOf(this.outputInfoListeners); + listener = this.listener; + } + fireStateChangedIfReady(queryOutputInfo, listener); + } + + public synchronized void setQueryCompleted() + { + if (queryCompleted) { + return; } - queryOutputInfo.ifPresent(info -> fireStateChanged(info, outputInfoListeners)); + queryCompleted = true; + inputsQueue.clear(); + noMoreInputs = true; } public void addOutputTaskFailureListener(TaskFailureListener listener) @@ -1387,14 +1410,15 @@ private synchronized Optional getQueryOutputInfo() if (columnNames == null || columnTypes == null) { return Optional.empty(); } - return Optional.of(new QueryOutputInfo(columnNames, columnTypes, inputs, noMoreInputs)); + return Optional.of(new QueryOutputInfo(columnNames, columnTypes, inputsQueue, noMoreInputs)); } - private void fireStateChanged(QueryOutputInfo queryOutputInfo, List> outputInfoListeners) + private void fireStateChangedIfReady(Optional info, Optional> listener) { - for (Consumer outputInfoListener : outputInfoListeners) { - executor.execute(() -> outputInfoListener.accept(queryOutputInfo)); + if (info.isEmpty() || listener.isEmpty()) { + return; } + executor.execute(() -> listener.get().accept(info.get())); } } } diff --git a/core/trino-main/src/main/java/io/trino/execution/RevokeTask.java b/core/trino-main/src/main/java/io/trino/execution/RevokeTask.java index 43276b0c8e13..39adb347b01f 100644 --- a/core/trino-main/src/main/java/io/trino/execution/RevokeTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/RevokeTask.java @@ -19,7 +19,6 @@ import io.trino.metadata.Metadata; import io.trino.metadata.QualifiedObjectName; import io.trino.metadata.RedirectionAwareTableHandle; -import io.trino.metadata.TableHandle; import io.trino.security.AccessControl; import io.trino.spi.connector.CatalogSchemaName; import io.trino.spi.security.Privilege; @@ -101,8 +100,7 @@ private void executeRevokeOnTable(Session session, Revoke statement) { QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getName()); RedirectionAwareTableHandle redirection = metadata.getRedirectionAwareTableHandle(session, tableName); - Optional tableHandle = redirection.getTableHandle(); - if (tableHandle.isEmpty()) { + if (redirection.getTableHandle().isEmpty()) { throw semanticException(TABLE_NOT_FOUND, statement, "Table '%s' does not exist", tableName); } if (redirection.getRedirectedTableName().isPresent()) { diff --git a/core/trino-main/src/main/java/io/trino/execution/RollbackTask.java b/core/trino-main/src/main/java/io/trino/execution/RollbackTask.java index 51318e9708f9..965018c050d7 100644 --- a/core/trino-main/src/main/java/io/trino/execution/RollbackTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/RollbackTask.java @@ -55,10 +55,7 @@ public ListenableFuture execute( WarningCollector warningCollector) { Session session = stateMachine.getSession(); - if (session.getTransactionId().isEmpty()) { - throw new TrinoException(NOT_IN_TRANSACTION, "No transaction in progress"); - } - TransactionId transactionId = session.getTransactionId().get(); + TransactionId transactionId = session.getTransactionId().orElseThrow(() -> new TrinoException(NOT_IN_TRANSACTION, "No transaction in progress")); stateMachine.clearTransactionId(); transactionManager.asyncAbort(transactionId); diff --git a/core/trino-main/src/main/java/io/trino/execution/SetPropertiesTask.java b/core/trino-main/src/main/java/io/trino/execution/SetPropertiesTask.java index 32dbde1aa781..66f935284b01 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SetPropertiesTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/SetPropertiesTask.java @@ -118,13 +118,11 @@ private void setTableProperties(SetProperties statement, QualifiedObjectName tab throw semanticException(NOT_SUPPORTED, statement, "Cannot set properties to a view in ALTER TABLE"); } - Optional tableHandle = plannerContext.getMetadata().getTableHandle(session, tableName); - if (tableHandle.isEmpty()) { - throw semanticException(TABLE_NOT_FOUND, statement, "Table does not exist: %s", tableName); - } + TableHandle tableHandle = plannerContext.getMetadata().getTableHandle(session, tableName) + .orElseThrow(() -> semanticException(TABLE_NOT_FOUND, statement, "Table does not exist: %s", tableName)); accessControl.checkCanSetTableProperties(session.toSecurityContext(), tableName, properties); - plannerContext.getMetadata().setTableProperties(session, tableHandle.get(), properties); + plannerContext.getMetadata().setTableProperties(session, tableHandle, properties); } private void setMaterializedViewProperties( diff --git a/core/trino-main/src/main/java/io/trino/execution/SetTableAuthorizationTask.java b/core/trino-main/src/main/java/io/trino/execution/SetTableAuthorizationTask.java index aaceea09beed..651642269ce7 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SetTableAuthorizationTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/SetTableAuthorizationTask.java @@ -19,7 +19,6 @@ import io.trino.metadata.Metadata; import io.trino.metadata.QualifiedObjectName; import io.trino.metadata.RedirectionAwareTableHandle; -import io.trino.metadata.TableHandle; import io.trino.security.AccessControl; import io.trino.spi.security.TrinoPrincipal; import io.trino.sql.tree.Expression; @@ -71,8 +70,7 @@ public ListenableFuture execute( getRequiredCatalogHandle(metadata, session, statement, tableName.getCatalogName()); RedirectionAwareTableHandle redirection = metadata.getRedirectionAwareTableHandle(session, tableName); - Optional tableHandle = redirection.getTableHandle(); - if (tableHandle.isEmpty()) { + if (redirection.getTableHandle().isEmpty()) { throw semanticException(TABLE_NOT_FOUND, statement, "Table '%s' does not exist", tableName); } if (redirection.getRedirectedTableName().isPresent()) { diff --git a/core/trino-main/src/main/java/io/trino/execution/SplitAssignment.java b/core/trino-main/src/main/java/io/trino/execution/SplitAssignment.java index d2b7cf250760..77cf862590d8 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SplitAssignment.java +++ b/core/trino-main/src/main/java/io/trino/execution/SplitAssignment.java @@ -78,10 +78,8 @@ public SplitAssignment update(SplitAssignment assignment) newSplits, assignment.isNoMoreSplits()); } - else { - // the specified assignment is older than this one - return this; - } + // the specified assignment is older than this one + return this; } private boolean isNewer(SplitAssignment assignment) diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java b/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java index 1f60d3b5bc82..55c88d341350 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java @@ -598,9 +598,9 @@ public boolean isDone() } @Override - public void addOutputInfoListener(Consumer listener) + public void setOutputInfoListener(Consumer listener) { - stateMachine.addOutputInfoListener(listener); + stateMachine.setOutputInfoListener(listener); } @Override diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlQueryManager.java b/core/trino-main/src/main/java/io/trino/execution/SqlQueryManager.java index f04472c2f06c..248cd6a6e7b2 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlQueryManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlQueryManager.java @@ -148,11 +148,11 @@ public List getQueries() } @Override - public void addOutputInfoListener(QueryId queryId, Consumer listener) + public void setOutputInfoListener(QueryId queryId, Consumer listener) { requireNonNull(listener, "listener is null"); - queryTracker.getQuery(queryId).addOutputInfoListener(listener); + queryTracker.getQuery(queryId).setOutputInfoListener(listener); } @Override diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTask.java b/core/trino-main/src/main/java/io/trino/execution/SqlTask.java index 3d2819708ff6..c8e290eecc7e 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlTask.java @@ -448,13 +448,12 @@ public TaskInfo updateTask( this::notifyStatusChanged); taskHolderReference.compareAndSet(taskHolder, new TaskHolder(taskExecution)); needsPlan.set(false); + taskExecution.start(); } } - if (taskExecution != null) { - taskExecution.addSplitAssignments(splitAssignments); - taskExecution.getTaskContext().addDynamicFilter(dynamicFilterDomains); - } + taskExecution.addSplitAssignments(splitAssignments); + taskExecution.getTaskContext().addDynamicFilter(dynamicFilterDomains); } catch (Error e) { failed(e); diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java b/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java index 33f2bf84b6d0..a104d0a6b527 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java @@ -16,7 +16,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Iterables; import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; @@ -42,22 +41,22 @@ import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.NotThreadSafe; -import javax.annotation.concurrent.ThreadSafe; import java.lang.ref.WeakReference; import java.util.ArrayList; -import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Map.Entry; import java.util.Optional; import java.util.OptionalInt; +import java.util.Queue; import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; import java.util.stream.Collectors; import static com.google.common.base.MoreObjects.toStringHelper; @@ -65,7 +64,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableMap.toImmutableMap; -import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Iterables.concat; import static com.google.common.util.concurrent.Futures.immediateVoidFuture; import static io.trino.SystemSessionProperties.getInitialSplitsPerNode; import static io.trino.SystemSessionProperties.getMaxDriversPerTask; @@ -93,14 +92,9 @@ public class SqlTaskExecution private final SplitMonitor splitMonitor; - private final List> drivers = new CopyOnWriteArrayList<>(); - private final Map driverRunnerFactoriesWithSplitLifeCycle; private final List driverRunnerFactoriesWithTaskLifeCycle; - - // guarded for update only - @GuardedBy("this") - private final ConcurrentMap unpartitionedSplitAssignments = new ConcurrentHashMap<>(); + private final Map driverRunnerFactoriesWithRemoteSource; @GuardedBy("this") private long maxAcknowledgedSplit = Long.MIN_VALUE; @@ -113,34 +107,10 @@ public class SqlTaskExecution @GuardedBy("this") private final Map pendingSplitsByPlanNode; - private final Status status; + // number of created Drivers that haven't yet finished + private final AtomicLong remainingDrivers = new AtomicLong(); - static SqlTaskExecution createSqlTaskExecution( - TaskStateMachine taskStateMachine, - TaskContext taskContext, - OutputBuffer outputBuffer, - LocalExecutionPlan localExecutionPlan, - TaskExecutor taskExecutor, - Executor notificationExecutor, - SplitMonitor queryMonitor) - { - SqlTaskExecution task = new SqlTaskExecution( - taskStateMachine, - taskContext, - outputBuffer, - localExecutionPlan, - taskExecutor, - queryMonitor, - notificationExecutor); - try (SetThreadName ignored = new SetThreadName("Task-%s", task.getTaskId())) { - // The scheduleDriversForTaskLifeCycle method calls enqueueDriverSplitRunner, which registers a callback with access to this object. - // The call back is accessed from another thread, so this code cannot be placed in the constructor. - task.scheduleDriversForTaskLifeCycle(); - return task; - } - } - - private SqlTaskExecution( + public SqlTaskExecution( TaskStateMachine taskStateMachine, TaskContext taskContext, OutputBuffer outputBuffer, @@ -164,24 +134,24 @@ private SqlTaskExecution( Set partitionedSources = ImmutableSet.copyOf(localExecutionPlan.getPartitionedSourceOrder()); ImmutableMap.Builder driverRunnerFactoriesWithSplitLifeCycle = ImmutableMap.builder(); ImmutableList.Builder driverRunnerFactoriesWithTaskLifeCycle = ImmutableList.builder(); + ImmutableMap.Builder driverRunnerFactoriesWithRemoteSource = ImmutableMap.builder(); for (DriverFactory driverFactory : localExecutionPlan.getDriverFactories()) { Optional sourceId = driverFactory.getSourceId(); if (sourceId.isPresent() && partitionedSources.contains(sourceId.get())) { driverRunnerFactoriesWithSplitLifeCycle.put(sourceId.get(), new DriverSplitRunnerFactory(driverFactory, true)); } else { - driverRunnerFactoriesWithTaskLifeCycle.add(new DriverSplitRunnerFactory(driverFactory, false)); + DriverSplitRunnerFactory runnerFactory = new DriverSplitRunnerFactory(driverFactory, false); + sourceId.ifPresent(planNodeId -> driverRunnerFactoriesWithRemoteSource.put(planNodeId, runnerFactory)); + driverRunnerFactoriesWithTaskLifeCycle.add(runnerFactory); } } this.driverRunnerFactoriesWithSplitLifeCycle = driverRunnerFactoriesWithSplitLifeCycle.buildOrThrow(); this.driverRunnerFactoriesWithTaskLifeCycle = driverRunnerFactoriesWithTaskLifeCycle.build(); + this.driverRunnerFactoriesWithRemoteSource = driverRunnerFactoriesWithRemoteSource.buildOrThrow(); this.pendingSplitsByPlanNode = this.driverRunnerFactoriesWithSplitLifeCycle.keySet().stream() .collect(toImmutableMap(identity(), ignore -> new PendingSplitsForPlanNode())); - this.status = new Status( - localExecutionPlan.getDriverFactories().stream() - .map(DriverFactory::getPipelineId) - .collect(toImmutableSet())); sourceStartOrder = localExecutionPlan.getPartitionedSourceOrder(); checkArgument(this.driverRunnerFactoriesWithSplitLifeCycle.keySet().equals(partitionedSources), @@ -199,6 +169,15 @@ private SqlTaskExecution( } } + public void start() + { + try (SetThreadName ignored = new SetThreadName("Task-%s", getTaskId())) { + // The scheduleDriversForTaskLifeCycle method calls enqueueDriverSplitRunner, which registers a callback with access to this object. + // The call back is accessed from another thread, so this code cannot be placed in the constructor. + scheduleDriversForTaskLifeCycle(); + } + } + // this is a separate method to ensure that the `this` reference is not leaked during construction private static TaskHandle createTaskHandle( TaskStateMachine taskStateMachine, @@ -241,39 +220,20 @@ public void addSplitAssignments(List splitAssignments) try (SetThreadName ignored = new SetThreadName("Task-%s", taskId)) { // update our record of split assignments and schedule drivers for new partitioned splits - Map updatedUnpartitionedSources = updateSplitAssignments(splitAssignments); - - // tell existing drivers about the new splits; it is safe to update drivers - // multiple times and out of order because split assignments contain full record of - // the unpartitioned splits - for (WeakReference driverReference : drivers) { - Driver driver = driverReference.get(); - // the driver can be GCed due to a failure or a limit - if (driver == null) { - // remove the weak reference from the list to avoid a memory leak - // NOTE: this is a concurrent safe operation on a CopyOnWriteArrayList - drivers.remove(driverReference); - continue; - } - Optional sourceId = driver.getSourceId(); - if (sourceId.isEmpty()) { - continue; - } - SplitAssignment splitAssignmentUpdate = updatedUnpartitionedSources.get(sourceId.get()); - if (splitAssignmentUpdate == null) { - continue; - } - driver.updateSplitAssignment(splitAssignmentUpdate); + Set updatedUnpartitionedSources = updateSplitAssignments(splitAssignments); + for (PlanNodeId planNodeId : updatedUnpartitionedSources) { + DriverSplitRunnerFactory factory = driverRunnerFactoriesWithRemoteSource.get(planNodeId); + // schedule splits outside the lock + factory.scheduleSplits(); } - // we may have transitioned to no more splits, so check for completion checkTaskCompletion(); } } - private synchronized Map updateSplitAssignments(List splitAssignments) + private synchronized Set updateSplitAssignments(List splitAssignments) { - Map updatedUnpartitionedSplitAssignments = new HashMap<>(); + ImmutableSet.Builder updatedUnpartitionedSources = ImmutableSet.builder(); // first remove any split that was already acknowledged long currentMaxAcknowledgedSplit = this.maxAcknowledgedSplit; @@ -292,22 +252,20 @@ private synchronized Map updateSplitAssignments(Lis schedulePartitionedSource(assignment); } else { - scheduleUnpartitionedSource(assignment, updatedUnpartitionedSplitAssignments); + // tell existing drivers about the new splits + DriverSplitRunnerFactory factory = driverRunnerFactoriesWithRemoteSource.get(assignment.getPlanNodeId()); + factory.enqueueSplits(assignment.getSplits(), assignment.isNoMoreSplits()); + updatedUnpartitionedSources.add(assignment.getPlanNodeId()); } } - for (DriverSplitRunnerFactory driverSplitRunnerFactory : - Iterables.concat(driverRunnerFactoriesWithSplitLifeCycle.values(), driverRunnerFactoriesWithTaskLifeCycle)) { - driverSplitRunnerFactory.closeDriverFactoryIfFullyCreated(); - } - // update maxAcknowledgedSplit maxAcknowledgedSplit = splitAssignments.stream() .flatMap(source -> source.getSplits().stream()) .mapToLong(ScheduledSplit::getSequenceId) .max() .orElse(maxAcknowledgedSplit); - return updatedUnpartitionedSplitAssignments; + return updatedUnpartitionedSources.build(); } @GuardedBy("this") @@ -358,25 +316,6 @@ private synchronized void schedulePartitionedSource(SplitAssignment splitAssignm } } - private synchronized void scheduleUnpartitionedSource(SplitAssignment splitAssignmentUpdate, Map updatedUnpartitionedSources) - { - // create new source - SplitAssignment newSplitAssignment; - SplitAssignment currentSplitAssignment = unpartitionedSplitAssignments.get(splitAssignmentUpdate.getPlanNodeId()); - if (currentSplitAssignment == null) { - newSplitAssignment = splitAssignmentUpdate; - } - else { - newSplitAssignment = currentSplitAssignment.update(splitAssignmentUpdate); - } - - // only record new source if something changed - if (newSplitAssignment != currentSplitAssignment) { - unpartitionedSplitAssignments.put(splitAssignmentUpdate.getPlanNodeId(), newSplitAssignment); - updatedUnpartitionedSources.put(splitAssignmentUpdate.getPlanNodeId(), newSplitAssignment); - } - } - private void scheduleDriversForTaskLifeCycle() { // This method is called at the beginning of the task. @@ -392,6 +331,7 @@ private void scheduleDriversForTaskLifeCycle() driverRunnerFactory.noMoreDriverRunner(); verify(driverRunnerFactory.isNoMoreDriverRunner()); } + checkTaskCompletion(); } private synchronized void enqueueDriverSplitRunner(boolean forceRunSplit, List runners) @@ -406,7 +346,7 @@ private synchronized void enqueueDriverSplitRunner(boolean forceRunSplit, List() { @@ -415,7 +355,7 @@ public void onSuccess(Object result) { try (SetThreadName ignored = new SetThreadName("Task-%s", taskId)) { // record driver is finished - status.decrementRemainingDriver(); + remainingDrivers.decrementAndGet(); checkTaskCompletion(); @@ -430,7 +370,7 @@ public void onFailure(Throwable cause) taskStateMachine.failed(cause); // record driver is finished - status.decrementRemainingDriver(); + remainingDrivers.decrementAndGet(); // fire failed event with cause splitMonitor.splitFailedEvent(taskId, getDriverStats(), cause); @@ -458,14 +398,14 @@ private DriverStats getDriverStats() public synchronized Set getNoMoreSplits() { ImmutableSet.Builder noMoreSplits = ImmutableSet.builder(); - for (Entry entry : driverRunnerFactoriesWithSplitLifeCycle.entrySet()) { + for (Map.Entry entry : driverRunnerFactoriesWithSplitLifeCycle.entrySet()) { if (entry.getValue().isNoMoreDriverRunner()) { noMoreSplits.add(entry.getKey()); } } - for (SplitAssignment splitAssignment : unpartitionedSplitAssignments.values()) { - if (splitAssignment.isNoMoreSplits()) { - noMoreSplits.add(splitAssignment.getPlanNodeId()); + for (Map.Entry entry : driverRunnerFactoriesWithRemoteSource.entrySet()) { + if (entry.getValue().isNoMoreSplits()) { + noMoreSplits.add(entry.getKey()); } } return noMoreSplits.build(); @@ -477,14 +417,14 @@ private synchronized void checkTaskCompletion() return; } - // are there more partition splits expected? - for (DriverSplitRunnerFactory driverSplitRunnerFactory : driverRunnerFactoriesWithSplitLifeCycle.values()) { - if (!driverSplitRunnerFactory.isNoMoreDriverRunner()) { + // are there more drivers expected? + for (DriverSplitRunnerFactory driverSplitRunnerFactory : concat(driverRunnerFactoriesWithTaskLifeCycle, driverRunnerFactoriesWithSplitLifeCycle.values())) { + if (!driverSplitRunnerFactory.isNoMoreDrivers()) { return; } } // do we still have running tasks? - if (status.getRemainingDriver() != 0) { + if (remainingDrivers.get() != 0) { return; } @@ -520,8 +460,7 @@ public String toString() { return toStringHelper(this) .add("taskId", taskId) - .add("remainingDrivers", status.getRemainingDriver()) - .add("unpartitionedSplitAssignments", unpartitionedSplitAssignments) + .add("remainingDrivers", remainingDrivers.get()) .toString(); } @@ -595,7 +534,16 @@ private class DriverSplitRunnerFactory { private final DriverFactory driverFactory; private final PipelineContext pipelineContext; - private boolean closed; + + // number of created DriverSplitRunners that haven't created underlying Driver + private final AtomicInteger pendingCreations = new AtomicInteger(); + // true if no more DriverSplitRunners will be created + private final AtomicBoolean noMoreDriverRunner = new AtomicBoolean(); + + private final List> driverReferences = new CopyOnWriteArrayList<>(); + private final Queue queuedSplits = new ConcurrentLinkedQueue<>(); + private final AtomicLong inFlightSplits = new AtomicLong(); + private final AtomicBoolean noMoreSplits = new AtomicBoolean(); private DriverSplitRunnerFactory(DriverFactory driverFactory, boolean partitioned) { @@ -607,7 +555,8 @@ private DriverSplitRunnerFactory(DriverFactory driverFactory, boolean partitione // The former will take two arguments, and the latter will take one. This will simplify the signature quite a bit. public DriverSplitRunner createDriverRunner(@Nullable ScheduledSplit partitionedSplit) { - status.incrementPendingCreation(pipelineContext.getPipelineId()); + checkState(!noMoreDriverRunner.get(), "noMoreDriverRunner is set"); + pendingCreations.incrementAndGet(); // create driver context immediately so the driver existence is recorded in the stats // the number of drivers is used to balance work across nodes long splitWeight = partitionedSplit == null ? 0 : partitionedSplit.getSplit().getSplitWeight().getRawValue(); @@ -619,51 +568,101 @@ public Driver createDriver(DriverContext driverContext, @Nullable ScheduledSplit { Driver driver = driverFactory.createDriver(driverContext); - // record driver so other threads add unpartitioned sources can see the driver - // NOTE: this MUST be done before reading unpartitionedSources, so we see a consistent view of the unpartitioned sources - drivers.add(new WeakReference<>(driver)); - if (partitionedSplit != null) { // TableScanOperator requires partitioned split to be added before the first call to process driver.updateSplitAssignment(new SplitAssignment(partitionedSplit.getPlanNodeId(), ImmutableSet.of(partitionedSplit), true)); } - // add unpartitioned sources - Optional sourceId = driver.getSourceId(); - if (sourceId.isPresent()) { - SplitAssignment splitAssignment = unpartitionedSplitAssignments.get(sourceId.get()); - if (splitAssignment != null) { - driver.updateSplitAssignment(splitAssignment); + pendingCreations.decrementAndGet(); + closeDriverFactoryIfFullyCreated(); + + if (driverFactory.getSourceId().isPresent() && partitionedSplit == null) { + driverReferences.add(new WeakReference<>(driver)); + scheduleSplits(); + } + + return driver; + } + + public void enqueueSplits(Set splits, boolean noMoreSplits) + { + verify(driverFactory.getSourceId().isPresent(), "not a source driver"); + verify(!this.noMoreSplits.get() || splits.isEmpty(), "cannot add splits after noMoreSplits is set"); + queuedSplits.addAll(splits); + verify(!this.noMoreSplits.get() || noMoreSplits, "cannot unset noMoreSplits"); + if (noMoreSplits) { + this.noMoreSplits.set(true); + } + } + + public void scheduleSplits() + { + if (driverReferences.isEmpty()) { + return; + } + + PlanNodeId sourceId = driverFactory.getSourceId().orElseThrow(); + while (!queuedSplits.isEmpty()) { + int activeDriversCount = 0; + for (WeakReference driverReference : driverReferences) { + Driver driver = driverReference.get(); + if (driver == null) { + continue; + } + activeDriversCount++; + inFlightSplits.incrementAndGet(); + ScheduledSplit split = queuedSplits.poll(); + if (split == null) { + inFlightSplits.decrementAndGet(); + break; + } + driver.updateSplitAssignment(new SplitAssignment(sourceId, ImmutableSet.of(split), false)); + inFlightSplits.decrementAndGet(); + } + if (activeDriversCount == 0) { + break; } } - status.decrementPendingCreation(pipelineContext.getPipelineId()); - closeDriverFactoryIfFullyCreated(); + if (noMoreSplits.get() && queuedSplits.isEmpty() && inFlightSplits.get() == 0) { + for (WeakReference driverReference : driverReferences) { + Driver driver = driverReference.get(); + if (driver != null) { + driver.updateSplitAssignment(new SplitAssignment(sourceId, ImmutableSet.of(), true)); + } + } + } + } - return driver; + public boolean isNoMoreSplits() + { + return noMoreSplits.get(); } public void noMoreDriverRunner() { - status.setNoMoreDriverRunner(pipelineContext.getPipelineId()); + noMoreDriverRunner.set(true); closeDriverFactoryIfFullyCreated(); } public boolean isNoMoreDriverRunner() { - return status.isNoMoreDriverRunners(pipelineContext.getPipelineId()); + return noMoreDriverRunner.get(); } public void closeDriverFactoryIfFullyCreated() { - if (closed) { + if (driverFactory.isNoMoreDrivers()) { return; } - if (!isNoMoreDriverRunner() || status.getPendingCreation(pipelineContext.getPipelineId()) != 0) { - return; + if (isNoMoreDriverRunner() && pendingCreations.get() == 0) { + driverFactory.noMoreDrivers(); } - driverFactory.noMoreDrivers(); - closed = true; + } + + public boolean isNoMoreDrivers() + { + return driverFactory.isNoMoreDrivers(); } public OptionalInt getDriverInstances() @@ -780,94 +779,4 @@ public void stateChanged(BufferState newState) } } } - - @ThreadSafe - private static class Status - { - // no more driver runner: true if no more DriverSplitRunners will be created. - // pending creation: number of created DriverSplitRunners that haven't created underlying Driver. - // remaining driver: number of created Drivers that haven't yet finished. - - @GuardedBy("this") - private final int pipelineWithTaskLifeCycleCount; - - // For these 3 perX fields, they are populated lazily. If enumeration operations on the - // map can lead to side effects, no new entries can be created after such enumeration has - // happened. Otherwise, the order of entry creation and the enumeration operation will - // lead to different outcome. - @GuardedBy("this") - private final Map perPipeline; - @GuardedBy("this") - int pipelinesWithNoMoreDriverRunners; - - @GuardedBy("this") - private int overallRemainingDriver; - - public Status(Set pipelineIds) - { - int pipelineWithTaskLifeCycleCount = 0; - ImmutableMap.Builder perPipeline = ImmutableMap.builder(); - for (int pipelineId : pipelineIds) { - perPipeline.put(pipelineId, new PerPipelineStatus()); - pipelineWithTaskLifeCycleCount++; - } - this.pipelineWithTaskLifeCycleCount = pipelineWithTaskLifeCycleCount; - this.perPipeline = perPipeline.buildOrThrow(); - } - - public synchronized void setNoMoreDriverRunner(int pipelineId) - { - per(pipelineId).noMoreDriverRunners = true; - pipelinesWithNoMoreDriverRunners++; - } - - public synchronized void incrementPendingCreation(int pipelineId) - { - per(pipelineId).pendingCreation++; - } - - public synchronized void decrementPendingCreation(int pipelineId) - { - per(pipelineId).pendingCreation--; - } - - public synchronized void incrementRemainingDriver() - { - checkState(!(pipelinesWithNoMoreDriverRunners == pipelineWithTaskLifeCycleCount), "Cannot increment remainingDriver. NoMoreSplits is set."); - overallRemainingDriver++; - } - - public synchronized void decrementRemainingDriver() - { - checkState(overallRemainingDriver > 0, "Cannot decrement remainingDriver. Value is 0."); - overallRemainingDriver--; - } - - public synchronized int getPendingCreation(int pipelineId) - { - return per(pipelineId).pendingCreation; - } - - public synchronized int getRemainingDriver() - { - return overallRemainingDriver; - } - - public synchronized boolean isNoMoreDriverRunners(int pipelineId) - { - return per(pipelineId).noMoreDriverRunners; - } - - @GuardedBy("this") - private PerPipelineStatus per(int pipelineId) - { - return perPipeline.get(pipelineId); - } - } - - private static class PerPipelineStatus - { - int pendingCreation; - boolean noMoreDriverRunners; - } } diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecutionFactory.java b/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecutionFactory.java index 735f9319ba7e..e501684f0ebf 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecutionFactory.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecutionFactory.java @@ -28,7 +28,6 @@ import java.util.concurrent.Executor; import static com.google.common.base.Throwables.throwIfUnchecked; -import static io.trino.execution.SqlTaskExecution.createSqlTaskExecution; import static java.util.Objects.requireNonNull; public class SqlTaskExecutionFactory @@ -91,13 +90,13 @@ public SqlTaskExecution create( throw new RuntimeException(e); } } - return createSqlTaskExecution( + return new SqlTaskExecution( taskStateMachine, taskContext, outputBuffer, localExecutionPlan, taskExecutor, - taskNotificationExecutor, - splitMonitor); + splitMonitor, + taskNotificationExecutor); } } diff --git a/core/trino-main/src/main/java/io/trino/execution/StageStateMachine.java b/core/trino-main/src/main/java/io/trino/execution/StageStateMachine.java index 1ac8e0e564d5..6dc8479fce2f 100644 --- a/core/trino-main/src/main/java/io/trino/execution/StageStateMachine.java +++ b/core/trino-main/src/main/java/io/trino/execution/StageStateMachine.java @@ -24,11 +24,13 @@ import io.trino.operator.OperatorStats; import io.trino.operator.PipelineStats; import io.trino.operator.TaskStats; +import io.trino.plugin.base.metrics.TDigestHistogram; import io.trino.spi.eventlistener.StageGcStatistics; import io.trino.sql.planner.PlanFragment; import io.trino.sql.planner.plan.PlanNodeId; import io.trino.sql.planner.plan.TableScanNode; import io.trino.util.Failures; +import io.trino.util.Optionals; import org.joda.time.DateTime; import javax.annotation.concurrent.ThreadSafe; @@ -420,6 +422,7 @@ public StageInfo getStageInfo(Supplier> taskInfosSupplier) long failedInputBlockedTime = 0; long bufferedDataSize = 0; + Optional outputBufferUtilization = Optional.empty(); long outputDataSize = 0; long failedOutputDataSize = 0; long outputPositions = 0; @@ -495,6 +498,7 @@ public StageInfo getStageInfo(Supplier> taskInfosSupplier) inputBlockedTime += taskStats.getInputBlockedTime().roundTo(NANOSECONDS); bufferedDataSize += taskInfo.getOutputBuffers().getTotalBufferedBytes(); + outputBufferUtilization = Optionals.combine(outputBufferUtilization, taskInfo.getOutputBuffers().getUtilization(), TDigestHistogram::mergeWith); outputDataSize += taskStats.getOutputDataSize().toBytes(); outputPositions += taskStats.getOutputPositions(); @@ -596,6 +600,7 @@ public StageInfo getStageInfo(Supplier> taskInfosSupplier) succinctDuration(inputBlockedTime, NANOSECONDS), succinctDuration(failedInputBlockedTime, NANOSECONDS), succinctBytes(bufferedDataSize), + outputBufferUtilization, succinctBytes(outputDataSize), succinctBytes(failedOutputDataSize), outputPositions, diff --git a/core/trino-main/src/main/java/io/trino/execution/StageStats.java b/core/trino-main/src/main/java/io/trino/execution/StageStats.java index b783283c9f71..60fa4a84bf3b 100644 --- a/core/trino-main/src/main/java/io/trino/execution/StageStats.java +++ b/core/trino-main/src/main/java/io/trino/execution/StageStats.java @@ -22,12 +22,14 @@ import io.airlift.units.Duration; import io.trino.operator.BlockedReason; import io.trino.operator.OperatorStats; +import io.trino.plugin.base.metrics.TDigestHistogram; import io.trino.spi.eventlistener.StageGcStatistics; import org.joda.time.DateTime; import javax.annotation.concurrent.Immutable; import java.util.List; +import java.util.Optional; import java.util.OptionalDouble; import java.util.Set; @@ -96,6 +98,7 @@ public class StageStats private final Duration failedInputBlockedTime; private final DataSize bufferedDataSize; + private final Optional outputBufferUtilization; private final DataSize outputDataSize; private final DataSize failedOutputDataSize; private final long outputPositions; @@ -170,6 +173,7 @@ public StageStats( @JsonProperty("failedInputBlockedTime") Duration failedInputBlockedTime, @JsonProperty("bufferedDataSize") DataSize bufferedDataSize, + @JsonProperty("outputBufferUtilization") Optional outputBufferUtilization, @JsonProperty("outputDataSize") DataSize outputDataSize, @JsonProperty("failedOutputDataSize") DataSize failedOutputDataSize, @JsonProperty("outputPositions") long outputPositions, @@ -258,6 +262,7 @@ public StageStats( this.failedInputBlockedTime = requireNonNull(failedInputBlockedTime, "failedInputBlockedTime is null"); this.bufferedDataSize = requireNonNull(bufferedDataSize, "bufferedDataSize is null"); + this.outputBufferUtilization = requireNonNull(outputBufferUtilization, "outputBufferUtilization is null"); this.outputDataSize = requireNonNull(outputDataSize, "outputDataSize is null"); this.failedOutputDataSize = requireNonNull(failedOutputDataSize, "failedOutputDataSize is null"); checkArgument(outputPositions >= 0, "outputPositions is negative"); @@ -552,6 +557,12 @@ public DataSize getBufferedDataSize() return bufferedDataSize; } + @JsonProperty + public Optional getOutputBufferUtilization() + { + return outputBufferUtilization; + } + @JsonProperty public DataSize getOutputDataSize() { diff --git a/core/trino-main/src/main/java/io/trino/execution/TaskInfo.java b/core/trino-main/src/main/java/io/trino/execution/TaskInfo.java index 45721bba43f6..1558095b53f2 100644 --- a/core/trino-main/src/main/java/io/trino/execution/TaskInfo.java +++ b/core/trino-main/src/main/java/io/trino/execution/TaskInfo.java @@ -111,7 +111,7 @@ public boolean isNeedsPlan() public TaskInfo summarize() { if (taskStatus.getState().isDone()) { - return new TaskInfo(taskStatus, lastHeartbeat, outputBuffers.summarize(), noMoreSplits, stats.summarizeFinal(), estimatedMemory, needsPlan); + return new TaskInfo(taskStatus, lastHeartbeat, outputBuffers.summarizeFinal(), noMoreSplits, stats.summarizeFinal(), estimatedMemory, needsPlan); } return new TaskInfo(taskStatus, lastHeartbeat, outputBuffers.summarize(), noMoreSplits, stats.summarize(), estimatedMemory, needsPlan); } @@ -130,7 +130,7 @@ public static TaskInfo createInitialTask(TaskId taskId, URI location, String nod return new TaskInfo( initialTaskStatus(taskId, location, nodeId), DateTime.now(), - new OutputBufferInfo("UNINITIALIZED", OPEN, true, true, 0, 0, 0, 0, bufferStates), + new OutputBufferInfo("UNINITIALIZED", OPEN, true, true, 0, 0, 0, 0, bufferStates, Optional.empty()), ImmutableSet.of(), taskStats, Optional.empty(), diff --git a/core/trino-main/src/main/java/io/trino/execution/TruncateTableTask.java b/core/trino-main/src/main/java/io/trino/execution/TruncateTableTask.java index 1303913033b8..f5ed8e86eb60 100644 --- a/core/trino-main/src/main/java/io/trino/execution/TruncateTableTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/TruncateTableTask.java @@ -26,7 +26,6 @@ import javax.inject.Inject; import java.util.List; -import java.util.Optional; import static com.google.common.util.concurrent.Futures.immediateFuture; import static io.trino.metadata.MetadataUtil.createQualifiedObjectName; @@ -72,14 +71,12 @@ public ListenableFuture execute( throw semanticException(NOT_SUPPORTED, statement, "Cannot truncate a view"); } - Optional tableHandle = metadata.getTableHandle(session, tableName); - if (tableHandle.isEmpty()) { - throw semanticException(TABLE_NOT_FOUND, statement, "Table '%s' does not exist", tableName); - } + TableHandle tableHandle = metadata.getTableHandle(session, tableName) + .orElseThrow(() -> semanticException(TABLE_NOT_FOUND, statement, "Table '%s' does not exist", tableName)); accessControl.checkCanTruncateTable(session.toSecurityContext(), tableName); - metadata.truncateTable(session, tableHandle.get()); + metadata.truncateTable(session, tableHandle); return immediateFuture(null); } diff --git a/core/trino-main/src/main/java/io/trino/execution/buffer/ArbitraryOutputBuffer.java b/core/trino-main/src/main/java/io/trino/execution/buffer/ArbitraryOutputBuffer.java index 2e7799e45ef5..d3262645c7c1 100644 --- a/core/trino-main/src/main/java/io/trino/execution/buffer/ArbitraryOutputBuffer.java +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/ArbitraryOutputBuffer.java @@ -25,6 +25,7 @@ import io.trino.execution.buffer.OutputBuffers.OutputBufferId; import io.trino.execution.buffer.SerializedPageReference.PagesReleasedListener; import io.trino.memory.context.LocalMemoryContext; +import io.trino.plugin.base.metrics.TDigestHistogram; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; @@ -150,7 +151,8 @@ public OutputBufferInfo getInfo() totalBufferedPages, totalRowsAdded.get(), totalPagesAdded.get(), - infos.build()); + infos.build(), + Optional.of(new TDigestHistogram(memoryManager.getUtilizationHistogram()))); } @Override diff --git a/core/trino-main/src/main/java/io/trino/execution/buffer/BroadcastOutputBuffer.java b/core/trino-main/src/main/java/io/trino/execution/buffer/BroadcastOutputBuffer.java index 95dbff4a4d03..5b40b9ac447b 100644 --- a/core/trino-main/src/main/java/io/trino/execution/buffer/BroadcastOutputBuffer.java +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/BroadcastOutputBuffer.java @@ -24,6 +24,7 @@ import io.trino.execution.buffer.OutputBuffers.OutputBufferId; import io.trino.execution.buffer.SerializedPageReference.PagesReleasedListener; import io.trino.memory.context.LocalMemoryContext; +import io.trino.plugin.base.metrics.TDigestHistogram; import javax.annotation.concurrent.GuardedBy; @@ -142,7 +143,8 @@ public OutputBufferInfo getInfo() totalPagesAdded.get(), buffers.stream() .map(ClientBuffer::getInfo) - .collect(toImmutableList())); + .collect(toImmutableList()), + Optional.of(new TDigestHistogram(memoryManager.getUtilizationHistogram()))); } @Override diff --git a/core/trino-main/src/main/java/io/trino/execution/buffer/LazyOutputBuffer.java b/core/trino-main/src/main/java/io/trino/execution/buffer/LazyOutputBuffer.java index b9f26a3412fa..9f257b50499a 100644 --- a/core/trino-main/src/main/java/io/trino/execution/buffer/LazyOutputBuffer.java +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/LazyOutputBuffer.java @@ -137,7 +137,8 @@ public OutputBufferInfo getInfo() 0, 0, 0, - ImmutableList.of()); + ImmutableList.of(), + Optional.empty()); } return outputBuffer.getInfo(); } diff --git a/core/trino-main/src/main/java/io/trino/execution/buffer/OutputBufferInfo.java b/core/trino-main/src/main/java/io/trino/execution/buffer/OutputBufferInfo.java index f956fa150c8a..ca4371edeb24 100644 --- a/core/trino-main/src/main/java/io/trino/execution/buffer/OutputBufferInfo.java +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/OutputBufferInfo.java @@ -16,9 +16,11 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; +import io.trino.plugin.base.metrics.TDigestHistogram; import java.util.List; import java.util.Objects; +import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; @@ -33,6 +35,7 @@ public final class OutputBufferInfo private final long totalRowsSent; private final long totalPagesSent; private final List buffers; + private final Optional utilization; @JsonCreator public OutputBufferInfo( @@ -44,7 +47,8 @@ public OutputBufferInfo( @JsonProperty("totalBufferedPages") long totalBufferedPages, @JsonProperty("totalRowsSent") long totalRowsSent, @JsonProperty("totalPagesSent") long totalPagesSent, - @JsonProperty("buffers") List buffers) + @JsonProperty("buffers") List buffers, + @JsonProperty("utilization") Optional utilization) { this.type = type; this.state = state; @@ -55,6 +59,7 @@ public OutputBufferInfo( this.totalRowsSent = totalRowsSent; this.totalPagesSent = totalPagesSent; this.buffers = ImmutableList.copyOf(buffers); + this.utilization = utilization; } @JsonProperty @@ -111,9 +116,20 @@ public long getTotalPagesSent() return totalPagesSent; } + @JsonProperty + public Optional getUtilization() + { + return utilization; + } + public OutputBufferInfo summarize() { - return new OutputBufferInfo(type, state, canAddBuffers, canAddPages, totalBufferedBytes, totalBufferedPages, totalRowsSent, totalPagesSent, ImmutableList.of()); + return new OutputBufferInfo(type, state, canAddBuffers, canAddPages, totalBufferedBytes, totalBufferedPages, totalRowsSent, totalPagesSent, ImmutableList.of(), Optional.empty()); + } + + public OutputBufferInfo summarizeFinal() + { + return new OutputBufferInfo(type, state, canAddBuffers, canAddPages, totalBufferedBytes, totalBufferedPages, totalRowsSent, totalPagesSent, ImmutableList.of(), utilization); } @Override @@ -134,13 +150,14 @@ public boolean equals(Object o) Objects.equals(totalRowsSent, that.totalRowsSent) && Objects.equals(totalPagesSent, that.totalPagesSent) && state == that.state && - Objects.equals(buffers, that.buffers); + Objects.equals(buffers, that.buffers) && + Objects.equals(utilization, that.utilization); } @Override public int hashCode() { - return Objects.hash(state, canAddBuffers, canAddPages, totalBufferedBytes, totalBufferedPages, totalRowsSent, totalPagesSent, buffers); + return Objects.hash(state, canAddBuffers, canAddPages, totalBufferedBytes, totalBufferedPages, totalRowsSent, totalPagesSent, buffers, utilization); } @Override @@ -156,6 +173,7 @@ public String toString() .add("totalRowsSent", totalRowsSent) .add("totalPagesSent", totalPagesSent) .add("buffers", buffers) + .add("bufferUtilization", utilization) .toString(); } } diff --git a/core/trino-main/src/main/java/io/trino/execution/buffer/OutputBufferMemoryManager.java b/core/trino-main/src/main/java/io/trino/execution/buffer/OutputBufferMemoryManager.java index b805b86780e2..4c8c0801bc25 100644 --- a/core/trino-main/src/main/java/io/trino/execution/buffer/OutputBufferMemoryManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/OutputBufferMemoryManager.java @@ -15,8 +15,10 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Suppliers; +import com.google.common.base.Ticker; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; +import io.airlift.stats.TDigest; import io.trino.memory.context.LocalMemoryContext; import javax.annotation.Nullable; @@ -54,11 +56,20 @@ class OutputBufferMemoryManager @GuardedBy("this") private ListenableFuture blockedOnMemory = NOT_BLOCKED; + private final Ticker ticker = Ticker.systemTicker(); + private final AtomicBoolean blockOnFull = new AtomicBoolean(true); private final Supplier memoryContextSupplier; private final Executor notificationExecutor; + @GuardedBy("this") + private final TDigest bufferUtilization = new TDigest(); + @GuardedBy("this") + private long lastBufferUtilizationRecordTime; + @GuardedBy("this") + private double lastBufferUtilization; + public OutputBufferMemoryManager(long maxBufferedBytes, Supplier memoryContextSupplier, Executor notificationExecutor) { requireNonNull(memoryContextSupplier, "memoryContextSupplier is null"); @@ -66,6 +77,8 @@ public OutputBufferMemoryManager(long maxBufferedBytes, Supplier getBufferBlockedFuture() { if (bufferBlockedFuture == null) { @@ -155,6 +177,13 @@ public double getUtilization() return bufferedBytes.get() / (double) maxBufferedBytes; } + public synchronized TDigest getUtilizationHistogram() + { + // always get most up to date histogram + recordBufferUtilization(); + return TDigest.copyOf(bufferUtilization); + } + public boolean isOverutilized() { return isBufferFull(); diff --git a/core/trino-main/src/main/java/io/trino/execution/buffer/PagesSerdeUtil.java b/core/trino-main/src/main/java/io/trino/execution/buffer/PagesSerdeUtil.java index 6d45924364f0..7f0d2c199b21 100644 --- a/core/trino-main/src/main/java/io/trino/execution/buffer/PagesSerdeUtil.java +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/PagesSerdeUtil.java @@ -130,7 +130,7 @@ protected Page computeNext() context.close(); // Release context buffers return endOfData(); } - else if (read != headerBuffer.length) { + if (read != headerBuffer.length) { throw new EOFException(); } @@ -167,7 +167,7 @@ protected Slice computeNext() if (read <= 0) { return endOfData(); } - else if (read != headerBuffer.length) { + if (read != headerBuffer.length) { throw new EOFException(); } diff --git a/core/trino-main/src/main/java/io/trino/execution/buffer/PartitionedOutputBuffer.java b/core/trino-main/src/main/java/io/trino/execution/buffer/PartitionedOutputBuffer.java index b5b3677d8735..0a7dfcac1d09 100644 --- a/core/trino-main/src/main/java/io/trino/execution/buffer/PartitionedOutputBuffer.java +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/PartitionedOutputBuffer.java @@ -22,6 +22,7 @@ import io.trino.execution.buffer.OutputBuffers.OutputBufferId; import io.trino.execution.buffer.SerializedPageReference.PagesReleasedListener; import io.trino.memory.context.LocalMemoryContext; +import io.trino.plugin.base.metrics.TDigestHistogram; import java.util.List; import java.util.Optional; @@ -127,7 +128,8 @@ public OutputBufferInfo getInfo() totalBufferedPages, totalRowsAdded.get(), totalPagesAdded.get(), - infos.build()); + infos.build(), + Optional.of(new TDigestHistogram(memoryManager.getUtilizationHistogram()))); } @Override diff --git a/core/trino-main/src/main/java/io/trino/execution/buffer/SpoolingExchangeOutputBuffer.java b/core/trino-main/src/main/java/io/trino/execution/buffer/SpoolingExchangeOutputBuffer.java index 2d5f0bb7146e..e52860fcc7dc 100644 --- a/core/trino-main/src/main/java/io/trino/execution/buffer/SpoolingExchangeOutputBuffer.java +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/SpoolingExchangeOutputBuffer.java @@ -86,7 +86,8 @@ public OutputBufferInfo getInfo() totalPagesAdded.get(), totalRowsAdded.get(), totalPagesAdded.get(), - ImmutableList.of()); + ImmutableList.of(), + Optional.empty()); } @Override diff --git a/core/trino-main/src/main/java/io/trino/execution/resourcegroups/InternalResourceGroup.java b/core/trino-main/src/main/java/io/trino/execution/resourcegroups/InternalResourceGroup.java index 80bd6ab34857..05813cc68c5e 100644 --- a/core/trino-main/src/main/java/io/trino/execution/resourcegroups/InternalResourceGroup.java +++ b/core/trino-main/src/main/java/io/trino/execution/resourcegroups/InternalResourceGroup.java @@ -234,12 +234,10 @@ private ResourceGroupState getState() if (canRunMore()) { return CAN_RUN; } - else if (canQueueMore()) { + if (canQueueMore()) { return CAN_QUEUE; } - else { - return FULL; - } + return FULL; } } @@ -877,9 +875,7 @@ private static long getSubGroupSchedulingPriority(SchedulingPolicy policy, Inter if (policy == QUERY_PRIORITY) { return group.getHighestQueryPriority(); } - else { - return group.computeSchedulingWeight(); - } + return group.computeSchedulingWeight(); } private long computeSchedulingWeight() diff --git a/core/trino-main/src/main/java/io/trino/execution/resourcegroups/StochasticPriorityQueue.java b/core/trino-main/src/main/java/io/trino/execution/resourcegroups/StochasticPriorityQueue.java index ae6c92b90e26..81125d0f31d1 100644 --- a/core/trino-main/src/main/java/io/trino/execution/resourcegroups/StochasticPriorityQueue.java +++ b/core/trino-main/src/main/java/io/trino/execution/resourcegroups/StochasticPriorityQueue.java @@ -257,9 +257,7 @@ public Node addNode(E value, long tickets) if (left.get().descendants < right.get().descendants) { return left.get().addNode(value, tickets); } - else { - return right.get().addNode(value, tickets); - } + return right.get().addNode(value, tickets); } Node child = new Node<>(Optional.of(this), value); diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedSourcePartitionedScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedSourcePartitionedScheduler.java index 1a6b1542e9ed..5610763bb533 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedSourcePartitionedScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedSourcePartitionedScheduler.java @@ -148,10 +148,8 @@ public ScheduleResult schedule() if (blockedReason != null) { return new ScheduleResult(sourceSchedulers.isEmpty(), newTasks, blocked, blockedReason, splitsScheduled); } - else { - checkState(blocked.isDone(), "blockedReason not provided when scheduler is blocked"); - return new ScheduleResult(sourceSchedulers.isEmpty(), newTasks, splitsScheduled); - } + checkState(blocked.isDone(), "blockedReason not provided when scheduler is blocked"); + return new ScheduleResult(sourceSchedulers.isEmpty(), newTasks, splitsScheduled); } @Override diff --git a/core/trino-main/src/main/java/io/trino/json/PathEvaluationVisitor.java b/core/trino-main/src/main/java/io/trino/json/PathEvaluationVisitor.java index f197211eab0c..b82029b72062 100644 --- a/core/trino-main/src/main/java/io/trino/json/PathEvaluationVisitor.java +++ b/core/trino-main/src/main/java/io/trino/json/PathEvaluationVisitor.java @@ -528,47 +528,45 @@ private static long asArrayIndex(Object object) } return jsonNode.longValue(); } - else { - TypedValue value = (TypedValue) object; - Type type = value.getType(); - if (type.equals(BIGINT) || type.equals(INTEGER) || type.equals(SMALLINT) || type.equals(TINYINT)) { - return value.getLongValue(); + TypedValue value = (TypedValue) object; + Type type = value.getType(); + if (type.equals(BIGINT) || type.equals(INTEGER) || type.equals(SMALLINT) || type.equals(TINYINT)) { + return value.getLongValue(); + } + if (type.equals(DOUBLE)) { + try { + return DoubleOperators.castToLong(value.getDoubleValue()); } - if (type.equals(DOUBLE)) { - try { - return DoubleOperators.castToLong(value.getDoubleValue()); - } - catch (Exception e) { - throw new PathEvaluationError(e); - } + catch (Exception e) { + throw new PathEvaluationError(e); } - if (type.equals(REAL)) { - try { - return RealOperators.castToLong(value.getLongValue()); - } - catch (Exception e) { - throw new PathEvaluationError(e); - } + } + if (type.equals(REAL)) { + try { + return RealOperators.castToLong(value.getLongValue()); } - if (type instanceof DecimalType) { - DecimalType decimalType = (DecimalType) type; - int precision = decimalType.getPrecision(); - int scale = decimalType.getScale(); - if (((DecimalType) type).isShort()) { - long tenToScale = longTenToNth(DecimalConversions.intScale(scale)); - return DecimalCasts.shortDecimalToBigint(value.getLongValue(), precision, scale, tenToScale); - } - Int128 tenToScale = Int128Math.powerOfTen(DecimalConversions.intScale(scale)); - try { - return DecimalCasts.longDecimalToBigint((Int128) value.getObjectValue(), precision, scale, tenToScale); - } - catch (Exception e) { - throw new PathEvaluationError(e); - } + catch (Exception e) { + throw new PathEvaluationError(e); + } + } + if (type instanceof DecimalType) { + DecimalType decimalType = (DecimalType) type; + int precision = decimalType.getPrecision(); + int scale = decimalType.getScale(); + if (((DecimalType) type).isShort()) { + long tenToScale = longTenToNth(DecimalConversions.intScale(scale)); + return DecimalCasts.shortDecimalToBigint(value.getLongValue(), precision, scale, tenToScale); + } + Int128 tenToScale = Int128Math.powerOfTen(DecimalConversions.intScale(scale)); + try { + return DecimalCasts.longDecimalToBigint((Int128) value.getObjectValue(), precision, scale, tenToScale); + } + catch (Exception e) { + throw new PathEvaluationError(e); } - - throw itemTypeError("NUMBER", type.getDisplayName()); } + + throw itemTypeError("NUMBER", type.getDisplayName()); } @Override diff --git a/core/trino-main/src/main/java/io/trino/metadata/LiteralFunction.java b/core/trino-main/src/main/java/io/trino/metadata/LiteralFunction.java index 99d7bea38572..f519582e934c 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/LiteralFunction.java +++ b/core/trino-main/src/main/java/io/trino/metadata/LiteralFunction.java @@ -113,9 +113,7 @@ public static Type typeForMagicLiteral(Type type) if (type instanceof VarcharType) { return type; } - else { - return VARBINARY; - } + return VARBINARY; } if (clazz == boolean.class) { return BOOLEAN; diff --git a/core/trino-main/src/main/java/io/trino/metadata/QualifiedTablePrefix.java b/core/trino-main/src/main/java/io/trino/metadata/QualifiedTablePrefix.java index e500720ac777..3666e85c082b 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/QualifiedTablePrefix.java +++ b/core/trino-main/src/main/java/io/trino/metadata/QualifiedTablePrefix.java @@ -100,12 +100,10 @@ public SchemaTablePrefix asSchemaTablePrefix() if (schemaName.isEmpty()) { return new SchemaTablePrefix(); } - else if (tableName.isEmpty()) { + if (tableName.isEmpty()) { return new SchemaTablePrefix(schemaName.get()); } - else { - return new SchemaTablePrefix(schemaName.get(), tableName.get()); - } + return new SchemaTablePrefix(schemaName.get(), tableName.get()); } public Optional asQualifiedObjectName() diff --git a/core/trino-main/src/main/java/io/trino/metadata/SignatureBinder.java b/core/trino-main/src/main/java/io/trino/metadata/SignatureBinder.java index 748289c8572c..4c05c47d1451 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/SignatureBinder.java +++ b/core/trino-main/src/main/java/io/trino/metadata/SignatureBinder.java @@ -479,11 +479,9 @@ private boolean appendConstraintSolvers( ImmutableList.Builder formalTypeParameterTypeSignatures = ImmutableList.builder(); for (TypeSignatureParameter formalTypeParameter : formalTypeSignature.getParameters()) { - Optional typeSignature = formalTypeParameter.getTypeSignatureOrNamedTypeSignature(); - if (typeSignature.isEmpty()) { - throw new UnsupportedOperationException("Types with both type parameters and literal parameters at the same time are not supported"); - } - formalTypeParameterTypeSignatures.add(typeSignature.get()); + TypeSignature typeSignature = formalTypeParameter.getTypeSignatureOrNamedTypeSignature() + .orElseThrow(() -> new UnsupportedOperationException("Types with both type parameters and literal parameters at the same time are not supported")); + formalTypeParameterTypeSignatures.add(typeSignature); } return appendConstraintSolvers( @@ -692,13 +690,11 @@ private boolean canCast(Type fromType, Type toType) } return true; } - else if (toType instanceof JsonType) { + if (toType instanceof JsonType) { return fromType.getTypeParameters().stream() .allMatch(fromTypeParameter -> canCast(fromTypeParameter, toType)); } - else { - return false; - } + return false; } if (fromType instanceof JsonType) { if (toType instanceof RowType) { diff --git a/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java b/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java index 32415aa1d3f2..f5e461cc3e5d 100644 --- a/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java +++ b/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java @@ -45,6 +45,7 @@ import java.util.concurrent.locks.ReadWriteLock; import java.util.concurrent.locks.ReentrantReadWriteLock; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.Sets.newConcurrentHashSet; import static java.util.Objects.requireNonNull; @@ -155,10 +156,7 @@ public synchronized void addLocation(TaskId taskId, URI location) return; } - // ignore duplicate locations - if (allClients.containsKey(location)) { - return; - } + checkArgument(!allClients.containsKey(location), "location already exist: %s", location); checkState(!noMoreLocations, "No more locations already set"); buffer.addTask(taskId); diff --git a/core/trino-main/src/main/java/io/trino/operator/DriverContext.java b/core/trino-main/src/main/java/io/trino/operator/DriverContext.java index 169bdc534fe4..cd8ff2bcc31b 100644 --- a/core/trino-main/src/main/java/io/trino/operator/DriverContext.java +++ b/core/trino-main/src/main/java/io/trino/operator/DriverContext.java @@ -235,9 +235,7 @@ public CounterStat getInputDataSize() if (inputOperator != null) { return inputOperator.getInputDataSize(); } - else { - return new CounterStat(); - } + return new CounterStat(); } public CounterStat getInputPositions() @@ -246,9 +244,7 @@ public CounterStat getInputPositions() if (inputOperator != null) { return inputOperator.getInputPositions(); } - else { - return new CounterStat(); - } + return new CounterStat(); } public CounterStat getOutputDataSize() @@ -257,9 +253,7 @@ public CounterStat getOutputDataSize() if (inputOperator != null) { return inputOperator.getOutputDataSize(); } - else { - return new CounterStat(); - } + return new CounterStat(); } public CounterStat getOutputPositions() @@ -268,9 +262,7 @@ public CounterStat getOutputPositions() if (inputOperator != null) { return inputOperator.getOutputPositions(); } - else { - return new CounterStat(); - } + return new CounterStat(); } public long getPhysicalWrittenDataSize() diff --git a/core/trino-main/src/main/java/io/trino/operator/DriverFactory.java b/core/trino-main/src/main/java/io/trino/operator/DriverFactory.java index 812e5535f325..ba86b1998612 100644 --- a/core/trino-main/src/main/java/io/trino/operator/DriverFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/DriverFactory.java @@ -16,6 +16,8 @@ import com.google.common.collect.ImmutableList; import io.trino.sql.planner.plan.PlanNodeId; +import javax.annotation.concurrent.GuardedBy; + import java.util.List; import java.util.Optional; import java.util.OptionalInt; @@ -34,7 +36,8 @@ public class DriverFactory private final Optional sourceId; private final OptionalInt driverInstances; - private boolean closed; + @GuardedBy("this") + private boolean noMoreDrivers; public DriverFactory(int pipelineId, boolean inputDriver, boolean outputDriver, List operatorFactories, OptionalInt driverInstances) { @@ -91,7 +94,7 @@ public List getOperatorFactories() public synchronized Driver createDriver(DriverContext driverContext) { - checkState(!closed, "DriverFactory is already closed"); + checkState(!noMoreDrivers, "noMoreDrivers is already set"); requireNonNull(driverContext, "driverContext is null"); ImmutableList.Builder operators = ImmutableList.builder(); for (OperatorFactory operatorFactory : operatorFactories) { @@ -103,12 +106,17 @@ public synchronized Driver createDriver(DriverContext driverContext) public synchronized void noMoreDrivers() { - if (closed) { + if (noMoreDrivers) { return; } - closed = true; + noMoreDrivers = true; for (OperatorFactory operatorFactory : operatorFactories) { operatorFactory.noMoreOperators(); } } + + public synchronized boolean isNoMoreDrivers() + { + return noMoreDrivers; + } } diff --git a/core/trino-main/src/main/java/io/trino/operator/ExchangeOperator.java b/core/trino-main/src/main/java/io/trino/operator/ExchangeOperator.java index 3d1cb6ef8911..d84b3b0c4c48 100644 --- a/core/trino-main/src/main/java/io/trino/operator/ExchangeOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/ExchangeOperator.java @@ -29,6 +29,10 @@ import io.trino.spi.exchange.ExchangeId; import io.trino.split.RemoteSplit; import io.trino.sql.planner.plan.PlanNodeId; +import it.unimi.dsi.fastutil.ints.IntOpenHashSet; +import it.unimi.dsi.fastutil.ints.IntSet; + +import javax.annotation.concurrent.ThreadSafe; import java.util.Optional; import java.util.function.Supplier; @@ -56,6 +60,9 @@ public static class ExchangeOperatorFactory private ExchangeDataSource exchangeDataSource; private boolean closed; + private final NoMoreSplitsTracker noMoreSplitsTracker = new NoMoreSplitsTracker(); + private int nextOperatorInstanceId; + public ExchangeOperatorFactory( int operatorId, PlanNodeId sourceId, @@ -99,16 +106,28 @@ public SourceOperator createOperator(DriverContext driverContext) retryPolicy, exchangeManagerRegistry); } - return new ExchangeOperator( + int operatorInstanceId = nextOperatorInstanceId; + nextOperatorInstanceId++; + ExchangeOperator exchangeOperator = new ExchangeOperator( operatorContext, sourceId, exchangeDataSource, - serdeFactory.createPagesSerde()); + serdeFactory.createPagesSerde(), + noMoreSplitsTracker, + operatorInstanceId); + noMoreSplitsTracker.operatorAdded(operatorInstanceId); + return exchangeOperator; } @Override public void noMoreOperators() { + noMoreSplitsTracker.noMoreOperators(); + if (noMoreSplitsTracker.isNoMoreSplits()) { + if (exchangeDataSource != null) { + exchangeDataSource.noMoreInputs(); + } + } closed = true; } } @@ -117,18 +136,25 @@ public void noMoreOperators() private final PlanNodeId sourceId; private final ExchangeDataSource exchangeDataSource; private final PagesSerde serde; + private final NoMoreSplitsTracker noMoreSplitsTracker; + private final int operatorInstanceId; + private ListenableFuture isBlocked = NOT_BLOCKED; public ExchangeOperator( OperatorContext operatorContext, PlanNodeId sourceId, ExchangeDataSource exchangeDataSource, - PagesSerde serde) + PagesSerde serde, + NoMoreSplitsTracker noMoreSplitsTracker, + int operatorInstanceId) { this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); this.sourceId = requireNonNull(sourceId, "sourceId is null"); this.exchangeDataSource = requireNonNull(exchangeDataSource, "exchangeDataSource is null"); this.serde = requireNonNull(serde, "serde is null"); + this.noMoreSplitsTracker = requireNonNull(noMoreSplitsTracker, "noMoreSplitsTracker is null"); + this.operatorInstanceId = operatorInstanceId; operatorContext.setInfoSupplier(exchangeDataSource::getInfo); } @@ -154,7 +180,10 @@ public Supplier> addSplit(Split split) @Override public void noMoreSplits() { - exchangeDataSource.noMoreInputs(); + noMoreSplitsTracker.noMoreSplits(operatorInstanceId); + if (noMoreSplitsTracker.isNoMoreSplits()) { + exchangeDataSource.noMoreInputs(); + } } @Override @@ -220,4 +249,34 @@ public void close() { exchangeDataSource.close(); } + + @ThreadSafe + private static class NoMoreSplitsTracker + { + private final IntSet allOperators = new IntOpenHashSet(); + private final IntSet noMoreSplitsOperators = new IntOpenHashSet(); + private boolean noMoreOperators; + + public synchronized void operatorAdded(int operatorInstanceId) + { + checkState(!noMoreOperators, "noMoreOperators is set"); + allOperators.add(operatorInstanceId); + } + + public synchronized void noMoreOperators() + { + noMoreOperators = true; + } + + public synchronized void noMoreSplits(int operatorInstanceId) + { + checkState(allOperators.contains(operatorInstanceId), "operatorInstanceId not found: %s", operatorInstanceId); + noMoreSplitsOperators.add(operatorInstanceId); + } + + public synchronized boolean isNoMoreSplits() + { + return noMoreOperators && noMoreSplitsOperators.containsAll(allOperators); + } + } } diff --git a/core/trino-main/src/main/java/io/trino/operator/GroupedTopNRankAccumulator.java b/core/trino-main/src/main/java/io/trino/operator/GroupedTopNRankAccumulator.java index 52070cfdb0ae..f41c1f52b8ab 100644 --- a/core/trino-main/src/main/java/io/trino/operator/GroupedTopNRankAccumulator.java +++ b/core/trino-main/src/main/java/io/trino/operator/GroupedTopNRankAccumulator.java @@ -119,7 +119,7 @@ public boolean add(long groupId, RowReference rowReference) heapInsert(groupId, newPeerGroupIndex, 1); return true; } - else if (rowReference.compareTo(strategy, peekRootRowId(groupId)) < 0) { + if (rowReference.compareTo(strategy, peekRootRowId(groupId)) < 0) { // Given that total number of values >= topN, we can only consider values that are less than the root (otherwise topN would be violated) long newPeerGroupIndex = peerGroupBuffer.allocateNewNode(rowReference.allocateRowId(), UNKNOWN_INDEX); // Rank will increase by +1 after insertion, so only need to pop if root rank is already == topN. @@ -131,10 +131,8 @@ else if (rowReference.compareTo(strategy, peekRootRowId(groupId)) < 0) { } return true; } - else { - // Row cannot be accepted because the total number of values >= topN, and the row is greater than the root (meaning it's rank would be at least topN+1). - return false; - } + // Row cannot be accepted because the total number of values >= topN, and the row is greater than the root (meaning it's rank would be at least topN+1). + return false; } /** diff --git a/core/trino-main/src/main/java/io/trino/operator/GroupedTopNRowNumberAccumulator.java b/core/trino-main/src/main/java/io/trino/operator/GroupedTopNRowNumberAccumulator.java index a8f60b3578e3..866891aa1fac 100644 --- a/core/trino-main/src/main/java/io/trino/operator/GroupedTopNRowNumberAccumulator.java +++ b/core/trino-main/src/main/java/io/trino/operator/GroupedTopNRowNumberAccumulator.java @@ -88,13 +88,11 @@ public boolean add(long groupId, RowReference rowReference) heapInsert(groupId, rowReference.allocateRowId()); return true; } - else if (rowReference.compareTo(strategy, heapNodeBuffer.getRowId(heapRootNodeIndex)) < 0) { + if (rowReference.compareTo(strategy, heapNodeBuffer.getRowId(heapRootNodeIndex)) < 0) { heapPopAndInsert(groupId, rowReference.allocateRowId(), rowIdEvictionListener); return true; } - else { - return false; - } + return false; } /** diff --git a/core/trino-main/src/main/java/io/trino/operator/HashAggregationOperator.java b/core/trino-main/src/main/java/io/trino/operator/HashAggregationOperator.java index a500a160e102..291fdf67fc79 100644 --- a/core/trino-main/src/main/java/io/trino/operator/HashAggregationOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/HashAggregationOperator.java @@ -369,12 +369,10 @@ public boolean needsInput() if (finishing || outputPages != null) { return false; } - else if (aggregationBuilder != null && aggregationBuilder.isFull()) { + if (aggregationBuilder != null && aggregationBuilder.isFull()) { return false; } - else { - return unfinishedWork == null; - } + return unfinishedWork == null; } @Override diff --git a/core/trino-main/src/main/java/io/trino/operator/OperatorFactories.java b/core/trino-main/src/main/java/io/trino/operator/OperatorFactories.java index f4537718e6fe..1876c2edb5fd 100644 --- a/core/trino-main/src/main/java/io/trino/operator/OperatorFactories.java +++ b/core/trino-main/src/main/java/io/trino/operator/OperatorFactories.java @@ -14,8 +14,12 @@ package io.trino.operator; import io.trino.operator.join.JoinBridgeManager; +import io.trino.operator.join.LookupJoinOperatorFactory.JoinType; +import io.trino.operator.join.LookupSourceFactory; +import io.trino.operator.join.unspilled.PartitionedLookupSourceFactory; import io.trino.spi.type.Type; import io.trino.spiller.PartitioningSpillerFactory; +import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.PlanNodeId; import io.trino.type.BlockTypeOperators; @@ -23,31 +27,32 @@ import java.util.Optional; import java.util.OptionalInt; +import static io.trino.operator.join.LookupJoinOperatorFactory.JoinType.FULL_OUTER; +import static io.trino.operator.join.LookupJoinOperatorFactory.JoinType.INNER; +import static io.trino.operator.join.LookupJoinOperatorFactory.JoinType.LOOKUP_OUTER; +import static io.trino.operator.join.LookupJoinOperatorFactory.JoinType.PROBE_OUTER; +import static java.util.Objects.requireNonNull; + public interface OperatorFactories { - OperatorFactory innerJoin( + public OperatorFactory join( + JoinOperatorType joinType, int operatorId, PlanNodeId planNodeId, - JoinBridgeManager lookupSourceFactory, - boolean outputSingleMatch, - boolean waitForBuild, + JoinBridgeManager lookupSourceFactory, boolean hasFilter, - boolean spillingEnabled, List probeTypes, List probeJoinChannel, OptionalInt probeHashChannel, Optional> probeOutputChannels, - OptionalInt totalOperatorsCount, - PartitioningSpillerFactory partitioningSpillerFactory, BlockTypeOperators blockTypeOperators); - OperatorFactory probeOuterJoin( + public OperatorFactory spillingJoin( + JoinOperatorType joinType, int operatorId, PlanNodeId planNodeId, - JoinBridgeManager lookupSourceFactory, - boolean outputSingleMatch, + JoinBridgeManager lookupSourceFactory, boolean hasFilter, - boolean spillingEnabled, List probeTypes, List probeJoinChannel, OptionalInt probeHashChannel, @@ -56,32 +61,62 @@ OperatorFactory probeOuterJoin( PartitioningSpillerFactory partitioningSpillerFactory, BlockTypeOperators blockTypeOperators); - OperatorFactory lookupOuterJoin( - int operatorId, - PlanNodeId planNodeId, - JoinBridgeManager lookupSourceFactory, - boolean waitForBuild, - boolean hasFilter, - boolean spillingEnabled, - List probeTypes, - List probeJoinChannel, - OptionalInt probeHashChannel, - Optional> probeOutputChannels, - OptionalInt totalOperatorsCount, - PartitioningSpillerFactory partitioningSpillerFactory, - BlockTypeOperators blockTypeOperators); + class JoinOperatorType + { + private final JoinType type; + private final boolean outputSingleMatch; + private final boolean waitForBuild; - OperatorFactory fullOuterJoin( - int operatorId, - PlanNodeId planNodeId, - JoinBridgeManager lookupSourceFactory, - boolean hasFilter, - boolean spillingEnabled, - List probeTypes, - List probeJoinChannel, - OptionalInt probeHashChannel, - Optional> probeOutputChannels, - OptionalInt totalOperatorsCount, - PartitioningSpillerFactory partitioningSpillerFactory, - BlockTypeOperators blockTypeOperators); + public static JoinOperatorType ofJoinNodeType(JoinNode.Type joinNodeType, boolean outputSingleMatch, boolean waitForBuild) + { + return switch (joinNodeType) { + case INNER -> innerJoin(outputSingleMatch, waitForBuild); + case LEFT -> probeOuterJoin(outputSingleMatch); + case RIGHT -> lookupOuterJoin(waitForBuild); + case FULL -> fullOuterJoin(); + }; + } + + public static JoinOperatorType innerJoin(boolean outputSingleMatch, boolean waitForBuild) + { + return new JoinOperatorType(INNER, outputSingleMatch, waitForBuild); + } + + public static JoinOperatorType probeOuterJoin(boolean outputSingleMatch) + { + return new JoinOperatorType(PROBE_OUTER, outputSingleMatch, false); + } + + public static JoinOperatorType lookupOuterJoin(boolean waitForBuild) + { + return new JoinOperatorType(LOOKUP_OUTER, false, waitForBuild); + } + + public static JoinOperatorType fullOuterJoin() + { + return new JoinOperatorType(FULL_OUTER, false, false); + } + + private JoinOperatorType(JoinType type, boolean outputSingleMatch, boolean waitForBuild) + { + this.type = requireNonNull(type, "type is null"); + this.outputSingleMatch = outputSingleMatch; + this.waitForBuild = waitForBuild; + } + + public boolean isOutputSingleMatch() + { + return outputSingleMatch; + } + + public boolean isWaitForBuild() + { + return waitForBuild; + } + + public JoinType getType() + { + return type; + } + } } diff --git a/core/trino-main/src/main/java/io/trino/operator/RowReferencePageManager.java b/core/trino-main/src/main/java/io/trino/operator/RowReferencePageManager.java index cde894270e08..0e58d279b46a 100644 --- a/core/trino-main/src/main/java/io/trino/operator/RowReferencePageManager.java +++ b/core/trino-main/src/main/java/io/trino/operator/RowReferencePageManager.java @@ -98,10 +98,8 @@ public Page getPage(long rowId) checkState(currentCursor != null, "No active cursor"); return currentCursor.getPage(); } - else { - int pageId = rowIdBuffer.getPageId(rowId); - return pages.get(pageId).getPage(); - } + int pageId = rowIdBuffer.getPageId(rowId); + return pages.get(pageId).getPage(); } public int getPosition(long rowId) @@ -111,9 +109,7 @@ public int getPosition(long rowId) // rowId for cursors only reference the single current position return currentCursor.getCurrentPosition(); } - else { - return rowIdBuffer.getPosition(rowId); - } + return rowIdBuffer.getPosition(rowId); } private static boolean isCursorRowId(long rowId) diff --git a/core/trino-main/src/main/java/io/trino/operator/ScanFilterAndProjectOperator.java b/core/trino-main/src/main/java/io/trino/operator/ScanFilterAndProjectOperator.java index 9225837dd7d8..bb8c5b210c8f 100644 --- a/core/trino-main/src/main/java/io/trino/operator/ScanFilterAndProjectOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/ScanFilterAndProjectOperator.java @@ -272,10 +272,8 @@ public TransformationState> process(Split split) cursor = ((RecordPageSource) source).getCursor(); return ofResult(processColumnSource()); } - else { - pageSource = source; - return ofResult(processPageSource()); - } + pageSource = source; + return ofResult(processPageSource()); } WorkProcessor processColumnSource() @@ -356,14 +354,12 @@ public ProcessState process() outputMemoryContext.setBytes(pageBuilder.getRetainedSizeInBytes()); return ProcessState.ofResult(page); } - else if (finished) { + if (finished) { checkState(pageBuilder.isEmpty()); return ProcessState.finished(); } - else { - outputMemoryContext.setBytes(pageBuilder.getRetainedSizeInBytes()); - return ProcessState.yielded(); - } + outputMemoryContext.setBytes(pageBuilder.getRetainedSizeInBytes()); + return ProcessState.yielded(); } } @@ -396,9 +392,7 @@ public ProcessState process() if (pageSource.isFinished()) { return ProcessState.finished(); } - else { - return ProcessState.yielded(); - } + return ProcessState.yielded(); } recordMaterializedBytes(page, sizeInBytes -> processedBytes += sizeInBytes); diff --git a/core/trino-main/src/main/java/io/trino/operator/TableScanWorkProcessorOperator.java b/core/trino-main/src/main/java/io/trino/operator/TableScanWorkProcessorOperator.java index 2b80242bf056..5f5c5ee5e8b7 100644 --- a/core/trino-main/src/main/java/io/trino/operator/TableScanWorkProcessorOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/TableScanWorkProcessorOperator.java @@ -291,9 +291,7 @@ public ProcessState process() if (pageSource.isFinished()) { return ProcessState.finished(); } - else { - return ProcessState.yielded(); - } + return ProcessState.yielded(); } return ProcessState.ofResult(page); diff --git a/core/trino-main/src/main/java/io/trino/operator/TopNRankingOperator.java b/core/trino-main/src/main/java/io/trino/operator/TopNRankingOperator.java index b39b624cd4b4..112ffb7dc4e9 100644 --- a/core/trino-main/src/main/java/io/trino/operator/TopNRankingOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/TopNRankingOperator.java @@ -246,19 +246,17 @@ private static Supplier getGroupByHashSupplier( if (partitionChannels.isEmpty()) { return Suppliers.ofInstance(new NoChannelGroupByHash()); } - else { - checkArgument(expectedPositions > 0, "expectedPositions must be > 0"); - int[] channels = Ints.toArray(partitionChannels); - return () -> createGroupByHash( - session, - partitionTypes, - channels, - hashChannel, - expectedPositions, - joinCompiler, - blockTypeOperators, - updateMemory); - } + checkArgument(expectedPositions > 0, "expectedPositions must be > 0"); + int[] channels = Ints.toArray(partitionChannels); + return () -> createGroupByHash( + session, + partitionTypes, + channels, + hashChannel, + expectedPositions, + joinCompiler, + blockTypeOperators, + updateMemory); } private static Supplier getGroupedTopNBuilderSupplier( @@ -281,7 +279,7 @@ private static Supplier getGroupedTopNBuilderSupplier( generateRanking, groupByHashSupplier.get()); } - else if (rankingType == RankingType.RANK) { + if (rankingType == RankingType.RANK) { PageWithPositionComparator comparator = new SimplePageWithPositionComparator(sourceTypes, sortChannels, sortOrders, typeOperators); PageWithPositionEqualsAndHash equalsAndHash = new SimplePageWithPositionEqualsAndHash(ImmutableList.copyOf(sourceTypes), sortChannels, blockTypeOperators); return () -> new GroupedTopNRankBuilder( @@ -292,12 +290,10 @@ else if (rankingType == RankingType.RANK) { generateRanking, groupByHashSupplier.get()); } - else if (rankingType == RankingType.DENSE_RANK) { + if (rankingType == RankingType.DENSE_RANK) { throw new UnsupportedOperationException(); } - else { - throw new AssertionError("Unknown ranking type: " + rankingType); - } + throw new AssertionError("Unknown ranking type: " + rankingType); } @Override diff --git a/core/trino-main/src/main/java/io/trino/operator/TrinoOperatorFactories.java b/core/trino-main/src/main/java/io/trino/operator/TrinoOperatorFactories.java index a73d63b3a5d9..c154f035c18f 100644 --- a/core/trino-main/src/main/java/io/trino/operator/TrinoOperatorFactories.java +++ b/core/trino-main/src/main/java/io/trino/operator/TrinoOperatorFactories.java @@ -16,8 +16,8 @@ import io.trino.operator.join.JoinBridgeManager; import io.trino.operator.join.JoinProbe.JoinProbeFactory; import io.trino.operator.join.LookupJoinOperatorFactory; -import io.trino.operator.join.LookupJoinOperatorFactory.JoinType; import io.trino.operator.join.LookupSourceFactory; +import io.trino.operator.join.unspilled.PartitionedLookupSourceFactory; import io.trino.spi.type.Type; import io.trino.spiller.PartitioningSpillerFactory; import io.trino.sql.planner.plan.PlanNodeId; @@ -29,144 +29,76 @@ import java.util.stream.IntStream; import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.trino.operator.join.LookupJoinOperatorFactory.JoinType.FULL_OUTER; -import static io.trino.operator.join.LookupJoinOperatorFactory.JoinType.INNER; -import static io.trino.operator.join.LookupJoinOperatorFactory.JoinType.LOOKUP_OUTER; -import static io.trino.operator.join.LookupJoinOperatorFactory.JoinType.PROBE_OUTER; public class TrinoOperatorFactories implements OperatorFactories { @Override - public OperatorFactory innerJoin( + public OperatorFactory join( + JoinOperatorType joinType, int operatorId, PlanNodeId planNodeId, - JoinBridgeManager lookupSourceFactory, - boolean outputSingleMatch, - boolean waitForBuild, + JoinBridgeManager lookupSourceFactory, boolean hasFilter, - boolean spillingEnabled, List probeTypes, List probeJoinChannel, OptionalInt probeHashChannel, - Optional> probeOutputChannels, - OptionalInt totalOperatorsCount, - PartitioningSpillerFactory partitioningSpillerFactory, + Optional> probeOutputChannelsOptional, BlockTypeOperators blockTypeOperators) { - return createJoinOperatorFactory( - operatorId, - planNodeId, - lookupSourceFactory, - probeTypes, - probeJoinChannel, - probeHashChannel, - probeOutputChannels.orElse(rangeList(probeTypes.size())), - INNER, - outputSingleMatch, - waitForBuild, - spillingEnabled, - totalOperatorsCount, - partitioningSpillerFactory, - blockTypeOperators); - } + List probeOutputChannels = probeOutputChannelsOptional.orElse(rangeList(probeTypes.size())); + List probeOutputChannelTypes = probeOutputChannels.stream() + .map(probeTypes::get) + .collect(toImmutableList()); - @Override - public OperatorFactory probeOuterJoin( - int operatorId, - PlanNodeId planNodeId, - JoinBridgeManager lookupSourceFactory, - boolean outputSingleMatch, - boolean hasFilter, - boolean spillingEnabled, - List probeTypes, - List probeJoinChannel, - OptionalInt probeHashChannel, - Optional> probeOutputChannels, - OptionalInt totalOperatorsCount, - PartitioningSpillerFactory partitioningSpillerFactory, - BlockTypeOperators blockTypeOperators) - { - return createJoinOperatorFactory( + return new io.trino.operator.join.unspilled.LookupJoinOperatorFactory( operatorId, planNodeId, lookupSourceFactory, probeTypes, + probeOutputChannelTypes, + lookupSourceFactory.getBuildOutputTypes(), + joinType, + new JoinProbeFactory(probeOutputChannels.stream().mapToInt(i -> i).toArray(), probeJoinChannel, probeHashChannel), + blockTypeOperators, probeJoinChannel, - probeHashChannel, - probeOutputChannels.orElse(rangeList(probeTypes.size())), - PROBE_OUTER, - outputSingleMatch, - false, - spillingEnabled, - totalOperatorsCount, - partitioningSpillerFactory, - blockTypeOperators); + probeHashChannel); } @Override - public OperatorFactory lookupOuterJoin( + public OperatorFactory spillingJoin( + JoinOperatorType joinType, int operatorId, PlanNodeId planNodeId, - JoinBridgeManager lookupSourceFactory, - boolean waitForBuild, + JoinBridgeManager lookupSourceFactory, boolean hasFilter, - boolean spillingEnabled, List probeTypes, List probeJoinChannel, OptionalInt probeHashChannel, - Optional> probeOutputChannels, + Optional> probeOutputChannelsOptional, OptionalInt totalOperatorsCount, PartitioningSpillerFactory partitioningSpillerFactory, BlockTypeOperators blockTypeOperators) { - return createJoinOperatorFactory( - operatorId, - planNodeId, - lookupSourceFactory, - probeTypes, - probeJoinChannel, - probeHashChannel, - probeOutputChannels.orElse(rangeList(probeTypes.size())), - LOOKUP_OUTER, - false, - waitForBuild, - spillingEnabled, - totalOperatorsCount, - partitioningSpillerFactory, - blockTypeOperators); - } + List probeOutputChannels = probeOutputChannelsOptional.orElse(rangeList(probeTypes.size())); + List probeOutputChannelTypes = probeOutputChannels.stream() + .map(probeTypes::get) + .collect(toImmutableList()); - @Override - public OperatorFactory fullOuterJoin( - int operatorId, - PlanNodeId planNodeId, - JoinBridgeManager lookupSourceFactory, - boolean hasFilter, - boolean spillingEnabled, - List probeTypes, - List probeJoinChannel, - OptionalInt probeHashChannel, - Optional> probeOutputChannels, - OptionalInt totalOperatorsCount, - PartitioningSpillerFactory partitioningSpillerFactory, - BlockTypeOperators blockTypeOperators) - { - return createJoinOperatorFactory( + return new LookupJoinOperatorFactory( operatorId, planNodeId, lookupSourceFactory, probeTypes, + probeOutputChannelTypes, + lookupSourceFactory.getBuildOutputTypes(), + joinType, + new JoinProbeFactory(probeOutputChannels.stream().mapToInt(i -> i).toArray(), probeJoinChannel, probeHashChannel), + blockTypeOperators, + totalOperatorsCount, probeJoinChannel, probeHashChannel, - probeOutputChannels.orElse(rangeList(probeTypes.size())), - FULL_OUTER, - false, - false, - spillingEnabled, - totalOperatorsCount, - partitioningSpillerFactory, - blockTypeOperators); + partitioningSpillerFactory); } private static List rangeList(int endExclusive) @@ -175,59 +107,4 @@ private static List rangeList(int endExclusive) .boxed() .collect(toImmutableList()); } - - private OperatorFactory createJoinOperatorFactory( - int operatorId, - PlanNodeId planNodeId, - JoinBridgeManager lookupSourceFactoryManager, - List probeTypes, - List probeJoinChannel, - OptionalInt probeHashChannel, - List probeOutputChannels, - JoinType joinType, - boolean outputSingleMatch, - boolean waitForBuild, - boolean spillingEnabled, - OptionalInt totalOperatorsCount, - PartitioningSpillerFactory partitioningSpillerFactory, - BlockTypeOperators blockTypeOperators) - { - List probeOutputChannelTypes = probeOutputChannels.stream() - .map(probeTypes::get) - .collect(toImmutableList()); - - if (spillingEnabled) { - return new LookupJoinOperatorFactory( - operatorId, - planNodeId, - (JoinBridgeManager) lookupSourceFactoryManager, - probeTypes, - probeOutputChannelTypes, - lookupSourceFactoryManager.getBuildOutputTypes(), - joinType, - outputSingleMatch, - waitForBuild, - new JoinProbeFactory(probeOutputChannels.stream().mapToInt(i -> i).toArray(), probeJoinChannel, probeHashChannel), - blockTypeOperators, - totalOperatorsCount, - probeJoinChannel, - probeHashChannel, - partitioningSpillerFactory); - } - - return new io.trino.operator.join.unspilled.LookupJoinOperatorFactory( - operatorId, - planNodeId, - (JoinBridgeManager) lookupSourceFactoryManager, - probeTypes, - probeOutputChannelTypes, - lookupSourceFactoryManager.getBuildOutputTypes(), - joinType, - outputSingleMatch, - waitForBuild, - new JoinProbeFactory(probeOutputChannels.stream().mapToInt(i -> i).toArray(), probeJoinChannel, probeHashChannel), - blockTypeOperators, - probeJoinChannel, - probeHashChannel); - } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/Aggregator.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/Aggregator.java index c746174bdcee..45543c607509 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/Aggregator.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/Aggregator.java @@ -52,9 +52,7 @@ public Type getType() if (step.isOutputPartial()) { return intermediateType; } - else { - return finalType; - } + return finalType; } public void processPage(Page page) diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/DistinctAccumulatorFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/DistinctAccumulatorFactory.java index fe8f36a7d429..8a958fd0f918 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/DistinctAccumulatorFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/DistinctAccumulatorFactory.java @@ -267,9 +267,7 @@ private static Page filter(Page page, Block mask) if (!mask.isNull(0) && BOOLEAN.getBoolean(mask, 0)) { return page; } - else { - return page.getPositions(new int[0], 0, 0); - } + return page.getPositions(new int[0], 0, 0); } boolean mayHaveNull = mask.mayHaveNull(); int[] ids = new int[positions]; diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedAggregator.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedAggregator.java index 05f43e2b38b3..9098f325d145 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedAggregator.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedAggregator.java @@ -59,9 +59,7 @@ public Type getType() if (step.isOutputPartial()) { return intermediateType; } - else { - return finalType; - } + return finalType; } public void processPage(GroupByIdBlock groupIds, Page page) diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregationImplementation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregationImplementation.java index 7e3973f4c63d..f74054653d61 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregationImplementation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregationImplementation.java @@ -399,18 +399,12 @@ static AggregationParameterKind getInputParameterKind(boolean isNullable, boolea if (isNullable) { return NULLABLE_BLOCK_INPUT_CHANNEL; } - else { - return BLOCK_INPUT_CHANNEL; - } + return BLOCK_INPUT_CHANNEL; } - else { - if (isNullable) { - throw new IllegalArgumentException(methodName + " contains a parameter with @NullablePosition that is not @BlockPosition"); - } - else { - return INPUT_CHANNEL; - } + if (isNullable) { + throw new IllegalArgumentException(methodName + " contains a parameter with @NullablePosition that is not @BlockPosition"); } + return INPUT_CHANNEL; } private static Annotation baseTypeAnnotation(Annotation[] annotations, String methodName) diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ReduceAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ReduceAggregationFunction.java index 4959ebab017b..4bb4a734b81e 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ReduceAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ReduceAggregationFunction.java @@ -99,7 +99,7 @@ public AggregationImplementation specialize(BoundSignature boundSignature) .lambdaInterfaces(BinaryFunctionInterface.class, BinaryFunctionInterface.class) .build(); } - else if (stateType.getJavaType() == double.class) { + if (stateType.getJavaType() == double.class) { return AggregationImplementation.builder() .inputFunction(normalizeInputMethod(boundSignature, inputType, DOUBLE_STATE_INPUT_FUNCTION)) .combineFunction(DOUBLE_STATE_COMBINE_FUNCTION) @@ -111,7 +111,7 @@ else if (stateType.getJavaType() == double.class) { .lambdaInterfaces(BinaryFunctionInterface.class, BinaryFunctionInterface.class) .build(); } - else if (stateType.getJavaType() == boolean.class) { + if (stateType.getJavaType() == boolean.class) { return AggregationImplementation.builder() .inputFunction(normalizeInputMethod(boundSignature, inputType, BOOLEAN_STATE_INPUT_FUNCTION)) .combineFunction(BOOLEAN_STATE_COMBINE_FUNCTION) @@ -123,12 +123,10 @@ else if (stateType.getJavaType() == boolean.class) { .lambdaInterfaces(BinaryFunctionInterface.class, BinaryFunctionInterface.class) .build(); } - else { - // State with Slice or Block as native container type is intentionally not supported yet, - // as it may result in excessive JVM memory usage of remembered set. - // See JDK-8017163. - throw new TrinoException(NOT_SUPPORTED, format("State type not supported for %s: %s", NAME, stateType.getDisplayName())); - } + // State with Slice or Block as native container type is intentionally not supported yet, + // as it may result in excessive JVM memory usage of remembered set. + // See JDK-8017163. + throw new TrinoException(NOT_SUPPORTED, format("State type not supported for %s: %s", NAME, stateType.getDisplayName())); } private static MethodHandle normalizeInputMethod(BoundSignature boundSignature, Type inputType, MethodHandle inputMethodHandle) diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/TypedSet.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/TypedSet.java index 8287c6bf3135..f7e915d61343 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/TypedSet.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/TypedSet.java @@ -253,9 +253,7 @@ public boolean contains(Block block, int position) if (block.isNull(position)) { return containsNullElement; } - else { - return blockPositionByHash.getInt(getHashPositionOfElement(block, position)) != EMPTY_SLOT; - } + return blockPositionByHash.getInt(getHashPositionOfElement(block, position)) != EMPTY_SLOT; } /** diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/InMemoryHashAggregationBuilder.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/InMemoryHashAggregationBuilder.java index d4cd20cd0944..ea9b40ab054d 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/InMemoryHashAggregationBuilder.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/InMemoryHashAggregationBuilder.java @@ -138,17 +138,15 @@ public Work processPage(Page page) if (groupedAggregators.isEmpty()) { return groupByHash.addPage(page); } - else { - return new TransformWork<>( - groupByHash.getGroupIds(page), - groupByIdBlock -> { - for (GroupedAggregator groupedAggregator : groupedAggregators) { - groupedAggregator.processPage(groupByIdBlock, page); - } - // we do not need any output from TransformWork for this case - return null; - }); - } + return new TransformWork<>( + groupByHash.getGroupIds(page), + groupByIdBlock -> { + for (GroupedAggregator groupedAggregator : groupedAggregators) { + groupedAggregator.processPage(groupByIdBlock, page); + } + // we do not need any output from TransformWork for this case + return null; + }); } @Override diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/SpillableHashAggregationBuilder.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/SpillableHashAggregationBuilder.java index da7ad0f8461a..7d266fdd76c5 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/SpillableHashAggregationBuilder.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/SpillableHashAggregationBuilder.java @@ -152,9 +152,7 @@ private boolean hasPreviousSpillCompletedSuccessfully() getFutureValue(spillInProgress); return true; } - else { - return false; - } + return false; } @Override @@ -208,10 +206,8 @@ public WorkProcessor buildResult() if (shouldMergeWithMemory(getSizeInMemoryWhenUnspilling())) { return mergeFromDiskAndMemory(); } - else { - getFutureValue(spillToDisk()); - return mergeFromDisk(); - } + getFutureValue(spillToDisk()); + return mergeFromDisk(); } /** diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/GroupedTypedHistogram.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/GroupedTypedHistogram.java index 3991be7b9fb6..5a5c36a1936a 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/GroupedTypedHistogram.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/GroupedTypedHistogram.java @@ -436,10 +436,8 @@ private boolean processEntry(long groupId, Block block, int position, long count addNewGroup(groupId, block, position, count); return true; } - else { - valueNode.add(count); - return false; - } + valueNode.add(count); + return false; } private void addNewGroup(long groupId, Block block, int position, long count) @@ -488,16 +486,14 @@ private BucketDataNode createBucketDataNode(long groupId, Block block, int posit if (nodePointer == EMPTY_BUCKET) { return new BucketDataNode(bucketId, new ValueNode(nextNodePointer), valueHash, valueAndGroupHash, nextNodePointer, true); } - else if (groupAndValueMatches(groupId, block, position, nodePointer, valuePositions.get(nodePointer))) { + if (groupAndValueMatches(groupId, block, position, nodePointer, valuePositions.get(nodePointer))) { // value match return new BucketDataNode(bucketId, new ValueNode(nodePointer), valueHash, valueAndGroupHash, nodePointer, false); } - else { - // keep looking - int probe = nextProbe(probeCount); - bucketId = nextBucketId(originalBucketId, mask, probe); - probeCount++; - } + // keep looking + int probe = nextProbe(probeCount); + bucketId = nextBucketId(originalBucketId, mask, probe); + probeCount++; } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/ValueStore.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/ValueStore.java index 23873939bce2..cc4170ed0cd6 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/ValueStore.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/ValueStore.java @@ -98,15 +98,13 @@ public int addAndGetPosition(Block block, int position, long valueHash) return valuePointer; } - else if (equalOperator.equal(block, position, values, valuePointer)) { + if (equalOperator.equal(block, position, values, valuePointer)) { // value at position return valuePointer; } - else { - int probe = nextProbe(probeCount); - bucketId = nextBucketId(originalBucketId, mask, probe); - probeCount++; - } + int probe = nextProbe(probeCount); + bucketId = nextBucketId(originalBucketId, mask, probe); + probeCount++; } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/StateCompiler.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/StateCompiler.java index 72b1f109ab72..1e55902ab293 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/StateCompiler.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/StateCompiler.java @@ -1151,10 +1151,7 @@ public Class getType() Type getSqlType() { - if (sqlType.isEmpty()) { - throw new IllegalArgumentException("Unsupported type: " + type); - } - return sqlType.get(); + return sqlType.orElseThrow(() -> new IllegalArgumentException("Unsupported type: " + type)); } boolean isPrimitiveType() diff --git a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java index a88c01732a8a..a9be3416e296 100644 --- a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java +++ b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java @@ -14,7 +14,6 @@ package io.trino.operator.exchange; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Ints; import io.airlift.slice.XxHash64; @@ -26,7 +25,6 @@ import io.trino.operator.PartitionFunction; import io.trino.operator.PrecomputedHashGenerator; import io.trino.spi.Page; -import io.trino.spi.connector.ConnectorBucketNodeMap; import io.trino.spi.type.Type; import io.trino.sql.planner.NodePartitioningManager; import io.trino.sql.planner.PartitioningHandle; @@ -226,9 +224,7 @@ private static PartitionFunction createPartitionFunction( // The same bucket function (with the same bucket count) as for node // partitioning must be used. This way rows within a single bucket // will be being processed by single thread. - int bucketCount = nodePartitioningManager.getConnectorBucketNodeMap(session, partitioning) - .map(ConnectorBucketNodeMap::getBucketCount) - .orElseThrow(() -> new VerifyException("No bucket node map for partitioning: " + partitioning)); + int bucketCount = nodePartitioningManager.getBucketNodeMap(session, partitioning).getBucketCount(); int[] bucketToPartition = new int[bucketCount]; for (int bucket = 0; bucket < bucketCount; bucket++) { // mix the bucket bits so we don't use the same bucket number used to distribute between stages diff --git a/core/trino-main/src/main/java/io/trino/operator/join/DefaultPageJoiner.java b/core/trino-main/src/main/java/io/trino/operator/join/DefaultPageJoiner.java index f9f5508dc63d..b2dc1d9fb216 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/DefaultPageJoiner.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/DefaultPageJoiner.java @@ -384,9 +384,7 @@ private long getJoinPositionWithinPartition() if (joinPosition >= 0) { return lookupSourceProvider.withLease(lookupSourceLease -> lookupSourceLease.getLookupSource().joinPositionWithinPartition(joinPosition)); } - else { - return -1; - } + return -1; } private Page buildOutputPage() diff --git a/core/trino-main/src/main/java/io/trino/operator/join/LookupJoinOperatorFactory.java b/core/trino-main/src/main/java/io/trino/operator/join/LookupJoinOperatorFactory.java index 0190e8c7d320..5d0f901c8173 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/LookupJoinOperatorFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/LookupJoinOperatorFactory.java @@ -19,6 +19,7 @@ import io.trino.operator.InterpretedHashGenerator; import io.trino.operator.Operator; import io.trino.operator.OperatorContext; +import io.trino.operator.OperatorFactories.JoinOperatorType; import io.trino.operator.OperatorFactory; import io.trino.operator.PrecomputedHashGenerator; import io.trino.operator.ProcessorContext; @@ -80,9 +81,7 @@ public LookupJoinOperatorFactory( List probeTypes, List probeOutputTypes, List buildOutputTypes, - JoinType joinType, - boolean outputSingleMatch, - boolean waitForBuild, + JoinOperatorType joinOperatorType, JoinProbeFactory joinProbeFactory, BlockTypeOperators blockTypeOperators, OptionalInt totalOperatorsCount, @@ -94,9 +93,9 @@ public LookupJoinOperatorFactory( this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); this.probeTypes = ImmutableList.copyOf(requireNonNull(probeTypes, "probeTypes is null")); this.buildOutputTypes = ImmutableList.copyOf(requireNonNull(buildOutputTypes, "buildOutputTypes is null")); - this.joinType = requireNonNull(joinType, "joinType is null"); - this.outputSingleMatch = outputSingleMatch; - this.waitForBuild = waitForBuild; + this.joinType = requireNonNull(joinOperatorType.getType(), "joinType is null"); + this.outputSingleMatch = joinOperatorType.isOutputSingleMatch(); + this.waitForBuild = joinOperatorType.isWaitForBuild(); this.joinProbeFactory = requireNonNull(joinProbeFactory, "joinProbeFactory is null"); this.joinBridgeManager = lookupSourceFactoryManager; diff --git a/core/trino-main/src/main/java/io/trino/operator/join/NestedLoopJoinOperator.java b/core/trino-main/src/main/java/io/trino/operator/join/NestedLoopJoinOperator.java index a26da4b8e25e..f70775367d64 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/NestedLoopJoinOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/NestedLoopJoinOperator.java @@ -267,15 +267,13 @@ static NestedLoopOutputIterator createNestedLoopOutputIterator(Page probePage, P Page outputPage = new Page(max(probePositions, buildPositions)); return new PageRepeatingIterator(outputPage, min(probePositions, buildPositions)); } - else if (probeChannels.length == 0 && probePage.getPositionCount() <= buildPage.getPositionCount()) { + if (probeChannels.length == 0 && probePage.getPositionCount() <= buildPage.getPositionCount()) { return new PageRepeatingIterator(buildPage.getColumns(buildChannels), probePage.getPositionCount()); } - else if (buildChannels.length == 0 && buildPage.getPositionCount() <= probePage.getPositionCount()) { + if (buildChannels.length == 0 && buildPage.getPositionCount() <= probePage.getPositionCount()) { return new PageRepeatingIterator(probePage.getColumns(probeChannels), buildPage.getPositionCount()); } - else { - return new NestedLoopPageBuilder(probePage, buildPage, probeChannels, buildChannels); - } + return new NestedLoopPageBuilder(probePage, buildPage, probeChannels, buildChannels); } // bi-morphic parent class for the two implementations allowed. Adding a third implementation will make getOutput megamorphic and diff --git a/core/trino-main/src/main/java/io/trino/operator/join/PartitionedConsumption.java b/core/trino-main/src/main/java/io/trino/operator/join/PartitionedConsumption.java index 5752a33ec60a..4047d9095d16 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/PartitionedConsumption.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/PartitionedConsumption.java @@ -106,9 +106,7 @@ protected Partition computeNext() if (next != null) { return next; } - else { - return endOfData(); - } + return endOfData(); } }; } diff --git a/core/trino-main/src/main/java/io/trino/operator/join/PartitionedLookupSource.java b/core/trino-main/src/main/java/io/trino/operator/join/PartitionedLookupSource.java index b0c408833e98..e53f073b868c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/PartitionedLookupSource.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/PartitionedLookupSource.java @@ -69,16 +69,14 @@ public OuterPositionIterator getOuterPositionIterator() } }; } - else { - return TrackingLookupSourceSupplier.nonTracking( - () -> new PartitionedLookupSource( - partitions.stream() - .map(Supplier::get) - .collect(toImmutableList()), - hashChannelTypes, - Optional.empty(), - blockTypeOperators)); - } + return TrackingLookupSourceSupplier.nonTracking( + () -> new PartitionedLookupSource( + partitions.stream() + .map(Supplier::get) + .collect(toImmutableList()), + hashChannelTypes, + Optional.empty(), + blockTypeOperators)); } private final LookupSource[] lookupSources; diff --git a/core/trino-main/src/main/java/io/trino/operator/join/SortedPositionLinks.java b/core/trino-main/src/main/java/io/trino/operator/join/SortedPositionLinks.java index c57ef8df722d..1c55fd6bf997 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/SortedPositionLinks.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/SortedPositionLinks.java @@ -84,16 +84,14 @@ public int link(int from, int to) positionLinks.computeIfAbsent(to, key -> new IntArrayList()).add(from); return to; } - else { - // _to_ is larger so, move the chain to _from_ - IntArrayList links = positionLinks.remove(to); - if (links == null) { - links = new IntArrayList(); - } - links.add(to); - checkState(positionLinks.put(from, links) == null, "sorted links is corrupted"); - return from; + // _to_ is larger so, move the chain to _from_ + IntArrayList links = positionLinks.remove(to); + if (links == null) { + links = new IntArrayList(); } + links.add(to); + checkState(positionLinks.put(from, links) == null, "sorted links is corrupted"); + return from; } private boolean isNull(int position) diff --git a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/LookupJoinOperatorFactory.java b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/LookupJoinOperatorFactory.java index 881593987516..1352737f1a63 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/LookupJoinOperatorFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/LookupJoinOperatorFactory.java @@ -19,6 +19,7 @@ import io.trino.operator.InterpretedHashGenerator; import io.trino.operator.Operator; import io.trino.operator.OperatorContext; +import io.trino.operator.OperatorFactories.JoinOperatorType; import io.trino.operator.OperatorFactory; import io.trino.operator.PrecomputedHashGenerator; import io.trino.operator.ProcessorContext; @@ -72,9 +73,7 @@ public LookupJoinOperatorFactory( List probeTypes, List probeOutputTypes, List buildOutputTypes, - JoinType joinType, - boolean outputSingleMatch, - boolean waitForBuild, + JoinOperatorType joinOperatorType, JoinProbeFactory joinProbeFactory, BlockTypeOperators blockTypeOperators, List probeJoinChannels, @@ -84,9 +83,9 @@ public LookupJoinOperatorFactory( this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); this.probeTypes = ImmutableList.copyOf(requireNonNull(probeTypes, "probeTypes is null")); this.buildOutputTypes = ImmutableList.copyOf(requireNonNull(buildOutputTypes, "buildOutputTypes is null")); - this.joinType = requireNonNull(joinType, "joinType is null"); - this.outputSingleMatch = outputSingleMatch; - this.waitForBuild = waitForBuild; + this.joinType = requireNonNull(joinOperatorType.getType(), "joinType is null"); + this.outputSingleMatch = joinOperatorType.isOutputSingleMatch(); + this.waitForBuild = joinOperatorType.isWaitForBuild(); this.joinProbeFactory = requireNonNull(joinProbeFactory, "joinProbeFactory is null"); this.joinBridgeManager = lookupSourceFactoryManager; diff --git a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/PartitionedLookupSource.java b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/PartitionedLookupSource.java index 9cbf2bffbc91..a701de4c7049 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/PartitionedLookupSource.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/PartitionedLookupSource.java @@ -72,16 +72,14 @@ public OuterPositionIterator getOuterPositionIterator() } }; } - else { - return TrackingLookupSourceSupplier.nonTracking( - () -> new PartitionedLookupSource( - partitions.stream() - .map(Supplier::get) - .collect(toImmutableList()), - hashChannelTypes, - Optional.empty(), - blockTypeOperators)); - } + return TrackingLookupSourceSupplier.nonTracking( + () -> new PartitionedLookupSource( + partitions.stream() + .map(Supplier::get) + .collect(toImmutableList()), + hashChannelTypes, + Optional.empty(), + blockTypeOperators)); } private final LookupSource[] lookupSources; diff --git a/core/trino-main/src/main/java/io/trino/operator/output/SlicePositionsAppender.java b/core/trino-main/src/main/java/io/trino/operator/output/SlicePositionsAppender.java index c22ffca477eb..98cfda83d179 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/SlicePositionsAppender.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/SlicePositionsAppender.java @@ -13,6 +13,7 @@ */ package io.trino.operator.output; +import com.google.common.annotations.VisibleForTesting; import io.airlift.slice.Slice; import io.airlift.slice.Slices; import io.trino.spi.block.Block; @@ -80,38 +81,43 @@ public void append(IntArrayList positions, Block block) return; } ensurePositionCapacity(positionCount + positions.size()); - int[] positionArray = positions.elements(); - int newByteCount = 0; - int[] lengths = new int[positions.size()]; - - if (block.mayHaveNull()) { - for (int i = 0; i < positions.size(); i++) { - int position = positionArray[i]; - if (block.isNull(position)) { - offsets[positionCount + i + 1] = offsets[positionCount + i]; - valueIsNull[positionCount + i] = true; - hasNullValue = true; + if (block instanceof VariableWidthBlock) { + VariableWidthBlock variableWidthBlock = (VariableWidthBlock) block; + int newByteCount = 0; + int[] lengths = new int[positions.size()]; + int[] sourceOffsets = new int[positions.size()]; + int[] positionArray = positions.elements(); + + if (block.mayHaveNull()) { + for (int i = 0; i < positions.size(); i++) { + int position = positionArray[i]; + int length = variableWidthBlock.getSliceLength(position); + lengths[i] = length; + sourceOffsets[i] = variableWidthBlock.getRawSliceOffset(position); + newByteCount += length; + boolean isNull = block.isNull(position); + valueIsNull[positionCount + i] = isNull; + offsets[positionCount + i + 1] = offsets[positionCount + i] + length; + hasNullValue |= isNull; + hasNonNullValue |= !isNull; } - else { - int length = block.getSliceLength(position); + } + else { + for (int i = 0; i < positions.size(); i++) { + int position = positionArray[i]; + int length = variableWidthBlock.getSliceLength(position); lengths[i] = length; + sourceOffsets[i] = variableWidthBlock.getRawSliceOffset(position); newByteCount += length; offsets[positionCount + i + 1] = offsets[positionCount + i] + length; - hasNonNullValue = true; } + hasNonNullValue = true; } + copyBytes(variableWidthBlock.getRawSlice(), lengths, sourceOffsets, positions.size(), newByteCount); } else { - for (int i = 0; i < positions.size(); i++) { - int position = positionArray[i]; - int length = block.getSliceLength(position); - lengths[i] = length; - newByteCount += length; - offsets[positionCount + i + 1] = offsets[positionCount + i] + length; - } - hasNonNullValue = true; + appendGenericBlock(positions, block); } - copyBytes(block, lengths, positionArray, positions.size(), offsets, positionCount, newByteCount); } @Override @@ -132,7 +138,7 @@ public void appendRle(RunLengthEncodedBlock block) } else { hasNonNullValue = true; - duplicateBytes(block.getValue(), 0, rlePositionCount); + duplicateBytes(block.getSlice(0, 0, block.getSliceLength(0)), rlePositionCount); } } @@ -166,16 +172,20 @@ public long getSizeInBytes() return sizeInBytes; } - private void copyBytes(Block block, int[] lengths, int[] positions, int count, int[] targetOffsets, int targetOffsetsIndex, int newByteCount) + private void copyBytes(Slice rawSlice, int[] lengths, int[] sourceOffsets, int count, int newByteCount) { - ensureBytesCapacity(getCurrentOffset() + newByteCount); + ensureExtraBytesCapacity(newByteCount); - for (int i = 0; i < count; i++) { - int position = positions[i]; - if (!block.isNull(position)) { - int length = lengths[i]; - Slice slice = block.getSlice(position, 0, length); - slice.getBytes(0, bytes, targetOffsets[targetOffsetsIndex + i], length); + if (rawSlice.hasByteArray()) { + byte[] base = rawSlice.byteArray(); + int byteArrayOffset = rawSlice.byteArrayOffset(); + for (int i = 0; i < count; i++) { + System.arraycopy(base, byteArrayOffset + sourceOffsets[i], bytes, offsets[positionCount + i], lengths[i]); + } + } + else { + for (int i = 0; i < count; i++) { + rawSlice.getBytes(sourceOffsets[i], bytes, offsets[positionCount + i], lengths[i]); } } @@ -184,25 +194,75 @@ private void copyBytes(Block block, int[] lengths, int[] positions, int count, i } /** - * Copy {@code length} bytes from {@code block}, at position {@code position} to {@code count} consecutive positions in the {@link #bytes} array. + * Copy all bytes from {@code slice} to {@code count} consecutive positions in the {@link #bytes} array. */ - private void duplicateBytes(Block block, int position, int count) + private void duplicateBytes(Slice slice, int count) { - int length = block.getSliceLength(position); + int length = slice.length(); int newByteCount = toIntExact((long) count * length); int startOffset = getCurrentOffset(); - ensureBytesCapacity(startOffset + newByteCount); + ensureExtraBytesCapacity(newByteCount); + + duplicateBytes(slice, bytes, startOffset, count); - Slice slice = block.getSlice(position, 0, length); + int currentStartOffset = startOffset + length; for (int i = 0; i < count; i++) { - slice.getBytes(0, bytes, startOffset + (i * length), length); - offsets[positionCount + i + 1] = startOffset + ((i + 1) * length); + offsets[positionCount + i + 1] = currentStartOffset; + currentStartOffset += length; } positionCount += count; updateSize(count, newByteCount); } + /** + * Copy {@code length} bytes from {@code slice}, starting at offset {@code sourceOffset} to {@code count} consecutive positions in the {@link #bytes} array. + */ + @VisibleForTesting + static void duplicateBytes(Slice slice, byte[] bytes, int startOffset, int count) + { + int length = slice.length(); + if (length == 0) { + // nothing to copy + return; + } + // copy slice to the first position + slice.getBytes(0, bytes, startOffset, length); + int totalDuplicatedBytes = count * length; + int duplicatedBytes = length; + // copy every byte copied so far, doubling the number of bytes copied on evey iteration + while (duplicatedBytes * 2 <= totalDuplicatedBytes) { + System.arraycopy(bytes, startOffset, bytes, startOffset + duplicatedBytes, duplicatedBytes); + duplicatedBytes = duplicatedBytes * 2; + } + // copy the leftover + System.arraycopy(bytes, startOffset, bytes, startOffset + duplicatedBytes, totalDuplicatedBytes - duplicatedBytes); + } + + private void appendGenericBlock(IntArrayList positions, Block block) + { + int newByteCount = 0; + for (int i = 0; i < positions.size(); i++) { + int position = positions.getInt(i); + if (block.isNull(position)) { + offsets[positionCount + 1] = offsets[positionCount]; + valueIsNull[positionCount] = true; + hasNullValue = true; + } + else { + int length = block.getSliceLength(position); + ensureExtraBytesCapacity(length); + Slice slice = block.getSlice(position, 0, length); + slice.getBytes(0, bytes, offsets[positionCount], length); + offsets[positionCount + 1] = offsets[positionCount] + length; + hasNonNullValue = true; + newByteCount += length; + } + positionCount++; + } + updateSize(positions.size(), newByteCount); + } + private void reset() { initialEntryCount = calculateBlockResetSize(positionCount); @@ -228,12 +288,13 @@ private void updateSize(long positionsSize, int bytesWritten) sizeInBytes += (SIZE_OF_BYTE + SIZE_OF_INT) * positionsSize + bytesWritten; } - private void ensureBytesCapacity(int bytesCapacity) + private void ensureExtraBytesCapacity(int extraBytesCapacity) { - if (bytes.length < bytesCapacity) { + int totalBytesCapacity = getCurrentOffset() + extraBytesCapacity; + if (bytes.length < totalBytesCapacity) { int newBytesLength = Math.max(bytes.length, initialBytesSize); - if (bytesCapacity > newBytesLength) { - newBytesLength = Math.max(bytesCapacity, calculateNewArraySize(newBytesLength)); + if (totalBytesCapacity > newBytesLength) { + newBytesLength = Math.max(totalBytesCapacity, calculateNewArraySize(newBytesLength)); } bytes = Arrays.copyOf(bytes, newBytesLength); updateRetainedSize(); diff --git a/core/trino-main/src/main/java/io/trino/operator/project/DictionaryAwarePageProjection.java b/core/trino-main/src/main/java/io/trino/operator/project/DictionaryAwarePageProjection.java index dd6b9ef3ceb1..e413265e509b 100644 --- a/core/trino-main/src/main/java/io/trino/operator/project/DictionaryAwarePageProjection.java +++ b/core/trino-main/src/main/java/io/trino/operator/project/DictionaryAwarePageProjection.java @@ -117,9 +117,7 @@ public boolean process() if (produceLazyBlock) { return true; } - else { - return processInternal(); - } + return processInternal(); } private boolean processInternal() @@ -200,10 +198,8 @@ public Block getResult() return result.getLoadedBlock(); }); } - else { - checkState(result != null, "result has not been generated"); - return result; - } + checkState(result != null, "result has not been generated"); + return result; } private void setupDictionaryBlockProjection() diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayJoin.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayJoin.java index 30fb04145069..9abc95f268dd 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayJoin.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayJoin.java @@ -166,27 +166,25 @@ private static ChoicesSpecializedSqlScalarFunction specializeArrayJoin( methodHandle.bindTo(null), Optional.of(STATE_FACTORY)); } - else { - try { - InvocationConvention convention = new InvocationConvention(ImmutableList.of(BLOCK_POSITION), NULLABLE_RETURN, true, false); - MethodHandle cast = functionDependencies.getCastImplementation(type, VARCHAR, convention).getMethodHandle(); - - // if the cast doesn't take a ConnectorSession, create an adapter that drops the provided session - if (cast.type().parameterArray()[0] != ConnectorSession.class) { - cast = MethodHandles.dropArguments(cast, 0, ConnectorSession.class); - } + try { + InvocationConvention convention = new InvocationConvention(ImmutableList.of(BLOCK_POSITION), NULLABLE_RETURN, true, false); + MethodHandle cast = functionDependencies.getCastImplementation(type, VARCHAR, convention).getMethodHandle(); - MethodHandle target = MethodHandles.insertArguments(methodHandle, 0, cast); - return new ChoicesSpecializedSqlScalarFunction( - boundSignature, - FAIL_ON_NULL, - argumentConventions, - target, - Optional.of(STATE_FACTORY)); - } - catch (TrinoException e) { - throw new TrinoException(INVALID_FUNCTION_ARGUMENT, format("Input type %s not supported", type), e); + // if the cast doesn't take a ConnectorSession, create an adapter that drops the provided session + if (cast.type().parameterArray()[0] != ConnectorSession.class) { + cast = MethodHandles.dropArguments(cast, 0, ConnectorSession.class); } + + MethodHandle target = MethodHandles.insertArguments(methodHandle, 0, cast); + return new ChoicesSpecializedSqlScalarFunction( + boundSignature, + FAIL_ON_NULL, + argumentConventions, + target, + Optional.of(STATE_FACTORY)); + } + catch (TrinoException e) { + throw new TrinoException(INVALID_FUNCTION_ARGUMENT, format("Input type %s not supported", type), e); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/BitwiseFunctions.java b/core/trino-main/src/main/java/io/trino/operator/scalar/BitwiseFunctions.java index 0f861922f2d2..d86e3ce0ba61 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/BitwiseFunctions.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/BitwiseFunctions.java @@ -202,9 +202,7 @@ public static long bitwiseRightShiftArithmeticTinyint(@SqlType(StandardTypes.TIN if (value >= 0) { return 0L; } - else { - return -1L; - } + return -1L; } return preserveSign(value, TINYINT_MASK, TINYINT_SIGNED_BIT) >> shift; } @@ -218,9 +216,7 @@ public static long bitwiseRightShiftArithmeticSmallint(@SqlType(StandardTypes.SM if (value >= 0) { return 0L; } - else { - return -1L; - } + return -1L; } return preserveSign(value, SMALLINT_MASK, SMALLINT_SIGNED_BIT) >> shift; } @@ -234,9 +230,7 @@ public static long bitwiseRightShiftArithmeticInteger(@SqlType(StandardTypes.INT if (value >= 0) { return 0L; } - else { - return -1L; - } + return -1L; } return preserveSign(value, INTEGER_MASK, INTEGER_SIGNED_BIT) >> shift; } @@ -250,9 +244,7 @@ public static long bitwiseRightShiftArithmeticBigint(@SqlType(StandardTypes.BIGI if (value >= 0) { return 0L; } - else { - return -1L; - } + return -1L; } return value >> shift; } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/CharacterStringCasts.java b/core/trino-main/src/main/java/io/trino/operator/scalar/CharacterStringCasts.java index 5fcddee99505..8cf6809ea93e 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/CharacterStringCasts.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/CharacterStringCasts.java @@ -46,9 +46,7 @@ public static Slice varcharToVarcharCast(@LiteralParameter("x") Long x, @Literal if (x > y) { return truncateToLength(slice, y.intValue()); } - else { - return slice; - } + return slice; } @ScalarOperator(OperatorType.CAST) @@ -59,9 +57,7 @@ public static Slice charToCharCast(@LiteralParameter("x") Long x, @LiteralParame if (x > y) { return truncateToLength(slice, y.intValue()); } - else { - return slice; - } + return slice; } @ScalarOperator(OperatorType.CAST) diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ConcatWsFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ConcatWsFunction.java index 2b87b118580d..f81c688d2693 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ConcatWsFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ConcatWsFunction.java @@ -74,10 +74,8 @@ public Slice getElement(int i) if (elements.isNull(i)) { return null; } - else { - int sliceLength = elements.getSliceLength(i); - return elements.getSlice(i, 0, sliceLength); - } + int sliceLength = elements.getSliceLength(i); + return elements.getSlice(i, 0, sliceLength); } @Override diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/JoniRegexpFunctions.java b/core/trino-main/src/main/java/io/trino/operator/scalar/JoniRegexpFunctions.java index 3a943da09600..2eabd99b2587 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/JoniRegexpFunctions.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/JoniRegexpFunctions.java @@ -74,14 +74,10 @@ private static int getNextStart(Slice source, Matcher matcher) if (matcher.getBegin() < source.length()) { return matcher.getEnd() + lengthOfCodePointFromStartByte(source.getByte(matcher.getBegin())); } - else { - // last match is empty and we matched end of source, move past the source length to terminate the loop - return matcher.getEnd() + 1; - } - } - else { - return matcher.getEnd(); + // last match is empty and we matched end of source, move past the source length to terminate the loop + return matcher.getEnd() + 1; } + return matcher.getEnd(); } @Description("Removes substrings matching a regular expression") diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/MapToMapCast.java b/core/trino-main/src/main/java/io/trino/operator/scalar/MapToMapCast.java index 54ae82504367..ece2aae99f0b 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/MapToMapCast.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/MapToMapCast.java @@ -179,21 +179,19 @@ private MethodHandle nullChecker(Class javaType) if (javaType == Long.class) { return CHECK_LONG_IS_NOT_NULL; } - else if (javaType == Double.class) { + if (javaType == Double.class) { return CHECK_DOUBLE_IS_NOT_NULL; } - else if (javaType == Boolean.class) { + if (javaType == Boolean.class) { return CHECK_BOOLEAN_IS_NOT_NULL; } - else if (javaType == Slice.class) { + if (javaType == Slice.class) { return CHECK_SLICE_IS_NOT_NULL; } - else if (javaType == Block.class) { + if (javaType == Block.class) { return CHECK_BLOCK_IS_NOT_NULL; } - else { - throw new IllegalArgumentException("Unknown java type " + javaType); - } + throw new IllegalArgumentException("Unknown java type " + javaType); } @UsedByGeneratedCode diff --git a/core/trino-main/src/main/java/io/trino/security/AccessControl.java b/core/trino-main/src/main/java/io/trino/security/AccessControl.java index c266ba9373a0..6b53d9320718 100644 --- a/core/trino-main/src/main/java/io/trino/security/AccessControl.java +++ b/core/trino-main/src/main/java/io/trino/security/AccessControl.java @@ -379,6 +379,13 @@ default void checkCanSetViewAuthorization(SecurityContext context, QualifiedObje */ void checkCanGrantExecuteFunctionPrivilege(SecurityContext context, String functionName, Identity grantee, boolean grantOption); + /** + * Check if identity is allowed to create a view that executes the function. + * + * @throws AccessDeniedException if not allowed + */ + void checkCanGrantExecuteFunctionPrivilege(SecurityContext context, FunctionKind functionKind, QualifiedObjectName functionName, Identity grantee, boolean grantOption); + /** * Check if identity is allowed to grant a privilege to the grantee on the specified schema. * diff --git a/core/trino-main/src/main/java/io/trino/security/AccessControlManager.java b/core/trino-main/src/main/java/io/trino/security/AccessControlManager.java index 9946337b0ae1..40419aa91208 100644 --- a/core/trino-main/src/main/java/io/trino/security/AccessControlManager.java +++ b/core/trino-main/src/main/java/io/trino/security/AccessControlManager.java @@ -843,6 +843,21 @@ public void checkCanGrantExecuteFunctionPrivilege(SecurityContext securityContex grantOption)); } + @Override + public void checkCanGrantExecuteFunctionPrivilege(SecurityContext securityContext, FunctionKind functionKind, QualifiedObjectName functionName, Identity grantee, boolean grantOption) + { + requireNonNull(securityContext, "securityContext is null"); + requireNonNull(functionKind, "functionKind is null"); + requireNonNull(functionName, "functionName is null"); + + systemAuthorizationCheck(control -> control.checkCanGrantExecuteFunctionPrivilege( + securityContext.toSystemSecurityContext(), + functionKind, + functionName.asCatalogSchemaRoutineName(), + new TrinoPrincipal(PrincipalType.USER, grantee.getUser()), + grantOption)); + } + @Override public void checkCanGrantSchemaPrivilege(SecurityContext securityContext, Privilege privilege, CatalogSchemaName schemaName, TrinoPrincipal grantee, boolean grantOption) { diff --git a/core/trino-main/src/main/java/io/trino/security/AllowAllAccessControl.java b/core/trino-main/src/main/java/io/trino/security/AllowAllAccessControl.java index f96bcf927300..a3e34dddb459 100644 --- a/core/trino-main/src/main/java/io/trino/security/AllowAllAccessControl.java +++ b/core/trino-main/src/main/java/io/trino/security/AllowAllAccessControl.java @@ -271,6 +271,11 @@ public void checkCanGrantExecuteFunctionPrivilege(SecurityContext context, Strin { } + @Override + public void checkCanGrantExecuteFunctionPrivilege(SecurityContext context, FunctionKind functionKind, QualifiedObjectName functionName, Identity grantee, boolean grantOption) + { + } + @Override public void checkCanGrantSchemaPrivilege(SecurityContext context, Privilege privilege, CatalogSchemaName schemaName, TrinoPrincipal grantee, boolean grantOption) { diff --git a/core/trino-main/src/main/java/io/trino/security/DenyAllAccessControl.java b/core/trino-main/src/main/java/io/trino/security/DenyAllAccessControl.java index bb5fbb7e2d8a..9243fe23347b 100644 --- a/core/trino-main/src/main/java/io/trino/security/DenyAllAccessControl.java +++ b/core/trino-main/src/main/java/io/trino/security/DenyAllAccessControl.java @@ -379,6 +379,12 @@ public void checkCanGrantExecuteFunctionPrivilege(SecurityContext context, Strin denyGrantExecuteFunctionPrivilege(functionName, context.getIdentity(), grantee); } + @Override + public void checkCanGrantExecuteFunctionPrivilege(SecurityContext context, FunctionKind functionKind, QualifiedObjectName functionName, Identity grantee, boolean grantOption) + { + denyGrantExecuteFunctionPrivilege(functionName.toString(), context.getIdentity(), grantee); + } + @Override public void checkCanGrantSchemaPrivilege(SecurityContext context, Privilege privilege, CatalogSchemaName schemaName, TrinoPrincipal grantee, boolean grantOption) { diff --git a/core/trino-main/src/main/java/io/trino/security/ForwardingAccessControl.java b/core/trino-main/src/main/java/io/trino/security/ForwardingAccessControl.java index e86f97dd409b..ba1198188207 100644 --- a/core/trino-main/src/main/java/io/trino/security/ForwardingAccessControl.java +++ b/core/trino-main/src/main/java/io/trino/security/ForwardingAccessControl.java @@ -335,6 +335,12 @@ public void checkCanGrantExecuteFunctionPrivilege(SecurityContext context, Strin delegate().checkCanGrantExecuteFunctionPrivilege(context, functionName, grantee, grantOption); } + @Override + public void checkCanGrantExecuteFunctionPrivilege(SecurityContext context, FunctionKind functionKind, QualifiedObjectName functionName, Identity grantee, boolean grantOption) + { + delegate().checkCanGrantExecuteFunctionPrivilege(context, functionKind, functionName, grantee, grantOption); + } + @Override public void checkCanGrantSchemaPrivilege(SecurityContext context, Privilege privilege, CatalogSchemaName schemaName, TrinoPrincipal grantee, boolean grantOption) { diff --git a/core/trino-main/src/main/java/io/trino/security/ViewAccessControl.java b/core/trino-main/src/main/java/io/trino/security/ViewAccessControl.java index 62a041df5445..f2b13cbf0c65 100644 --- a/core/trino-main/src/main/java/io/trino/security/ViewAccessControl.java +++ b/core/trino-main/src/main/java/io/trino/security/ViewAccessControl.java @@ -15,6 +15,7 @@ import io.trino.metadata.QualifiedObjectName; import io.trino.spi.connector.CatalogSchemaTableName; +import io.trino.spi.function.FunctionKind; import io.trino.spi.security.AccessDeniedException; import io.trino.spi.security.Identity; import io.trino.spi.security.ViewExpression; @@ -72,6 +73,12 @@ public void checkCanExecuteFunction(SecurityContext context, String functionName wrapAccessDeniedException(() -> delegate.checkCanGrantExecuteFunctionPrivilege(context, functionName, invoker, false)); } + @Override + public void checkCanExecuteFunction(SecurityContext context, FunctionKind functionKind, QualifiedObjectName functionName) + { + wrapAccessDeniedException(() -> delegate.checkCanGrantExecuteFunctionPrivilege(context, functionKind, functionName, invoker, false)); + } + @Override public void checkCanGrantExecuteFunctionPrivilege(SecurityContext context, String functionName, Identity grantee, boolean grantOption) { diff --git a/core/trino-main/src/main/java/io/trino/server/ServerInfoResource.java b/core/trino-main/src/main/java/io/trino/server/ServerInfoResource.java index cd4a7d9ca8bc..78f8e31d5446 100644 --- a/core/trino-main/src/main/java/io/trino/server/ServerInfoResource.java +++ b/core/trino-main/src/main/java/io/trino/server/ServerInfoResource.java @@ -108,9 +108,7 @@ public NodeState getServerState() if (shutdownHandler.isShutdownRequested()) { return SHUTTING_DOWN; } - else { - return ACTIVE; - } + return ACTIVE; } @ResourceSecurity(PUBLIC) diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/Query.java b/core/trino-main/src/main/java/io/trino/server/protocol/Query.java index 6406a12e047c..c107956ae9ca 100644 --- a/core/trino-main/src/main/java/io/trino/server/protocol/Query.java +++ b/core/trino-main/src/main/java/io/trino/server/protocol/Query.java @@ -187,7 +187,7 @@ public static Query create( Query result = new Query(session, slug, queryManager, queryInfoUrl, exchangeDataSource, dataProcessorExecutor, timeoutExecutor, blockEncodingSerde); - result.queryManager.addOutputInfoListener(result.getQueryId(), result::setQueryOutputInfo); + result.queryManager.setOutputInfoListener(result.getQueryId(), result::setQueryOutputInfo); result.queryManager.addStateChangeListener(result.getQueryId(), state -> { // Wait for the query info to become available and close the exchange client if there is no output stage for the query results to be pulled from. @@ -582,7 +582,7 @@ private synchronized void setQueryOutputInfo(QueryExecution.QueryOutputInfo outp types = outputInfo.getColumnTypes(); } - outputInfo.getInputs().forEach(exchangeDataSource::addInput); + outputInfo.drainInputs(exchangeDataSource::addInput); if (outputInfo.isNoMoreInputs()) { exchangeDataSource.noMoreInputs(); } diff --git a/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java b/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java index 50d9006026e1..93db6b5a9f39 100644 --- a/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java +++ b/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java @@ -129,10 +129,7 @@ public final class HttpRemoteTask // The version of dynamic filters that has been successfully sent to the worker private final AtomicLong sentDynamicFiltersVersion = new AtomicLong(INITIAL_DYNAMIC_FILTERS_VERSION); - @GuardedBy("pendingRequestsCounter") - private Future currentRequest; - @GuardedBy("pendingRequestsCounter") - private long currentRequestStartNanos; + private final AtomicReference> currentRequest = new AtomicReference<>(); @GuardedBy("this") private final SetMultimap pendingSplits = HashMultimap.create(); @@ -165,7 +162,7 @@ public final class HttpRemoteTask private final RequestErrorTracker updateErrorTracker; - private final AtomicInteger pendingRequestsCounter = new AtomicInteger(1); + private final AtomicInteger pendingRequestsCounter = new AtomicInteger(0); private final AtomicBoolean sendPlan = new AtomicBoolean(true); private final PartitionedSplitCountTracker partitionedSplitCountTracker; @@ -352,7 +349,7 @@ public void start() try (SetThreadName ignored = new SetThreadName("HttpRemoteTask-%s", taskId)) { // to start we just need to trigger an update started.set(true); - scheduleUpdate(); + triggerUpdate(); dynamicFiltersFetcher.start(); taskStatusFetcher.start(); @@ -578,6 +575,10 @@ private void scheduleUpdate() private void triggerUpdate() { + if (!started.get()) { + // task has not started yet + return; + } if (pendingRequestsCounter.getAndIncrement() == 0) { // schedule update if this is the first update requested scheduleUpdate(); @@ -586,73 +587,59 @@ private void triggerUpdate() private void sendUpdate() { - synchronized (pendingRequestsCounter) { - TaskStatus taskStatus = getTaskStatus(); - // don't update if the task hasn't been started yet or if it is already finished - if (!started.get() || taskStatus.getState().isDone()) { - return; - } - - int currentPendingRequestsCounter = pendingRequestsCounter.get(); - if (currentPendingRequestsCounter == 0) { - return; - } + TaskStatus taskStatus = getTaskStatus(); + // don't update if the task is already finished + if (taskStatus.getState().isDone()) { + return; + } + checkState(started.get()); - // if there is a request already running, wait for it to complete - // currentRequest is always cleared when request is complete - if (currentRequest != null) { - return; - } + int currentPendingRequestsCounter = pendingRequestsCounter.get(); + checkState(currentPendingRequestsCounter > 0, "sendUpdate shouldn't be called without pending requests"); - // if throttled due to error, asynchronously wait for timeout and try again - ListenableFuture errorRateLimit = updateErrorTracker.acquireRequestPermit(); - if (!errorRateLimit.isDone()) { - errorRateLimit.addListener(this::sendUpdate, executor); - return; - } + // if throttled due to error, asynchronously wait for timeout and try again + ListenableFuture errorRateLimit = updateErrorTracker.acquireRequestPermit(); + if (!errorRateLimit.isDone()) { + errorRateLimit.addListener(this::sendUpdate, executor); + return; + } - List splitAssignments = getSplitAssignments(); - VersionedDynamicFilterDomains dynamicFilterDomains = outboundDynamicFiltersCollector.acknowledgeAndGetNewDomains(sentDynamicFiltersVersion.get()); - - // Workers don't need the embedded JSON representation when the fragment is sent - Optional fragment = sendPlan.get() ? Optional.of(planFragment.withoutEmbeddedJsonRepresentation()) : Optional.empty(); - TaskUpdateRequest updateRequest = new TaskUpdateRequest( - session.toSessionRepresentation(), - session.getIdentity().getExtraCredentials(), - fragment, - splitAssignments, - outputBuffers.get(), - dynamicFilterDomains.getDynamicFilterDomains()); - byte[] taskUpdateRequestJson = taskUpdateRequestCodec.toJsonBytes(updateRequest); - if (fragment.isPresent()) { - stats.updateWithPlanBytes(taskUpdateRequestJson.length); - } - if (!dynamicFilterDomains.getDynamicFilterDomains().isEmpty()) { - stats.updateWithDynamicFilterBytes(taskUpdateRequestJson.length); - } + List splitAssignments = getSplitAssignments(); + VersionedDynamicFilterDomains dynamicFilterDomains = outboundDynamicFiltersCollector.acknowledgeAndGetNewDomains(sentDynamicFiltersVersion.get()); - HttpUriBuilder uriBuilder = getHttpUriBuilder(taskStatus); - Request request = preparePost() - .setUri(uriBuilder.build()) - .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.JSON_UTF_8.toString()) - .setBodyGenerator(createStaticBodyGenerator(taskUpdateRequestJson)) - .build(); + // Workers don't need the embedded JSON representation when the fragment is sent + Optional fragment = sendPlan.get() ? Optional.of(planFragment.withoutEmbeddedJsonRepresentation()) : Optional.empty(); + TaskUpdateRequest updateRequest = new TaskUpdateRequest( + session.toSessionRepresentation(), + session.getIdentity().getExtraCredentials(), + fragment, + splitAssignments, + outputBuffers.get(), + dynamicFilterDomains.getDynamicFilterDomains()); + byte[] taskUpdateRequestJson = taskUpdateRequestCodec.toJsonBytes(updateRequest); + if (fragment.isPresent()) { + stats.updateWithPlanBytes(taskUpdateRequestJson.length); + } + if (!dynamicFilterDomains.getDynamicFilterDomains().isEmpty()) { + stats.updateWithDynamicFilterBytes(taskUpdateRequestJson.length); + } - updateErrorTracker.startRequest(); + HttpUriBuilder uriBuilder = getHttpUriBuilder(taskStatus); + Request request = preparePost() + .setUri(uriBuilder.build()) + .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.JSON_UTF_8.toString()) + .setBodyGenerator(createStaticBodyGenerator(taskUpdateRequestJson)) + .build(); - ListenableFuture> future = httpClient.executeAsync(request, createFullJsonResponseHandler(taskInfoCodec)); - currentRequest = future; - currentRequestStartNanos = System.nanoTime(); + updateErrorTracker.startRequest(); - // if pendingRequestsCounter is still non-zero (e.g. because triggerUpdate was called in the meantime) - // then the request Future callback will send a new update via sendUpdate method call - pendingRequestsCounter.addAndGet(-currentPendingRequestsCounter); + ListenableFuture> future = httpClient.executeAsync(request, createFullJsonResponseHandler(taskInfoCodec)); + checkState(currentRequest.getAndSet(future) == null, "There should be no previous request running"); - Futures.addCallback( - future, - new SimpleHttpResponseHandler<>(new UpdateResponseHandler(splitAssignments, dynamicFilterDomains.getVersion()), request.getUri(), stats), - executor); - } + Futures.addCallback( + future, + new SimpleHttpResponseHandler<>(new UpdateResponseHandler(splitAssignments, dynamicFilterDomains.getVersion(), System.nanoTime(), currentPendingRequestsCounter), request.getUri(), stats), + executor); } private synchronized List getSplitAssignments() @@ -710,12 +697,9 @@ private void cleanUpTask() outboundDynamicFiltersCollector.acknowledge(Long.MAX_VALUE); // cancel pending request - synchronized (pendingRequestsCounter) { - if (currentRequest != null) { - currentRequest.cancel(true); - currentRequest = null; - currentRequestStartNanos = 0; - } + Future request = currentRequest.getAndSet(null); + if (request != null) { + request.cancel(true); } taskStatusFetcher.stop(); @@ -915,34 +899,33 @@ private class UpdateResponseHandler { private final List splitAssignments; private final long currentRequestDynamicFiltersVersion; + private final long currentRequestStartNanos; + private final int currentPendingRequestsCounter; - private UpdateResponseHandler(List splitAssignments, long currentRequestDynamicFiltersVersion) + private UpdateResponseHandler(List splitAssignments, long currentRequestDynamicFiltersVersion, long currentRequestStartNanos, int currentPendingRequestsCounter) { this.splitAssignments = ImmutableList.copyOf(requireNonNull(splitAssignments, "splitAssignments is null")); this.currentRequestDynamicFiltersVersion = currentRequestDynamicFiltersVersion; + this.currentRequestStartNanos = currentRequestStartNanos; + this.currentPendingRequestsCounter = currentPendingRequestsCounter; } @Override public void success(TaskInfo value) { try (SetThreadName ignored = new SetThreadName("UpdateResponseHandler-%s", taskId)) { - try { - sentDynamicFiltersVersion.set(currentRequestDynamicFiltersVersion); - // Remove dynamic filters which were successfully sent to free up memory - outboundDynamicFiltersCollector.acknowledge(currentRequestDynamicFiltersVersion); - sendPlan.set(value.isNeedsPlan()); - long currentRequestStartNanos; - synchronized (pendingRequestsCounter) { - currentRequest = null; - currentRequestStartNanos = HttpRemoteTask.this.currentRequestStartNanos; - } - updateStats(currentRequestStartNanos); - processTaskUpdate(value, splitAssignments); - updateErrorTracker.requestSucceeded(); - } - finally { - sendUpdate(); + sentDynamicFiltersVersion.set(currentRequestDynamicFiltersVersion); + // Remove dynamic filters which were successfully sent to free up memory + outboundDynamicFiltersCollector.acknowledge(currentRequestDynamicFiltersVersion); + sendPlan.set(value.isNeedsPlan()); + currentRequest.set(null); + updateStats(); + updateErrorTracker.requestSucceeded(); + if (pendingRequestsCounter.addAndGet(-currentPendingRequestsCounter) > 0) { + // schedule an update because triggerUpdate was called in the meantime + scheduleUpdate(); } + processTaskUpdate(value, splitAssignments); } } @@ -951,21 +934,17 @@ public void failed(Throwable cause) { try (SetThreadName ignored = new SetThreadName("UpdateResponseHandler-%s", taskId)) { try { - long currentRequestStartNanos; - synchronized (pendingRequestsCounter) { - currentRequest = null; - currentRequestStartNanos = HttpRemoteTask.this.currentRequestStartNanos; - } - updateStats(currentRequestStartNanos); - - // on failure assume we need to update again - pendingRequestsCounter.incrementAndGet(); + currentRequest.set(null); + updateStats(); // if task not already done, record error TaskStatus taskStatus = getTaskStatus(); if (!taskStatus.getState().isDone()) { updateErrorTracker.requestFailed(cause); } + + // on failure assume we need to update again + scheduleUpdate(); } catch (Error e) { fail(e); @@ -974,9 +953,6 @@ public void failed(Throwable cause) catch (RuntimeException e) { fail(e); } - finally { - sendUpdate(); - } } } @@ -988,7 +964,7 @@ public void fatal(Throwable cause) } } - private void updateStats(long currentRequestStartNanos) + private void updateStats() { Duration requestRoundTrip = Duration.nanosSince(currentRequestStartNanos); stats.updateRoundTripMillis(requestRoundTrip.toMillis()); diff --git a/core/trino-main/src/main/java/io/trino/server/security/oauth2/JweTokenSerializer.java b/core/trino-main/src/main/java/io/trino/server/security/oauth2/JweTokenSerializer.java index eda58cab016d..7c9b4f47a435 100644 --- a/core/trino-main/src/main/java/io/trino/server/security/oauth2/JweTokenSerializer.java +++ b/core/trino-main/src/main/java/io/trino/server/security/oauth2/JweTokenSerializer.java @@ -39,7 +39,6 @@ import java.time.Clock; import java.util.Date; import java.util.Map; -import java.util.Optional; import static com.google.common.base.Preconditions.checkState; import static io.trino.server.security.jwt.JwtUtil.newJwtBuilder; @@ -122,11 +121,7 @@ public String serialize(TokenPair tokenPair) { requireNonNull(tokenPair, "tokenPair is null"); - Optional> accessTokenClaims = client.getClaims(tokenPair.getAccessToken()); - if (accessTokenClaims.isEmpty()) { - throw new IllegalArgumentException("Claims are missing"); - } - Map claims = accessTokenClaims.get(); + Map claims = client.getClaims(tokenPair.getAccessToken()).orElseThrow(() -> new IllegalArgumentException("Claims are missing")); if (!claims.containsKey(principalField)) { throw new IllegalArgumentException(format("%s field is missing", principalField)); } diff --git a/core/trino-main/src/main/java/io/trino/server/security/oauth2/OidcDiscovery.java b/core/trino-main/src/main/java/io/trino/server/security/oauth2/OidcDiscovery.java index 4ecaa614bc92..e35d1efae9d7 100644 --- a/core/trino-main/src/main/java/io/trino/server/security/oauth2/OidcDiscovery.java +++ b/core/trino-main/src/main/java/io/trino/server/security/oauth2/OidcDiscovery.java @@ -97,9 +97,7 @@ private OAuth2ServerConfig parseConfigurationResponse(HTTPResponse response) if (statusCode < 400 || statusCode >= 500 || statusCode == REQUEST_TIMEOUT.code() || statusCode == TOO_MANY_REQUESTS.code()) { throw new RuntimeException("Invalid response from OpenID Metadata endpoint: " + statusCode); } - else { - throw new IllegalStateException(format("Invalid response from OpenID Metadata endpoint. Expected response code to be %s, but was %s", OK.code(), statusCode)); - } + throw new IllegalStateException(format("Invalid response from OpenID Metadata endpoint. Expected response code to be %s, but was %s", OK.code(), statusCode)); } return readConfiguration(response.getContent()); } diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/CanonicalizationAware.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/CanonicalizationAware.java index 73e545326784..4936fffd46fe 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/CanonicalizationAware.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/CanonicalizationAware.java @@ -99,7 +99,7 @@ public static OptionalInt canonicalizationAwareHash(Node node) if (node instanceof Identifier) { return OptionalInt.of(((Identifier) node).getCanonicalValue().hashCode()); } - else if (node.getChildren().isEmpty()) { + if (node.getChildren().isEmpty()) { return OptionalInt.of(node.hashCode()); } return OptionalInt.empty(); diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java index 78948b2ad36a..afb1c4b4d82f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java @@ -722,8 +722,7 @@ protected Type visitDereferenceExpression(DereferenceExpression node, StackableA throw semanticException(COLUMN_NOT_FOUND, node, "Column %s prefixed with label %s cannot be resolved", unlabeledName, label); } Identifier unlabeled = qualifiedName.getOriginalParts().get(1); - Optional resolvedField = context.getContext().getScope().tryResolveField(node, unlabeledName); - if (resolvedField.isEmpty()) { + if (context.getContext().getScope().tryResolveField(node, unlabeledName).isEmpty()) { throw semanticException(COLUMN_NOT_FOUND, node, "Column %s prefixed with label %s cannot be resolved", unlabeledName, label); } // Correlation is not allowed in pattern recognition context. Visitor's context for pattern recognition has CorrelationSupport.DISALLOWED, @@ -901,13 +900,8 @@ private void coerceCaseOperandToToSingleType(SimpleCaseExpression node, Stackabl Type whenOperandType = process(whenOperand, context); whenOperandTypes.add(whenOperandType); - Optional operandCommonType = typeCoercion.getCommonSuperType(commonType, whenOperandType); - - if (operandCommonType.isEmpty()) { - throw semanticException(TYPE_MISMATCH, whenOperand, "CASE operand type does not match WHEN clause operand type: %s vs %s", operandType, whenOperandType); - } - - commonType = operandCommonType.get(); + commonType = typeCoercion.getCommonSuperType(commonType, whenOperandType) + .orElseThrow(() -> semanticException(TYPE_MISMATCH, whenOperand, "CASE operand type does not match WHEN clause operand type: %s vs %s", operandType, whenOperandType)); } if (commonType != operandType) { @@ -1505,10 +1499,8 @@ else if (frame.getType() == GROUPS) { private void analyzeFrameRangeOffset(Expression offsetValue, FrameBound.Type boundType, StackableAstVisitorContext context, ResolvedWindow window, Node originalNode) { - if (window.getOrderBy().isEmpty()) { - throw semanticException(MISSING_ORDER_BY, originalNode, "Window frame of type RANGE PRECEDING or FOLLOWING requires ORDER BY"); - } - OrderBy orderBy = window.getOrderBy().get(); + OrderBy orderBy = window.getOrderBy() + .orElseThrow(() -> semanticException(MISSING_ORDER_BY, originalNode, "Window frame of type RANGE PRECEDING or FOLLOWING requires ORDER BY")); if (orderBy.getSortItems().size() != 1) { throw semanticException(INVALID_ORDER_BY, orderBy, "Window frame of type RANGE PRECEDING or FOLLOWING requires single sort item in ORDER BY (actual: %s)", orderBy.getSortItems().size()); } @@ -1824,9 +1816,7 @@ private ArgumentLabel validateLabelConsistency(FunctionCall node, boolean labelR if (labelRequired) { throw semanticException(INVALID_ARGUMENTS, node, "Pattern navigation function %s must contain at least one column reference or CLASSIFIER()", name); } - else { - return ArgumentLabel.noLabel(); - } + return ArgumentLabel.noLabel(); } // Label consistency rules: @@ -1884,12 +1874,12 @@ private ArgumentLabel validateLabelConsistency(FunctionCall node, boolean labelR if (!inputColumnLabels.isEmpty()) { return ArgumentLabel.explicitLabel(getOnlyElement(inputColumnLabels)); } - else if (!classifierLabels.isEmpty()) { + if (!classifierLabels.isEmpty()) { return getOnlyElement(classifierLabels) .map(ArgumentLabel::explicitLabel) .orElse(ArgumentLabel.universalLabel()); } - else if (!unlabeledInputColumns.isEmpty()) { + if (!unlabeledInputColumns.isEmpty()) { return ArgumentLabel.universalLabel(); } return ArgumentLabel.noLabel(); @@ -2516,9 +2506,7 @@ public Type visitGroupingOperation(GroupingOperation node, StackableAstVisitorCo if (node.getGroupingColumns().size() <= MAX_NUMBER_GROUPING_ARGUMENTS_INTEGER) { return setExpressionType(node, INTEGER); } - else { - return setExpressionType(node, BIGINT); - } + return setExpressionType(node, BIGINT); } @Override @@ -2815,12 +2803,10 @@ private ResolvedFunction getInputFunction(Type type, JsonFormat format, Node nod if (UNKNOWN.equals(type) || isCharacterStringType(type)) { yield QualifiedName.of(VARCHAR_TO_JSON); } - else if (isStringType(type)) { + if (isStringType(type)) { yield QualifiedName.of(VARBINARY_TO_JSON); } - else { - throw semanticException(TYPE_MISMATCH, node, format("Cannot read input of type %s as JSON using formatting %s", type, format)); - } + throw semanticException(TYPE_MISMATCH, node, format("Cannot read input of type %s as JSON using formatting %s", type, format)); } case UTF8 -> QualifiedName.of(VARBINARY_UTF8_TO_JSON); case UTF16 -> QualifiedName.of(VARBINARY_UTF16_TO_JSON); @@ -2842,12 +2828,10 @@ private ResolvedFunction getOutputFunction(Type type, JsonFormat format, Node no if (isCharacterStringType(type)) { yield QualifiedName.of(JSON_TO_VARCHAR); } - else if (isStringType(type)) { + if (isStringType(type)) { yield QualifiedName.of(JSON_TO_VARBINARY); } - else { - throw semanticException(TYPE_MISMATCH, node, format("Cannot output JSON value as %s using formatting %s", type, format)); - } + throw semanticException(TYPE_MISMATCH, node, format("Cannot output JSON value as %s using formatting %s", type, format)); } case UTF8 -> { if (!VARBINARY.equals(type)) { diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/RelationId.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/RelationId.java index e214bee84ae7..43a88a6ef454 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/RelationId.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/RelationId.java @@ -86,11 +86,9 @@ public String toString() .addValue(format("x%08x", identityHashCode(this))) .toString(); } - else { - return toStringHelper(this) - .addValue(sourceNode.get().getClass().getSimpleName()) - .addValue(format("x%08x", identityHashCode(sourceNode.get()))) - .toString(); - } + return toStringHelper(this) + .addValue(sourceNode.get().getClass().getSimpleName()) + .addValue(format("x%08x", identityHashCode(sourceNode.get()))) + .toString(); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/Scope.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/Scope.java index 5f25fe8da060..138299100875 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/Scope.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/Scope.java @@ -248,7 +248,7 @@ private Optional resolveField(Expression node, QualifiedName name if (matches.size() > 1) { throw ambiguousAttributeException(node, name); } - else if (matches.size() == 1) { + if (matches.size() == 1) { int parentFieldCount = getLocalParent() .map(Scope::getLocalScopeFieldCount) .orElse(0); @@ -256,18 +256,16 @@ else if (matches.size() == 1) { Field field = getOnlyElement(matches); return Optional.of(asResolvedField(field, parentFieldCount, local)); } - else { - if (isColumnReference(name, relation)) { - return Optional.empty(); - } - if (parent.isPresent()) { - if (queryBoundary) { - return parent.get().resolveField(node, name, false); - } - return parent.get().resolveField(node, name, local); - } + if (isColumnReference(name, relation)) { return Optional.empty(); } + if (parent.isPresent()) { + if (queryBoundary) { + return parent.get().resolveField(node, name, false); + } + return parent.get().resolveField(node, name, local); + } + return Optional.empty(); } public ResolvedField getField(int index) diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index ebb6d08f0773..cf734b1c36f6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -613,11 +613,8 @@ protected Scope visitInsert(Insert insert, Optional scope) protected Scope visitRefreshMaterializedView(RefreshMaterializedView refreshMaterializedView, Optional scope) { QualifiedObjectName name = createQualifiedObjectName(session, refreshMaterializedView, refreshMaterializedView.getName()); - Optional optionalView = metadata.getMaterializedView(session, name); - - if (optionalView.isEmpty()) { - throw semanticException(TABLE_NOT_FOUND, refreshMaterializedView, "Materialized view '%s' does not exist", name); - } + MaterializedViewDefinition view = metadata.getMaterializedView(session, name) + .orElseThrow(() -> semanticException(TABLE_NOT_FOUND, refreshMaterializedView, "Materialized view '%s' does not exist", name)); accessControl.checkCanRefreshMaterializedView(session.toSecurityContext(), name); analysis.setUpdateType("REFRESH MATERIALIZED VIEW"); @@ -631,38 +628,33 @@ protected Scope visitRefreshMaterializedView(RefreshMaterializedView refreshMate return createAndAssignScope(refreshMaterializedView, scope); } - Optional storageName = getMaterializedViewStorageTableName(optionalView.get()); - - if (storageName.isEmpty()) { - throw semanticException(TABLE_NOT_FOUND, refreshMaterializedView, "Storage Table '%s' for materialized view '%s' does not exist", storageName, name); - } + QualifiedName storageName = getMaterializedViewStorageTableName(view) + .orElseThrow(() -> semanticException(TABLE_NOT_FOUND, refreshMaterializedView, "Storage Table for materialized view '%s' does not exist", name)); - QualifiedObjectName targetTable = createQualifiedObjectName(session, refreshMaterializedView, storageName.get()); + QualifiedObjectName targetTable = createQualifiedObjectName(session, refreshMaterializedView, storageName); checkStorageTableNotRedirected(targetTable); // analyze the query that creates the data - Query query = parseView(optionalView.get().getOriginalSql(), name, refreshMaterializedView); + Query query = parseView(view.getOriginalSql(), name, refreshMaterializedView); Scope queryScope = process(query, scope); // verify the insert destination columns match the query - Optional targetTableHandle = metadata.getTableHandle(session, targetTable); - if (targetTableHandle.isEmpty()) { - throw semanticException(TABLE_NOT_FOUND, refreshMaterializedView, "Table '%s' does not exist", targetTable); - } + TableHandle targetTableHandle = metadata.getTableHandle(session, targetTable) + .orElseThrow(() -> semanticException(TABLE_NOT_FOUND, refreshMaterializedView, "Table '%s' does not exist", targetTable)); analysis.setSkipMaterializedViewRefresh(metadata.getMaterializedViewFreshness(session, name).isMaterializedViewFresh()); - TableMetadata tableMetadata = metadata.getTableMetadata(session, targetTableHandle.get()); + TableMetadata tableMetadata = metadata.getTableMetadata(session, targetTableHandle); List insertColumns = tableMetadata.getColumns().stream() .filter(column -> !column.isHidden()) .map(ColumnMetadata::getName) .collect(toImmutableList()); - Map columnHandles = metadata.getColumnHandles(session, targetTableHandle.get()); + Map columnHandles = metadata.getColumnHandles(session, targetTableHandle); analysis.setRefreshMaterializedView(new Analysis.RefreshMaterializedViewAnalysis( refreshMaterializedView.getTable(), - targetTableHandle.get(), query, + targetTableHandle, query, insertColumns.stream().map(columnHandles::get).collect(toImmutableList()))); List tableTypes = insertColumns.stream() @@ -756,6 +748,9 @@ protected Scope visitDelete(Delete node, Optional scope) { Table table = node.getTable(); QualifiedObjectName originalName = createQualifiedObjectName(session, table, table.getName()); + if (metadata.isMaterializedView(session, originalName)) { + throw semanticException(NOT_SUPPORTED, node, "Deleting from materialized views is not supported"); + } if (metadata.isView(session, originalName)) { throw semanticException(NOT_SUPPORTED, node, "Deleting from views is not supported"); } @@ -921,10 +916,8 @@ protected Scope visitCreateTableAsSelect(CreateTableAsSelect node, Optional names = new HashSet<>(); for (Field field : descriptor.getVisibleFields()) { - Optional fieldName = field.getName(); - if (fieldName.isEmpty()) { - throw semanticException(MISSING_COLUMN_NAME, node, "Column name not specified at position %s", descriptor.indexOf(field) + 1); - } - if (!names.add(fieldName.get())) { - throw semanticException(DUPLICATE_COLUMN_NAME, node, "Column name '%s' specified more than once", fieldName.get()); + String fieldName = field.getName() + .orElseThrow(() -> semanticException(MISSING_COLUMN_NAME, node, "Column name not specified at position %s", descriptor.indexOf(field) + 1)); + if (!names.add(fieldName)) { + throw semanticException(DUPLICATE_COLUMN_NAME, node, "Column name '%s' specified more than once", fieldName); } if (field.getType().equals(UNKNOWN)) { - throw semanticException(COLUMN_TYPE_UNKNOWN, node, "Column type is unknown: %s", fieldName.get()); + throw semanticException(COLUMN_TYPE_UNKNOWN, node, "Column type is unknown: %s", fieldName); } } } @@ -1794,23 +1785,16 @@ protected Scope visitTable(Table table, Optional scope) if (optionalMaterializedView.isPresent()) { if (metadata.getMaterializedViewFreshness(session, name).isMaterializedViewFresh()) { // If materialized view is current, answer the query using the storage table - Optional storageName = getMaterializedViewStorageTableName(optionalMaterializedView.get()); - if (storageName.isEmpty()) { - throw semanticException(INVALID_VIEW, table, "Materialized view '%s' is fresh but does not have storage table name", name); - } - QualifiedObjectName storageTableName = createQualifiedObjectName(session, table, storageName.get()); + QualifiedName storageName = getMaterializedViewStorageTableName(optionalMaterializedView.get()) + .orElseThrow(() -> semanticException(INVALID_VIEW, table, "Materialized view '%s' is fresh but does not have storage table name", name)); + QualifiedObjectName storageTableName = createQualifiedObjectName(session, table, storageName); checkStorageTableNotRedirected(storageTableName); - Optional tableHandle = metadata.getTableHandle(session, storageTableName); - if (tableHandle.isEmpty()) { - throw semanticException(INVALID_VIEW, table, "Storage table '%s' does not exist", storageTableName); - } - - return createScopeForMaterializedView(table, name, scope, optionalMaterializedView.get(), tableHandle); - } - else { - // This is a stale materialized view and should be expanded like a logical view - return createScopeForMaterializedView(table, name, scope, optionalMaterializedView.get(), Optional.empty()); + TableHandle tableHandle = metadata.getTableHandle(session, storageTableName) + .orElseThrow(() -> semanticException(INVALID_VIEW, table, "Storage table '%s' does not exist", storageTableName)); + return createScopeForMaterializedView(table, name, scope, optionalMaterializedView.get(), Optional.of(tableHandle)); } + // This is a stale materialized view and should be expanded like a logical view + return createScopeForMaterializedView(table, name, scope, optionalMaterializedView.get(), Optional.empty()); } // This could be a reference to a logical view or a table @@ -2709,8 +2693,11 @@ protected Scope visitUpdate(Update update, Optional scope) { Table table = update.getTable(); QualifiedObjectName originalName = createQualifiedObjectName(session, table, table.getName()); + if (metadata.isMaterializedView(session, originalName)) { + throw semanticException(NOT_SUPPORTED, update, "Updating materialized views is not supported"); + } if (metadata.isView(session, originalName)) { - throw semanticException(NOT_SUPPORTED, update, "Updating through views is not supported"); + throw semanticException(NOT_SUPPORTED, update, "Updating views is not supported"); } RedirectionAwareTableHandle redirection = metadata.getRedirectionAwareTableHandle(session, originalName); @@ -2812,8 +2799,11 @@ protected Scope visitMerge(Merge merge, Optional scope) Relation relation = merge.getTarget(); Table table = getMergeTargetTable(relation); QualifiedObjectName tableName = createQualifiedObjectName(session, table, table.getName()); + if (metadata.isMaterializedView(session, tableName)) { + throw semanticException(NOT_SUPPORTED, merge, "Merging into materialized views is not supported"); + } if (metadata.getView(session, tableName).isPresent()) { - throw semanticException(NOT_SUPPORTED, merge, "MERGE INTO a view is not supported"); + throw semanticException(NOT_SUPPORTED, merge, "Merging into views is not supported"); } TableHandle targetTableHandle = metadata.getTableHandle(session, tableName) @@ -3069,35 +3059,30 @@ private Scope analyzeJoinUsing(Join node, List columns, Optional leftField = left.tryResolveField(column); - Optional rightField = right.tryResolveField(column); - - if (leftField.isEmpty()) { - throw semanticException(COLUMN_NOT_FOUND, column, "Column '%s' is missing from left side of join", column.getValue()); - } - if (rightField.isEmpty()) { - throw semanticException(COLUMN_NOT_FOUND, column, "Column '%s' is missing from right side of join", column.getValue()); - } + ResolvedField leftField = left.tryResolveField(column) + .orElseThrow(() -> semanticException(COLUMN_NOT_FOUND, column, "Column '%s' is missing from left side of join", column.getValue())); + ResolvedField rightField = right.tryResolveField(column) + .orElseThrow(() -> semanticException(COLUMN_NOT_FOUND, column, "Column '%s' is missing from right side of join", column.getValue())); // ensure a comparison operator exists for the given types (applying coercions if necessary) try { metadata.resolveOperator(session, OperatorType.EQUAL, ImmutableList.of( - leftField.get().getType(), rightField.get().getType())); + leftField.getType(), rightField.getType())); } catch (OperatorNotFoundException e) { throw semanticException(TYPE_MISMATCH, column, e, "%s", e.getMessage()); } - Optional type = typeCoercion.getCommonSuperType(leftField.get().getType(), rightField.get().getType()); + Optional type = typeCoercion.getCommonSuperType(leftField.getType(), rightField.getType()); analysis.addTypes(ImmutableMap.of(NodeRef.of(column), type.orElseThrow())); joinFields.add(Field.newUnqualified(column.getValue(), type.get())); - leftJoinFields.add(leftField.get().getRelationFieldIndex()); - rightJoinFields.add(rightField.get().getRelationFieldIndex()); + leftJoinFields.add(leftField.getRelationFieldIndex()); + rightJoinFields.add(rightField.getRelationFieldIndex()); - recordColumnAccess(leftField.get().getField()); - recordColumnAccess(rightField.get().getField()); + recordColumnAccess(leftField.getField()); + recordColumnAccess(rightField.getField()); } ImmutableList.Builder outputs = ImmutableList.builder(); @@ -3178,15 +3163,12 @@ protected Scope visitValues(Values node, Optional scope) } // determine common super type of the rows - Optional partialSuperType = typeCoercion.getCommonSuperType(rowType, commonSuperType); - if (partialSuperType.isEmpty()) { - throw semanticException(TYPE_MISMATCH, - node, - "Values rows have mismatched types: %s vs %s", - rowTypes.get(0), - rowType); - } - commonSuperType = partialSuperType.get(); + commonSuperType = typeCoercion.getCommonSuperType(rowType, commonSuperType) + .orElseThrow(() -> semanticException(TYPE_MISMATCH, + node, + "Values rows have mismatched types: %s vs %s", + rowTypes.get(0), + rowType)); } // add coercions @@ -3723,22 +3705,20 @@ private void analyzeSelectAllColumns( QualifiedName prefix = asQualifiedName(expression); if (prefix != null) { // analyze prefix as an 'asterisked identifier chain' - Optional identifierChainBasis = scope.resolveAsteriskedIdentifierChainBasis(prefix, allColumns); - if (identifierChainBasis.isEmpty()) { - throw semanticException(TABLE_NOT_FOUND, allColumns, "Unable to resolve reference %s", prefix); - } - if (identifierChainBasis.get().getBasisType() == TABLE) { - RelationType relationType = identifierChainBasis.get().getRelationType().orElseThrow(); + AsteriskedIdentifierChainBasis identifierChainBasis = scope.resolveAsteriskedIdentifierChainBasis(prefix, allColumns) + .orElseThrow(() -> semanticException(TABLE_NOT_FOUND, allColumns, "Unable to resolve reference %s", prefix)); + if (identifierChainBasis.getBasisType() == TABLE) { + RelationType relationType = identifierChainBasis.getRelationType().orElseThrow(); List fields = filterInaccessibleFields(relationType.resolveVisibleFieldsWithRelationPrefix(Optional.of(prefix))); if (fields.isEmpty()) { throw semanticException(COLUMN_NOT_FOUND, allColumns, "SELECT * not allowed from relation that has no columns"); } - boolean local = scope.isLocalScope(identifierChainBasis.get().getScope().orElseThrow()); + boolean local = scope.isLocalScope(identifierChainBasis.getScope().orElseThrow()); analyzeAllColumnsFromTable( fields, allColumns, node, - local ? scope : identifierChainBasis.get().getScope().get(), + local ? scope : identifierChainBasis.getScope().get(), outputExpressionBuilder, selectExpressionBuilder, relationType, @@ -4684,9 +4664,7 @@ private boolean analyzeLimit(Node node, Scope scope) if (node instanceof FetchFirst) { return analyzeLimit((FetchFirst) node, scope); } - else { - return analyzeLimit((Limit) node, scope); - } + return analyzeLimit((Limit) node, scope); } private boolean analyzeLimit(FetchFirst node, Scope scope) @@ -4744,28 +4722,26 @@ private OptionalLong analyzeParameterAsRowCount(Parameter parameter, Scope scope analysis.addCoercion(parameter, BIGINT, false); return OptionalLong.empty(); } - else { - // validate parameter index - analyzeExpression(parameter, scope); - Expression providedValue = analysis.getParameters().get(NodeRef.of(parameter)); - Object value; - try { - value = evaluateConstantExpression( - providedValue, - BIGINT, - plannerContext, - session, - accessControl, - analysis.getParameters()); - } - catch (VerifyException e) { - throw semanticException(INVALID_ARGUMENTS, parameter, "Non constant parameter value for %s: %s", context, providedValue); - } - if (value == null) { - throw semanticException(INVALID_ARGUMENTS, parameter, "Parameter value provided for %s is NULL: %s", context, providedValue); - } - return OptionalLong.of((long) value); + // validate parameter index + analyzeExpression(parameter, scope); + Expression providedValue = analysis.getParameters().get(NodeRef.of(parameter)); + Object value; + try { + value = evaluateConstantExpression( + providedValue, + BIGINT, + plannerContext, + session, + accessControl, + analysis.getParameters()); + } + catch (VerifyException e) { + throw semanticException(INVALID_ARGUMENTS, parameter, "Non constant parameter value for %s: %s", context, providedValue); + } + if (value == null) { + throw semanticException(INVALID_ARGUMENTS, parameter, "Parameter value provided for %s is NULL: %s", context, providedValue); } + return OptionalLong.of((long) value); } private Scope createAndAssignScope(Node node, Optional parentScope) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java index 1f07a8073c4b..633ca43afe0a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java @@ -94,7 +94,7 @@ import static io.trino.spi.expression.StandardFunctions.IS_NULL_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.LESS_THAN_OPERATOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME; -import static io.trino.spi.expression.StandardFunctions.LIKE_PATTERN_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.LIKE_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.MODULUS_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.MULTIPLY_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.NEGATE_FUNCTION_NAME; @@ -252,7 +252,7 @@ protected Optional translateCall(Call call) return translate(getOnlyElement(call.getArguments())).map(argument -> new ArithmeticUnaryExpression(ArithmeticUnaryExpression.Sign.MINUS, argument)); } - if (LIKE_PATTERN_FUNCTION_NAME.equals(call.getFunctionName())) { + if (LIKE_FUNCTION_NAME.equals(call.getFunctionName())) { return switch (call.getArguments().size()) { case 2 -> translateLike(call.getArguments().get(0), call.getArguments().get(1), Optional.empty()); case 3 -> translateLike(call.getArguments().get(0), call.getArguments().get(1), Optional.of(call.getArguments().get(2))); @@ -742,12 +742,12 @@ protected Optional visitLikePredicate(LikePredicate node, V Optional pattern = process(node.getPattern()); if (value.isPresent() && pattern.isPresent()) { if (node.getEscape().isEmpty()) { - return Optional.of(new Call(typeOf(node), LIKE_PATTERN_FUNCTION_NAME, List.of(value.get(), pattern.get()))); + return Optional.of(new Call(typeOf(node), LIKE_FUNCTION_NAME, List.of(value.get(), pattern.get()))); } Optional escape = process(node.getEscape().get()); if (escape.isPresent()) { - return Optional.of(new Call(typeOf(node), LIKE_PATTERN_FUNCTION_NAME, List.of(value.get(), pattern.get(), escape.get()))); + return Optional.of(new Call(typeOf(node), LIKE_FUNCTION_NAME, List.of(value.get(), pattern.get(), escape.get()))); } } return Optional.empty(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/DomainCoercer.java b/core/trino-main/src/main/java/io/trino/sql/planner/DomainCoercer.java index 3118bcfa68f6..f8b85132bc18 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/DomainCoercer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/DomainCoercer.java @@ -211,9 +211,7 @@ private Optional applySaturatedCast(Object originalValue) if (originalComparedToCoerced == 0) { return Optional.of(coercedFloorValue); } - else { - return Optional.empty(); - } + return Optional.empty(); } private int compareOriginalValueToCoerced(ResolvedFunction castToOriginalTypeOperator, MethodHandle comparisonOperator, Object originalValue, Object coercedValue) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/GroupingOperationRewriter.java b/core/trino-main/src/main/java/io/trino/sql/planner/GroupingOperationRewriter.java index af2e053aaa30..3c9a6c115014 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/GroupingOperationRewriter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/GroupingOperationRewriter.java @@ -51,29 +51,27 @@ public static Expression rewriteGroupingOperation(GroupingOperation expression, if (groupingSets.size() == 1) { return new LongLiteral("0"); } - else { - checkState(groupIdSymbol.isPresent(), "groupId symbol is missing"); + checkState(groupIdSymbol.isPresent(), "groupId symbol is missing"); - RelationId relationId = columnReferenceFields.get(NodeRef.of(expression.getGroupingColumns().get(0))).getFieldId().getRelationId(); + RelationId relationId = columnReferenceFields.get(NodeRef.of(expression.getGroupingColumns().get(0))).getFieldId().getRelationId(); - List columns = expression.getGroupingColumns().stream() - .map(NodeRef::of) - .peek(groupingColumn -> checkState(columnReferenceFields.containsKey(groupingColumn), "the grouping column is not in the columnReferencesField map")) - .map(columnReferenceFields::get) - .map(ResolvedField::getFieldId) - .map(fieldId -> translateFieldToInteger(fieldId, relationId)) - .collect(toImmutableList()); + List columns = expression.getGroupingColumns().stream() + .map(NodeRef::of) + .peek(groupingColumn -> checkState(columnReferenceFields.containsKey(groupingColumn), "the grouping column is not in the columnReferencesField map")) + .map(columnReferenceFields::get) + .map(ResolvedField::getFieldId) + .map(fieldId -> translateFieldToInteger(fieldId, relationId)) + .collect(toImmutableList()); - List groupingResults = groupingSets.stream() - .map(groupingSet -> String.valueOf(calculateGrouping(groupingSet, columns))) - .map(LongLiteral::new) - .collect(toImmutableList()); + List groupingResults = groupingSets.stream() + .map(groupingSet -> String.valueOf(calculateGrouping(groupingSet, columns))) + .map(LongLiteral::new) + .collect(toImmutableList()); - // It is necessary to add a 1 to the groupId because the underlying array is indexed starting at 1 - return new SubscriptExpression( - new ArrayConstructor(groupingResults), - new ArithmeticBinaryExpression(ADD, groupIdSymbol.get().toSymbolReference(), new GenericLiteral("BIGINT", "1"))); - } + // It is necessary to add a 1 to the groupId because the underlying array is indexed starting at 1 + return new SubscriptExpression( + new ArrayConstructor(groupingResults), + new ArithmeticBinaryExpression(ADD, groupIdSymbol.get().toSymbolReference(), new GenericLiteral("BIGINT", "1"))); } private static int translateFieldToInteger(FieldId fieldId, RelationId requiredOriginRelationId) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index 193545092f29..897d6917e25b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -69,6 +69,7 @@ import io.trino.operator.MergeProcessorOperator; import io.trino.operator.MergeWriterOperator.MergeWriterOperatorFactory; import io.trino.operator.OperatorFactories; +import io.trino.operator.OperatorFactories.JoinOperatorType; import io.trino.operator.OperatorFactory; import io.trino.operator.OrderByOperator.OrderByOperatorFactory; import io.trino.operator.OutputFactory; @@ -115,7 +116,6 @@ import io.trino.operator.index.IndexLookupSourceFactory; import io.trino.operator.index.IndexSourceOperator; import io.trino.operator.join.HashBuilderOperator.HashBuilderOperatorFactory; -import io.trino.operator.join.JoinBridge; import io.trino.operator.join.JoinBridgeManager; import io.trino.operator.join.JoinOperatorFactory; import io.trino.operator.join.LookupSourceFactory; @@ -1944,19 +1944,17 @@ else if (sourceNode instanceof SampleNode) { return new PhysicalOperation(operatorFactory, outputMappings, context); } - else { - Supplier pageProcessor = expressionCompiler.compilePageProcessor(translatedFilter, translatedProjections, Optional.of(context.getStageId() + "_" + planNodeId)); + Supplier pageProcessor = expressionCompiler.compilePageProcessor(translatedFilter, translatedProjections, Optional.of(context.getStageId() + "_" + planNodeId)); - OperatorFactory operatorFactory = FilterAndProjectOperator.createOperatorFactory( - context.getNextOperatorId(), - planNodeId, - pageProcessor, - getTypes(projections, expressionTypes), - getFilterAndProjectMinOutputPageSize(session), - getFilterAndProjectMinOutputPageRowCount(session)); + OperatorFactory operatorFactory = FilterAndProjectOperator.createOperatorFactory( + context.getNextOperatorId(), + planNodeId, + pageProcessor, + getTypes(projections, expressionTypes), + getFilterAndProjectMinOutputPageSize(session), + getFilterAndProjectMinOutputPageRowCount(session)); - return new PhysicalOperation(operatorFactory, outputMappings, context, source); - } + return new PhysicalOperation(operatorFactory, outputMappings, context, source); } catch (TrinoException e) { throw e; @@ -2306,15 +2304,14 @@ public PhysicalOperation visitIndexJoin(IndexJoinNode node, LocalExecutionPlanCo OperatorFactory lookupJoinOperatorFactory; OptionalInt totalOperatorsCount = context.getDriverInstanceCount(); + // We use spilling operator since Non-spilling one does not support index lookup sources lookupJoinOperatorFactory = switch (node.getType()) { - case INNER -> operatorFactories.innerJoin( + case INNER -> operatorFactories.spillingJoin( + JoinOperatorType.innerJoin(false, false), context.getNextOperatorId(), node.getId(), lookupSourceFactoryManager, false, - false, - false, - true, // Non-spilling operator does not support index lookup sources probeSource.getTypes(), probeChannels, probeHashChannel, @@ -2322,13 +2319,12 @@ public PhysicalOperation visitIndexJoin(IndexJoinNode node, LocalExecutionPlanCo totalOperatorsCount, unsupportedPartitioningSpillerFactory(), blockTypeOperators); - case SOURCE_OUTER -> operatorFactories.probeOuterJoin( + case SOURCE_OUTER -> operatorFactories.spillingJoin( + JoinOperatorType.probeOuterJoin(false), context.getNextOperatorId(), node.getId(), lookupSourceFactoryManager, false, - false, - true, // Non-spilling operator does not support index lookup sources probeSource.getTypes(), probeChannels, probeHashChannel, @@ -2680,39 +2676,29 @@ private PhysicalOperation createLookupJoin( boolean spillEnabled = isSpillEnabled(session) && node.isSpillable().orElseThrow(() -> new IllegalArgumentException("spillable not yet set")) && !buildOuter; - JoinBridgeManager lookupSourceFactory = - createLookupSourceFactoryManager(node, buildNode, buildSymbols, buildHashSymbol, probeSource, context, spillEnabled, localDynamicFilters); - OperatorFactory operator = createLookupJoin( - node, - probeSource, - probeSymbols, - probeHashSymbol, - lookupSourceFactory, - context, - spillEnabled, - !localDynamicFilters.isEmpty()); - - ImmutableMap.Builder outputMappings = ImmutableMap.builder(); - List outputSymbols = node.getOutputSymbols(); - for (int i = 0; i < outputSymbols.size(); i++) { - Symbol symbol = outputSymbols.get(i); - outputMappings.put(symbol, i); + boolean consumedLocalDynamicFilters = !localDynamicFilters.isEmpty(); + List probeTypes = probeSource.getTypes(); + List probeOutputChannels = ImmutableList.copyOf(getChannelsForSymbols(node.getLeftOutputSymbols(), probeSource.getLayout())); + List probeJoinChannels = ImmutableList.copyOf(getChannelsForSymbols(probeSymbols, probeSource.getLayout())); + OptionalInt probeHashChannel = probeHashSymbol.map(channelGetter(probeSource)) + .map(OptionalInt::of).orElse(OptionalInt.empty()); + OptionalInt totalOperatorsCount = OptionalInt.empty(); + if (spillEnabled) { + totalOperatorsCount = context.getDriverInstanceCount(); + checkState(totalOperatorsCount.isPresent(), "A fixed distribution is required for JOIN when spilling is enabled"); } - return new PhysicalOperation(operator, outputMappings.buildOrThrow(), context, probeSource); - } + // Implementation of hash join operator may only take advantage of output duplicates insensitive joins when: + // 1. Join is of INNER or LEFT type. For right or full joins all matching build rows must be tagged as visited. + // 2. Right (build) output symbols are subset of equi-clauses right symbols. If additional build symbols + // are produced, then skipping build rows could skip some distinct rows. + boolean outputSingleMatch = node.isMaySkipOutputDuplicates() && + node.getCriteria().stream() + .map(JoinNode.EquiJoinClause::getRight) + .collect(toImmutableSet()) + .containsAll(node.getRightOutputSymbols()); - private JoinBridgeManager createLookupSourceFactoryManager( - JoinNode node, - PlanNode buildNode, - List buildSymbols, - Optional buildHashSymbol, - PhysicalOperation probeSource, - LocalExecutionPlanContext context, - boolean spillEnabled, - Set localDynamicFilters) - { LocalExecutionPlanContext buildContext = context.createSubContext(); PhysicalOperation buildSource = buildNode.accept(this, buildContext); @@ -2721,7 +2707,6 @@ private JoinBridgeManager createLookupSourceFactoryManager( OptionalInt buildHashChannel = buildHashSymbol.map(channelGetter(buildSource)) .map(OptionalInt::of).orElse(OptionalInt.empty()); - boolean buildOuter = node.getType() == RIGHT || node.getType() == FULL; int partitionCount = buildContext.getDriverInstanceCount().orElse(1); Map buildLayout = buildSource.getLayout(); @@ -2757,11 +2742,6 @@ private JoinBridgeManager createLookupSourceFactoryManager( .map(buildSource.getTypes()::get) .collect(toImmutableList()); List buildTypes = buildSource.getTypes(); - JoinBridgeManager lookupSourceFactoryManager = new JoinBridgeManager<>( - buildOuter, - createLookupSourceFactory(buildChannels, buildOuter, partitionCount, buildOutputTypes, buildTypes, spillEnabled, session), - buildOutputTypes); - int operatorId = buildContext.getNextOperatorId(); boolean isReplicatedJoin = isBuildSideReplicated(node); Optional localDynamicFilter = createDynamicFilter(buildSource, node, context, localDynamicFilters, isReplicatedJoin); @@ -2777,12 +2757,29 @@ private JoinBridgeManager createLookupSourceFactoryManager( } int taskConcurrency = getTaskConcurrency(session); - OperatorFactory hashBuilderOperatorFactory; + + // Wait for build side to be collected before local dynamic filters are + // consumed by table scan. This way table scan can filter data more efficiently. + boolean waitForBuild = consumedLocalDynamicFilters; + OperatorFactory operator; if (useSpillingJoinOperator(spillEnabled, session)) { - hashBuilderOperatorFactory = new HashBuilderOperatorFactory( + JoinBridgeManager lookupSourceFactory = new JoinBridgeManager<>( + buildOuter, + new PartitionedLookupSourceFactory( + buildTypes, + buildOutputTypes, + buildChannels.stream() + .map(buildTypes::get) + .collect(toImmutableList()), + partitionCount, + buildOuter, + blockTypeOperators), + buildOutputTypes); + + OperatorFactory hashBuilderOperatorFactory = new HashBuilderOperatorFactory( buildContext.getNextOperatorId(), node.getId(), - (JoinBridgeManager) lookupSourceFactoryManager, + lookupSourceFactory, buildOutputChannels, buildChannels, buildHashChannel, @@ -2798,12 +2795,46 @@ private JoinBridgeManager createLookupSourceFactoryManager( // scale load factor in case partition count (and number of hash build operators) // is reduced (e.g. by plan rule) with respect to default task concurrency taskConcurrency / partitionCount)); + + context.addDriverFactory( + buildContext.isInputDriver(), + false, + new PhysicalOperation(hashBuilderOperatorFactory, buildSource), + buildContext.getDriverInstanceCount()); + + JoinOperatorType joinType = JoinOperatorType.ofJoinNodeType(node.getType(), outputSingleMatch, waitForBuild); + operator = operatorFactories.spillingJoin( + joinType, + context.getNextOperatorId(), + node.getId(), + lookupSourceFactory, + node.getFilter().isPresent(), + probeTypes, + probeJoinChannels, + probeHashChannel, + Optional.of(probeOutputChannels), + totalOperatorsCount, + partitioningSpillerFactory, + blockTypeOperators); } else { - hashBuilderOperatorFactory = new HashBuilderOperator.HashBuilderOperatorFactory( + JoinBridgeManager lookupSourceFactory = new JoinBridgeManager<>( + buildOuter, + new io.trino.operator.join.unspilled.PartitionedLookupSourceFactory( + buildTypes, + buildOutputTypes, + buildChannels.stream() + .map(buildTypes::get) + .collect(toImmutableList()), + partitionCount, + buildOuter, + blockTypeOperators), + buildOutputTypes); + + OperatorFactory hashBuilderOperatorFactory = new HashBuilderOperator.HashBuilderOperatorFactory( buildContext.getNextOperatorId(), node.getId(), - (JoinBridgeManager) lookupSourceFactoryManager, + lookupSourceFactory, buildOutputChannels, buildChannels, buildHashChannel, @@ -2817,14 +2848,35 @@ private JoinBridgeManager createLookupSourceFactoryManager( // scale load factor in case partition count (and number of hash build operators) // is reduced (e.g. by plan rule) with respect to default task concurrency taskConcurrency / partitionCount)); + + context.addDriverFactory( + buildContext.isInputDriver(), + false, + new PhysicalOperation(hashBuilderOperatorFactory, buildSource), + buildContext.getDriverInstanceCount()); + + JoinOperatorType joinType = JoinOperatorType.ofJoinNodeType(node.getType(), outputSingleMatch, waitForBuild); + operator = operatorFactories.join( + joinType, + context.getNextOperatorId(), + node.getId(), + lookupSourceFactory, + node.getFilter().isPresent(), + probeTypes, + probeJoinChannels, + probeHashChannel, + Optional.of(probeOutputChannels), + blockTypeOperators); } - context.addDriverFactory( - buildContext.isInputDriver(), - false, - new PhysicalOperation(hashBuilderOperatorFactory, buildSource), - buildContext.getDriverInstanceCount()); - return lookupSourceFactoryManager; + ImmutableMap.Builder outputMappings = ImmutableMap.builder(); + List outputSymbols = node.getOutputSymbols(); + for (int i = 0; i < outputSymbols.size(); i++) { + Symbol symbol = outputSymbols.get(i); + outputMappings.put(symbol, i); + } + + return new PhysicalOperation(operator, outputMappings.buildOrThrow(), context, probeSource); } @Override @@ -2960,99 +3012,6 @@ private JoinFilterFunctionFactory compileJoinFilterFunction( return joinFilterFunctionCompiler.compileJoinFilterFunction(translatedFilter, buildLayout.size()); } - private OperatorFactory createLookupJoin( - JoinNode node, - PhysicalOperation probeSource, - List probeSymbols, - Optional probeHashSymbol, - JoinBridgeManager lookupSourceFactoryManager, - LocalExecutionPlanContext context, - boolean spillEnabled, - boolean consumedLocalDynamicFilters) - { - List probeTypes = probeSource.getTypes(); - List probeOutputChannels = ImmutableList.copyOf(getChannelsForSymbols(node.getLeftOutputSymbols(), probeSource.getLayout())); - List probeJoinChannels = ImmutableList.copyOf(getChannelsForSymbols(probeSymbols, probeSource.getLayout())); - OptionalInt probeHashChannel = probeHashSymbol.map(channelGetter(probeSource)) - .map(OptionalInt::of).orElse(OptionalInt.empty()); - OptionalInt totalOperatorsCount = OptionalInt.empty(); - if (spillEnabled) { - totalOperatorsCount = context.getDriverInstanceCount(); - checkState(totalOperatorsCount.isPresent(), "A fixed distribution is required for JOIN when spilling is enabled"); - } - - // Implementation of hash join operator may only take advantage of output duplicates insensitive joins when: - // 1. Join is of INNER or LEFT type. For right or full joins all matching build rows must be tagged as visited. - // 2. Right (build) output symbols are subset of equi-clauses right symbols. If additional build symbols - // are produced, then skipping build rows could skip some distinct rows. - boolean outputSingleMatch = node.isMaySkipOutputDuplicates() && - node.getCriteria().stream() - .map(JoinNode.EquiJoinClause::getRight) - .collect(toImmutableSet()) - .containsAll(node.getRightOutputSymbols()); - // Wait for build side to be collected before local dynamic filters are - // consumed by table scan. This way table scan can filter data more efficiently. - boolean waitForBuild = consumedLocalDynamicFilters; - return switch (node.getType()) { - case INNER -> operatorFactories.innerJoin( - context.getNextOperatorId(), - node.getId(), - lookupSourceFactoryManager, - outputSingleMatch, - waitForBuild, - node.getFilter().isPresent(), - useSpillingJoinOperator(spillEnabled, session), - probeTypes, - probeJoinChannels, - probeHashChannel, - Optional.of(probeOutputChannels), - totalOperatorsCount, - partitioningSpillerFactory, - blockTypeOperators); - case LEFT -> operatorFactories.probeOuterJoin( - context.getNextOperatorId(), - node.getId(), - lookupSourceFactoryManager, - outputSingleMatch, - node.getFilter().isPresent(), - useSpillingJoinOperator(spillEnabled, session), - probeTypes, - probeJoinChannels, - probeHashChannel, - Optional.of(probeOutputChannels), - totalOperatorsCount, - partitioningSpillerFactory, - blockTypeOperators); - case RIGHT -> operatorFactories.lookupOuterJoin( - context.getNextOperatorId(), - node.getId(), - lookupSourceFactoryManager, - waitForBuild, - node.getFilter().isPresent(), - useSpillingJoinOperator(spillEnabled, session), - probeTypes, - probeJoinChannels, - probeHashChannel, - Optional.of(probeOutputChannels), - totalOperatorsCount, - partitioningSpillerFactory, - blockTypeOperators); - case FULL -> operatorFactories.fullOuterJoin( - context.getNextOperatorId(), - node.getId(), - lookupSourceFactoryManager, - node.getFilter().isPresent(), - useSpillingJoinOperator(spillEnabled, session), - probeTypes, - probeJoinChannels, - probeHashChannel, - Optional.of(probeOutputChannels), - totalOperatorsCount, - partitioningSpillerFactory, - blockTypeOperators); - }; - } - private Map createJoinSourcesLayout(Map lookupSourceLayout, Map probeSourceLayout) { ImmutableMap.Builder joinSourcesLayout = ImmutableMap.builder(); @@ -3935,61 +3894,26 @@ private OperatorFactory createHashAggregationOperatorFactory( aggregatorFactories, joinCompiler); } - else { - Optional hashChannel = hashSymbol.map(channelGetter(source)); - return new HashAggregationOperatorFactory( - context.getNextOperatorId(), - planNodeId, - groupByTypes, - groupByChannels, - ImmutableList.copyOf(globalGroupingSets), - step, - hasDefaultOutput, - aggregatorFactories, - hashChannel, - groupIdChannel, - expectedGroups, - maxPartialAggregationMemorySize, - spillEnabled, - unspillMemoryLimit, - spillerFactory, - joinCompiler, - blockTypeOperators, - createPartialAggregationController(step, session)); - } - } - } - - private JoinBridge createLookupSourceFactory( - List buildChannels, - boolean buildOuter, - int partitionCount, - ImmutableList buildOutputTypes, - List buildTypes, - boolean spillEnabled, - Session session) - { - if (useSpillingJoinOperator(spillEnabled, session)) { - return new PartitionedLookupSourceFactory( - buildTypes, - buildOutputTypes, - buildChannels.stream() - .map(buildTypes::get) - .collect(toImmutableList()), - partitionCount, - buildOuter, - blockTypeOperators); - } - else { - return new io.trino.operator.join.unspilled.PartitionedLookupSourceFactory( - buildTypes, - buildOutputTypes, - buildChannels.stream() - .map(buildTypes::get) - .collect(toImmutableList()), - partitionCount, - buildOuter, - blockTypeOperators); + Optional hashChannel = hashSymbol.map(channelGetter(source)); + return new HashAggregationOperatorFactory( + context.getNextOperatorId(), + planNodeId, + groupByTypes, + groupByChannels, + ImmutableList.copyOf(globalGroupingSets), + step, + hasDefaultOutput, + aggregatorFactories, + hashChannel, + groupIdChannel, + expectedGroups, + maxPartialAggregationMemorySize, + spillEnabled, + unspillMemoryLimit, + spillerFactory, + joinCompiler, + blockTypeOperators, + createPartialAggregationController(step, session)); } } @@ -4073,10 +3997,10 @@ private static TableFinisher createTableFinisher(Session session, TableFinishNod if (target instanceof CreateTarget) { return metadata.finishCreateTable(session, ((CreateTarget) target).getHandle(), fragments, statistics); } - else if (target instanceof InsertTarget) { + if (target instanceof InsertTarget) { return metadata.finishInsert(session, ((InsertTarget) target).getHandle(), fragments, statistics); } - else if (target instanceof TableWriterNode.RefreshMaterializedViewTarget) { + if (target instanceof TableWriterNode.RefreshMaterializedViewTarget) { TableWriterNode.RefreshMaterializedViewTarget refreshTarget = (TableWriterNode.RefreshMaterializedViewTarget) target; return metadata.finishRefreshMaterializedView( session, @@ -4086,27 +4010,25 @@ else if (target instanceof TableWriterNode.RefreshMaterializedViewTarget) { statistics, refreshTarget.getSourceTableHandles()); } - else if (target instanceof DeleteTarget) { + if (target instanceof DeleteTarget) { metadata.finishDelete(session, ((DeleteTarget) target).getHandleOrElseThrow(), fragments); return Optional.empty(); } - else if (target instanceof UpdateTarget) { + if (target instanceof UpdateTarget) { metadata.finishUpdate(session, ((UpdateTarget) target).getHandleOrElseThrow(), fragments); return Optional.empty(); } - else if (target instanceof TableExecuteTarget) { + if (target instanceof TableExecuteTarget) { TableExecuteHandle tableExecuteHandle = ((TableExecuteTarget) target).getExecuteHandle(); metadata.finishTableExecute(session, tableExecuteHandle, fragments, tableExecuteContext.getSplitsInfo()); return Optional.empty(); } - else if (target instanceof MergeTarget mergeTarget) { + if (target instanceof MergeTarget mergeTarget) { MergeHandle mergeHandle = mergeTarget.getMergeHandle().orElseThrow(() -> new IllegalArgumentException("mergeHandle not present")); metadata.finishMerge(session, mergeHandle, fragments, statistics); return Optional.empty(); } - else { - throw new AssertionError("Unhandled target type: " + target.getClass().getName()); - } + throw new AssertionError("Unhandled target type: " + target.getClass().getName()); }; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/Partitioning.java b/core/trino-main/src/main/java/io/trino/sql/planner/Partitioning.java index ca686bb8de09..78977f68e400 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/Partitioning.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/Partitioning.java @@ -141,26 +141,20 @@ private static boolean isPartitionedWith( Set mappedColumns = leftToRightMappings.apply(leftArgument.getColumn()); return mappedColumns.contains(rightArgument.getColumn()); } - else { - // variable == constant - // Normally, this would be a false condition, but if we happen to have an external - // mapping from the symbol to a constant value and that constant value matches the - // right value, then we are co-partitioned. - Optional leftConstant = leftConstantMapping.apply(leftArgument.getColumn()); - return leftConstant.isPresent() && leftConstant.get().equals(rightArgument.getConstant()); - } + // variable == constant + // Normally, this would be a false condition, but if we happen to have an external + // mapping from the symbol to a constant value and that constant value matches the + // right value, then we are co-partitioned. + Optional leftConstant = leftConstantMapping.apply(leftArgument.getColumn()); + return leftConstant.isPresent() && leftConstant.get().equals(rightArgument.getConstant()); } - else { - if (rightArgument.isConstant()) { - // constant == constant - return leftArgument.getConstant().equals(rightArgument.getConstant()); - } - else { - // constant == variable - Optional rightConstant = rightConstantMapping.apply(rightArgument.getColumn()); - return rightConstant.isPresent() && rightConstant.get().equals(leftArgument.getConstant()); - } + if (rightArgument.isConstant()) { + // constant == constant + return leftArgument.getConstant().equals(rightArgument.getConstant()); } + // constant == variable + Optional rightConstant = rightConstantMapping.apply(rightArgument.getColumn()); + return rightConstant.isPresent() && rightConstant.get().equals(leftArgument.getConstant()); } public boolean isPartitionedOn(Collection columns, Set knownConstants) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/ScopeAware.java b/core/trino-main/src/main/java/io/trino/sql/planner/ScopeAware.java index d4ffef3c6f4f..ae3919ab1247 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/ScopeAware.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/ScopeAware.java @@ -138,13 +138,13 @@ private Boolean scopeAwareComparison(Node left, Node right) } // For references that come from the current query scope or an outer scope of the current // expression, compare by resolved field - else if (!leftFieldInSubqueryScope && !rightFieldInSubqueryScope) { + if (!leftFieldInSubqueryScope && !rightFieldInSubqueryScope) { return leftField.getFieldId().equals(rightField.getFieldId()); } // References come from different scopes return false; } - else if (leftExpression instanceof Identifier && rightExpression instanceof Identifier) { + if (leftExpression instanceof Identifier && rightExpression instanceof Identifier) { return treeEqual(leftExpression, rightExpression, CanonicalizationAware::canonicalizationAwareComparison); } } @@ -170,10 +170,10 @@ private OptionalInt scopeAwareHash(Node node) return OptionalInt.of(field.getFieldId().hashCode()); } - else if (expression instanceof Identifier) { + if (expression instanceof Identifier) { return OptionalInt.of(treeHash(expression, CanonicalizationAware::canonicalizationAwareHash)); } - else if (node.getChildren().isEmpty()) { + if (node.getChildren().isEmpty()) { // Calculate shallow hash since node doesn't have any children return OptionalInt.of(expression.hashCode()); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ApplyTableScanRedirection.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ApplyTableScanRedirection.java index a11459c95883..38d7f6472bd5 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ApplyTableScanRedirection.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ApplyTableScanRedirection.java @@ -101,15 +101,13 @@ public Result apply(TableScanNode scanNode, Captures captures, Context context) return Result.empty(); } - Optional destinationTableHandle = plannerContext.getMetadata().getTableHandle( - context.getSession(), - convertFromSchemaTableName(destinationTable.getCatalogName()).apply(destinationTable.getSchemaTableName())); - if (destinationTableHandle.isEmpty()) { - throw new TrinoException(TABLE_NOT_FOUND, format("Destination table %s from table scan redirection not found", destinationTable)); - } + TableHandle destinationTableHandle = plannerContext.getMetadata().getTableHandle( + context.getSession(), + convertFromSchemaTableName(destinationTable.getCatalogName()).apply(destinationTable.getSchemaTableName())) + .orElseThrow(() -> new TrinoException(TABLE_NOT_FOUND, format("Destination table %s from table scan redirection not found", destinationTable))); Map columnMapping = tableScanRedirectApplicationResult.get().getDestinationColumns(); - Map destinationColumnHandles = plannerContext.getMetadata().getColumnHandles(context.getSession(), destinationTableHandle.get()); + Map destinationColumnHandles = plannerContext.getMetadata().getColumnHandles(context.getSession(), destinationTableHandle); ImmutableMap.Builder casts = ImmutableMap.builder(); ImmutableMap.Builder newAssignmentsBuilder = ImmutableMap.builder(); for (Map.Entry assignment : scanNode.getAssignments().entrySet()) { @@ -124,7 +122,7 @@ public Result apply(TableScanNode scanNode, Captures captures, Context context) // insert ts if redirected types don't match source types Type sourceType = context.getSymbolAllocator().getTypes().get(assignment.getKey()); - Type redirectedType = plannerContext.getMetadata().getColumnMetadata(context.getSession(), destinationTableHandle.get(), destinationColumnHandle).getType(); + Type redirectedType = plannerContext.getMetadata().getColumnMetadata(context.getSession(), destinationTableHandle, destinationColumnHandle).getType(); if (!sourceType.equals(redirectedType)) { Symbol redirectedSymbol = context.getSymbolAllocator().newSymbol(destinationColumn, redirectedType); Cast cast = getCast( @@ -153,7 +151,7 @@ public Result apply(TableScanNode scanNode, Captures captures, Context context) casts.buildOrThrow(), new TableScanNode( scanNode.getId(), - destinationTableHandle.get(), + destinationTableHandle, ImmutableList.copyOf(newAssignments.keySet()), newAssignments, TupleDomain.all(), @@ -185,7 +183,7 @@ public Result apply(TableScanNode scanNode, Captures captures, Context context) } // insert casts if redirected types don't match domain types - Type redirectedType = plannerContext.getMetadata().getColumnMetadata(context.getSession(), destinationTableHandle.get(), destinationColumnHandle).getType(); + Type redirectedType = plannerContext.getMetadata().getColumnMetadata(context.getSession(), destinationTableHandle, destinationColumnHandle).getType(); if (!domainType.equals(redirectedType)) { Symbol redirectedSymbol = context.getSymbolAllocator().newSymbol(destinationColumn, redirectedType); Cast cast = getCast( @@ -210,7 +208,7 @@ public Result apply(TableScanNode scanNode, Captures captures, Context context) Map newAssignments = newAssignmentsBuilder.buildOrThrow(); TableScanNode newScanNode = new TableScanNode( scanNode.getId(), - destinationTableHandle.get(), + destinationTableHandle, ImmutableList.copyOf(newAssignments.keySet()), newAssignments, TupleDomain.all(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/CanonicalizeExpressionRewriter.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/CanonicalizeExpressionRewriter.java index dbc66894c3cf..0e4c1dbda907 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/CanonicalizeExpressionRewriter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/CanonicalizeExpressionRewriter.java @@ -247,7 +247,7 @@ public Expression rewriteFunctionCall(FunctionCall node, Void context, Expressio || argumentType instanceof TimestampWithTimeZoneType || argumentType instanceof VarcharType) { // prefer `CAST(x as DATE)` to `date(x)` - return new Cast(argument, toSqlType(DateType.DATE)); + return new Cast(treeRewriter.rewrite(argument, context), toSqlType(DateType.DATE)); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/GatherAndMergeWindows.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/GatherAndMergeWindows.java index 36e1ba5561a1..c72e921438c9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/GatherAndMergeWindows.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/GatherAndMergeWindows.java @@ -221,9 +221,7 @@ protected Optional manipulateAdjacentWindowNodes(WindowNode parent, Wi restrictOutputs(context.getIdAllocator(), transposedWindows, ImmutableSet.copyOf(parent.getOutputSymbols())) .orElse(transposedWindows)); } - else { - return Optional.empty(); - } + return Optional.empty(); } private static int compare(WindowNode o1, WindowNode o2) @@ -291,11 +289,9 @@ private static int compareOrderBy(WindowNode o1, WindowNode o2) if (orderByComparison != 0) { return orderByComparison; } - else { - int sortOrderComparison = o1OrderingScheme.getOrdering(symbol1).compareTo(o2OrderingScheme.getOrdering(symbol2)); - if (sortOrderComparison != 0) { - return sortOrderComparison; - } + int sortOrderComparison = o1OrderingScheme.getOrdering(symbol1).compareTo(o2OrderingScheme.getOrdering(symbol2)); + if (sortOrderComparison != 0) { + return sortOrderComparison; } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MergeLimitWithTopN.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MergeLimitWithTopN.java index 722206216b2f..7a9386fae224 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MergeLimitWithTopN.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MergeLimitWithTopN.java @@ -72,9 +72,7 @@ public Result apply(LimitNode parent, Captures captures, Context context) if (parent.getCount() < child.getCount()) { return Result.empty(); } - else { - return Result.ofPlanNode(child); - } + return Result.ofPlanNode(child); } return Result.ofPlanNode( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneDistinctAggregation.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneDistinctAggregation.java index 7f1f2eb545dc..ba8e1306ec83 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneDistinctAggregation.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneDistinctAggregation.java @@ -58,9 +58,7 @@ public Result apply(AggregationNode node, Captures captures, Context context) if (rewriter.isRewritten()) { return Result.ofPlanNode(replaceChildren(node, newSources)); } - else { - return Result.empty(); - } + return Result.empty(); } private static boolean isDistinctOperator(AggregationNode node) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushdownFilterIntoRowNumber.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushdownFilterIntoRowNumber.java index cfce0e542cfc..3d99f735cf72 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushdownFilterIntoRowNumber.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushdownFilterIntoRowNumber.java @@ -97,9 +97,7 @@ public Result apply(FilterNode node, Captures captures, Context context) if (needRewriteSource) { return Result.ofPlanNode(new FilterNode(node.getId(), source, node.getPredicate())); } - else { - return Result.empty(); - } + return Result.empty(); } TupleDomain newTupleDomain = tupleDomain.filter((symbol, domain) -> !symbol.equals(rowNumberSymbol)); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceRedundantJoinWithProject.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceRedundantJoinWithProject.java index a6c9361c0dba..3f761f28b0fa 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceRedundantJoinWithProject.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceRedundantJoinWithProject.java @@ -82,7 +82,7 @@ public Result apply(JoinNode node, Captures captures, Context context) context.getIdAllocator(), context.getSymbolAllocator())); } - else if (!leftSourceEmpty && rightSourceEmpty) { + if (!leftSourceEmpty && rightSourceEmpty) { yield Result.ofPlanNode(appendNulls( node.getLeft(), node.getLeftOutputSymbols(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java index 47ce4c17308a..a3f592f921a7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java @@ -370,9 +370,7 @@ protected Optional visitPlan(PlanNode node, PlanNode reference) if (isCorrelatedRecursively(node)) { return Optional.empty(); } - else { - return Optional.of(new Decorrelated(ImmutableList.of(), reference)); - } + return Optional.of(new Decorrelated(ImmutableList.of(), reference)); } private boolean isCorrelatedRecursively(PlanNode node) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapCastInComparison.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapCastInComparison.java index e52d03885534..51c0d031edd2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapCastInComparison.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapCastInComparison.java @@ -227,10 +227,8 @@ private Expression unwrapCast(ComparisonExpression expression) if (!typeHasNaN(sourceType)) { return TRUE_LITERAL; } - else { - // NaN on the right of comparison will be cast to source type later - break; - } + // NaN on the right of comparison will be cast to source type later + break; default: throw new UnsupportedOperationException("Not yet implemented: " + operator); } @@ -404,13 +402,11 @@ private boolean hasInjectiveImplicitCoercion(Type source, Type target, Object va Double.isNaN(doubleValue) || (doubleValue > -1L << 53 && doubleValue < 1L << 53); // in (-2^53, 2^53), bigint follows an injective implicit coercion w.r.t double } - else { - float realValue = intBitsToFloat(toIntExact((long) value)); - return (source.equals(BIGINT) && (realValue > Long.MAX_VALUE || realValue < Long.MIN_VALUE)) || - (source.equals(INTEGER) && (realValue > Integer.MAX_VALUE || realValue < Integer.MIN_VALUE)) || - Float.isNaN(realValue) || - (realValue > -1L << 23 && realValue < 1L << 23); // in (-2^23, 2^23), bigint (and integer) follows an injective implicit coercion w.r.t real - } + float realValue = intBitsToFloat(toIntExact((long) value)); + return (source.equals(BIGINT) && (realValue > Long.MAX_VALUE || realValue < Long.MIN_VALUE)) || + (source.equals(INTEGER) && (realValue > Integer.MAX_VALUE || realValue < Integer.MIN_VALUE)) || + Float.isNaN(realValue) || + (realValue > -1L << 23 && realValue < 1L << 23); // in (-2^23, 2^23), bigint (and integer) follows an injective implicit coercion w.r.t real } if (source instanceof DecimalType) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/ActualProperties.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/ActualProperties.java index 033ea5079075..af1066506313 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/ActualProperties.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/ActualProperties.java @@ -109,9 +109,7 @@ public boolean isStreamPartitionedOn(Collection columns, boolean nullsAn if (exactly) { return global.isStreamPartitionedOnExactly(columns, constants.keySet(), nullsAndAnyReplicated); } - else { - return global.isStreamPartitionedOn(columns, constants.keySet(), nullsAndAnyReplicated); - } + return global.isStreamPartitionedOn(columns, constants.keySet(), nullsAndAnyReplicated); } public boolean isNodePartitionedOn(Collection columns, boolean exactly) @@ -124,9 +122,7 @@ public boolean isNodePartitionedOn(Collection columns, boolean nullsAndA if (exactly) { return global.isNodePartitionedOnExactly(columns, constants.keySet(), nullsAndAnyReplicated); } - else { - return global.isNodePartitionedOn(columns, constants.keySet(), nullsAndAnyReplicated); - } + return global.isNodePartitionedOn(columns, constants.keySet(), nullsAndAnyReplicated); } public boolean isCompatibleTablePartitioningWith(Partitioning partitioning, boolean nullsAndAnyReplicated, Metadata metadata, Session session) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java index a9c88d98bc0a..d09790704647 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java @@ -780,9 +780,7 @@ && isNodePartitionedOn(left.getProperties(), leftSymbols) && !left.getProperties return planReplicatedJoin(node, left); } - else { - return planPartitionedJoin(node, leftSymbols, rightSymbols); - } + return planPartitionedJoin(node, leftSymbols, rightSymbols); } private PlanWithProperties planPartitionedJoin(JoinNode node, List leftSymbols, List rightSymbols) @@ -1272,19 +1270,17 @@ private PlanWithProperties arbitraryDistributeUnion( // instead of "arbitraryPartition". return new PlanWithProperties(node.replaceChildren(partitionedChildren)); } - else { - // Trino currently cannot execute stage that has multiple table scans, so in that case - // we have to insert REMOTE exchange with FIXED_ARBITRARY_DISTRIBUTION instead of local exchange - return new PlanWithProperties( - new ExchangeNode( - idAllocator.getNextId(), - REPARTITION, - REMOTE, - new PartitioningScheme(Partitioning.create(FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of()), node.getOutputSymbols()), - partitionedChildren, - partitionedOutputLayouts, - Optional.empty())); - } + // Trino currently cannot execute stage that has multiple table scans, so in that case + // we have to insert REMOTE exchange with FIXED_ARBITRARY_DISTRIBUTION instead of local exchange + return new PlanWithProperties( + new ExchangeNode( + idAllocator.getNextId(), + REPARTITION, + REMOTE, + new PartitioningScheme(Partitioning.create(FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of()), node.getOutputSymbols()), + partitionedChildren, + partitionedOutputLayouts, + Optional.empty())); } @Override diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/MetadataQueryOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/MetadataQueryOptimizer.java index 7b82987c7b3e..7f7e3cfd07e5 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/MetadataQueryOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/MetadataQueryOptimizer.java @@ -162,9 +162,7 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext cont // partition key does not have a single value, so bail out to be safe return context.defaultRewrite(node); } - else { - rowBuilder.add(literalEncoder.toExpression(session, value.getValue(), type)); - } + rowBuilder.add(literalEncoder.toExpression(session, value.getValue(), type)); } rowsBuilder.add(new Row(rowBuilder.build())); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanNodeSearcher.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanNodeSearcher.java index 7f0ac26a3d8e..a04de1f5f41f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanNodeSearcher.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanNodeSearcher.java @@ -201,12 +201,10 @@ private PlanNode removeFirstRecursive(PlanNode node) if (sources.isEmpty()) { return node; } - else if (sources.size() == 1) { + if (sources.size() == 1) { return replaceChildren(node, ImmutableList.of(removeFirstRecursive(sources.get(0)))); } - else { - throw new IllegalArgumentException("Unable to remove first node when a node has multiple children, use removeAll instead"); - } + throw new IllegalArgumentException("Unable to remove first node when a node has multiple children, use removeAll instead"); } return node; } @@ -248,12 +246,10 @@ private PlanNode replaceFirstRecursive(PlanNode node, PlanNode nodeToReplace) if (sources.isEmpty()) { return node; } - else if (sources.size() == 1) { + if (sources.size() == 1) { return replaceChildren(node, ImmutableList.of(replaceFirstRecursive(node, sources.get(0)))); } - else { - throw new IllegalArgumentException("Unable to replace first node when a node has multiple children, use replaceAll instead"); - } + throw new IllegalArgumentException("Unable to replace first node when a node has multiple children, use replaceAll instead"); } public boolean matches() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java index 9990cfe494f7..353506f8da14 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java @@ -1154,24 +1154,22 @@ private JoinNode tryNormalizeToOuterToInnerJoin(JoinNode node, Expression inheri node.getDynamicFilters(), node.getReorderJoinStatsAndCost()); } - else { - return new JoinNode( - node.getId(), - canConvertToLeftJoin ? LEFT : RIGHT, - node.getLeft(), - node.getRight(), - node.getCriteria(), - node.getLeftOutputSymbols(), - node.getRightOutputSymbols(), - node.isMaySkipOutputDuplicates(), - node.getFilter(), - node.getLeftHashSymbol(), - node.getRightHashSymbol(), - node.getDistributionType(), - node.isSpillable(), - node.getDynamicFilters(), - node.getReorderJoinStatsAndCost()); - } + return new JoinNode( + node.getId(), + canConvertToLeftJoin ? LEFT : RIGHT, + node.getLeft(), + node.getRight(), + node.getCriteria(), + node.getLeftOutputSymbols(), + node.getRightOutputSymbols(), + node.isMaySkipOutputDuplicates(), + node.getFilter(), + node.getLeftHashSymbol(), + node.getRightHashSymbol(), + node.getDistributionType(), + node.isSpillable(), + node.getDynamicFilters(), + node.getReorderJoinStatsAndCost()); } if (node.getType() == JoinNode.Type.LEFT && !canConvertOuterToInner(node.getRight().getOutputSymbols(), inheritedPredicate) || diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java index 1b831b18a403..444f03845499 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java @@ -964,7 +964,7 @@ public static Optional filterOrRewrite(Collection columns, Colle if (equality.getLeft().equals(column) && columns.contains(equality.getRight())) { return Optional.of(equality.getRight()); } - else if (equality.getRight().equals(column) && columns.contains(equality.getLeft())) { + if (equality.getRight().equals(column) && columns.contains(equality.getLeft())) { return Optional.of(equality.getLeft()); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/QueryCardinalityUtil.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/QueryCardinalityUtil.java index 6a06d03f1750..2159e1c8d18d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/QueryCardinalityUtil.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/QueryCardinalityUtil.java @@ -183,9 +183,7 @@ public Range visitOffset(OffsetNode node, Void context) if (sourceCardinalityRange.hasUpperBound()) { return Range.closed(lower, max(sourceCardinalityRange.upperEndpoint() - node.getCount(), 0L)); } - else { - return Range.atLeast(lower); - } + return Range.atLeast(lower); } @Override @@ -197,9 +195,7 @@ public Range visitLimit(LimitNode node, Void context) if (sourceCardinalityRange.hasUpperBound()) { return Range.closed(lower, sourceCardinalityRange.upperEndpoint()); } - else { - return Range.atLeast(lower); - } + return Range.atLeast(lower); } return applyLimit(node.getSource(), node.getCount()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPreferredProperties.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPreferredProperties.java index 4d81bcd8a9ec..3f0083d109e9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPreferredProperties.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPreferredProperties.java @@ -183,10 +183,10 @@ public boolean isSatisfiedBy(StreamProperties actualProperties) if (distribution.get() == SINGLE && actualDistribution != SINGLE) { return false; } - else if (distribution.get() == FIXED && actualDistribution != FIXED) { + if (distribution.get() == FIXED && actualDistribution != FIXED) { return false; } - else if (distribution.get() == MULTIPLE && actualDistribution != FIXED && actualDistribution != MULTIPLE) { + if (distribution.get() == MULTIPLE && actualDistribution != FIXED && actualDistribution != MULTIPLE) { return false; } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/CounterBasedAnonymizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/CounterBasedAnonymizer.java index 1a1d45bdc816..0a9ec9ecad56 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/CounterBasedAnonymizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/CounterBasedAnonymizer.java @@ -225,22 +225,22 @@ public String anonymize(WriterTarget target) if (target instanceof CreateTarget) { return anonymize((CreateTarget) target); } - else if (target instanceof InsertTarget) { + if (target instanceof InsertTarget) { return anonymize((InsertTarget) target); } - else if (target instanceof MergeTarget) { + if (target instanceof MergeTarget) { return anonymize((MergeTarget) target); } - else if (target instanceof RefreshMaterializedViewTarget) { + if (target instanceof RefreshMaterializedViewTarget) { return anonymize((RefreshMaterializedViewTarget) target); } - else if (target instanceof DeleteTarget) { + if (target instanceof DeleteTarget) { return anonymize((DeleteTarget) target); } - else if (target instanceof UpdateTarget) { + if (target instanceof UpdateTarget) { return anonymize((UpdateTarget) target); } - else if (target instanceof TableExecuteTarget) { + if (target instanceof TableExecuteTarget) { return anonymize((TableExecuteTarget) target); } throw new UnsupportedOperationException("Anonymization is not supported for WriterTarget type: " + target.getClass().getSimpleName()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/rowpattern/IrRowPatternFlattener.java b/core/trino-main/src/main/java/io/trino/sql/planner/rowpattern/IrRowPatternFlattener.java index 4778f2be9822..9d77c2383cf1 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/rowpattern/IrRowPatternFlattener.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/rowpattern/IrRowPatternFlattener.java @@ -118,9 +118,7 @@ protected IrRowPattern visitIrAlternation(IrAlternation node, Boolean inExclusio if (child instanceof IrAlternation) { return ((IrAlternation) child).getPatterns().stream(); } - else { - return Stream.of(child); - } + return Stream.of(child); }) .collect(toImmutableList()); @@ -174,9 +172,7 @@ protected IrRowPattern visitIrConcatenation(IrConcatenation node, Boolean inExcl if (child instanceof IrConcatenation) { return ((IrConcatenation) child).getPatterns().stream(); } - else { - return Stream.of(child); - } + return Stream.of(child); }) .filter(child -> !(child instanceof IrEmpty)) .collect(toImmutableList()); diff --git a/core/trino-main/src/main/java/io/trino/sql/relational/optimizer/ExpressionOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/relational/optimizer/ExpressionOptimizer.java index 2deb772241af..25d5de77e612 100644 --- a/core/trino-main/src/main/java/io/trino/sql/relational/optimizer/ExpressionOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/relational/optimizer/ExpressionOptimizer.java @@ -130,9 +130,7 @@ public RowExpression visitSpecialForm(SpecialForm specialForm, Void context) return specialForm.getArguments().get(1).accept(this, context); } // FALSE and NULL - else { - return specialForm.getArguments().get(2).accept(this, context); - } + return specialForm.getArguments().get(2).accept(this, context); } List arguments = specialForm.getArguments().stream() .map(argument -> argument.accept(this, null)) diff --git a/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowQueriesRewrite.java b/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowQueriesRewrite.java index 92123f6349a0..b5f05dbfa5ee 100644 --- a/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowQueriesRewrite.java +++ b/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowQueriesRewrite.java @@ -306,8 +306,7 @@ protected Node visitShowGrants(ShowGrants showGrants, Void context) QualifiedObjectName qualifiedTableName = createQualifiedObjectName(session, showGrants, tableName.get()); if (!metadata.isView(session, qualifiedTableName)) { RedirectionAwareTableHandle redirection = metadata.getRedirectionAwareTableHandle(session, qualifiedTableName); - Optional tableHandle = redirection.getTableHandle(); - if (tableHandle.isEmpty()) { + if (redirection.getTableHandle().isEmpty()) { throw semanticException(TABLE_NOT_FOUND, showGrants, "Table '%s' does not exist", tableName); } if (redirection.getRedirectedTableName().isPresent()) { @@ -374,13 +373,11 @@ protected Node visitShowRoles(ShowRoles node, Void context) .collect(toList()); return singleColumnValues(rows, "Role"); } - else { - accessControl.checkCanShowRoles(session.toSecurityContext(), catalog); - List rows = metadata.listRoles(session, catalog).stream() - .map(role -> row(new StringLiteral(role))) - .collect(toList()); - return singleColumnValues(rows, "Role"); - } + accessControl.checkCanShowRoles(session.toSecurityContext(), catalog); + List rows = metadata.listRoles(session, catalog).stream() + .map(role -> row(new StringLiteral(role))) + .collect(toList()); + return singleColumnValues(rows, "Role"); } @Override @@ -654,16 +651,14 @@ protected Node visitShowCreate(ShowCreate node, Void context) } RedirectionAwareTableHandle redirection = metadata.getRedirectionAwareTableHandle(session, objectName); - Optional tableHandle = redirection.getTableHandle(); - if (tableHandle.isEmpty()) { - throw semanticException(TABLE_NOT_FOUND, node, "Table '%s' does not exist", objectName); - } + TableHandle tableHandle = redirection.getTableHandle() + .orElseThrow(() -> semanticException(TABLE_NOT_FOUND, node, "Table '%s' does not exist", objectName)); QualifiedObjectName targetTableName = redirection.getRedirectedTableName().orElse(objectName); accessControl.checkCanShowCreateTable(session.toSecurityContext(), targetTableName); - ConnectorTableMetadata connectorTableMetadata = metadata.getTableMetadata(session, tableHandle.get()).getMetadata(); + ConnectorTableMetadata connectorTableMetadata = metadata.getTableMetadata(session, tableHandle).getMetadata(); - Collection> allColumnProperties = columnPropertyManager.getAllProperties(tableHandle.get().getCatalogHandle()); + Collection> allColumnProperties = columnPropertyManager.getAllProperties(tableHandle.getCatalogHandle()); List columns = connectorTableMetadata.getColumns().stream() .filter(column -> !column.isHidden()) @@ -679,7 +674,7 @@ protected Node visitShowCreate(ShowCreate node, Void context) .collect(toImmutableList()); Map properties = connectorTableMetadata.getProperties(); - Collection> allTableProperties = tablePropertyManager.getAllProperties(tableHandle.get().getCatalogHandle()); + Collection> allTableProperties = tablePropertyManager.getAllProperties(tableHandle.getCatalogHandle()); List propertyNodes = buildProperties(targetTableName, Optional.empty(), INVALID_TABLE_PROPERTY, properties, allTableProperties); CreateTable createTable = new CreateTable( diff --git a/core/trino-main/src/main/java/io/trino/testing/AllowAllAccessControlManager.java b/core/trino-main/src/main/java/io/trino/testing/AllowAllAccessControlManager.java index 213a1e3dd1dd..fa6f9bf0e857 100644 --- a/core/trino-main/src/main/java/io/trino/testing/AllowAllAccessControlManager.java +++ b/core/trino-main/src/main/java/io/trino/testing/AllowAllAccessControlManager.java @@ -186,6 +186,9 @@ public void checkCanSetMaterializedViewProperties(SecurityContext context, Quali @Override public void checkCanGrantExecuteFunctionPrivilege(SecurityContext context, String functionName, Identity grantee, boolean grantOption) {} + @Override + public void checkCanGrantExecuteFunctionPrivilege(SecurityContext context, FunctionKind functionKind, QualifiedObjectName functionName, Identity grantee, boolean grantOption) {} + @Override public void checkCanGrantSchemaPrivilege(SecurityContext context, Privilege privilege, CatalogSchemaName schemaName, TrinoPrincipal grantee, boolean grantOption) {} diff --git a/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java b/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java index 7189e5d7ada9..9a3b7d38e1b4 100644 --- a/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java +++ b/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java @@ -23,6 +23,7 @@ import io.trino.spi.connector.CatalogSchemaName; import io.trino.spi.connector.CatalogSchemaTableName; import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.function.FunctionKind; import io.trino.spi.security.Identity; import io.trino.spi.security.ViewExpression; import io.trino.spi.type.Type; @@ -63,6 +64,7 @@ import static io.trino.spi.security.AccessDeniedException.denyDropView; import static io.trino.spi.security.AccessDeniedException.denyExecuteFunction; import static io.trino.spi.security.AccessDeniedException.denyExecuteQuery; +import static io.trino.spi.security.AccessDeniedException.denyExecuteTableProcedure; import static io.trino.spi.security.AccessDeniedException.denyGrantExecuteFunctionPrivilege; import static io.trino.spi.security.AccessDeniedException.denyImpersonateUser; import static io.trino.spi.security.AccessDeniedException.denyInsertTable; @@ -101,6 +103,7 @@ import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.DROP_VIEW; import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.EXECUTE_FUNCTION; import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.EXECUTE_QUERY; +import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.EXECUTE_TABLE_PROCEDURE; import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.GRANT_EXECUTE_FUNCTION; import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.IMPERSONATE_USER; import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.INSERT_TABLE; @@ -606,6 +609,17 @@ public void checkCanGrantExecuteFunctionPrivilege(SecurityContext context, Strin } } + @Override + public void checkCanGrantExecuteFunctionPrivilege(SecurityContext context, FunctionKind functionKind, QualifiedObjectName functionName, Identity grantee, boolean grantOption) + { + if (shouldDenyPrivilege(context.getIdentity().getUser(), functionName.toString(), GRANT_EXECUTE_FUNCTION)) { + denyGrantExecuteFunctionPrivilege(functionName.toString(), context.getIdentity(), grantee); + } + if (denyPrivileges.isEmpty()) { + super.checkCanGrantExecuteFunctionPrivilege(context, functionName.toString(), grantee, grantOption); + } + } + @Override public void checkCanShowColumns(SecurityContext context, CatalogSchemaTableName table) { @@ -670,6 +684,17 @@ public void checkCanExecuteFunction(SecurityContext context, String functionName } } + @Override + public void checkCanExecuteTableProcedure(SecurityContext context, QualifiedObjectName table, String procedure) + { + if (shouldDenyPrivilege(context.getIdentity().getUser(), table + "." + procedure, EXECUTE_TABLE_PROCEDURE)) { + denyExecuteTableProcedure(table.toString(), procedure); + } + if (denyPrivileges.isEmpty()) { + super.checkCanExecuteTableProcedure(context, table, procedure); + } + } + @Override public List getRowFilters(SecurityContext context, QualifiedObjectName tableName) { @@ -709,7 +734,7 @@ public enum TestingPrivilegeType { SET_USER, IMPERSONATE_USER, EXECUTE_QUERY, VIEW_QUERY, KILL_QUERY, - EXECUTE_FUNCTION, + EXECUTE_FUNCTION, EXECUTE_TABLE_PROCEDURE, CREATE_SCHEMA, DROP_SCHEMA, RENAME_SCHEMA, SHOW_CREATE_TABLE, CREATE_TABLE, DROP_TABLE, RENAME_TABLE, COMMENT_TABLE, COMMENT_VIEW, COMMENT_COLUMN, INSERT_TABLE, DELETE_TABLE, MERGE_TABLE, UPDATE_TABLE, TRUNCATE_TABLE, SET_TABLE_PROPERTIES, SHOW_COLUMNS, ADD_COLUMN, DROP_COLUMN, RENAME_COLUMN, SELECT_COLUMN, diff --git a/core/trino-main/src/main/java/io/trino/transaction/InMemoryTransactionManager.java b/core/trino-main/src/main/java/io/trino/transaction/InMemoryTransactionManager.java index 8df986833457..6d359a64b4c8 100644 --- a/core/trino-main/src/main/java/io/trino/transaction/InMemoryTransactionManager.java +++ b/core/trino-main/src/main/java/io/trino/transaction/InMemoryTransactionManager.java @@ -375,9 +375,7 @@ public void checkOpenTransaction() // Should not happen normally throw new IllegalStateException("Current transaction already committed"); } - else { - throw new TrinoException(TRANSACTION_ALREADY_ABORTED, "Current transaction is aborted, commands ignored until end of transaction block"); - } + throw new TrinoException(TRANSACTION_ALREADY_ABORTED, "Current transaction is aborted, commands ignored until end of transaction block"); } } diff --git a/core/trino-main/src/main/java/io/trino/type/LikeFunctions.java b/core/trino-main/src/main/java/io/trino/type/LikeFunctions.java index 23826d98b37c..f3c9e1ecc59e 100644 --- a/core/trino-main/src/main/java/io/trino/type/LikeFunctions.java +++ b/core/trino-main/src/main/java/io/trino/type/LikeFunctions.java @@ -54,9 +54,7 @@ public static boolean likeVarchar(@SqlType("varchar(x)") Slice value, @SqlType(L if (value.hasByteArray()) { return matcher.match(value.byteArray(), value.byteArrayOffset(), value.length()); } - else { - return matcher.match(value.getBytes(), 0, value.length()); - } + return matcher.match(value.getBytes(), 0, value.length()); } @ScalarFunction(value = LIKE_PATTERN_FUNCTION_NAME, hidden = true) diff --git a/core/trino-main/src/main/java/io/trino/type/TypeUtils.java b/core/trino-main/src/main/java/io/trino/type/TypeUtils.java index 2d017e8cc54b..3f2f330c68a1 100644 --- a/core/trino-main/src/main/java/io/trino/type/TypeUtils.java +++ b/core/trino-main/src/main/java/io/trino/type/TypeUtils.java @@ -113,9 +113,7 @@ private static String getRowDisplayLabelForLegacyClients(RowType type) if (field.getName().isPresent()) { return field.getName().get() + ' ' + typeDisplayName; } - else { - return typeDisplayName; - } + return typeDisplayName; }) .collect(toImmutableList()); diff --git a/core/trino-main/src/main/java/io/trino/util/DateTimeZoneIndex.java b/core/trino-main/src/main/java/io/trino/util/DateTimeZoneIndex.java index 62902c4213f8..3a014450dc91 100644 --- a/core/trino-main/src/main/java/io/trino/util/DateTimeZoneIndex.java +++ b/core/trino-main/src/main/java/io/trino/util/DateTimeZoneIndex.java @@ -97,8 +97,6 @@ public static int extractZoneOffsetMinutes(long epochMillis, short zoneKey) if (FIXED_ZONE_OFFSET[zoneKey] == VARIABLE_ZONE) { return DATE_TIME_ZONES[zoneKey].getOffset(epochMillis) / 60_000; } - else { - return FIXED_ZONE_OFFSET[zoneKey]; - } + return FIXED_ZONE_OFFSET[zoneKey]; } } diff --git a/core/trino-main/src/main/java/io/trino/util/DisjointSet.java b/core/trino-main/src/main/java/io/trino/util/DisjointSet.java index 8eb5343d9b39..8100185b95bb 100644 --- a/core/trino-main/src/main/java/io/trino/util/DisjointSet.java +++ b/core/trino-main/src/main/java/io/trino/util/DisjointSet.java @@ -120,11 +120,9 @@ private T findInternal(T element) if (value.getParent() == null) { return element; } - else { - T root = findInternal(value.getParent()); - value.setParent(root); - return root; - } + T root = findInternal(value.getParent()); + value.setParent(root); + return root; } public Collection> getEquivalentClasses() diff --git a/core/trino-main/src/main/java/io/trino/util/FastutilSetHelper.java b/core/trino-main/src/main/java/io/trino/util/FastutilSetHelper.java index 2963e9ab19d9..e6373cc62547 100644 --- a/core/trino-main/src/main/java/io/trino/util/FastutilSetHelper.java +++ b/core/trino-main/src/main/java/io/trino/util/FastutilSetHelper.java @@ -65,12 +65,10 @@ public static Set toFastutilHashSet(Set set, Type type, MethodHandle hashC if (javaElementType == boolean.class) { return new BooleanOpenHashSet((Collection) set, 0.25f); } - else if (!type.getJavaType().isPrimitive()) { + if (!type.getJavaType().isPrimitive()) { return new ObjectOpenCustomHashSet<>(set, 0.25f, new ObjectStrategy(hashCodeHandle, equalsHandle, type)); } - else { - throw new UnsupportedOperationException("Unsupported native type in set: " + type.getJavaType() + " with type " + type.getTypeSignature()); - } + throw new UnsupportedOperationException("Unsupported native type in set: " + type.getJavaType() + " with type " + type.getTypeSignature()); } public static boolean in(boolean booleanValue, BooleanOpenHashSet set) diff --git a/core/trino-main/src/main/java/io/trino/util/JsonUtil.java b/core/trino-main/src/main/java/io/trino/util/JsonUtil.java index 50dcb5561e27..1b14f5de8caf 100644 --- a/core/trino-main/src/main/java/io/trino/util/JsonUtil.java +++ b/core/trino-main/src/main/java/io/trino/util/JsonUtil.java @@ -133,9 +133,7 @@ public static String truncateIfNecessaryForErrorMessage(Slice json) if (json.length() <= MAX_JSON_LENGTH_IN_ERROR_MESSAGE) { return json.toStringUtf8(); } - else { - return json.slice(0, MAX_JSON_LENGTH_IN_ERROR_MESSAGE).toStringUtf8() + "...(truncated)"; - } + return json.slice(0, MAX_JSON_LENGTH_IN_ERROR_MESSAGE).toStringUtf8() + "...(truncated)"; } public static boolean canCastToJson(Type type) diff --git a/core/trino-main/src/main/java/io/trino/util/MoreMath.java b/core/trino-main/src/main/java/io/trino/util/MoreMath.java index 06820b18b384..402e0056108d 100644 --- a/core/trino-main/src/main/java/io/trino/util/MoreMath.java +++ b/core/trino-main/src/main/java/io/trino/util/MoreMath.java @@ -34,14 +34,12 @@ public static boolean nearlyEqual(double a, double b, double epsilon) if (a == b) { // shortcut, handles infinities return true; } - else if (a == 0 || b == 0 || diff < Double.MIN_NORMAL) { + if (a == 0 || b == 0 || diff < Double.MIN_NORMAL) { // a or b is zero or both are extremely close to it // relative error is less meaningful here return diff < (epsilon * Double.MIN_NORMAL); - } - else { // use relative error - return diff / Math.min((absA + absB), Double.MAX_VALUE) < epsilon; - } + } // use relative error + return diff / Math.min((absA + absB), Double.MAX_VALUE) < epsilon; } /** @@ -56,14 +54,12 @@ public static boolean nearlyEqual(float a, float b, float epsilon) if (a == b) { // shortcut, handles infinities return true; } - else if (a == 0 || b == 0 || diff < Float.MIN_NORMAL) { + if (a == 0 || b == 0 || diff < Float.MIN_NORMAL) { // a or b is zero or both are extremely close to it // relative error is less meaningful here return diff < (epsilon * Float.MIN_NORMAL); - } - else { // use relative error - return diff / Math.min((absA + absB), Float.MAX_VALUE) < epsilon; - } + } // use relative error + return diff / Math.min((absA + absB), Float.MAX_VALUE) < epsilon; } public static double min(double... values) diff --git a/core/trino-main/src/main/java/io/trino/util/Optionals.java b/core/trino-main/src/main/java/io/trino/util/Optionals.java index dc5fd6f8f684..3d9e29977952 100644 --- a/core/trino-main/src/main/java/io/trino/util/Optionals.java +++ b/core/trino-main/src/main/java/io/trino/util/Optionals.java @@ -26,11 +26,9 @@ public static Optional combine(Optional left, Optional right, Binar if (left.isPresent() && right.isPresent()) { return Optional.of(combiner.apply(left.get(), right.get())); } - else if (left.isPresent()) { + if (left.isPresent()) { return left; } - else { - return right; - } + return right; } } diff --git a/core/trino-main/src/test/java/io/trino/block/AbstractTestBlock.java b/core/trino-main/src/test/java/io/trino/block/AbstractTestBlock.java index 95a38b7d5c3e..456d0d98eb67 100644 --- a/core/trino-main/src/test/java/io/trino/block/AbstractTestBlock.java +++ b/core/trino-main/src/test/java/io/trino/block/AbstractTestBlock.java @@ -404,9 +404,7 @@ private static long getCompactedBlockSizeInBytes(Block block) // dictionary blocks might become unwrapped when copyRegion is called on a block that is already compact return ((DictionaryBlock) block).compact().getSizeInBytes(); } - else { - return copyBlockViaCopyRegion(block).getSizeInBytes(); - } + return copyBlockViaCopyRegion(block).getSizeInBytes(); } private static Block copyBlockViaCopyRegion(Block block) diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java index 43bdf21962be..f4d769c4e011 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java @@ -135,14 +135,15 @@ public void testSimple() OptionalInt.empty())), ImmutableList.of(TABLE_SCAN_NODE_ID)); TaskContext taskContext = newTestingTaskContext(taskNotificationExecutor, driverYieldExecutor, taskStateMachine); - SqlTaskExecution sqlTaskExecution = SqlTaskExecution.createSqlTaskExecution( + SqlTaskExecution sqlTaskExecution = new SqlTaskExecution( taskStateMachine, taskContext, outputBuffer, localExecutionPlan, taskExecutor, - taskNotificationExecutor, - createTestSplitMonitor()); + createTestSplitMonitor(), + taskNotificationExecutor); + sqlTaskExecution.start(); // // test body diff --git a/core/trino-main/src/test/java/io/trino/execution/TestStageStats.java b/core/trino-main/src/test/java/io/trino/execution/TestStageStats.java index ba743d7244c3..c2ecc8ca631a 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestStageStats.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestStageStats.java @@ -18,12 +18,16 @@ import io.airlift.json.JsonCodec; import io.airlift.stats.Distribution; import io.airlift.stats.Distribution.DistributionSnapshot; +import io.airlift.stats.TDigest; import io.airlift.units.DataSize; import io.airlift.units.Duration; +import io.trino.plugin.base.metrics.TDigestHistogram; import io.trino.spi.eventlistener.StageGcStatistics; import org.joda.time.DateTime; import org.testng.annotations.Test; +import java.util.Optional; + import static java.util.concurrent.TimeUnit.NANOSECONDS; import static org.testng.Assert.assertEquals; @@ -87,6 +91,7 @@ public class TestStageStats new Duration(202, NANOSECONDS), DataSize.ofBytes(34), + Optional.of(getTDigestHistogram(10)), DataSize.ofBytes(35), DataSize.ofBytes(36), 37, @@ -177,6 +182,7 @@ private static void assertExpectedStageStats(StageStats actual) assertEquals(actual.getFailedInputBlockedTime(), new Duration(202, NANOSECONDS)); assertEquals(actual.getBufferedDataSize(), DataSize.ofBytes(34)); + assertEquals(actual.getOutputBufferUtilization().get().getMax(), 9.0); assertEquals(actual.getOutputDataSize(), DataSize.ofBytes(35)); assertEquals(actual.getFailedOutputDataSize(), DataSize.ofBytes(36)); assertEquals(actual.getOutputPositions(), 37); @@ -205,4 +211,13 @@ private static DistributionSnapshot getTestDistribution(int count) } return distribution.snapshot(); } + + private static TDigestHistogram getTDigestHistogram(int count) + { + TDigest digest = new TDigest(); + for (int i = 0; i < count; i++) { + digest.add(i); + } + return new TDigestHistogram(digest); + } } diff --git a/core/trino-main/src/test/java/io/trino/execution/TestingRemoteTaskFactory.java b/core/trino-main/src/test/java/io/trino/execution/TestingRemoteTaskFactory.java index db478391af51..6c6324ecea5c 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestingRemoteTaskFactory.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestingRemoteTaskFactory.java @@ -20,6 +20,7 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Multimap; import com.google.common.util.concurrent.ListenableFuture; +import io.airlift.stats.TDigest; import io.airlift.units.DataSize; import io.airlift.units.Duration; import io.trino.Session; @@ -31,6 +32,7 @@ import io.trino.metadata.InternalNode; import io.trino.metadata.Split; import io.trino.operator.TaskStats; +import io.trino.plugin.base.metrics.TDigestHistogram; import io.trino.sql.planner.PlanFragment; import io.trino.sql.planner.plan.DynamicFilterId; import io.trino.sql.planner.plan.PlanNodeId; @@ -146,7 +148,8 @@ public TaskInfo getTaskInfo() 0, 0, 0, - ImmutableList.of()), + ImmutableList.of(), + Optional.of(new TDigestHistogram(new TDigest()))), ImmutableSet.copyOf(noMoreSplits), new TaskStats(DateTime.now(), null), Optional.empty(), diff --git a/core/trino-main/src/test/java/io/trino/memory/TestLeastWastedEffortTaskLowMemoryKiller.java b/core/trino-main/src/test/java/io/trino/memory/TestLeastWastedEffortTaskLowMemoryKiller.java index b35a14e1fc99..60b602dcb87b 100644 --- a/core/trino-main/src/test/java/io/trino/memory/TestLeastWastedEffortTaskLowMemoryKiller.java +++ b/core/trino-main/src/test/java/io/trino/memory/TestLeastWastedEffortTaskLowMemoryKiller.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import io.airlift.stats.TDigest; import io.airlift.units.DataSize; import io.airlift.units.Duration; import io.trino.execution.TaskId; @@ -26,6 +27,7 @@ import io.trino.execution.buffer.BufferState; import io.trino.execution.buffer.OutputBufferInfo; import io.trino.operator.TaskStats; +import io.trino.plugin.base.metrics.TDigestHistogram; import org.joda.time.DateTime; import org.testng.annotations.Test; @@ -247,7 +249,8 @@ private static TaskInfo buildTaskInfo(TaskId taskId, TaskState state, Duration s 0, 0, 0, - ImmutableList.of()), + ImmutableList.of(), + Optional.of(new TDigestHistogram(new TDigest()))), ImmutableSet.of(), new TaskStats(DateTime.now(), null, diff --git a/core/trino-main/src/test/java/io/trino/metadata/TestSignatureBinder.java b/core/trino-main/src/test/java/io/trino/metadata/TestSignatureBinder.java index 11cab519d155..9cfa25ccb960 100644 --- a/core/trino-main/src/test/java/io/trino/metadata/TestSignatureBinder.java +++ b/core/trino-main/src/test/java/io/trino/metadata/TestSignatureBinder.java @@ -1239,9 +1239,7 @@ private Optional bindVariables() if (returnType == null) { return signatureBinder.bindVariables(argumentTypes); } - else { - return signatureBinder.bindVariables(argumentTypes, returnType.getTypeSignature()); - } + return signatureBinder.bindVariables(argumentTypes, returnType.getTypeSignature()); } } } diff --git a/core/trino-main/src/test/java/io/trino/operator/BenchmarkScanFilterAndProjectOperator.java b/core/trino-main/src/test/java/io/trino/operator/BenchmarkScanFilterAndProjectOperator.java index bc57a56041b6..8d3b708b14b7 100644 --- a/core/trino-main/src/test/java/io/trino/operator/BenchmarkScanFilterAndProjectOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/BenchmarkScanFilterAndProjectOperator.java @@ -239,9 +239,7 @@ private static Page createPage(List types, int positions, boolea if (dictionary) { return SequencePageBuilder.createSequencePageWithDictionaryBlocks(types, positions); } - else { - return SequencePageBuilder.createSequencePage(types, positions); - } + return SequencePageBuilder.createSequencePage(types, positions); } } diff --git a/core/trino-main/src/test/java/io/trino/operator/TestExchangeOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestExchangeOperator.java index 3bc935e7e01e..afd4d1495774 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestExchangeOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestExchangeOperator.java @@ -272,6 +272,7 @@ private SourceOperator createExchangeOperator() SourceOperator operator = operatorFactory.createOperator(driverContext); assertEquals(getOnlyElement(operator.getOperatorContext().getNestedOperatorStats()).getUserMemoryReservation().toBytes(), 0); + operatorFactory.noMoreOperators(); return operator; } @@ -292,9 +293,7 @@ private static List waitForPages(Operator operator, int expectedPageCount) greaterThanZero = true; break; } - else { - Thread.sleep(10); - } + Thread.sleep(10); } assertTrue(greaterThanZero); diff --git a/core/trino-main/src/test/java/io/trino/operator/TestingExchangeHttpClientHandler.java b/core/trino-main/src/test/java/io/trino/operator/TestingExchangeHttpClientHandler.java index a800874367da..a93c1a5c680e 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestingExchangeHttpClientHandler.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestingExchangeHttpClientHandler.java @@ -88,7 +88,7 @@ public Response handle(Request request) output.writeBytes(serializedPage); return new TestingResponse(HttpStatus.OK, headers.build(), output.slice().getInput()); } - else if (taskBuffer.isFinished()) { + if (taskBuffer.isFinished()) { headers.put(TRINO_PAGE_NEXT_TOKEN, String.valueOf(pageToken)); headers.put(TRINO_BUFFER_COMPLETE, String.valueOf(true)); DynamicSliceOutput output = new DynamicSliceOutput(8); @@ -97,10 +97,8 @@ else if (taskBuffer.isFinished()) { output.writeInt(0); return new TestingResponse(HttpStatus.OK, headers.build(), output.slice().getInput()); } - else { - headers.put(TRINO_PAGE_NEXT_TOKEN, String.valueOf(pageToken)); - headers.put(TRINO_BUFFER_COMPLETE, String.valueOf(false)); - return new TestingResponse(HttpStatus.NO_CONTENT, headers.build(), new byte[0]); - } + headers.put(TRINO_PAGE_NEXT_TOKEN, String.valueOf(pageToken)); + headers.put(TRINO_BUFFER_COMPLETE, String.valueOf(false)); + return new TestingResponse(HttpStatus.NO_CONTENT, headers.build(), new byte[0]); } } diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/AbstractTestApproximateCountDistinct.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/AbstractTestApproximateCountDistinct.java index e2a55f3a8313..6af897c3231f 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/AbstractTestApproximateCountDistinct.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/AbstractTestApproximateCountDistinct.java @@ -164,11 +164,9 @@ private Page createPage(List values, double maxStandardError) if (values.isEmpty()) { return new Page(0); } - else { - return new Page(values.size(), - createBlock(getValueType(), values), - createBlock(DOUBLE, ImmutableList.copyOf(Collections.nCopies(values.size(), maxStandardError)))); - } + return new Page(values.size(), + createBlock(getValueType(), values), + createBlock(DOUBLE, ImmutableList.copyOf(Collections.nCopies(values.size(), maxStandardError)))); } /** diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/AbstractTestApproximateSetGeneric.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/AbstractTestApproximateSetGeneric.java index b5304fbd0179..26a0f2f265b8 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/AbstractTestApproximateSetGeneric.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/AbstractTestApproximateSetGeneric.java @@ -190,9 +190,7 @@ private Page createPage(List values) if (values.isEmpty()) { return new Page(0); } - else { - return new Page(values.size(), createBlock(getValueType(), values)); - } + return new Page(values.size(), createBlock(getValueType(), values)); } /** diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/groupby/GroupByAggregationTestUtils.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/groupby/GroupByAggregationTestUtils.java index d46f9df67f9f..1a4a9503d67c 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/groupby/GroupByAggregationTestUtils.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/groupby/GroupByAggregationTestUtils.java @@ -35,18 +35,16 @@ public static Page[] createPages(Block[] blocks) if (positions == 0) { return new Page[] {}; } - else if (positions == 1) { + if (positions == 1) { return new Page[] {new Page(positions, blocks)}; } - else { - int split = positions / 2; // [0, split - 1] goes to first list of blocks; [split, positions - 1] goes to second list of blocks. - Block[] blockArray1 = new Block[blocks.length]; - Block[] blockArray2 = new Block[blocks.length]; - for (int i = 0; i < blocks.length; i++) { - blockArray1[i] = blocks[i].getRegion(0, split); - blockArray2[i] = blocks[i].getRegion(split, positions - split); - } - return new Page[] {new Page(blockArray1), new Page(blockArray2)}; + int split = positions / 2; // [0, split - 1] goes to first list of blocks; [split, positions - 1] goes to second list of blocks. + Block[] blockArray1 = new Block[blocks.length]; + Block[] blockArray2 = new Block[blocks.length]; + for (int i = 0; i < blocks.length; i++) { + blockArray1[i] = blocks[i].getRegion(0, split); + blockArray2[i] = blocks[i].getRegion(split, positions - split); } + return new Page[] {new Page(blockArray1), new Page(blockArray2)}; } } diff --git a/core/trino-main/src/test/java/io/trino/operator/join/BenchmarkHashBuildAndJoinOperators.java b/core/trino-main/src/test/java/io/trino/operator/join/BenchmarkHashBuildAndJoinOperators.java index 3720a91d13fc..a589d378fe9d 100644 --- a/core/trino-main/src/test/java/io/trino/operator/join/BenchmarkHashBuildAndJoinOperators.java +++ b/core/trino-main/src/test/java/io/trino/operator/join/BenchmarkHashBuildAndJoinOperators.java @@ -71,6 +71,7 @@ import static io.trino.SessionTestUtils.TEST_SESSION; import static io.trino.jmh.Benchmarks.benchmark; import static io.trino.operator.HashArraySizeSupplier.incrementalLoadFactorHashArraySizeSupplier; +import static io.trino.operator.OperatorFactories.JoinOperatorType.innerJoin; import static io.trino.operator.join.JoinBridgeManager.lookupAllAtOnce; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.VarcharType.VARCHAR; @@ -239,14 +240,12 @@ public void setup(OperatorFactories operatorFactories) } JoinBridgeManager lookupSourceFactory = getLookupSourceFactoryManager(this, outputChannels, partitionCount); - joinOperatorFactory = operatorFactories.innerJoin( + joinOperatorFactory = operatorFactories.spillingJoin( + innerJoin(false, false), HASH_JOIN_OPERATOR_ID, TEST_PLAN_NODE_ID, lookupSourceFactory, false, - false, - false, - true, types, hashChannels, hashChannel, diff --git a/core/trino-main/src/test/java/io/trino/operator/join/JoinTestUtils.java b/core/trino-main/src/test/java/io/trino/operator/join/JoinTestUtils.java index a65d20a9a806..6fb25a7411c3 100644 --- a/core/trino-main/src/test/java/io/trino/operator/join/JoinTestUtils.java +++ b/core/trino-main/src/test/java/io/trino/operator/join/JoinTestUtils.java @@ -63,6 +63,7 @@ import static com.google.common.util.concurrent.Futures.immediateVoidFuture; import static io.airlift.concurrent.MoreFutures.getFutureValue; import static io.trino.operator.HashArraySizeSupplier.incrementalLoadFactorHashArraySizeSupplier; +import static io.trino.operator.OperatorFactories.JoinOperatorType.innerJoin; import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static java.util.Objects.requireNonNull; @@ -92,14 +93,12 @@ public static OperatorFactory innerJoinOperatorFactory( boolean outputSingleMatch, boolean hasFilter) { - return operatorFactories.innerJoin( + return operatorFactories.spillingJoin( + innerJoin(outputSingleMatch, false), 0, new PlanNodeId("test"), lookupSourceFactoryManager, - outputSingleMatch, - false, hasFilter, - true, probePages.getTypes(), probePages.getHashChannels().orElseThrow(), getHashChannelAsInt(probePages), diff --git a/core/trino-main/src/test/java/io/trino/operator/join/TestHashJoinOperator.java b/core/trino-main/src/test/java/io/trino/operator/join/TestHashJoinOperator.java index cd8056526a75..b92d2b57e239 100644 --- a/core/trino-main/src/test/java/io/trino/operator/join/TestHashJoinOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/join/TestHashJoinOperator.java @@ -93,6 +93,10 @@ import static io.trino.operator.OperatorAssertion.assertOperatorEquals; import static io.trino.operator.OperatorAssertion.dropChannel; import static io.trino.operator.OperatorAssertion.without; +import static io.trino.operator.OperatorFactories.JoinOperatorType.fullOuterJoin; +import static io.trino.operator.OperatorFactories.JoinOperatorType.innerJoin; +import static io.trino.operator.OperatorFactories.JoinOperatorType.lookupOuterJoin; +import static io.trino.operator.OperatorFactories.JoinOperatorType.probeOuterJoin; import static io.trino.operator.WorkProcessor.ProcessState.finished; import static io.trino.operator.WorkProcessor.ProcessState.ofResult; import static io.trino.operator.join.JoinTestUtils.buildLookupSource; @@ -293,13 +297,11 @@ public void testUnwrapsLazyBlocks() .map(page -> new Page(page.getBlock(0), new LazyBlock(1, () -> page.getBlock(1)))) .collect(toImmutableList()); - OperatorFactory joinOperatorFactory = operatorFactories.innerJoin( + OperatorFactory joinOperatorFactory = operatorFactories.spillingJoin( + innerJoin(false, false), 0, new PlanNodeId("test"), lookupSourceFactory, - false, - false, - true, true, probePages.getTypes(), Ints.asList(0), @@ -348,13 +350,11 @@ public void testYield() // probe matching the above 40 entries RowPagesBuilder probePages = rowPagesBuilder(false, Ints.asList(0), ImmutableList.of(BIGINT)); List probeInput = probePages.addSequencePage(100, 0).build(); - OperatorFactory joinOperatorFactory = operatorFactories.innerJoin( + OperatorFactory joinOperatorFactory = operatorFactories.spillingJoin( + innerJoin(false, false), 0, new PlanNodeId("test"), lookupSourceFactory, - false, - false, - true, true, probePages.getTypes(), Ints.asList(0), @@ -1235,14 +1235,12 @@ public void testInnerJoinWithEmptyLookupSource(boolean parallelBuild, boolean pr // probe factory List probeTypes = ImmutableList.of(VARCHAR); RowPagesBuilder probePages = rowPagesBuilder(probeHashEnabled, Ints.asList(0), probeTypes); - OperatorFactory joinOperatorFactory = operatorFactories.innerJoin( + OperatorFactory joinOperatorFactory = operatorFactories.spillingJoin( + innerJoin(false, false), 0, new PlanNodeId("test"), lookupSourceFactoryManager, false, - false, - false, - true, probePages.getTypes(), Ints.asList(0), getHashChannelAsInt(probePages), @@ -1276,13 +1274,12 @@ public void testLookupOuterJoinWithEmptyLookupSource(boolean parallelBuild, bool // probe factory List probeTypes = ImmutableList.of(VARCHAR); RowPagesBuilder probePages = rowPagesBuilder(probeHashEnabled, Ints.asList(0), probeTypes); - OperatorFactory joinOperatorFactory = operatorFactories.lookupOuterJoin( + OperatorFactory joinOperatorFactory = operatorFactories.spillingJoin( + lookupOuterJoin(false), 0, new PlanNodeId("test"), lookupSourceFactoryManager, false, - false, - true, probePages.getTypes(), Ints.asList(0), getHashChannelAsInt(probePages), @@ -1322,13 +1319,12 @@ public void testProbeOuterJoinWithEmptyLookupSource(boolean parallelBuild, boole .row((String) null) .row("c") .build(); - OperatorFactory joinOperatorFactory = operatorFactories.probeOuterJoin( + OperatorFactory joinOperatorFactory = operatorFactories.spillingJoin( + probeOuterJoin(false), 0, new PlanNodeId("test"), lookupSourceFactoryManager, false, - false, - true, probePages.getTypes(), Ints.asList(0), getHashChannelAsInt(probePages), @@ -1371,12 +1367,12 @@ public void testFullOuterJoinWithEmptyLookupSource(boolean parallelBuild, boolea .row((String) null) .row("c") .build(); - OperatorFactory joinOperatorFactory = operatorFactories.fullOuterJoin( + OperatorFactory joinOperatorFactory = operatorFactories.spillingJoin( + fullOuterJoin(), 0, new PlanNodeId("test"), lookupSourceFactoryManager, false, - true, probePages.getTypes(), Ints.asList(0), getHashChannelAsInt(probePages), @@ -1418,14 +1414,12 @@ public void testInnerJoinWithNonEmptyLookupSourceAndEmptyProbe(boolean parallelB List probeTypes = ImmutableList.of(VARCHAR); RowPagesBuilder probePages = rowPagesBuilder(probeHashEnabled, Ints.asList(0), probeTypes); List probeInput = probePages.build(); - OperatorFactory joinOperatorFactory = operatorFactories.innerJoin( + OperatorFactory joinOperatorFactory = operatorFactories.spillingJoin( + innerJoin(false, false), 0, new PlanNodeId("test"), lookupSourceFactoryManager, false, - false, - false, - true, probePages.getTypes(), Ints.asList(0), getHashChannelAsInt(probePages), @@ -1612,14 +1606,12 @@ private OperatorFactory createJoinOperatorFactoryWithBlockingLookupSource(TaskCo // probe factory List probeTypes = ImmutableList.of(VARCHAR); RowPagesBuilder probePages = rowPagesBuilder(probeHashEnabled, Ints.asList(0), probeTypes); - OperatorFactory joinOperatorFactory = operatorFactories.innerJoin( + OperatorFactory joinOperatorFactory = operatorFactories.spillingJoin( + innerJoin(false, waitForBuild), 0, new PlanNodeId("test"), lookupSourceFactoryManager, false, - waitForBuild, - false, - true, probePages.getTypes(), Ints.asList(0), getHashChannelAsInt(probePages), @@ -1666,13 +1658,12 @@ private OperatorFactory probeOuterJoinOperatorFactory( RowPagesBuilder probePages, boolean hasFilter) { - return operatorFactories.probeOuterJoin( + return operatorFactories.spillingJoin( + probeOuterJoin(false), 0, new PlanNodeId("test"), lookupSourceFactoryManager, - false, hasFilter, - true, probePages.getTypes(), Ints.asList(0), getHashChannelAsInt(probePages), diff --git a/core/trino-main/src/test/java/io/trino/operator/join/unspilled/BenchmarkHashBuildAndJoinOperators.java b/core/trino-main/src/test/java/io/trino/operator/join/unspilled/BenchmarkHashBuildAndJoinOperators.java index 7e80fc8750a7..a388b8cc9997 100644 --- a/core/trino-main/src/test/java/io/trino/operator/join/unspilled/BenchmarkHashBuildAndJoinOperators.java +++ b/core/trino-main/src/test/java/io/trino/operator/join/unspilled/BenchmarkHashBuildAndJoinOperators.java @@ -72,9 +72,9 @@ import static io.trino.SessionTestUtils.TEST_SESSION; import static io.trino.jmh.Benchmarks.benchmark; import static io.trino.operator.HashArraySizeSupplier.incrementalLoadFactorHashArraySizeSupplier; +import static io.trino.operator.OperatorFactories.JoinOperatorType.innerJoin; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.VarcharType.VARCHAR; -import static io.trino.spiller.PartitioningSpillerFactory.unsupportedPartitioningSpillerFactory; import static java.lang.String.format; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; @@ -239,20 +239,16 @@ public void setup(OperatorFactories operatorFactories) } JoinBridgeManager lookupSourceFactory = getLookupSourceFactoryManager(this, outputChannels, partitionCount); - joinOperatorFactory = operatorFactories.innerJoin( + joinOperatorFactory = operatorFactories.join( + innerJoin(false, false), HASH_JOIN_OPERATOR_ID, TEST_PLAN_NODE_ID, lookupSourceFactory, false, - false, - false, - false, types, hashChannels, hashChannel, Optional.of(outputChannels), - OptionalInt.empty(), - unsupportedPartitioningSpillerFactory(), TYPE_OPERATOR_FACTORY); buildHash(this, lookupSourceFactory, outputChannels, partitionCount); initializeProbePages(); diff --git a/core/trino-main/src/test/java/io/trino/operator/join/unspilled/JoinTestUtils.java b/core/trino-main/src/test/java/io/trino/operator/join/unspilled/JoinTestUtils.java index d2e1aff1dc17..71e15f8533f6 100644 --- a/core/trino-main/src/test/java/io/trino/operator/join/unspilled/JoinTestUtils.java +++ b/core/trino-main/src/test/java/io/trino/operator/join/unspilled/JoinTestUtils.java @@ -54,6 +54,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.concurrent.MoreFutures.getFutureValue; import static io.trino.operator.HashArraySizeSupplier.incrementalLoadFactorHashArraySizeSupplier; +import static io.trino.operator.OperatorFactories.JoinOperatorType.innerJoin; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static java.util.Objects.requireNonNull; @@ -80,20 +81,16 @@ public static OperatorFactory innerJoinOperatorFactory( boolean outputSingleMatch, boolean hasFilter) { - return operatorFactories.innerJoin( + return operatorFactories.join( + innerJoin(outputSingleMatch, false), 0, new PlanNodeId("test"), lookupSourceFactoryManager, - outputSingleMatch, - false, hasFilter, - false, probePages.getTypes(), probePages.getHashChannels().orElseThrow(), getHashChannelAsInt(probePages), Optional.empty(), - OptionalInt.of(1), - null, TYPE_OPERATOR_FACTORY); } diff --git a/core/trino-main/src/test/java/io/trino/operator/join/unspilled/TestHashJoinOperator.java b/core/trino-main/src/test/java/io/trino/operator/join/unspilled/TestHashJoinOperator.java index 4c5c6c0f8ca4..ac363cbc7c04 100644 --- a/core/trino-main/src/test/java/io/trino/operator/join/unspilled/TestHashJoinOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/join/unspilled/TestHashJoinOperator.java @@ -60,7 +60,6 @@ import java.util.List; import java.util.Optional; -import java.util.OptionalInt; import java.util.concurrent.ExecutorService; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.SynchronousQueue; @@ -73,6 +72,9 @@ import static io.trino.RowPagesBuilder.rowPagesBuilder; import static io.trino.SessionTestUtils.TEST_SESSION; import static io.trino.operator.OperatorAssertion.assertOperatorEquals; +import static io.trino.operator.OperatorFactories.JoinOperatorType.fullOuterJoin; +import static io.trino.operator.OperatorFactories.JoinOperatorType.innerJoin; +import static io.trino.operator.OperatorFactories.JoinOperatorType.probeOuterJoin; import static io.trino.operator.WorkProcessor.ProcessState.finished; import static io.trino.operator.WorkProcessor.ProcessState.ofResult; import static io.trino.operator.join.unspilled.JoinTestUtils.buildLookupSource; @@ -249,20 +251,16 @@ public void testUnwrapsLazyBlocks(boolean singleBigintLookupSource) .map(page -> new Page(page.getBlock(0), new LazyBlock(1, () -> page.getBlock(1)))) .collect(toImmutableList()); - OperatorFactory joinOperatorFactory = operatorFactories.innerJoin( + OperatorFactory joinOperatorFactory = operatorFactories.join( + innerJoin(false, false), 0, new PlanNodeId("test"), lookupSourceFactory, - false, - false, true, - false, probePages.getTypes(), Ints.asList(0), getHashChannelAsInt(probePages), Optional.empty(), - OptionalInt.of(1), - null, TYPE_OPERATOR_FACTORY); instantiateBuildDrivers(buildSideSetup, taskContext); @@ -304,20 +302,16 @@ public void testYield(boolean singleBigintLookupSource) // probe matching the above 40 entries RowPagesBuilder probePages = rowPagesBuilder(false, Ints.asList(0), ImmutableList.of(BIGINT)); List probeInput = probePages.addSequencePage(100, 0).build(); - OperatorFactory joinOperatorFactory = operatorFactories.innerJoin( + OperatorFactory joinOperatorFactory = operatorFactories.join( + innerJoin(false, false), 0, new PlanNodeId("test"), lookupSourceFactory, - false, - false, true, - false, probePages.getTypes(), Ints.asList(0), getHashChannelAsInt(probePages), Optional.empty(), - OptionalInt.of(1), - null, TYPE_OPERATOR_FACTORY); instantiateBuildDrivers(buildSideSetup, taskContext); @@ -891,20 +885,16 @@ public void testInnerJoinWithEmptyLookupSource(boolean parallelBuild, boolean pr // probe factory List probeTypes = ImmutableList.of(BIGINT); RowPagesBuilder probePages = rowPagesBuilder(probeHashEnabled, Ints.asList(0), probeTypes); - OperatorFactory joinOperatorFactory = operatorFactories.innerJoin( + OperatorFactory joinOperatorFactory = operatorFactories.join( + innerJoin(false, false), 0, new PlanNodeId("test"), lookupSourceFactoryManager, false, - false, - false, - false, probePages.getTypes(), Ints.asList(0), getHashChannelAsInt(probePages), Optional.empty(), - OptionalInt.of(1), - null, TYPE_OPERATOR_FACTORY); // drivers and operators @@ -932,19 +922,16 @@ public void testLookupOuterJoinWithEmptyLookupSource(boolean parallelBuild, bool // probe factory List probeTypes = ImmutableList.of(BIGINT); RowPagesBuilder probePages = rowPagesBuilder(probeHashEnabled, Ints.asList(0), probeTypes); - OperatorFactory joinOperatorFactory = operatorFactories.lookupOuterJoin( + OperatorFactory joinOperatorFactory = operatorFactories.join( + OperatorFactories.JoinOperatorType.lookupOuterJoin(false), 0, new PlanNodeId("test"), lookupSourceFactoryManager, false, - false, - false, probePages.getTypes(), Ints.asList(0), getHashChannelAsInt(probePages), Optional.empty(), - OptionalInt.of(1), - null, TYPE_OPERATOR_FACTORY); // drivers and operators @@ -978,19 +965,16 @@ public void testProbeOuterJoinWithEmptyLookupSource(boolean parallelBuild, boole .row((String) null) .row(3L) .build(); - OperatorFactory joinOperatorFactory = operatorFactories.probeOuterJoin( + OperatorFactory joinOperatorFactory = operatorFactories.join( + probeOuterJoin(false), 0, new PlanNodeId("test"), lookupSourceFactoryManager, false, - false, - false, probePages.getTypes(), Ints.asList(0), getHashChannelAsInt(probePages), Optional.empty(), - OptionalInt.of(1), - null, TYPE_OPERATOR_FACTORY); // build drivers and operators @@ -1027,18 +1011,16 @@ public void testFullOuterJoinWithEmptyLookupSource(boolean parallelBuild, boolea .row((String) null) .row(3L) .build(); - OperatorFactory joinOperatorFactory = operatorFactories.fullOuterJoin( + OperatorFactory joinOperatorFactory = operatorFactories.join( + fullOuterJoin(), 0, new PlanNodeId("test"), lookupSourceFactoryManager, false, - false, probePages.getTypes(), Ints.asList(0), getHashChannelAsInt(probePages), Optional.empty(), - OptionalInt.of(1), - null, TYPE_OPERATOR_FACTORY); // build drivers and operators @@ -1074,20 +1056,16 @@ public void testInnerJoinWithNonEmptyLookupSourceAndEmptyProbe(boolean parallelB List probeTypes = ImmutableList.of(BIGINT); RowPagesBuilder probePages = rowPagesBuilder(probeHashEnabled, Ints.asList(0), probeTypes); List probeInput = probePages.build(); - OperatorFactory joinOperatorFactory = operatorFactories.innerJoin( + OperatorFactory joinOperatorFactory = operatorFactories.join( + innerJoin(false, false), 0, new PlanNodeId("test"), lookupSourceFactoryManager, false, - false, - false, - false, probePages.getTypes(), Ints.asList(0), getHashChannelAsInt(probePages), Optional.empty(), - OptionalInt.of(1), - null, TYPE_OPERATOR_FACTORY); // build drivers and operators @@ -1268,20 +1246,16 @@ private OperatorFactory createJoinOperatorFactoryWithBlockingLookupSource(TaskCo // probe factory List probeTypes = ImmutableList.of(VARCHAR); RowPagesBuilder probePages = rowPagesBuilder(probeHashEnabled, Ints.asList(0), probeTypes); - OperatorFactory joinOperatorFactory = operatorFactories.innerJoin( + OperatorFactory joinOperatorFactory = operatorFactories.join( + innerJoin(false, waitForBuild), 0, new PlanNodeId("test"), lookupSourceFactoryManager, false, - waitForBuild, - false, - false, probePages.getTypes(), Ints.asList(0), getHashChannelAsInt(probePages), Optional.empty(), - OptionalInt.of(1), - null, TYPE_OPERATOR_FACTORY); // build drivers and operators @@ -1345,19 +1319,16 @@ private OperatorFactory probeOuterJoinOperatorFactory( RowPagesBuilder probePages, boolean hasFilter) { - return operatorFactories.probeOuterJoin( + return operatorFactories.join( + probeOuterJoin(false), 0, new PlanNodeId("test"), lookupSourceFactoryManager, - false, hasFilter, - false, probePages.getTypes(), Ints.asList(0), getHashChannelAsInt(probePages), Optional.empty(), - OptionalInt.of(1), - null, TYPE_OPERATOR_FACTORY); } diff --git a/core/trino-main/src/test/java/io/trino/operator/output/TestPositionsAppender.java b/core/trino-main/src/test/java/io/trino/operator/output/TestPositionsAppender.java index 7e0f6bba624e..183212d586d2 100644 --- a/core/trino-main/src/test/java/io/trino/operator/output/TestPositionsAppender.java +++ b/core/trino-main/src/test/java/io/trino/operator/output/TestPositionsAppender.java @@ -14,7 +14,10 @@ package io.trino.operator.output; import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; import io.airlift.slice.Slices; +import io.trino.block.BlockAssertions; +import io.trino.spi.block.AbstractVariableWidthBlock; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; @@ -22,20 +25,37 @@ import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.block.RowBlock; import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.type.ArrayType; +import io.trino.spi.type.BigintType; +import io.trino.spi.type.BooleanType; import io.trino.spi.type.Decimals; +import io.trino.spi.type.DoubleType; +import io.trino.spi.type.IntegerType; import io.trino.spi.type.LongTimestamp; import io.trino.spi.type.RowType; +import io.trino.spi.type.SmallintType; +import io.trino.spi.type.TinyintType; import io.trino.spi.type.Type; +import io.trino.spi.type.VarbinaryType; +import io.trino.spi.type.VarcharType; import io.trino.type.BlockTypeOperators; import it.unimi.dsi.fastutil.ints.IntArrayList; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; +import javax.annotation.Nullable; + +import java.util.Arrays; import java.util.List; import java.util.Optional; +import java.util.OptionalInt; +import java.util.function.Function; +import java.util.function.ObjLongConsumer; import java.util.stream.IntStream; +import static com.google.common.base.Preconditions.checkArgument; +import static io.airlift.slice.Slices.EMPTY_SLICE; import static io.airlift.testing.Assertions.assertGreaterThanOrEqual; import static io.airlift.testing.Assertions.assertInstanceOf; import static io.trino.block.BlockAssertions.assertBlockEquals; @@ -46,7 +66,6 @@ import static io.trino.block.BlockAssertions.createLongDecimalsBlock; import static io.trino.block.BlockAssertions.createLongTimestampBlock; import static io.trino.block.BlockAssertions.createLongsBlock; -import static io.trino.block.BlockAssertions.createRandomBlockForType; import static io.trino.block.BlockAssertions.createRandomDictionaryBlock; import static io.trino.block.BlockAssertions.createSlicesBlock; import static io.trino.block.BlockAssertions.createSmallintsBlock; @@ -55,16 +74,10 @@ import static io.trino.spi.block.DictionaryId.randomDictionaryId; import static io.trino.spi.block.PageBuilderStatus.DEFAULT_MAX_PAGE_SIZE_IN_BYTES; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.CharType.createCharType; import static io.trino.spi.type.DecimalType.createDecimalType; -import static io.trino.spi.type.DoubleType.DOUBLE; -import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.RowType.anonymousRow; -import static io.trino.spi.type.SmallintType.SMALLINT; import static io.trino.spi.type.TimestampType.createTimestampType; -import static io.trino.spi.type.TinyintType.TINYINT; -import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; import static java.util.Objects.requireNonNull; @@ -76,7 +89,7 @@ public class TestPositionsAppender private static final PositionsAppenderFactory POSITIONS_APPENDER_FACTORY = new PositionsAppenderFactory(new BlockTypeOperators()); @Test(dataProvider = "types") - public void testMixedBlockTypes(Type type) + public void testMixedBlockTypes(TestType type) { List input = ImmutableList.of( input(emptyBlock(type)), @@ -103,16 +116,16 @@ public void testMixedBlockTypes(Type type) testAppend(type, input); } - @Test(dataProvider = "nullRleTypes") - public void testNullRle(Type type) + @Test(dataProvider = "types") + public void testNullRle(TestType type) { - testNullRle(type, nullBlock(type, 2)); - testNullRle(type, nullRleBlock(type, 2)); - testNullRle(type, createRandomBlockForType(type, 4, 0.5f)); + testNullRle(type.getType(), nullBlock(type, 2)); + testNullRle(type.getType(), nullRleBlock(type, 2)); + testNullRle(type.getType(), createRandomBlockForType(type, 4, 0.5f)); } @Test(dataProvider = "types") - public void testRleSwitchToFlat(Type type) + public void testRleSwitchToFlat(TestType type) { List inputs = ImmutableList.of( input(rleBlock(type, 3), 0, 1), @@ -126,7 +139,7 @@ public void testRleSwitchToFlat(Type type) } @Test(dataProvider = "types") - public void testFlatAppendRle(Type type) + public void testFlatAppendRle(TestType type) { List inputs = ImmutableList.of( input(notNullBlock(type, 2), 0, 1), @@ -140,7 +153,7 @@ public void testFlatAppendRle(Type type) } @Test(dataProvider = "differentValues") - public void testMultipleRleBlocksWithDifferentValues(Type type, Block value1, Block value2) + public void testMultipleRleBlocksWithDifferentValues(TestType type, Block value1, Block value2) { List input = ImmutableList.of( input(rleBlock(value1, 3), 0, 1), @@ -153,28 +166,27 @@ public static Object[][] differentValues() { return new Object[][] { - {BIGINT, createLongsBlock(0), createLongsBlock(1)}, - {BOOLEAN, createBooleansBlock(true), createBooleansBlock(false)}, - {INTEGER, createIntsBlock(0), createIntsBlock(1)}, - {createCharType(10), createStringsBlock("0"), createStringsBlock("1")}, - {createUnboundedVarcharType(), createStringsBlock("0"), createStringsBlock("1")}, - {DOUBLE, createDoublesBlock(0D), createDoublesBlock(1D)}, - {SMALLINT, createSmallintsBlock(0), createSmallintsBlock(1)}, - {TINYINT, createTinyintsBlock(0), createTinyintsBlock(1)}, - {VARBINARY, createSlicesBlock(Slices.wrappedLongArray(0)), createSlicesBlock(Slices.wrappedLongArray(1))}, - {createDecimalType(Decimals.MAX_SHORT_PRECISION + 1), createLongDecimalsBlock("0"), createLongDecimalsBlock("1")}, - {new ArrayType(BIGINT), createArrayBigintBlock(ImmutableList.of(ImmutableList.of(0L))), createArrayBigintBlock(ImmutableList.of(ImmutableList.of(1L)))}, - { - createTimestampType(9), - createLongTimestampBlock(createTimestampType(9), new LongTimestamp(0, 0)), - createLongTimestampBlock(createTimestampType(9), new LongTimestamp(1, 0))} + {TestType.BIGINT, createLongsBlock(0), createLongsBlock(1)}, + {TestType.BOOLEAN, createBooleansBlock(true), createBooleansBlock(false)}, + {TestType.INTEGER, createIntsBlock(0), createIntsBlock(1)}, + {TestType.CHAR_10, createStringsBlock("0"), createStringsBlock("1")}, + {TestType.VARCHAR, createStringsBlock("0"), createStringsBlock("1")}, + {TestType.DOUBLE, createDoublesBlock(0D), createDoublesBlock(1D)}, + {TestType.SMALLINT, createSmallintsBlock(0), createSmallintsBlock(1)}, + {TestType.TINYINT, createTinyintsBlock(0), createTinyintsBlock(1)}, + {TestType.VARBINARY, createSlicesBlock(Slices.wrappedLongArray(0)), createSlicesBlock(Slices.wrappedLongArray(1))}, + {TestType.LONG_DECIMAL, createLongDecimalsBlock("0"), createLongDecimalsBlock("1")}, + {TestType.ARRAY_BIGINT, createArrayBigintBlock(ImmutableList.of(ImmutableList.of(0L))), createArrayBigintBlock(ImmutableList.of(ImmutableList.of(1L)))}, + {TestType.LONG_TIMESTAMP, createLongTimestampBlock(createTimestampType(9), new LongTimestamp(0, 0)), + createLongTimestampBlock(createTimestampType(9), new LongTimestamp(1, 0))}, + {TestType.VARCHAR_WITH_TEST_BLOCK, TestVariableWidthBlock.adapt(createStringsBlock("0")), TestVariableWidthBlock.adapt(createStringsBlock("1"))} }; } @Test(dataProvider = "types") - public void testMultipleRleWithTheSameValueProduceRle(Type type) + public void testMultipleRleWithTheSameValueProduceRle(TestType type) { - PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type, 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); Block value = notNullBlock(type, 1); positionsAppender.append(allPositions(3), rleBlock(value, 3)); @@ -186,9 +198,9 @@ public void testMultipleRleWithTheSameValueProduceRle(Type type) } @Test(dataProvider = "types") - public void testConsecutiveBuilds(Type type) + public void testConsecutiveBuilds(TestType type) { - PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type, 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); // empty block positionsAppender.append(positions(), emptyBlock(type)); @@ -204,12 +216,12 @@ public void testConsecutiveBuilds(Type type) // append null and not null position positionsAppender.append(allPositions(2), block); - assertBlockEquals(type, positionsAppender.build(), block); + assertBlockEquals(type.getType(), positionsAppender.build(), block); // append not null rle Block rleBlock = rleBlock(type, 1); positionsAppender.append(allPositions(1), rleBlock); - assertBlockEquals(type, positionsAppender.build(), rleBlock); + assertBlockEquals(type.getType(), positionsAppender.build(), rleBlock); // append empty rle positionsAppender.append(positions(), rleBlock(type, 0)); @@ -218,7 +230,7 @@ public void testConsecutiveBuilds(Type type) // append null rle Block nullRleBlock = nullRleBlock(type, 1); positionsAppender.append(allPositions(1), nullRleBlock); - assertBlockEquals(type, positionsAppender.build(), nullRleBlock); + assertBlockEquals(type.getType(), positionsAppender.build(), nullRleBlock); // just build to confirm appender was reset assertEquals(positionsAppender.build().getPositionCount(), 0); @@ -246,9 +258,9 @@ public void testRowWithNestedFields() { RowType type = anonymousRow(BIGINT, BIGINT, VARCHAR); Block rowBLock = RowBlock.fromFieldBlocks(2, Optional.empty(), new Block[] { - notNullBlock(BIGINT, 2), - dictionaryBlock(BIGINT, 2, 2, 0.5F), - rleBlock(VARCHAR, 2) + notNullBlock(TestType.BIGINT, 2), + dictionaryBlock(TestType.BIGINT, 2, 2, 0.5F), + rleBlock(TestType.VARCHAR, 2) }); PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type, 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); @@ -259,45 +271,12 @@ public void testRowWithNestedFields() assertBlockEquals(type, actual, rowBLock); } - @DataProvider(name = "nullRleTypes") - public static Object[][] nullRleTypes() - { - return new Object[][] - { - {BIGINT}, - {BOOLEAN}, - {INTEGER}, - {createCharType(10)}, - {createUnboundedVarcharType()}, - {DOUBLE}, - {SMALLINT}, - {TINYINT}, - {VARBINARY}, - {createDecimalType(Decimals.MAX_SHORT_PRECISION + 1)}, - {createTimestampType(9)}, - {anonymousRow(BIGINT, VARCHAR)} - }; - } - @DataProvider(name = "types") public static Object[][] types() { - return new Object[][] - { - {BIGINT}, - {BOOLEAN}, - {INTEGER}, - {createCharType(10)}, - {createUnboundedVarcharType()}, - {DOUBLE}, - {SMALLINT}, - {TINYINT}, - {VARBINARY}, - {createDecimalType(Decimals.MAX_SHORT_PRECISION + 1)}, - {new ArrayType(BIGINT)}, - {createTimestampType(9)}, - {anonymousRow(BIGINT, VARCHAR)} - }; + return Arrays.stream(TestType.values()) + .map(type -> new Object[] {type}) + .toArray(Object[][]::new); } private static Block singleValueBlock(String value) @@ -332,7 +311,7 @@ private DictionaryBlock dictionaryBlock(Block dictionary, int[] ids) return new DictionaryBlock(0, ids.length, dictionary, ids, false, randomDictionaryId()); } - private DictionaryBlock dictionaryBlock(Type type, int positionCount, int dictionarySize, float nullRate) + private DictionaryBlock dictionaryBlock(TestType type, int positionCount, int dictionarySize, float nullRate) { Block dictionary = createRandomBlockForType(type, dictionarySize, nullRate); return createRandomDictionaryBlock(dictionary, positionCount); @@ -343,40 +322,45 @@ private RunLengthEncodedBlock rleBlock(Block value, int positionCount) return new RunLengthEncodedBlock(value, positionCount); } - private RunLengthEncodedBlock rleBlock(Type type, int positionCount) + private RunLengthEncodedBlock rleBlock(TestType type, int positionCount) { Block rleValue = createRandomBlockForType(type, 1, 0); return new RunLengthEncodedBlock(rleValue, positionCount); } - private RunLengthEncodedBlock nullRleBlock(Type type, int positionCount) + private RunLengthEncodedBlock nullRleBlock(TestType type, int positionCount) { Block rleValue = nullBlock(type, 1); return new RunLengthEncodedBlock(rleValue, positionCount); } - private Block partiallyNullBlock(Type type, int positionCount) + private Block partiallyNullBlock(TestType type, int positionCount) { return createRandomBlockForType(type, positionCount, 0.5F); } - private Block notNullBlock(Type type, int positionCount) + private Block notNullBlock(TestType type, int positionCount) { return createRandomBlockForType(type, positionCount, 0); } - private Block nullBlock(Type type, int positionCount) + private Block nullBlock(TestType type, int positionCount) { - BlockBuilder blockBuilder = type.createBlockBuilder(null, positionCount); + BlockBuilder blockBuilder = type.getType().createBlockBuilder(null, positionCount); for (int i = 0; i < positionCount; i++) { blockBuilder.appendNull(); } - return blockBuilder.build(); + return type.adapt(blockBuilder.build()); + } + + private Block emptyBlock(TestType type) + { + return type.adapt(type.getType().createBlockBuilder(null, 0).build()); } - private Block emptyBlock(Type type) + private Block createRandomBlockForType(TestType type, int positionCount, float nullRate) { - return type.createBlockBuilder(null, 0).build(); + return type.adapt(BlockAssertions.createRandomBlockForType(type.getType(), positionCount, nullRate)); } private void testNullRle(Type type, Block source) @@ -398,9 +382,9 @@ private void testNullRle(Type type, Block source) assertInstanceOf(actual, RunLengthEncodedBlock.class); } - private void testAppend(Type type, List inputs) + private void testAppend(TestType type, List inputs) { - PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type, 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); long initialRetainedSize = positionsAppender.getRetainedSizeInBytes(); inputs.forEach(input -> positionsAppender.append(input.getPositions(), input.getBlock())); @@ -408,7 +392,7 @@ private void testAppend(Type type, List inputs) assertGreaterThanOrEqual(positionsAppender.getRetainedSizeInBytes(), sizeInBytes); Block actual = positionsAppender.build(); - assertBlockIsValid(actual, sizeInBytes, type, inputs); + assertBlockIsValid(actual, sizeInBytes, type.getType(), inputs); // verify positionsAppender reset assertEquals(positionsAppender.getSizeInBytes(), 0); assertEquals(positionsAppender.getRetainedSizeInBytes(), initialRetainedSize); @@ -437,6 +421,48 @@ private Block buildBlock(Type type, List inputs, BlockBuilderStatus b return blockBuilder.build(); } + private enum TestType + { + BIGINT(BigintType.BIGINT), + BOOLEAN(BooleanType.BOOLEAN), + INTEGER(IntegerType.INTEGER), + CHAR_10(createCharType(10)), + VARCHAR(createUnboundedVarcharType()), + DOUBLE(DoubleType.DOUBLE), + SMALLINT(SmallintType.SMALLINT), + TINYINT(TinyintType.TINYINT), + VARBINARY(VarbinaryType.VARBINARY), + LONG_DECIMAL(createDecimalType(Decimals.MAX_SHORT_PRECISION + 1)), + LONG_TIMESTAMP(createTimestampType(9)), + ROW_BIGINT_VARCHAR(anonymousRow(BigintType.BIGINT, VarcharType.VARCHAR)), + ARRAY_BIGINT(new ArrayType(BigintType.BIGINT)), + VARCHAR_WITH_TEST_BLOCK(VarcharType.VARCHAR, TestVariableWidthBlock.adaptation()); + + private final Type type; + private final Function blockAdaptation; + + TestType(Type type) + { + this(type, Function.identity()); + } + + TestType(Type type, Function blockAdaptation) + { + this.type = requireNonNull(type, "type is null"); + this.blockAdaptation = requireNonNull(blockAdaptation, "blockAdaptation is null"); + } + + public Block adapt(Block block) + { + return blockAdaptation.apply(block); + } + + public Type getType() + { + return type; + } + } + private static class BlockView { private final Block block; @@ -463,4 +489,160 @@ public void appendTo(PositionsAppender positionsAppender) positionsAppender.append(getPositions(), getBlock()); } } + + private static class TestVariableWidthBlock + extends AbstractVariableWidthBlock + { + private final int arrayOffset; + private final int positionCount; + private final Slice slice; + private final int[] offsets; + @Nullable + private final boolean[] valueIsNull; + + private static Function adaptation() + { + return TestVariableWidthBlock::adapt; + } + + private static Block adapt(Block block) + { + if (block instanceof RunLengthEncodedBlock) { + checkArgument(block.getPositionCount() == 0 || block.isNull(0)); + return new RunLengthEncodedBlock(new TestVariableWidthBlock(0, 1, EMPTY_SLICE, new int[] {0, 0}, new boolean[] {true}), block.getPositionCount()); + } + + int[] offsets = new int[block.getPositionCount() + 1]; + boolean[] valueIsNull = new boolean[block.getPositionCount()]; + boolean hasNullValue = false; + for (int i = 0; i < block.getPositionCount(); i++) { + if (block.isNull(i)) { + valueIsNull[i] = true; + hasNullValue = true; + offsets[i + 1] = offsets[i]; + } + else { + offsets[i + 1] = offsets[i] + block.getSliceLength(i); + } + } + + return new TestVariableWidthBlock(0, block.getPositionCount(), ((VariableWidthBlock) block).getRawSlice(), offsets, hasNullValue ? valueIsNull : null); + } + + private TestVariableWidthBlock(int arrayOffset, int positionCount, Slice slice, int[] offsets, boolean[] valueIsNull) + { + checkArgument(arrayOffset >= 0); + this.arrayOffset = arrayOffset; + checkArgument(positionCount >= 0); + this.positionCount = positionCount; + this.slice = requireNonNull(slice, "slice is null"); + this.offsets = offsets; + this.valueIsNull = valueIsNull; + } + + @Override + protected Slice getRawSlice(int position) + { + return slice; + } + + @Override + protected int getPositionOffset(int position) + { + return offsets[position + arrayOffset]; + } + + @Override + public int getSliceLength(int position) + { + return getPositionOffset(position + 1) - getPositionOffset(position); + } + + @Override + protected boolean isEntryNull(int position) + { + return valueIsNull != null && valueIsNull[position + arrayOffset]; + } + + @Override + public int getPositionCount() + { + return positionCount; + } + + @Override + public Block getRegion(int positionOffset, int length) + { + return new TestVariableWidthBlock(positionOffset + arrayOffset, length, slice, offsets, valueIsNull); + } + + @Override + public Block getSingleValueBlock(int position) + { + if (isNull(position)) { + return new TestVariableWidthBlock(0, 1, EMPTY_SLICE, new int[] {0, 0}, new boolean[] {true}); + } + + int offset = getPositionOffset(position); + int entrySize = getSliceLength(position); + + Slice copy = Slices.copyOf(getRawSlice(position), offset, entrySize); + + return new TestVariableWidthBlock(0, 1, copy, new int[] {0, copy.length()}, null); + } + + @Override + public long getSizeInBytes() + { + throw new UnsupportedOperationException(); + } + + @Override + public long getRegionSizeInBytes(int position, int length) + { + throw new UnsupportedOperationException(); + } + + @Override + public OptionalInt fixedSizeInBytesPerPosition() + { + return OptionalInt.empty(); + } + + @Override + public long getPositionsSizeInBytes(boolean[] positions, int selectedPositionsCount) + { + throw new UnsupportedOperationException(); + } + + @Override + public long getRetainedSizeInBytes() + { + throw new UnsupportedOperationException(); + } + + @Override + public void retainedBytesForEachPart(ObjLongConsumer consumer) + { + throw new UnsupportedOperationException(); + } + + @Override + public Block copyPositions(int[] positions, int offset, int length) + { + throw new UnsupportedOperationException(); + } + + @Override + public Block copyRegion(int position, int length) + { + throw new UnsupportedOperationException(); + } + + @Override + public Block copyWithAppendedNull() + { + throw new UnsupportedOperationException(); + } + } } diff --git a/core/trino-main/src/test/java/io/trino/operator/output/TestSlicePositionsAppender.java b/core/trino-main/src/test/java/io/trino/operator/output/TestSlicePositionsAppender.java new file mode 100644 index 000000000000..a102a75667c8 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/operator/output/TestSlicePositionsAppender.java @@ -0,0 +1,144 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.output; + +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import io.trino.spi.block.Block; +import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.VariableWidthBlock; +import it.unimi.dsi.fastutil.ints.IntArrayList; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.Optional; + +import static io.trino.block.BlockAssertions.assertBlockEquals; +import static io.trino.block.BlockAssertions.createStringsBlock; +import static io.trino.operator.output.SlicePositionsAppender.duplicateBytes; +import static io.trino.spi.block.PageBuilderStatus.DEFAULT_MAX_PAGE_SIZE_IN_BYTES; +import static io.trino.spi.type.VarcharType.VARCHAR; +import static org.testng.internal.junit.ArrayAsserts.assertArrayEquals; + +public class TestSlicePositionsAppender +{ + @Test + public void testAppendEmptySliceRle() + { + // test SlicePositionAppender.appendRle with empty value (Slice with length 0) + PositionsAppender positionsAppender = new SlicePositionsAppender(1, 100); + RunLengthEncodedBlock rleBlock = new RunLengthEncodedBlock(createStringsBlock(""), 10); + positionsAppender.appendRle(rleBlock); + + Block actualBlock = positionsAppender.build(); + + assertBlockEquals(VARCHAR, actualBlock, rleBlock); + } + + // test append with VariableWidthBlock using Slice not backed by byte array + // to test special handling in SlicePositionsAppender.copyBytes + @Test + public void testAppendSliceNotBackedByByteArray() + { + PositionsAppender positionsAppender = new SlicePositionsAppender(1, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + Block block = new VariableWidthBlock(3, Slices.wrappedLongArray(257, 2), new int[] {0, 1, Long.BYTES, 2 * Long.BYTES}, Optional.empty()); + positionsAppender.append(IntArrayList.wrap(new int[] {0, 2}), block); + + Block actual = positionsAppender.build(); + + Block expected = new VariableWidthBlock( + 2, + Slices.wrappedBuffer(new byte[] {1, 2, 0, 0, 0, 0, 0, 0, 0}), + new int[] {0, 1, Long.BYTES + 1}, + Optional.empty()); + assertBlockEquals(VARCHAR, actual, expected); + } + + @Test + public void testDuplicateZeroLength() + { + Slice slice = Slices.wrappedBuffer(); + byte[] target = new byte[] {-1}; + duplicateBytes(slice, target, 0, 100); + assertArrayEquals(new byte[] {-1}, target); + } + + @Test + public void testDuplicate1Byte() + { + Slice slice = Slices.wrappedBuffer(new byte[] {2}); + byte[] target = new byte[5]; + Arrays.fill(target, (byte) -1); + duplicateBytes(slice, target, 3, 2); + assertArrayEquals(new byte[] {-1, -1, -1, 2, 2}, target); + } + + @Test + public void testDuplicate2Bytes() + { + Slice slice = Slices.wrappedBuffer(new byte[] {1, 2}); + byte[] target = new byte[8]; + Arrays.fill(target, (byte) -1); + duplicateBytes(slice, target, 1, 3); + assertArrayEquals(new byte[] {-1, 1, 2, 1, 2, 1, 2, -1}, target); + } + + @Test + public void testDuplicate1Time() + { + Slice slice = Slices.wrappedBuffer(new byte[] {1, 2}); + byte[] target = new byte[8]; + Arrays.fill(target, (byte) -1); + + duplicateBytes(slice, target, 1, 1); + + assertArrayEquals(new byte[] {-1, 1, 2, -1, -1, -1, -1, -1}, target); + } + + @Test + public void testDuplicateMultipleBytesOffNumberOfTimes() + { + Slice slice = Slices.wrappedBuffer(new byte[] {5, 3, 1}); + byte[] target = new byte[17]; + Arrays.fill(target, (byte) -1); + + duplicateBytes(slice, target, 1, 5); + + assertArrayEquals(new byte[] {-1, 5, 3, 1, 5, 3, 1, 5, 3, 1, 5, 3, 1, 5, 3, 1, -1}, target); + } + + @Test + public void testDuplicateMultipleBytesEvenNumberOfTimes() + { + Slice slice = Slices.wrappedBuffer(new byte[] {5, 3, 1}); + byte[] target = new byte[20]; + Arrays.fill(target, (byte) -1); + + duplicateBytes(slice, target, 1, 6); + + assertArrayEquals(new byte[] {-1, 5, 3, 1, 5, 3, 1, 5, 3, 1, 5, 3, 1, 5, 3, 1, 5, 3, 1, -1}, target); + } + + @Test + public void testDuplicateMultipleBytesPowerOfTwoNumberOfTimes() + { + Slice slice = Slices.wrappedBuffer(new byte[] {5, 3, 1}); + byte[] target = new byte[14]; + Arrays.fill(target, (byte) -1); + + duplicateBytes(slice, target, 1, 4); + + assertArrayEquals(new byte[] {-1, 5, 3, 1, 5, 3, 1, 5, 3, 1, 5, 3, 1, -1}, target); + } +} diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/AbstractTestExtract.java b/core/trino-main/src/test/java/io/trino/operator/scalar/AbstractTestExtract.java index 27bb54a41e95..e25720198d85 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/AbstractTestExtract.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/AbstractTestExtract.java @@ -159,7 +159,7 @@ protected void testUnsupportedExtract(String extractField) { types().forEach(type -> { String expression = format("EXTRACT(%s FROM CAST(NULL AS %s))", extractField, type); - assertThatThrownBy(() -> assertions.expression(expression), expression) + assertThatThrownBy(() -> assertions.expression(expression).evaluate(), expression) .as(expression) .isInstanceOf(TrinoException.class) .hasMessageMatching(format("line 1:\\d+:\\Q Cannot extract %s from %s", extractField, type)); diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/AbstractTestFunctions.java b/core/trino-main/src/test/java/io/trino/operator/scalar/AbstractTestFunctions.java index 09f1a5b5f292..5202dd415146 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/AbstractTestFunctions.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/AbstractTestFunctions.java @@ -85,11 +85,19 @@ public final void destroyTestFunctions() functionAssertions = null; } + /** + * @deprecated Use {@link io.trino.sql.query.QueryAssertions#function(String, String...)} + */ + @Deprecated protected void assertFunction(@Language("SQL") String projection, Type expectedType, Object expected) { functionAssertions.assertFunction(projection, expectedType, expected); } + /** + * @deprecated Use {@link io.trino.sql.query.QueryAssertions#operator(OperatorType, String...)} + */ + @Deprecated protected void assertOperator(OperatorType operator, String value, Type expectedType, Object expected) { functionAssertions.assertFunction(format("\"%s\"(%s)", mangleOperatorName(operator), value), expectedType, expected); diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkJsonToArrayCast.java b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkJsonToArrayCast.java index a2497c0f6367..0d1cb166ec82 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkJsonToArrayCast.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkJsonToArrayCast.java @@ -141,10 +141,10 @@ private static String generateRandomJsonValue(Type valueType) if (valueType == BIGINT) { return Long.toString(ThreadLocalRandom.current().nextLong()); } - else if (valueType == DOUBLE) { + if (valueType == DOUBLE) { return Double.toString(ThreadLocalRandom.current().nextDouble()); } - else if (valueType == VARCHAR) { + if (valueType == VARCHAR) { String characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"; int length = ThreadLocalRandom.current().nextInt(10) + 1; @@ -156,9 +156,7 @@ else if (valueType == VARCHAR) { builder.append('"'); return builder.toString(); } - else { - throw new UnsupportedOperationException(); - } + throw new UnsupportedOperationException(); } public PageProcessor getPageProcessor() diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkJsonToMapCast.java b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkJsonToMapCast.java index 80df4518f490..8bb2df88515d 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkJsonToMapCast.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkJsonToMapCast.java @@ -147,10 +147,10 @@ private static String generateRandomJsonValue(Type valueType) if (valueType == BIGINT) { return Long.toString(ThreadLocalRandom.current().nextLong()); } - else if (valueType == DOUBLE) { + if (valueType == DOUBLE) { return Double.toString(ThreadLocalRandom.current().nextDouble()); } - else if (valueType == VARCHAR) { + if (valueType == VARCHAR) { String characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"; int length = ThreadLocalRandom.current().nextInt(10) + 1; @@ -162,9 +162,7 @@ else if (valueType == VARCHAR) { builder.append('"'); return builder.toString(); } - else { - throw new UnsupportedOperationException(); - } + throw new UnsupportedOperationException(); } public PageProcessor getPageProcessor() diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/FunctionAssertions.java b/core/trino-main/src/test/java/io/trino/operator/scalar/FunctionAssertions.java index 3c7f9eeca120..f2fdcbb55734 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/FunctionAssertions.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/FunctionAssertions.java @@ -291,6 +291,10 @@ public void installPlugin(Plugin plugin) runner.installPlugin(plugin); } + /** + * @deprecated Use {@link io.trino.sql.query.QueryAssertions#function(String, String...)} + */ + @Deprecated public void assertFunction(String projection, Type expectedType, Object expected) { if (expected instanceof Slice) { @@ -301,17 +305,29 @@ public void assertFunction(String projection, Type expectedType, Object expected assertEquals(actual, expected); } + /** + * @deprecated Use {@link io.trino.sql.query.QueryAssertions#function(String, String...)} + */ + @Deprecated public void assertFunctionString(String projection, Type expectedType, String expected) { Object actual = selectSingleValue(projection, expectedType, runner.getExpressionCompiler()); assertEquals(actual.toString(), expected); } + /** + * @deprecated Use {@link io.trino.sql.query.QueryAssertions#expression(String)} + */ + @Deprecated public void tryEvaluate(String expression, Type expectedType) { tryEvaluate(expression, expectedType, session); } + /** + * @deprecated Use {@link io.trino.sql.query.QueryAssertions#expression(String)} + */ + @Deprecated public void tryEvaluate(String expression, Type expectedType, Session session) { selectUniqueValue(expression, expectedType, session, runner.getExpressionCompiler()); @@ -322,6 +338,10 @@ public void tryEvaluateWithAll(String expression, Type expectedType) tryEvaluateWithAll(expression, expectedType, session); } + /** + * @deprecated Use {@link io.trino.sql.query.QueryAssertions#expression(String)} + */ + @Deprecated public void tryEvaluateWithAll(String expression, Type expectedType, Session session) { executeProjectionWithAll(expression, expectedType, session, runner.getExpressionCompiler()); @@ -783,21 +803,19 @@ private Object interpret(Expression expression, Type expectedType, Session sessi if (javaType == boolean.class) { return type.getBoolean(block, position); } - else if (javaType == long.class) { + if (javaType == long.class) { return type.getLong(block, position); } - else if (javaType == double.class) { + if (javaType == double.class) { return type.getDouble(block, position); } - else if (javaType == Slice.class) { + if (javaType == Slice.class) { return type.getSlice(block, position); } - else if (javaType == Block.class || javaType == Int128.class) { + if (javaType == Block.class || javaType == Int128.class) { return type.getObject(block, position); } - else { - throw new UnsupportedOperationException("not yet implemented"); - } + throw new UnsupportedOperationException("not yet implemented"); }); // convert result from stack type to Type ObjectValue @@ -953,9 +971,7 @@ public ConnectorPageSource createPageSource(Session session, Split split, TableH .build(); return new RecordPageSource(records); } - else { - return new FixedPageSource(ImmutableList.of(SOURCE_PAGE)); - } + return new FixedPageSource(ImmutableList.of(SOURCE_PAGE)); } } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayContainsSequence.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayContainsSequence.java index 3e7b0404f83b..6937ddd8a363 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayContainsSequence.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayContainsSequence.java @@ -13,30 +13,72 @@ */ package io.trino.operator.scalar; -import org.testng.annotations.Test; +import io.trino.sql.query.QueryAssertions; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; -import static io.trino.spi.type.BooleanType.BOOLEAN; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +@TestInstance(PER_CLASS) public class TestArrayContainsSequence - extends AbstractTestFunctions { + private QueryAssertions assertions; + + @BeforeAll + public void init() + { + assertions = new QueryAssertions(); + } + + @AfterAll + public void teardown() + { + assertions.close(); + assertions = null; + } + @Test public void testBasic() { - assertFunction("contains_sequence(ARRAY [1,2,3,4,5,6], ARRAY[1,2])", BOOLEAN, true); - assertFunction("contains_sequence(ARRAY [1,2,3,4,5,6], ARRAY[3,4])", BOOLEAN, true); - assertFunction("contains_sequence(ARRAY [1,2,3,4,5,6], ARRAY[5,6])", BOOLEAN, true); - assertFunction("contains_sequence(ARRAY [1,2,3,4,5,6], ARRAY[1,2,4])", BOOLEAN, false); - assertFunction("contains_sequence(ARRAY [1,2,3,NULL,4,5,6], ARRAY[3,NULL,4])", BOOLEAN, true); - assertFunction("contains_sequence(ARRAY [1,2,3,4,5,6], ARRAY[1,2,3,4,5,6])", BOOLEAN, true); - assertFunction("contains_sequence(ARRAY [1,2,3,4,5,6], ARRAY[])", BOOLEAN, true); - assertFunction("contains_sequence(ARRAY ['1','2','3'], ARRAY['1','2'])", BOOLEAN, true); - assertFunction("contains_sequence(ARRAY [1.1,2.2,3.3], ARRAY[1.1,2.2])", BOOLEAN, true); - assertFunction("contains_sequence(ARRAY [ARRAY[1,2],ARRAY[3],ARRAY[4,5]], ARRAY[ARRAY[1,2],ARRAY[3]])", BOOLEAN, true); - assertFunction("contains_sequence(ARRAY [ARRAY[1,2],ARRAY[3],ARRAY[4,5]], ARRAY[ARRAY[1,2],ARRAY[4]])", BOOLEAN, false); + assertThat(assertions.function("contains_sequence", "ARRAY[1, 2, 3, 4, 5, 6]", "ARRAY[1, 2]")) + .isEqualTo(true); + + assertThat(assertions.function("contains_sequence", "ARRAY[1, 2, 3, 4, 5, 6]", "ARRAY[3, 4]")) + .isEqualTo(true); + + assertThat(assertions.function("contains_sequence", "ARRAY[1, 2, 3, 4, 5, 6]", "ARRAY[5, 6]")) + .isEqualTo(true); + + assertThat(assertions.function("contains_sequence", "ARRAY[1, 2, 3, 4, 5, 6]", "ARRAY[1, 2, 4]")) + .isEqualTo(false); + + assertThat(assertions.function("contains_sequence", "ARRAY[1, 2, 3, NULL, 4, 5, 6]", "ARRAY[3, NULL, 4]")) + .isEqualTo(true); + + assertThat(assertions.function("contains_sequence", "ARRAY[1, 2, 3, 4, 5, 6]", "ARRAY[1, 2, 3, 4, 5, 6]")) + .isEqualTo(true); + + assertThat(assertions.function("contains_sequence", "ARRAY[1, 2, 3, 4, 5, 6]", "ARRAY[]")) + .isEqualTo(true); + + assertThat(assertions.function("contains_sequence", "ARRAY['1', '2', '3']", "ARRAY['1', '2']")) + .isEqualTo(true); + + assertThat(assertions.function("contains_sequence", "ARRAY[1.1, 2.2, 3.3]", "ARRAY[1.1, 2.2]")) + .isEqualTo(true); + + assertThat(assertions.function("contains_sequence", "ARRAY[ARRAY[1,2], ARRAY[3], ARRAY[4,5]]", "ARRAY[ARRAY[1,2], ARRAY[3]]")) + .isEqualTo(true); + + assertThat(assertions.function("contains_sequence", "ARRAY[ARRAY[1,2], ARRAY[3], ARRAY[4,5]]", "ARRAY[ARRAY[1,2], ARRAY[4]]")) + .isEqualTo(false); for (int i = 1; i <= 6; i++) { - assertFunction("contains_sequence(ARRAY [1,2,3,4,5,6], ARRAY[" + i + "])", BOOLEAN, true); + assertThat(assertions.function("contains_sequence", "ARRAY[1, 2, 3, 4, 5, 6]", "ARRAY[%d]".formatted(i))) + .isEqualTo(true); } } } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayExceptFunction.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayExceptFunction.java index 011500c6e539..d90fc4bff773 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayExceptFunction.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayExceptFunction.java @@ -13,64 +13,112 @@ */ package io.trino.operator.scalar; -import com.google.common.collect.ImmutableList; import io.trino.spi.type.ArrayType; -import org.testng.annotations.Test; +import io.trino.sql.query.QueryAssertions; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; -import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.spi.type.BooleanType.BOOLEAN; -import static io.trino.spi.type.DoubleType.DOUBLE; -import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.type.UnknownType.UNKNOWN; -import static java.util.Arrays.asList; -import static java.util.Collections.singletonList; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +@TestInstance(PER_CLASS) public class TestArrayExceptFunction - extends AbstractTestFunctions { + private QueryAssertions assertions; + + @BeforeAll + public void init() + { + assertions = new QueryAssertions(); + } + + @AfterAll + public void teardown() + { + assertions.close(); + assertions = null; + } + @Test public void testBasic() { - assertFunction("array_except(ARRAY[1, 5, 3], ARRAY[3])", new ArrayType(INTEGER), ImmutableList.of(1, 5)); - assertFunction("array_except(ARRAY[CAST(1 as BIGINT), 5, 3], ARRAY[5])", new ArrayType(BIGINT), ImmutableList.of(1L, 3L)); - assertFunction("array_except(ARRAY[VARCHAR 'x', 'y', 'z'], ARRAY['x'])", new ArrayType(VARCHAR), ImmutableList.of("y", "z")); - assertFunction("array_except(ARRAY[true, false, null], ARRAY[true])", new ArrayType(BOOLEAN), asList(false, null)); - assertFunction("array_except(ARRAY[1.1E0, 5.4E0, 3.9E0], ARRAY[5, 5.4E0])", new ArrayType(DOUBLE), ImmutableList.of(1.1, 3.9)); + assertThat(assertions.function("array_except", "ARRAY[1, 5, 3]", "ARRAY[3]")) + .matches("ARRAY[1, 5]"); + + assertThat(assertions.function("array_except", "ARRAY[BIGINT '1', 5, 3]", "ARRAY[5]")) + .matches("ARRAY[BIGINT '1', BIGINT '3']"); + + assertThat(assertions.function("array_except", "ARRAY[VARCHAR 'x', 'y', 'z']", "ARRAY['x']")) + .matches("ARRAY[VARCHAR 'y', VARCHAR 'z']"); + + assertThat(assertions.function("array_except", "ARRAY[true, false, null]", "ARRAY[true]")) + .matches("ARRAY[false, null]"); + + assertThat(assertions.function("array_except", "ARRAY[1.1E0, 5.4E0, 3.9E0]", "ARRAY[5, 5.4E0]")) + .matches("ARRAY[1.1E0, 3.9E0]"); } @Test public void testEmpty() { - assertFunction("array_except(ARRAY[], ARRAY[])", new ArrayType(UNKNOWN), ImmutableList.of()); - assertFunction("array_except(ARRAY[], ARRAY[1, 3])", new ArrayType(INTEGER), ImmutableList.of()); - assertFunction("array_except(ARRAY[VARCHAR 'abc'], ARRAY[])", new ArrayType(VARCHAR), ImmutableList.of("abc")); + assertThat(assertions.function("array_except", "ARRAY[]", "ARRAY[]")) + .matches("ARRAY[]"); + + assertThat(assertions.function("array_except", "ARRAY[]", "ARRAY[1, 3]")) + .matches("CAST(ARRAY[] AS array(integer))"); + + assertThat(assertions.function("array_except", "ARRAY[VARCHAR 'abc']", "ARRAY[]")) + .matches("ARRAY[VARCHAR 'abc']"); } @Test public void testNull() { - assertFunction("array_except(ARRAY[NULL], NULL)", new ArrayType(UNKNOWN), null); - assertFunction("array_except(NULL, NULL)", new ArrayType(UNKNOWN), null); - assertFunction("array_except(NULL, ARRAY[NULL])", new ArrayType(UNKNOWN), null); - assertFunction("array_except(ARRAY[NULL], ARRAY[NULL])", new ArrayType(UNKNOWN), ImmutableList.of()); - assertFunction("array_except(ARRAY[], ARRAY[NULL])", new ArrayType(UNKNOWN), ImmutableList.of()); - assertFunction("array_except(ARRAY[NULL], ARRAY[])", new ArrayType(UNKNOWN), singletonList(null)); + assertThat(assertions.function("array_except", "ARRAY[NULL]", "NULL")) + .isNull(new ArrayType(UNKNOWN)); + + assertThat(assertions.function("array_except", "NULL", "NULL")) + .isNull(new ArrayType(UNKNOWN)); + + assertThat(assertions.function("array_except", "NULL", "ARRAY[NULL]")) + .isNull(new ArrayType(UNKNOWN)); + + assertThat(assertions.function("array_except", "ARRAY[NULL]", "ARRAY[NULL]")) + .matches("ARRAY[]"); + + assertThat(assertions.function("array_except", "ARRAY[]", "ARRAY[NULL]")) + .matches("ARRAY[]"); + + assertThat(assertions.function("array_except", "ARRAY[NULL]", "ARRAY[]")) + .matches("ARRAY[NULL]"); } @Test public void testDuplicates() { - assertFunction("array_except(ARRAY[1, 5, 3, 5, 1], ARRAY[3])", new ArrayType(INTEGER), ImmutableList.of(1, 5)); - assertFunction("array_except(ARRAY[CAST(1 as BIGINT), 5, 5, 3, 3, 3, 1], ARRAY[3, 5])", new ArrayType(BIGINT), ImmutableList.of(1L)); - assertFunction("array_except(ARRAY[VARCHAR 'x', 'x', 'y', 'z'], ARRAY['x', 'y', 'x'])", new ArrayType(VARCHAR), ImmutableList.of("z")); - assertFunction("array_except(ARRAY[true, false, null, true, false, null], ARRAY[true, true, true])", new ArrayType(BOOLEAN), asList(false, null)); + assertThat(assertions.function("array_except", "ARRAY[1, 5, 3, 5, 1]", "ARRAY[3]")) + .matches("ARRAY[1, 5]"); + + assertThat(assertions.function("array_except", "ARRAY[BIGINT '1', 5, 5, 3, 3, 3, 1]", "ARRAY[3, 5]")) + .matches("ARRAY[BIGINT '1']"); + + assertThat(assertions.function("array_except", "ARRAY[VARCHAR 'x', 'x', 'y', 'z']", "ARRAY['x', 'y', 'x']")) + .matches("ARRAY[VARCHAR 'z']"); + + assertThat(assertions.function("array_except", "ARRAY[true, false, null, true, false, null]", "ARRAY[true, true, true]")) + .matches("ARRAY[false, null]"); } @Test public void testNonDistinctNonEqualValues() { - assertFunction("array_except(ARRAY[NaN()], ARRAY[NaN()])", new ArrayType(DOUBLE), ImmutableList.of()); - assertFunction("array_except(ARRAY[1, NaN(), 3], ARRAY[NaN(), 3])", new ArrayType(DOUBLE), ImmutableList.of(1.0)); + assertThat(assertions.function("array_except", "ARRAY[NaN()]", "ARRAY[NaN()]")) + .matches("CAST(ARRAY[] AS array(double))"); + + assertThat(assertions.function("array_except", "ARRAY[1, NaN(), 3]", "ARRAY[NaN(), 3]")) + .matches("ARRAY[1E0]"); } } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayFilterFunction.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayFilterFunction.java index 3642546486e7..bd20cfb31c42 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayFilterFunction.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayFilterFunction.java @@ -13,73 +13,142 @@ */ package io.trino.operator.scalar; -import com.google.common.collect.ImmutableList; -import io.trino.spi.type.ArrayType; -import org.testng.annotations.Test; - -import static io.trino.spi.type.BooleanType.BOOLEAN; -import static io.trino.spi.type.DoubleType.DOUBLE; -import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.spi.type.TimestampType.createTimestampType; -import static io.trino.spi.type.VarcharType.createVarcharType; -import static io.trino.type.UnknownType.UNKNOWN; -import static java.util.Arrays.asList; -import static java.util.Collections.singletonList; +import io.trino.sql.query.QueryAssertions; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; + +@TestInstance(PER_CLASS) public class TestArrayFilterFunction - extends AbstractTestFunctions { + private QueryAssertions assertions; + + @BeforeAll + public void init() + { + assertions = new QueryAssertions(); + } + + @AfterAll + public void teardown() + { + assertions.close(); + assertions = null; + } + @Test public void testBasic() { - assertFunction("filter(ARRAY [5, 6], x -> x = 5)", new ArrayType(INTEGER), ImmutableList.of(5)); - assertFunction("filter(ARRAY [5 + RANDOM(1), 6 + RANDOM(1)], x -> x = 5)", new ArrayType(INTEGER), ImmutableList.of(5)); - assertFunction("filter(ARRAY [true, false, true, false], x -> nullif(x, false))", new ArrayType(BOOLEAN), ImmutableList.of(true, true)); - assertFunction("filter(ARRAY [true, false, null, true, false, null], x -> not x)", new ArrayType(BOOLEAN), ImmutableList.of(false, false)); - assertFunction( - "filter(ARRAY [TIMESTAMP '2020-05-10 12:34:56.123456789', TIMESTAMP '1111-05-10 12:34:56.123456789'], t -> year(t) = 1111)", - new ArrayType(createTimestampType(9)), - ImmutableList.of(timestamp(9, "1111-05-10 12:34:56.123456789"))); + assertThat(assertions.expression("filter(a, x -> x = 5)") + .binding("a", "ARRAY[5, 6]")) + .matches("ARRAY[5]"); + + assertThat(assertions.expression("filter(a, x -> x = 5)") + .binding("a", "ARRAY[5 + random(1), 6 + random(1)]")) + .matches("ARRAY[5]"); + + assertThat(assertions.expression("filter(a, x -> nullif(x, false))") + .binding("a", "ARRAY[true, false, true, false]")) + .matches("ARRAY[true, true]"); + + assertThat(assertions.expression("filter(a, x -> not x)") + .binding("a", "ARRAY[true, false, null, true, false, null]")) + .matches("ARRAY[false, false]"); + + assertThat(assertions.expression("filter(a, t -> year(t) = 1111)") + .binding("a", "ARRAY[TIMESTAMP '2020-05-10 12:34:56.123456789', TIMESTAMP '1111-05-10 12:34:56.123456789']")) + .matches("ARRAY[TIMESTAMP '1111-05-10 12:34:56.123456789']"); } @Test public void testEmpty() { - assertFunction("filter(ARRAY [], x -> true)", new ArrayType(UNKNOWN), ImmutableList.of()); - assertFunction("filter(ARRAY [], x -> false)", new ArrayType(UNKNOWN), ImmutableList.of()); - assertFunction("filter(ARRAY [], x -> CAST (null AS BOOLEAN))", new ArrayType(UNKNOWN), ImmutableList.of()); - assertFunction("filter(CAST (ARRAY [] AS ARRAY(INTEGER)), x -> true)", new ArrayType(INTEGER), ImmutableList.of()); + assertThat(assertions.expression("filter(a, x -> true)") + .binding("a", "ARRAY[]")) + .matches("ARRAY[]"); + + assertThat(assertions.expression("filter(a, x -> false)") + .binding("a", "ARRAY[]")) + .matches("ARRAY[]"); + + assertThat(assertions.expression("filter(a, x -> CAST(null AS boolean))") + .binding("a", "ARRAY[]")) + .matches("ARRAY[]"); + + assertThat(assertions.expression("filter(a, x -> true)") + .binding("a", "CAST(ARRAY[] AS array(integer))")) + .matches("CAST(ARRAY[] AS array(integer))"); } @Test public void testNull() { - assertFunction("filter(ARRAY [NULL], x -> x IS NULL)", new ArrayType(UNKNOWN), singletonList(null)); - assertFunction("filter(ARRAY [NULL], x -> x IS NOT NULL)", new ArrayType(UNKNOWN), ImmutableList.of()); - assertFunction("filter(ARRAY [CAST (NULL AS INTEGER)], x -> x IS NULL)", new ArrayType(INTEGER), singletonList(null)); - assertFunction("filter(ARRAY [NULL, NULL, NULL], x -> x IS NULL)", new ArrayType(UNKNOWN), asList(null, null, null)); - assertFunction("filter(ARRAY [NULL, NULL, NULL], x -> x IS NOT NULL)", new ArrayType(UNKNOWN), ImmutableList.of()); - - assertFunction("filter(ARRAY [25, 26, NULL], x -> x % 2 = 1 OR x IS NULL)", new ArrayType(INTEGER), asList(25, null)); - assertFunction("filter(ARRAY [25.6E0, 37.3E0, NULL], x -> x < 30.0E0 OR x IS NULL)", new ArrayType(DOUBLE), asList(25.6, null)); - assertFunction("filter(ARRAY [true, false, NULL], x -> not x OR x IS NULL)", new ArrayType(BOOLEAN), asList(false, null)); - assertFunction("filter(ARRAY ['abc', 'def', NULL], x -> substr(x, 1, 1) = 'a' OR x IS NULL)", new ArrayType(createVarcharType(3)), asList("abc", null)); - assertFunction( - "filter(ARRAY [ARRAY ['abc', null, '123'], NULL], x -> x[2] IS NULL OR x IS NULL)", - new ArrayType(new ArrayType(createVarcharType(3))), - asList(asList("abc", null, "123"), null)); + assertThat(assertions.expression("filter(a, x -> x IS NULL)") + .binding("a", "ARRAY[NULL]")) + .matches("ARRAY[NULL]"); + + assertThat(assertions.expression("filter(a, x -> x IS NOT NULL)") + .binding("a", "ARRAY[NULL]")) + .matches("ARRAY[]"); + + assertThat(assertions.expression("filter(a, x -> x IS NULL)") + .binding("a", "ARRAY[CAST(NULL AS integer)]")) + .matches("CAST(ARRAY[NULL] AS array(integer))"); + + assertThat(assertions.expression("filter(a, x -> x IS NULL)") + .binding("a", "ARRAY[NULL, NULL, NULL]")) + .matches("ARRAY[NULL, NULL, NULL]"); + + assertThat(assertions.expression("filter(a, x -> x IS NOT NULL)") + .binding("a", "ARRAY[NULL, NULL, NULL]")) + .matches("ARRAY[]"); + + assertThat(assertions.expression("filter(a, x -> x % 2 = 1 OR x IS NULL)") + .binding("a", "ARRAY[25, 26, NULL]")) + .matches("ARRAY[25, NULL]"); + + assertThat(assertions.expression("filter(a, x -> x < 30.0E0 OR x IS NULL)") + .binding("a", "ARRAY[25.6E0, 37.3E0, NULL]")) + .matches("ARRAY[25.6E0, NULL]"); + + assertThat(assertions.expression("filter(a, x -> NOT x OR x IS NULL)") + .binding("a", "ARRAY[true, false, NULL]")) + .matches("ARRAY[false, NULL]"); + + assertThat(assertions.expression("filter(a, x -> substr(x, 1, 1) = 'a' OR x IS NULL)") + .binding("a", "ARRAY['abc', 'def', NULL]")) + .matches("ARRAY['abc', NULL]"); + + assertThat(assertions.expression("filter(a, x -> x[2] IS NULL OR x IS NULL)") + .binding("a", "ARRAY[ARRAY['abc', NULL, '123']]")) + .matches("ARRAY[ARRAY['abc', NULL, '123']]"); } @Test public void testTypeCombinations() { - assertFunction("filter(ARRAY [25, 26, 27], x -> x % 2 = 1)", new ArrayType(INTEGER), ImmutableList.of(25, 27)); - assertFunction("filter(ARRAY [25.6E0, 37.3E0, 28.6E0], x -> x < 30.0E0)", new ArrayType(DOUBLE), ImmutableList.of(25.6, 28.6)); - assertFunction("filter(ARRAY [true, false, true], x -> not x)", new ArrayType(BOOLEAN), ImmutableList.of(false)); - assertFunction("filter(ARRAY ['abc', 'def', 'ayz'], x -> substr(x, 1, 1) = 'a')", new ArrayType(createVarcharType(3)), ImmutableList.of("abc", "ayz")); - assertFunction( - "filter(ARRAY [ARRAY ['abc', null, '123'], ARRAY ['def', 'x', '456']], x -> x[2] IS NULL)", - new ArrayType(new ArrayType(createVarcharType(3))), - ImmutableList.of(asList("abc", null, "123"))); + assertThat(assertions.expression("filter(a, x -> x % 2 = 1)") + .binding("a", "ARRAY[25, 26, 27]")) + .matches("ARRAY[25, 27]"); + + assertThat(assertions.expression("filter(a, x -> x < 30.0E0)") + .binding("a", "ARRAY[25.6E0, 37.3E0, 28.6E0]")) + .matches("ARRAY[25.6E0, 28.6E0]"); + + assertThat(assertions.expression("filter(a, x -> NOT x)") + .binding("a", "ARRAY[true, false, true]")) + .matches("ARRAY[false]"); + + assertThat(assertions.expression("filter(a, x -> substr(x, 1, 1) = 'a' OR x IS NULL)") + .binding("a", "ARRAY['abc', 'def', 'ayz']")) + .matches("ARRAY['abc', 'ayz']"); + + assertThat(assertions.expression("filter(a, x -> x[2] IS NULL)") + .binding("a", "ARRAY[ARRAY['abc', NULL, '123'], ARRAY ['def', 'x', '456']]")) + .matches("ARRAY[ARRAY['abc', NULL, '123']]"); } } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayFunctions.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayFunctions.java index b10300100cd4..14cb70d41bda 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayFunctions.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayFunctions.java @@ -14,34 +14,68 @@ package io.trino.operator.scalar; import com.google.common.base.Joiner; +import io.trino.spi.TrinoException; import io.trino.spi.type.ArrayType; -import org.testng.annotations.Test; +import io.trino.sql.query.QueryAssertions; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; -import static io.trino.spi.StandardErrorCode.TOO_MANY_ARGUMENTS; import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; import static java.util.Collections.nCopies; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +@TestInstance(PER_CLASS) public class TestArrayFunctions - extends AbstractTestFunctions { + private QueryAssertions assertions; + + @BeforeAll + public void init() + { + assertions = new QueryAssertions(); + } + + @AfterAll + public void teardown() + { + assertions.close(); + assertions = null; + } + @Test public void testArrayConstructor() { - tryEvaluateWithAll("array[" + Joiner.on(", ").join(nCopies(254, "rand()")) + "]", new ArrayType(DOUBLE)); - assertInvalidFunction( - "array[" + Joiner.on(", ").join(nCopies(255, "rand()")) + "]", - TOO_MANY_ARGUMENTS, - "Too many arguments for array constructor"); + assertThat(assertions.expression("array[" + Joiner.on(", ").join(nCopies(254, "rand()")) + "]")) + .hasType(new ArrayType(DOUBLE)); + + assertThat(assertions.expression("array[a, b, c]") + .binding("a", "1") + .binding("b", "2") + .binding("c", "3")) + .matches("ARRAY[1, 2, 3]"); + + assertThatThrownBy(() -> assertions.expression("array[" + Joiner.on(", ").join(nCopies(255, "rand()")) + "]").evaluate()) + .isInstanceOf(TrinoException.class) + .hasMessage("Too many arguments for array constructor"); } @Test public void testArrayConcat() { - assertFunction("CONCAT(" + Joiner.on(", ").join(nCopies(127, "array[1]")) + ")", new ArrayType(INTEGER), nCopies(127, 1)); - assertInvalidFunction( - "CONCAT(" + Joiner.on(", ").join(nCopies(128, "array[1]")) + ")", - TOO_MANY_ARGUMENTS, - "line 1:1: Too many arguments for function call concat()"); + assertThat(assertions.expression("CONCAT(" + Joiner.on(", ").join(nCopies(127, "array[1]")) + ")")) + .hasType(new ArrayType(INTEGER)) + .matches("ARRAY[%s]".formatted(Joiner.on(",").join(nCopies(127, 1)))); + + assertThat(assertions.function("concat", "ARRAY[1]", "ARRAY[2]", "ARRAY[3]")) + .matches("ARRAY[1, 2, 3]"); + + assertThatThrownBy(() -> assertions.expression("CONCAT(" + Joiner.on(", ").join(nCopies(128, "array[1]")) + ")").evaluate()) + .isInstanceOf(TrinoException.class) + .hasMessage("line 1:8: Too many arguments for function call concat()"); } } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayMatchFunctions.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayMatchFunctions.java index e01f4c37b86e..e21ee971a6b0 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayMatchFunctions.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayMatchFunctions.java @@ -13,54 +13,163 @@ */ package io.trino.operator.scalar; -import io.trino.spi.type.BooleanType; -import org.testng.annotations.Test; +import io.trino.sql.query.QueryAssertions; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; + +@TestInstance(PER_CLASS) public class TestArrayMatchFunctions - extends AbstractTestFunctions { + private QueryAssertions assertions; + + @BeforeAll + public void init() + { + assertions = new QueryAssertions(); + } + + @AfterAll + public void teardown() + { + assertions.close(); + assertions = null; + } + @Test public void testAllMatch() { - assertFunction("all_match(ARRAY [5, 7, 9], x -> x % 2 = 1)", BooleanType.BOOLEAN, true); - assertFunction("all_match(ARRAY [true, false, true], x -> x)", BooleanType.BOOLEAN, false); - assertFunction("all_match(ARRAY ['abc', 'ade', 'afg'], x -> substr(x, 1, 1) = 'a')", BooleanType.BOOLEAN, true); - assertFunction("all_match(ARRAY [], x -> true)", BooleanType.BOOLEAN, true); - assertFunction("all_match(ARRAY [true, true, NULL], x -> x)", BooleanType.BOOLEAN, null); - assertFunction("all_match(ARRAY [true, false, NULL], x -> x)", BooleanType.BOOLEAN, false); - assertFunction("all_match(ARRAY [NULL, NULL, NULL], x -> x > 1)", BooleanType.BOOLEAN, null); - assertFunction("all_match(ARRAY [NULL, NULL, NULL], x -> x IS NULL)", BooleanType.BOOLEAN, true); - assertFunction("all_match(ARRAY [MAP(ARRAY[1,2], ARRAY[3,4]), MAP(ARRAY[1,2,3], ARRAY[3,4,5])], x -> cardinality(x) > 1)", BooleanType.BOOLEAN, true); - assertFunction("all_match(ARRAY [TIMESTAMP '2020-05-10 12:34:56.123456789', TIMESTAMP '1111-05-10 12:34:56.123456789'], t -> month(t) = 5)", BooleanType.BOOLEAN, true); + assertThat(assertions.expression("all_match(a, x -> x % 2 = 1)") + .binding("a", "ARRAY[5, 7, 9]")) + .isEqualTo(true); + + assertThat(assertions.expression("all_match(a, x -> x)") + .binding("a", "ARRAY[true, false, true]")) + .isEqualTo(false); + + assertThat(assertions.expression("all_match(a, x -> substr(x, 1, 1) = 'a')") + .binding("a", "ARRAY['abc', 'ade', 'afg']")) + .isEqualTo(true); + + assertThat(assertions.expression("all_match(a, x -> true)") + .binding("a", "ARRAY[]")) + .isEqualTo(true); + + assertThat(assertions.expression("all_match(a, x -> x)") + .binding("a", "ARRAY[true, true, NULL]")) + .matches("CAST(NULL AS boolean)"); + + assertThat(assertions.expression("all_match(a, x -> x)") + .binding("a", "ARRAY[true, false, NULL]")) + .isEqualTo(false); + + assertThat(assertions.expression("all_match(a, x -> x > 1)") + .binding("a", "ARRAY[NULL, NULL, NULL]")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("all_match(a, x -> x IS NULL)") + .binding("a", "ARRAY[NULL, NULL, NULL]")) + .isEqualTo(true); + + assertThat(assertions.expression("all_match(a, x -> cardinality(x) > 1)") + .binding("a", "ARRAY[MAP(ARRAY[1,2], ARRAY[3,4]), MAP(ARRAY[1,2,3], ARRAY[3,4,5])]")) + .isEqualTo(true); + + assertThat(assertions.expression("all_match(a, t -> month(t) = 5)") + .binding("a", "ARRAY[TIMESTAMP '2020-05-10 12:34:56.123456789', TIMESTAMP '1111-05-10 12:34:56.123456789']")) + .isEqualTo(true); } @Test public void testAnyMatch() { - assertFunction("any_match(ARRAY [5, 8, 10], x -> x % 2 = 1)", BooleanType.BOOLEAN, true); - assertFunction("any_match(ARRAY [false, false, false], x -> x)", BooleanType.BOOLEAN, false); - assertFunction("any_match(ARRAY ['abc', 'def', 'ghi'], x -> substr(x, 1, 1) = 'a')", BooleanType.BOOLEAN, true); - assertFunction("any_match(ARRAY [], x -> true)", BooleanType.BOOLEAN, false); - assertFunction("any_match(ARRAY [false, false, NULL], x -> x)", BooleanType.BOOLEAN, null); - assertFunction("any_match(ARRAY [true, false, NULL], x -> x)", BooleanType.BOOLEAN, true); - assertFunction("any_match(ARRAY [NULL, NULL, NULL], x -> x > 1)", BooleanType.BOOLEAN, null); - assertFunction("any_match(ARRAY [true, false, NULL], x -> x IS NULL)", BooleanType.BOOLEAN, true); - assertFunction("any_match(ARRAY [MAP(ARRAY[1,2], ARRAY[3,4]), MAP(ARRAY[1,2,3], ARRAY[3,4,5])], x -> cardinality(x) > 4)", BooleanType.BOOLEAN, false); - assertFunction("any_match(ARRAY [TIMESTAMP '2020-05-10 12:34:56.123456789', TIMESTAMP '1111-05-10 12:34:56.123456789'], t -> year(t) = 2020)", BooleanType.BOOLEAN, true); + assertThat(assertions.expression("any_match(a, x -> x % 2 = 1)") + .binding("a", "ARRAY[5, 8, 10]")) + .isEqualTo(true); + + assertThat(assertions.expression("any_match(a, x -> x)") + .binding("a", "ARRAY[false, false, false]")) + .isEqualTo(false); + + assertThat(assertions.expression("any_match(a, x -> substr(x, 1, 1) = 'a')") + .binding("a", "ARRAY['abc', 'def', 'ghi']")) + .isEqualTo(true); + + assertThat(assertions.expression("any_match(a, x -> true)") + .binding("a", "ARRAY[]")) + .isEqualTo(false); + + assertThat(assertions.expression("any_match(a, x -> x)") + .binding("a", "ARRAY[false, false, NULL]")) + .matches("CAST(NULL AS boolean)"); + + assertThat(assertions.expression("any_match(a, x -> x)") + .binding("a", "ARRAY[true, false, NULL]")) + .isEqualTo(true); + + assertThat(assertions.expression("any_match(a, x -> x > 1)") + .binding("a", "ARRAY[NULL, NULL, NULL]")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("any_match(a, x -> x IS NULL)") + .binding("a", "ARRAY[true, false, NULL]")) + .isEqualTo(true); + + assertThat(assertions.expression("any_match(a, x -> cardinality(x) > 4)") + .binding("a", "ARRAY[MAP(ARRAY[1,2], ARRAY[3,4]), MAP(ARRAY[1,2,3], ARRAY[3,4,5])]")) + .isEqualTo(false); + + assertThat(assertions.expression("any_match(a, t -> year(t) = 2020)") + .binding("a", "ARRAY[TIMESTAMP '2020-05-10 12:34:56.123456789', TIMESTAMP '1111-05-10 12:34:56.123456789']")) + .isEqualTo(true); } @Test public void testNoneMatch() { - assertFunction("none_match(ARRAY [5, 8, 10], x -> x % 2 = 1)", BooleanType.BOOLEAN, false); - assertFunction("none_match(ARRAY [false, false, false], x -> x)", BooleanType.BOOLEAN, true); - assertFunction("none_match(ARRAY ['abc', 'def', 'ghi'], x -> substr(x, 1, 1) = 'a')", BooleanType.BOOLEAN, false); - assertFunction("none_match(ARRAY [], x -> true)", BooleanType.BOOLEAN, true); - assertFunction("none_match(ARRAY [false, false, NULL], x -> x)", BooleanType.BOOLEAN, null); - assertFunction("none_match(ARRAY [true, false, NULL], x -> x)", BooleanType.BOOLEAN, false); - assertFunction("none_match(ARRAY [NULL, NULL, NULL], x -> x > 1)", BooleanType.BOOLEAN, null); - assertFunction("none_match(ARRAY [true, false, NULL], x -> x IS NULL)", BooleanType.BOOLEAN, false); - assertFunction("none_match(ARRAY [MAP(ARRAY[1,2], ARRAY[3,4]), MAP(ARRAY[1,2,3], ARRAY[3,4,5])], x -> cardinality(x) > 4)", BooleanType.BOOLEAN, true); - assertFunction("none_match(ARRAY [TIMESTAMP '2020-05-10 12:34:56.123456789', TIMESTAMP '1111-05-10 12:34:56.123456789'], t -> month(t) = 10)", BooleanType.BOOLEAN, true); + assertThat(assertions.expression("none_match(a, x -> x % 2 = 1)") + .binding("a", "ARRAY[5, 8, 10]")) + .isEqualTo(false); + + assertThat(assertions.expression("none_match(a, x -> x)") + .binding("a", "ARRAY[false, false, false]")) + .isEqualTo(true); + + assertThat(assertions.expression("none_match(a, x -> substr(x, 1, 1) = 'a')") + .binding("a", "ARRAY['abc', 'def', 'ghi']")) + .isEqualTo(false); + + assertThat(assertions.expression("none_match(a, x -> true)") + .binding("a", "ARRAY[]")) + .isEqualTo(true); + + assertThat(assertions.expression("none_match(a, x -> x)") + .binding("a", "ARRAY[false, false, NULL]")) + .matches("CAST(NULL AS boolean)"); + + assertThat(assertions.expression("none_match(a, x -> x)") + .binding("a", "ARRAY[true, false, NULL]")) + .isEqualTo(false); + + assertThat(assertions.expression("none_match(a, x -> x > 1)") + .binding("a", "ARRAY[NULL, NULL, NULL]")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("none_match(a, x -> x IS NULL)") + .binding("a", "ARRAY[true, false, NULL]")) + .isEqualTo(false); + + assertThat(assertions.expression("none_match(a, x -> cardinality(x) > 4)") + .binding("a", "ARRAY[MAP(ARRAY[1,2], ARRAY[3,4]), MAP(ARRAY[1,2,3], ARRAY[3,4,5])]")) + .isEqualTo(true); + + assertThat(assertions.expression("none_match(a, t -> month(t) = 10)") + .binding("a", "ARRAY[TIMESTAMP '2020-05-10 12:34:56.123456789', TIMESTAMP '1111-05-10 12:34:56.123456789']")) + .isEqualTo(true); } } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestColorFunctions.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestColorFunctions.java index 9b880e495dc9..a125133cd2cd 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestColorFunctions.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestColorFunctions.java @@ -13,7 +13,7 @@ */ package io.trino.operator.scalar; -import io.airlift.slice.Slice; +import io.trino.sql.query.QueryAssertions; import org.testng.annotations.Test; import static io.airlift.slice.Slices.utf8Slice; @@ -26,34 +26,33 @@ import static io.trino.operator.scalar.ColorFunctions.render; import static io.trino.operator.scalar.ColorFunctions.rgb; import static io.trino.spi.function.OperatorType.INDETERMINATE; -import static io.trino.spi.type.BooleanType.BOOLEAN; import static java.lang.String.format; +import static org.assertj.core.api.Assertions.assertThat; import static org.testng.Assert.assertEquals; public class TestColorFunctions - extends AbstractTestFunctions { @Test public void testParseRgb() { - assertEquals(parseRgb(toSlice("#000")), 0x00_00_00); - assertEquals(parseRgb(toSlice("#FFF")), 0xFF_FF_FF); - assertEquals(parseRgb(toSlice("#F00")), 0xFF_00_00); - assertEquals(parseRgb(toSlice("#0F0")), 0x00_FF_00); - assertEquals(parseRgb(toSlice("#00F")), 0x00_00_FF); - assertEquals(parseRgb(toSlice("#700")), 0x77_00_00); - assertEquals(parseRgb(toSlice("#070")), 0x00_77_00); - assertEquals(parseRgb(toSlice("#007")), 0x00_00_77); - - assertEquals(parseRgb(toSlice("#cde")), 0xCC_DD_EE); + assertEquals(parseRgb(utf8Slice("#000")), 0x00_00_00); + assertEquals(parseRgb(utf8Slice("#FFF")), 0xFF_FF_FF); + assertEquals(parseRgb(utf8Slice("#F00")), 0xFF_00_00); + assertEquals(parseRgb(utf8Slice("#0F0")), 0x00_FF_00); + assertEquals(parseRgb(utf8Slice("#00F")), 0x00_00_FF); + assertEquals(parseRgb(utf8Slice("#700")), 0x77_00_00); + assertEquals(parseRgb(utf8Slice("#070")), 0x00_77_00); + assertEquals(parseRgb(utf8Slice("#007")), 0x00_00_77); + + assertEquals(parseRgb(utf8Slice("#cde")), 0xCC_DD_EE); } @Test public void testGetComponent() { - assertEquals(getRed(parseRgb(toSlice("#789"))), 0x77); - assertEquals(getGreen(parseRgb(toSlice("#789"))), 0x88); - assertEquals(getBlue(parseRgb(toSlice("#789"))), 0x99); + assertEquals(getRed(parseRgb(utf8Slice("#789"))), 0x77); + assertEquals(getGreen(parseRgb(utf8Slice("#789"))), 0x88); + assertEquals(getBlue(parseRgb(utf8Slice("#789"))), 0x99); } @Test @@ -67,102 +66,102 @@ public void testToRgb() @Test public void testColor() { - assertEquals(color(toSlice("black")), -1); - assertEquals(color(toSlice("red")), -2); - assertEquals(color(toSlice("green")), -3); - assertEquals(color(toSlice("yellow")), -4); - assertEquals(color(toSlice("blue")), -5); - assertEquals(color(toSlice("magenta")), -6); - assertEquals(color(toSlice("cyan")), -7); - assertEquals(color(toSlice("white")), -8); - - assertEquals(color(toSlice("#f00")), 0xFF_00_00); - assertEquals(color(toSlice("#0f0")), 0x00_FF_00); - assertEquals(color(toSlice("#00f")), 0x00_00_FF); + assertEquals(color(utf8Slice("black")), -1); + assertEquals(color(utf8Slice("red")), -2); + assertEquals(color(utf8Slice("green")), -3); + assertEquals(color(utf8Slice("yellow")), -4); + assertEquals(color(utf8Slice("blue")), -5); + assertEquals(color(utf8Slice("magenta")), -6); + assertEquals(color(utf8Slice("cyan")), -7); + assertEquals(color(utf8Slice("white")), -8); + + assertEquals(color(utf8Slice("#f00")), 0xFF_00_00); + assertEquals(color(utf8Slice("#0f0")), 0x00_FF_00); + assertEquals(color(utf8Slice("#00f")), 0x00_00_FF); } @Test public void testBar() { - assertEquals(bar(0.6, 5, color(toSlice("#f0f")), color(toSlice("#00f"))), - toSlice("\u001B[38;5;201m\u2588\u001B[38;5;165m\u2588\u001B[38;5;129m\u2588\u001B[0m ")); + assertEquals(bar(0.6, 5, color(utf8Slice("#f0f")), color(utf8Slice("#00f"))), + utf8Slice("\u001B[38;5;201m\u2588\u001B[38;5;165m\u2588\u001B[38;5;129m\u2588\u001B[0m ")); - assertEquals(bar(1, 10, color(toSlice("#f00")), color(toSlice("#0f0"))), - toSlice("\u001B[38;5;196m\u2588\u001B[38;5;202m\u2588\u001B[38;5;208m\u2588\u001B[38;5;214m\u2588\u001B[38;5;226m\u2588\u001B[38;5;226m\u2588\u001B[38;5;154m\u2588\u001B[38;5;118m\u2588\u001B[38;5;82m\u2588\u001B[38;5;46m\u2588\u001B[0m")); + assertEquals(bar(1, 10, color(utf8Slice("#f00")), color(utf8Slice("#0f0"))), + utf8Slice("\u001B[38;5;196m\u2588\u001B[38;5;202m\u2588\u001B[38;5;208m\u2588\u001B[38;5;214m\u2588\u001B[38;5;226m\u2588\u001B[38;5;226m\u2588\u001B[38;5;154m\u2588\u001B[38;5;118m\u2588\u001B[38;5;82m\u2588\u001B[38;5;46m\u2588\u001B[0m")); - assertEquals(bar(0.6, 5, color(toSlice("#f0f")), color(toSlice("#00f"))), - toSlice("\u001B[38;5;201m\u2588\u001B[38;5;165m\u2588\u001B[38;5;129m\u2588\u001B[0m ")); + assertEquals(bar(0.6, 5, color(utf8Slice("#f0f")), color(utf8Slice("#00f"))), + utf8Slice("\u001B[38;5;201m\u2588\u001B[38;5;165m\u2588\u001B[38;5;129m\u2588\u001B[0m ")); } @Test public void testRenderBoolean() { - assertEquals(render(true), toSlice("\u001b[38;5;2m✓\u001b[0m")); - assertEquals(render(false), toSlice("\u001b[38;5;1m✗\u001b[0m")); + assertEquals(render(true), utf8Slice("\u001b[38;5;2m✓\u001b[0m")); + assertEquals(render(false), utf8Slice("\u001b[38;5;1m✗\u001b[0m")); } @Test public void testRenderString() { - assertEquals(render(toSlice("hello"), color(toSlice("red"))), toSlice("\u001b[38;5;1mhello\u001b[0m")); + assertEquals(render(utf8Slice("hello"), color(utf8Slice("red"))), utf8Slice("\u001b[38;5;1mhello\u001b[0m")); - assertEquals(render(toSlice("hello"), color(toSlice("#f00"))), toSlice("\u001b[38;5;196mhello\u001b[0m")); - assertEquals(render(toSlice("hello"), color(toSlice("#0f0"))), toSlice("\u001b[38;5;46mhello\u001b[0m")); - assertEquals(render(toSlice("hello"), color(toSlice("#00f"))), toSlice("\u001b[38;5;21mhello\u001b[0m")); + assertEquals(render(utf8Slice("hello"), color(utf8Slice("#f00"))), utf8Slice("\u001b[38;5;196mhello\u001b[0m")); + assertEquals(render(utf8Slice("hello"), color(utf8Slice("#0f0"))), utf8Slice("\u001b[38;5;46mhello\u001b[0m")); + assertEquals(render(utf8Slice("hello"), color(utf8Slice("#00f"))), utf8Slice("\u001b[38;5;21mhello\u001b[0m")); } @Test public void testRenderLong() { - assertEquals(render(1234, color(toSlice("red"))), toSlice("\u001b[38;5;1m1234\u001b[0m")); + assertEquals(render(1234, color(utf8Slice("red"))), utf8Slice("\u001b[38;5;1m1234\u001b[0m")); - assertEquals(render(1234, color(toSlice("#f00"))), toSlice("\u001b[38;5;196m1234\u001b[0m")); - assertEquals(render(1234, color(toSlice("#0f0"))), toSlice("\u001b[38;5;46m1234\u001b[0m")); - assertEquals(render(1234, color(toSlice("#00f"))), toSlice("\u001b[38;5;21m1234\u001b[0m")); + assertEquals(render(1234, color(utf8Slice("#f00"))), utf8Slice("\u001b[38;5;196m1234\u001b[0m")); + assertEquals(render(1234, color(utf8Slice("#0f0"))), utf8Slice("\u001b[38;5;46m1234\u001b[0m")); + assertEquals(render(1234, color(utf8Slice("#00f"))), utf8Slice("\u001b[38;5;21m1234\u001b[0m")); } @Test public void testRenderDouble() { - assertEquals(render(1234.5678, color(toSlice("red"))), toSlice("\u001b[38;5;1m1234.5678\u001b[0m")); - assertEquals(render(1234.5678f, color(toSlice("red"))), toSlice(format("\u001b[38;5;1m%s\u001b[0m", (double) 1234.5678f))); + assertEquals(render(1234.5678, color(utf8Slice("red"))), utf8Slice("\u001b[38;5;1m1234.5678\u001b[0m")); + assertEquals(render(1234.5678f, color(utf8Slice("red"))), utf8Slice(format("\u001b[38;5;1m%s\u001b[0m", (double) 1234.5678f))); - assertEquals(render(1234.5678, color(toSlice("#f00"))), toSlice("\u001b[38;5;196m1234.5678\u001b[0m")); - assertEquals(render(1234.5678, color(toSlice("#0f0"))), toSlice("\u001b[38;5;46m1234.5678\u001b[0m")); - assertEquals(render(1234.5678, color(toSlice("#00f"))), toSlice("\u001b[38;5;21m1234.5678\u001b[0m")); + assertEquals(render(1234.5678, color(utf8Slice("#f00"))), utf8Slice("\u001b[38;5;196m1234.5678\u001b[0m")); + assertEquals(render(1234.5678, color(utf8Slice("#0f0"))), utf8Slice("\u001b[38;5;46m1234.5678\u001b[0m")); + assertEquals(render(1234.5678, color(utf8Slice("#00f"))), utf8Slice("\u001b[38;5;21m1234.5678\u001b[0m")); } @Test public void testInterpolate() { - assertEquals(color(0, 0, 255, color(toSlice("#000")), color(toSlice("#fff"))), 0x00_00_00); - assertEquals(color(0.0f, 0.0f, 255.0f, color(toSlice("#000")), color(toSlice("#fff"))), 0x00_00_00); - assertEquals(color(128, 0, 255, color(toSlice("#000")), color(toSlice("#fff"))), 0x80_80_80); - assertEquals(color(255, 0, 255, color(toSlice("#000")), color(toSlice("#fff"))), 0xFF_FF_FF); + assertEquals(color(0, 0, 255, color(utf8Slice("#000")), color(utf8Slice("#fff"))), 0x00_00_00); + assertEquals(color(0.0f, 0.0f, 255.0f, color(utf8Slice("#000")), color(utf8Slice("#fff"))), 0x00_00_00); + assertEquals(color(128, 0, 255, color(utf8Slice("#000")), color(utf8Slice("#fff"))), 0x80_80_80); + assertEquals(color(255, 0, 255, color(utf8Slice("#000")), color(utf8Slice("#fff"))), 0xFF_FF_FF); assertEquals(color(-1, 42, 52, rgb(0xFF, 0, 0), rgb(0xFF, 0xFF, 0)), 0xFF_00_00); assertEquals(color(47, 42, 52, rgb(0xFF, 0, 0), rgb(0xFF, 0xFF, 0)), 0xFF_80_00); assertEquals(color(142, 42, 52, rgb(0xFF, 0, 0), rgb(0xFF, 0xFF, 0)), 0xFF_FF_00); - assertEquals(color(-42, color(toSlice("#000")), color(toSlice("#fff"))), 0x00_00_00); - assertEquals(color(0.0, color(toSlice("#000")), color(toSlice("#fff"))), 0x00_00_00); - assertEquals(color(0.5, color(toSlice("#000")), color(toSlice("#fff"))), 0x80_80_80); - assertEquals(color(1.0, color(toSlice("#000")), color(toSlice("#fff"))), 0xFF_FF_FF); - assertEquals(color(42, color(toSlice("#000")), color(toSlice("#fff"))), 0xFF_FF_FF); - assertEquals(color(1.0f, color(toSlice("#000")), color(toSlice("#fff"))), 0xFF_FF_FF); - assertEquals(color(-0.0f, color(toSlice("#000")), color(toSlice("#fff"))), 0x00_00_00); - assertEquals(color(0.0f, color(toSlice("#000")), color(toSlice("#fff"))), 0x00_00_00); + assertEquals(color(-42, color(utf8Slice("#000")), color(utf8Slice("#fff"))), 0x00_00_00); + assertEquals(color(0.0, color(utf8Slice("#000")), color(utf8Slice("#fff"))), 0x00_00_00); + assertEquals(color(0.5, color(utf8Slice("#000")), color(utf8Slice("#fff"))), 0x80_80_80); + assertEquals(color(1.0, color(utf8Slice("#000")), color(utf8Slice("#fff"))), 0xFF_FF_FF); + assertEquals(color(42, color(utf8Slice("#000")), color(utf8Slice("#fff"))), 0xFF_FF_FF); + assertEquals(color(1.0f, color(utf8Slice("#000")), color(utf8Slice("#fff"))), 0xFF_FF_FF); + assertEquals(color(-0.0f, color(utf8Slice("#000")), color(utf8Slice("#fff"))), 0x00_00_00); + assertEquals(color(0.0f, color(utf8Slice("#000")), color(utf8Slice("#fff"))), 0x00_00_00); } @Test public void testIndeterminate() { - assertOperator(INDETERMINATE, "color(null)", BOOLEAN, true); - assertOperator(INDETERMINATE, "color('black')", BOOLEAN, false); - } + try (QueryAssertions assertions = new QueryAssertions()) { + assertThat(assertions.operator(INDETERMINATE, "color(null)")) + .isEqualTo(true); - private static Slice toSlice(String string) - { - return utf8Slice(string); + assertThat(assertions.operator(INDETERMINATE, "color('black')")) + .isEqualTo(false); + } } } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestConditions.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestConditions.java index 92c8cdf8527d..718a17863bba 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestConditions.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestConditions.java @@ -13,145 +13,519 @@ */ package io.trino.operator.scalar; -import org.testng.annotations.Test; +import io.trino.spi.TrinoException; +import io.trino.sql.query.QueryAssertions; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; -import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.DecimalType.createDecimalType; -import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.spi.type.SqlDecimal.decimal; -import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +@TestInstance(PER_CLASS) public class TestConditions - extends AbstractTestFunctions { - @Test - public void testLike() - { - assertFunction("'_monkey_' like 'X_monkeyX_' escape 'X'", BOOLEAN, true); - - assertFunction("'monkey' like 'monkey'", BOOLEAN, true); - assertFunction("'monkey' like 'mon%'", BOOLEAN, true); - assertFunction("'monkey' like 'mon_ey'", BOOLEAN, true); - assertFunction("'monkey' like 'm____y'", BOOLEAN, true); - - assertFunction("'monkey' like 'dain'", BOOLEAN, false); - assertFunction("'monkey' like 'key'", BOOLEAN, false); - - assertFunction("'_monkey_' like '\\_monkey\\_'", BOOLEAN, false); - assertFunction("'_monkey_' like 'X_monkeyX_' escape 'X'", BOOLEAN, true); - - assertFunction("null like 'monkey'", BOOLEAN, null); - assertFunction("'monkey' like null", BOOLEAN, null); - assertFunction("'monkey' like 'monkey' escape null", BOOLEAN, null); + private QueryAssertions assertions; - assertFunction("'_monkey_' not like 'X_monkeyX_' escape 'X'", BOOLEAN, false); - - assertFunction("'monkey' not like 'monkey'", BOOLEAN, false); - assertFunction("'monkey' not like 'mon%'", BOOLEAN, false); - assertFunction("'monkey' not like 'mon_ey'", BOOLEAN, false); - assertFunction("'monkey' not like 'm____y'", BOOLEAN, false); - - assertFunction("'monkey' not like 'dain'", BOOLEAN, true); - assertFunction("'monkey' not like 'key'", BOOLEAN, true); - - assertFunction("'_monkey_' not like '\\_monkey\\_'", BOOLEAN, true); - assertFunction("'_monkey_' not like 'X_monkeyX_' escape 'X'", BOOLEAN, false); + @BeforeAll + public void init() + { + assertions = new QueryAssertions(); + } - assertFunction("null not like 'monkey'", BOOLEAN, null); - assertFunction("'monkey' not like null", BOOLEAN, null); - assertFunction("'monkey' not like 'monkey' escape null", BOOLEAN, null); + @AfterAll + public void teardown() + { + assertions.close(); + assertions = null; + } - assertInvalidFunction("'monkey' like 'monkey' escape 'foo'", "Escape string must be a single character"); + @Test + public void testLike() + { + // like + assertThat(assertions.expression("a like 'X_monkeyX_' escape 'X'") + .binding("a", "'_monkey_'")) + .isEqualTo(true); + + assertThat(assertions.expression("a like 'monkey'") + .binding("a", "'monkey'")) + .isEqualTo(true); + + assertThat(assertions.expression("a like 'mon%'") + .binding("a", "'monkey'")) + .isEqualTo(true); + + assertThat(assertions.expression("a like '%key'") + .binding("a", "'monkey'")) + .isEqualTo(true); + + assertThat(assertions.expression("a like 'm____y'") + .binding("a", "'monkey'")) + .isEqualTo(true); + + assertThat(assertions.expression("a like 'lion'") + .binding("a", "'monkey'")) + .isEqualTo(false); + + assertThat(assertions.expression("a like 'monkey'") + .binding("a", "null")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("a like null") + .binding("a", "'monkey'")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("a like 'monkey' escape null") + .binding("a", "'monkey'")) + .isNull(BOOLEAN); + + // not like + assertThat(assertions.expression("a not like 'X_monkeyX_' escape 'X'") + .binding("a", "'_monkey_'")) + .isEqualTo(false); + + assertThat(assertions.expression("a not like 'monkey'") + .binding("a", "'monkey'")) + .isEqualTo(false); + + assertThat(assertions.expression("a not like 'mon%'") + .binding("a", "'monkey'")) + .isEqualTo(false); + + assertThat(assertions.expression("a not like '%key'") + .binding("a", "'monkey'")) + .isEqualTo(false); + + assertThat(assertions.expression("a not like 'm____y'") + .binding("a", "'monkey'")) + .isEqualTo(false); + + assertThat(assertions.expression("a not like 'lion'") + .binding("a", "'monkey'")) + .isEqualTo(true); + + assertThat(assertions.expression("a not like 'monkey'") + .binding("a", "null")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("a not like null") + .binding("a", "'monkey'")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("a not like 'monkey' escape null") + .binding("a", "'monkey'")) + .isNull(BOOLEAN); + + assertThatThrownBy(() -> assertions.expression("a like 'monkey' escape 'foo'") + .binding("a", "'monkey'") + .evaluate()) + .isInstanceOf(TrinoException.class) + .hasMessage("Escape string must be a single character"); } @Test public void testDistinctFrom() { - assertFunction("NULL IS DISTINCT FROM NULL", BOOLEAN, false); - assertFunction("NULL IS DISTINCT FROM 1", BOOLEAN, true); - assertFunction("1 IS DISTINCT FROM NULL", BOOLEAN, true); - assertFunction("1 IS DISTINCT FROM 1", BOOLEAN, false); - assertFunction("1 IS DISTINCT FROM 2", BOOLEAN, true); - - assertFunction("NULL IS NOT DISTINCT FROM NULL", BOOLEAN, true); - assertFunction("NULL IS NOT DISTINCT FROM 1", BOOLEAN, false); - assertFunction("1 IS NOT DISTINCT FROM NULL", BOOLEAN, false); - assertFunction("1 IS NOT DISTINCT FROM 1", BOOLEAN, true); - assertFunction("1 IS NOT DISTINCT FROM 2", BOOLEAN, false); + // distinct from + assertThat(assertions.expression("a IS DISTINCT FROM b") + .binding("a", "null") + .binding("b", "null")) + .isEqualTo(false); + + assertThat(assertions.expression("a IS DISTINCT FROM b") + .binding("a", "null") + .binding("b", "1")) + .isEqualTo(true); + + assertThat(assertions.expression("a IS DISTINCT FROM b") + .binding("a", "1") + .binding("b", "null")) + .isEqualTo(true); + + assertThat(assertions.expression("a IS DISTINCT FROM b") + .binding("a", "1") + .binding("b", "1")) + .isEqualTo(false); + + assertThat(assertions.expression("a IS DISTINCT FROM b") + .binding("a", "1") + .binding("b", "2")) + .isEqualTo(true); + + // not distinct from + assertThat(assertions.expression("a IS NOT DISTINCT FROM b") + .binding("a", "null") + .binding("b", "null")) + .isEqualTo(true); + + assertThat(assertions.expression("a IS NOT DISTINCT FROM b") + .binding("a", "null") + .binding("b", "1")) + .isEqualTo(false); + + assertThat(assertions.expression("a IS NOT DISTINCT FROM b") + .binding("a", "1") + .binding("b", "null")) + .isEqualTo(false); + + assertThat(assertions.expression("a IS NOT DISTINCT FROM b") + .binding("a", "1") + .binding("b", "1")) + .isEqualTo(true); + + assertThat(assertions.expression("a IS NOT DISTINCT FROM b") + .binding("a", "1") + .binding("b", "2")) + .isEqualTo(false); } @Test public void testBetween() { - assertFunction("3 between 2 and 4", BOOLEAN, true); - assertFunction("3 between 3 and 3", BOOLEAN, true); - assertFunction("3 between 2 and 3", BOOLEAN, true); - assertFunction("3 between 3 and 4", BOOLEAN, true); - assertFunction("3 between 4 and 2", BOOLEAN, false); - assertFunction("2 between 3 and 4", BOOLEAN, false); - assertFunction("5 between 3 and 4", BOOLEAN, false); - assertFunction("null between 2 and 4", BOOLEAN, null); - assertFunction("3 between null and 4", BOOLEAN, null); - assertFunction("3 between 2 and null", BOOLEAN, null); - - assertFunction("3 between 3 and 4000000000", BOOLEAN, true); - assertFunction("5 between 3 and 4000000000", BOOLEAN, true); - assertFunction("3 between BIGINT '3' and 4", BOOLEAN, true); - assertFunction("BIGINT '3' between 3 and 4", BOOLEAN, true); - - assertFunction("'c' between 'b' and 'd'", BOOLEAN, true); - assertFunction("'c' between 'c' and 'c'", BOOLEAN, true); - assertFunction("'c' between 'b' and 'c'", BOOLEAN, true); - assertFunction("'c' between 'c' and 'd'", BOOLEAN, true); - assertFunction("'c' between 'd' and 'b'", BOOLEAN, false); - assertFunction("'b' between 'c' and 'd'", BOOLEAN, false); - assertFunction("'e' between 'c' and 'd'", BOOLEAN, false); - assertFunction("null between 'b' and 'd'", BOOLEAN, null); - assertFunction("'c' between null and 'd'", BOOLEAN, null); - assertFunction("'c' between 'b' and null", BOOLEAN, null); - - assertFunction("3 not between 2 and 4", BOOLEAN, false); - assertFunction("3 not between 3 and 3", BOOLEAN, false); - assertFunction("3 not between 2 and 3", BOOLEAN, false); - assertFunction("3 not between 3 and 4", BOOLEAN, false); - assertFunction("3 not between 4 and 2", BOOLEAN, true); - assertFunction("2 not between 3 and 4", BOOLEAN, true); - assertFunction("5 not between 3 and 4", BOOLEAN, true); - assertFunction("null not between 2 and 4", BOOLEAN, null); - assertFunction("3 not between null and 4", BOOLEAN, null); - assertFunction("3 not between 2 and null", BOOLEAN, null); - - assertFunction("'c' not between 'b' and 'd'", BOOLEAN, false); - assertFunction("'c' not between 'c' and 'c'", BOOLEAN, false); - assertFunction("'c' not between 'b' and 'c'", BOOLEAN, false); - assertFunction("'c' not between 'c' and 'd'", BOOLEAN, false); - assertFunction("'c' not between 'd' and 'b'", BOOLEAN, true); - assertFunction("'b' not between 'c' and 'd'", BOOLEAN, true); - assertFunction("'e' not between 'c' and 'd'", BOOLEAN, true); - assertFunction("null not between 'b' and 'd'", BOOLEAN, null); - assertFunction("'c' not between null and 'd'", BOOLEAN, null); - assertFunction("'c' not between 'b' and null", BOOLEAN, null); + // between + assertThat(assertions.expression("value between low and high") + .binding("value", "3") + .binding("low", "2") + .binding("high", "4")) + .isEqualTo(true); + + assertThat(assertions.expression("value between low and high") + .binding("value", "3") + .binding("low", "3") + .binding("high", "3")) + .isEqualTo(true); + + assertThat(assertions.expression("value between low and high") + .binding("value", "3") + .binding("low", "2") + .binding("high", "3")) + .isEqualTo(true); + + assertThat(assertions.expression("value between low and high") + .binding("value", "3") + .binding("low", "3") + .binding("high", "4")) + .isEqualTo(true); + + assertThat(assertions.expression("value between low and high") + .binding("value", "3") + .binding("low", "4") + .binding("high", "2")) + .isEqualTo(false); + + assertThat(assertions.expression("value between low and high") + .binding("value", "2") + .binding("low", "3") + .binding("high", "4")) + .isEqualTo(false); + + assertThat(assertions.expression("value between low and high") + .binding("value", "5") + .binding("low", "3") + .binding("high", "4")) + .isEqualTo(false); + + assertThat(assertions.expression("value between low and high") + .binding("value", "null") + .binding("low", "3") + .binding("high", "4")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("value between low and high") + .binding("value", "3") + .binding("low", "null") + .binding("high", "4")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("value between low and high") + .binding("value", "3") + .binding("low", "2") + .binding("high", "null")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("value between low and high") + .binding("value", "3") + .binding("low", "3") + .binding("high", "4000000000")) + .isEqualTo(true); + + assertThat(assertions.expression("value between low and high") + .binding("value", "5") + .binding("low", "3") + .binding("high", "4000000000")) + .isEqualTo(true); + + assertThat(assertions.expression("value between low and high") + .binding("value", "5") + .binding("low", "BIGINT '3'") + .binding("high", "4")) + .isEqualTo(false); + + assertThat(assertions.expression("value between low and high") + .binding("value", "BIGINT '3'") + .binding("low", "3") + .binding("high", "4")) + .isEqualTo(true); + + assertThat(assertions.expression("value between low and high") + .binding("value", "'c'") + .binding("low", "'b'") + .binding("high", "'b'")) + .isEqualTo(false); + + assertThat(assertions.expression("value between low and high") + .binding("value", "'c'") + .binding("low", "'c'") + .binding("high", "'c'")) + .isEqualTo(true); + + assertThat(assertions.expression("value between low and high") + .binding("value", "'c'") + .binding("low", "'b'") + .binding("high", "'c'")) + .isEqualTo(true); + + assertThat(assertions.expression("value between low and high") + .binding("value", "'c'") + .binding("low", "'c'") + .binding("high", "'d'")) + .isEqualTo(true); + + assertThat(assertions.expression("value between low and high") + .binding("value", "'c'") + .binding("low", "'d'") + .binding("high", "'b'")) + .isEqualTo(false); + + assertThat(assertions.expression("value between low and high") + .binding("value", "'b'") + .binding("low", "'c'") + .binding("high", "'d'")) + .isEqualTo(false); + + assertThat(assertions.expression("value between low and high") + .binding("value", "'e'") + .binding("low", "'c'") + .binding("high", "'d'")) + .isEqualTo(false); + + assertThat(assertions.expression("value between low and high") + .binding("value", "null") + .binding("low", "'b'") + .binding("high", "'d'")) + .matches("CAST(null AS boolean)"); + + assertThat(assertions.expression("value between low and high") + .binding("value", "'c'") + .binding("low", "null") + .binding("high", "'d'")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("value between low and high") + .binding("value", "'c'") + .binding("low", "'b'") + .binding("high", "null")) + .isNull(BOOLEAN); + + // not between + assertThat(assertions.expression("value not between low and high") + .binding("value", "3") + .binding("low", "2") + .binding("high", "4")) + .isEqualTo(false); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "3") + .binding("low", "3") + .binding("high", "3")) + .isEqualTo(false); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "3") + .binding("low", "2") + .binding("high", "3")) + .isEqualTo(false); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "3") + .binding("low", "3") + .binding("high", "4")) + .isEqualTo(false); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "3") + .binding("low", "4") + .binding("high", "2")) + .isEqualTo(true); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "2") + .binding("low", "3") + .binding("high", "4")) + .isEqualTo(true); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "5") + .binding("low", "3") + .binding("high", "4")) + .isEqualTo(true); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "null") + .binding("low", "3") + .binding("high", "4")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "3") + .binding("low", "null") + .binding("high", "4")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "3") + .binding("low", "2") + .binding("high", "null")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "3") + .binding("low", "3") + .binding("high", "4000000000")) + .isEqualTo(false); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "5") + .binding("low", "3") + .binding("high", "4000000000")) + .isEqualTo(false); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "5") + .binding("low", "BIGINT '3'") + .binding("high", "4")) + .isEqualTo(true); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "BIGINT '3'") + .binding("low", "3") + .binding("high", "4")) + .isEqualTo(false); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "'c'") + .binding("low", "'b'") + .binding("high", "'b'")) + .isEqualTo(true); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "'c'") + .binding("low", "'c'") + .binding("high", "'c'")) + .isEqualTo(false); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "'c'") + .binding("low", "'b'") + .binding("high", "'c'")) + .isEqualTo(false); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "'c'") + .binding("low", "'c'") + .binding("high", "'d'")) + .isEqualTo(false); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "'c'") + .binding("low", "'d'") + .binding("high", "'b'")) + .isEqualTo(true); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "'b'") + .binding("low", "'c'") + .binding("high", "'d'")) + .isEqualTo(true); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "'e'") + .binding("low", "'c'") + .binding("high", "'d'")) + .isEqualTo(true); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "null") + .binding("low", "'b'") + .binding("high", "'d'")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "'c'") + .binding("low", "null") + .binding("high", "'d'")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "'c'") + .binding("low", "'b'") + .binding("high", "null")) + .isNull(BOOLEAN); } @Test public void testIn() { - assertFunction("3 in (2, 4, 3, 5)", BOOLEAN, true); - assertFunction("3 not in (2, 4, 3, 5)", BOOLEAN, false); - assertFunction("3 in (2, 4, 9, 5)", BOOLEAN, false); - assertFunction("3 in (2, null, 3, 5)", BOOLEAN, true); + assertThat(assertions.expression("value in (2, 4, 3, 5)") + .binding("value", "3")) + .isEqualTo(true); + + assertThat(assertions.expression("value not in (2, 4, 3, 5)") + .binding("value", "3")) + .isEqualTo(false); + + assertThat(assertions.expression("value in (2, 4, 9, 5)") + .binding("value", "3")) + .isEqualTo(false); + + assertThat(assertions.expression("value in (2, null, 3, 5)") + .binding("value", "3")) + .isEqualTo(true); + + assertThat(assertions.expression("value in ('bar', 'baz', 'foo', 'blah')") + .binding("value", "'foo'")) + .isEqualTo(true); + + assertThat(assertions.expression("value in ('bar', 'baz', 'buz', 'blah')") + .binding("value", "'foo'")) + .isEqualTo(false); + + assertThat(assertions.expression("value in ('bar', null, 'foo', 'blah')") + .binding("value", "'foo'")) + .isEqualTo(true); + + assertThat(assertions.expression("value in (2, null, 3, 5)") + .binding("value", "null")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("value in (2, null)") + .binding("value", "3")) + .isNull(BOOLEAN); - assertFunction("'foo' in ('bar', 'baz', 'foo', 'blah')", BOOLEAN, true); - assertFunction("'foo' in ('bar', 'baz', 'buz', 'blah')", BOOLEAN, false); - assertFunction("'foo' in ('bar', null, 'foo', 'blah')", BOOLEAN, true); + assertThat(assertions.expression("value not in (2, null, 3, 5)") + .binding("value", "null")) + .isNull(BOOLEAN); - assertFunction("(null in (2, null, 3, 5)) is null", BOOLEAN, true); - assertFunction("(3 in (2, null)) is null", BOOLEAN, true); - assertFunction("(null not in (2, null, 3, 5)) is null", BOOLEAN, true); - assertFunction("(3 not in (2, null)) is null", BOOLEAN, true); + assertThat(assertions.expression("value not in (2, null)") + .binding("value", "3")) + .isNull(BOOLEAN); // Because of the failing in-list item 5 / 0, the in-predicate cannot be simplified. // It is instead processed with the use of generated code which applies the short-circuit @@ -162,296 +536,376 @@ public void testIn() @Test public void testSearchCase() { - assertFunction("case " + - "when true then 33 " + - "end", - INTEGER, - 33); - - assertFunction("case " + - "when true then BIGINT '33' " + - "end", - BIGINT, - 33L); - - assertFunction("case " + - "when false then 1 " + - "else 33 " + - "end", - INTEGER, - 33); - - assertFunction("case " + - "when false then 10000000000 " + - "else 33 " + - "end", - BIGINT, - 33L); - - assertFunction("case " + - "when false then 1 " + - "when false then 1 " + - "when true then 33 " + - "else 1 " + - "end", - INTEGER, - 33); - - assertFunction("case " + - "when false then BIGINT '1' " + - "when false then 1 " + - "when true then 33 " + - "else 1 " + - "end", - BIGINT, - 33L); - - assertFunction("case " + - "when false then 10000000000 " + - "when false then 1 " + - "when true then 33 " + - "else 1 " + - "end", - BIGINT, - 33L); - - assertFunction("case " + - "when false then 1 " + - "end", - INTEGER, - null); - - assertFunction("case " + - "when true then null " + - "else 'foo' " + - "end", - createVarcharType(3), - null); - - assertFunction("case " + - "when null then 1 " + - "when true then 33 " + - "end", - INTEGER, - 33); - - assertFunction("case " + - "when null then 10000000000 " + - "when true then 33 " + - "end", - BIGINT, - 33L); - - assertFunction("case " + - "when false then 1.0E0 " + - "when true then 33 " + - "end", - DOUBLE, - 33.0); - - assertDecimalFunction("case " + - "when false then DECIMAL '2.2' " + - "when true then DECIMAL '2.2' " + - "end", - decimal("2.2", createDecimalType(2, 1))); - - assertDecimalFunction("case " + - "when false then DECIMAL '1234567890.0987654321' " + - "when true then DECIMAL '3.3' " + - "end", - decimal("0000000003.3000000000", createDecimalType(20, 10))); - - assertDecimalFunction("case " + - "when false then 1 " + - "when true then DECIMAL '2.2' " + - "end", - decimal("0000000002.2", createDecimalType(11, 1))); - - assertDecimalFunction("case " + - "when false then 2.2 " + - "when true then 2.2 " + - "end", - decimal("2.2", createDecimalType(2, 1))); - - assertDecimalFunction("case " + - "when false then 1234567890.0987654321 " + - "when true then 3.3 " + - "end", - decimal("0000000003.3000000000", createDecimalType(20, 10))); - - assertDecimalFunction("case " + - "when false then 1 " + - "when true then 2.2 " + - "end", - decimal("0000000002.2", createDecimalType(11, 1))); - - assertFunction("case " + - "when false then DECIMAL '1.1' " + - "when true then 33.0E0 " + - "end", - DOUBLE, - 33.0); - - assertFunction("case " + - "when false then 1.1 " + - "when true then 33.0E0 " + - "end", - DOUBLE, - 33.0); + assertThat(assertions.expression(""" + case + when value then 33 + end + """) + .binding("value", "true")) + .matches("33"); + + assertThat(assertions.expression(""" + case + when value then BIGINT '33' + end + """) + .binding("value", "true")) + .matches("BIGINT '33'"); + + assertThat(assertions.expression(""" + case + when value then 1 + else 33 + end + """) + .binding("value", "false")) + .matches("33"); + + assertThat(assertions.expression(""" + case + when value then 10000000000 + else 33 + end + """) + .binding("value", "false")) + .matches("BIGINT '33'"); + + assertThat(assertions.expression(""" + case + when condition1 then 1 + when condition2 then 1 + when condition3 then 33 + else 1 + end + """) + .binding("condition1", "false") + .binding("condition2", "false") + .binding("condition3", "true")) + .matches("33"); + + assertThat(assertions.expression(""" + case + when condition1 then BIGINT '1' + when condition2 then 1 + when condition3 then 33 + else 1 + end + """) + .binding("condition1", "false") + .binding("condition2", "false") + .binding("condition3", "true")) + .matches("BIGINT '33'"); + + assertThat(assertions.expression(""" + case + when condition1 then 10000000000 + when condition2 then 1 + when condition3 then 33 + else 1 + end + """) + .binding("condition1", "false") + .binding("condition2", "false") + .binding("condition3", "true")) + .matches("BIGINT '33'"); + + assertThat(assertions.expression(""" + case + when value then 1 + end + """) + .binding("value", "false")) + .matches("CAST(null AS integer)"); + + assertThat(assertions.expression(""" + case + when value then null + else 'foo' + end + """) + .binding("value", "true")) + .isNull(createVarcharType(3)); + + assertThat(assertions.expression(""" + case + when condition1 then 1 + when condition2 then 33 + end + """) + .binding("condition1", "null") + .binding("condition2", "true")) + .matches("33"); + + assertThat(assertions.expression(""" + case + when condition1 then 10000000000 + when condition2 then 33 + end + """) + .binding("condition1", "null") + .binding("condition2", "true")) + .matches("BIGINT '33'"); + + assertThat(assertions.expression(""" + case + when condition1 then 1.0E0 + when condition2 then 33 + end + """) + .binding("condition1", "false") + .binding("condition2", "true")) + .matches("33E0"); + + assertThat(assertions.expression(""" + case + when condition1 then 2.2 + when condition2 then 2.2 + end + """) + .binding("condition1", "false") + .binding("condition2", "true")) + .hasType(createDecimalType(2, 1)) + .matches("2.2"); + + assertThat(assertions.expression(""" + case + when condition1 then 1234567890.0987654321 + when condition2 then 3.3 + end + """) + .binding("condition1", "false") + .binding("condition2", "true")) + .matches("CAST(3.3 AS decimal(20, 10))"); + + assertThat(assertions.expression(""" + case + when condition1 then 1 + when condition2 then 2.2 + end + """) + .binding("condition1", "false") + .binding("condition2", "true")) + .matches("CAST(2.2 AS decimal(11, 1))"); + + assertThat(assertions.expression(""" + case + when condition1 then 1.1 + when condition2 then 33E0 + end + """) + .binding("condition1", "false") + .binding("condition2", "true")) + .matches("33E0"); } @Test public void testSimpleCase() { - assertFunction("case true " + - "when true then cast(null as varchar) " + - "else 'foo' " + - "end", - VARCHAR, - null); - - assertFunction("case true " + - "when true then 33 " + - "end", - INTEGER, - 33); - - assertFunction("case true " + - "when true then BIGINT '33' " + - "end", - BIGINT, - 33L); - - assertFunction("case true " + - "when false then 1 " + - "else 33 " + - "end", - INTEGER, - 33); - - assertFunction("case true " + - "when false then 10000000000 " + - "else 33 " + - "end", - BIGINT, - 33L); - - assertFunction("case true " + - "when false then 1 " + - "when false then 1 " + - "when true then 33 " + - "else 1 " + - "end", - INTEGER, - 33); - - assertFunction("case true " + - "when false then 1 " + - "end", - INTEGER, - null); - - assertFunction("case true " + - "when true then null " + - "else 'foo' " + - "end", - createVarcharType(3), - null); - - assertFunction("case true " + - "when null then 10000000000 " + - "when true then 33 " + - "end", - BIGINT, - 33L); - - assertFunction("case true " + - "when null then 1 " + - "when true then 33 " + - "end", - INTEGER, - 33); - - assertFunction("case null " + - "when true then 1 " + - "else 33 " + - "end", - INTEGER, - 33); - - assertFunction("case true " + - "when false then 1.0E0 " + - "when true then 33 " + - "end", - DOUBLE, - 33.0); - - assertDecimalFunction("case true " + - "when false then DECIMAL '2.2' " + - "when true then DECIMAL '2.2' " + - "end", - decimal("2.2", createDecimalType(2, 1))); - - assertDecimalFunction("case true " + - "when false then DECIMAL '1234567890.0987654321' " + - "when true then DECIMAL '3.3' " + - "end", - decimal("0000000003.3000000000", createDecimalType(20, 10))); - - assertDecimalFunction("case true " + - "when false then 1 " + - "when true then DECIMAL '2.2' " + - "end", - decimal("0000000002.2", createDecimalType(11, 1))); - - assertFunction("case true " + - "when false then DECIMAL '1.1' " + - "when true then 33.0E0 " + - "end", - DOUBLE, - 33.0); - - assertDecimalFunction("case true " + - "when false then 2.2 " + - "when true then 2.2 " + - "end", - decimal("2.2", createDecimalType(2, 1))); - - assertDecimalFunction("case true " + - "when false then 1234567890.0987654321 " + - "when true then 3.3 " + - "end", - decimal("0000000003.3000000000", createDecimalType(20, 10))); - - assertDecimalFunction("case true " + - "when false then 1 " + - "when true then 2.2 " + - "end", - decimal("0000000002.2", createDecimalType(11, 1))); - - assertFunction("case true " + - "when false then 1.1 " + - "when true then 33.0E0 " + - "end", - DOUBLE, - 33.0); + assertThat(assertions.expression(""" + case value + when condition then CAST(null AS varchar) + else 'foo' + end + """) + .binding("value", "true") + .binding("condition", "true")) + .matches("CAST(null AS varchar)"); + + assertThat(assertions.expression(""" + case value + when condition then 33 + end + """) + .binding("value", "true") + .binding("condition", "true")) + .matches("33"); + + assertThat(assertions.expression(""" + case value + when condition then BIGINT '33' + end + """) + .binding("value", "true") + .binding("condition", "true")) + .matches("BIGINT '33'"); + + assertThat(assertions.expression(""" + case value + when condition then 1 + else 33 + end + """) + .binding("value", "true") + .binding("condition", "false")) + .matches("33"); + + assertThat(assertions.expression(""" + case value + when condition then 10000000000 + else 33 + end + """) + .binding("value", "true") + .binding("condition", "false")) + .matches("BIGINT '33'"); + + assertThat(assertions.expression(""" + case value + when condition1 then 1 + when condition2 then 1 + when condition3 then 33 + else 1 + end + """) + .binding("value", "true") + .binding("condition1", "false") + .binding("condition2", "false") + .binding("condition3", "true")) + .matches("33"); + + assertThat(assertions.expression(""" + case value + when condition then 1 + end + """) + .binding("value", "true") + .binding("condition", "false")) + .isNull(INTEGER); + + assertThat(assertions.expression(""" + case value + when condition then null + else 'foo' + end + """) + .binding("value", "true") + .binding("condition", "true")) + .isNull(createVarcharType(3)); + + assertThat(assertions.expression(""" + case value + when condition1 then 10000000000 + when condition2 then 33 + end + """) + .binding("value", "true") + .binding("condition1", "null") + .binding("condition2", "true")) + .matches("BIGINT '33'"); + + assertThat(assertions.expression(""" + case value + when condition1 then 1 + when condition2 then 33 + end + """) + .binding("value", "true") + .binding("condition1", "null") + .binding("condition2", "true")) + .matches("33"); + + assertThat(assertions.expression(""" + case value + when condition then 1 + else 33 + end + """) + .binding("value", "null") + .binding("condition", "true")) + .matches("33"); + + assertThat(assertions.expression(""" + case value + when condition1 then 1E0 + when condition2 then 33 + end + """) + .binding("value", "true") + .binding("condition1", "false") + .binding("condition2", "true")) + .matches("33E0"); + + assertThat(assertions.expression(""" + case value + when condition1 then 2.2 + when condition2 then 2.2 + end + """) + .binding("value", "true") + .binding("condition1", "false") + .binding("condition2", "true")) + .matches("2.2"); + + assertThat(assertions.expression(""" + case value + when condition1 then 1234567890.0987654321 + when condition2 then 3.3 + end + """) + .binding("value", "true") + .binding("condition1", "false") + .binding("condition2", "true")) + .matches("CAST(3.3 AS decimal(20, 10))"); + + assertThat(assertions.expression(""" + case value + when condition1 then 1 + when condition2 then 2.2 + end + """) + .binding("value", "true") + .binding("condition1", "false") + .binding("condition2", "true")) + .matches("CAST(2.2 AS decimal(11, 1))"); + + assertThat(assertions.expression(""" + case value + when condition1 then 1.1 + when condition2 then 33E0 + end + """) + .binding("value", "true") + .binding("condition1", "false") + .binding("condition2", "true")) + .matches("33E0"); + + assertThat(assertions.expression(""" + case value + when condition1 then result1 + when condition2 then result2 + end + """) + .binding("value", "true") + .binding("condition1", "false") + .binding("result1", "1.1") + .binding("condition2", "true") + .binding("result2", "33.0E0")) + .matches("33.0E0"); } @Test public void testSimpleCaseWithCoercions() { - assertFunction("case 8 " + - "when double '76.1' then 1 " + - "when real '8.1' then 2 " + - "end", - INTEGER, - null); - - assertFunction("case 8 " + - "when 9 then 1 " + - "when cast(null as decimal) then 2 " + - "end", - INTEGER, - null); + assertThat(assertions.expression(""" + case value + when condition1 then 1 + when condition2 then 2 + end + """) + .binding("value", "8") + .binding("condition1", "double '76.1'") + .binding("condition2", "real '8.1'")) + .isNull(INTEGER); + + assertThat(assertions.expression(""" + case value + when condition1 then 1 + when condition2 then 2 + end + """) + .binding("value", "8") + .binding("condition1", "9") + .binding("condition2", "cast(NULL as decimal)")) + .isNull(INTEGER); } } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestDataSizeFunctions.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestDataSizeFunctions.java index 288ec7856f03..7a27180c0427 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestDataSizeFunctions.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestDataSizeFunctions.java @@ -13,46 +13,114 @@ */ package io.trino.operator.scalar; -import io.trino.spi.type.Type; -import org.testng.annotations.Test; +import io.trino.spi.type.DecimalType; +import io.trino.sql.query.QueryAssertions; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; import static io.trino.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE; import static io.trino.spi.type.DecimalType.createDecimalType; import static io.trino.spi.type.SqlDecimal.decimal; +import static io.trino.testing.assertions.TrinoExceptionAssert.assertTrinoExceptionThrownBy; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +@TestInstance(PER_CLASS) public class TestDataSizeFunctions - extends AbstractTestFunctions { - private static final Type DECIMAL = createDecimalType(38, 0); + private static final DecimalType DECIMAL = createDecimalType(38, 0); + + private QueryAssertions assertions; + + @BeforeAll + public void init() + { + assertions = new QueryAssertions(); + } + + @AfterAll + public void teardown() + { + assertions.close(); + assertions = null; + } @Test public void testParseDataSize() { - assertFunction("parse_data_size('0B')", DECIMAL, decimal("0", createDecimalType(38))); - assertFunction("parse_data_size('1B')", DECIMAL, decimal("1", createDecimalType(38))); - assertFunction("parse_data_size('1.2B')", DECIMAL, decimal("1", createDecimalType(38))); - assertFunction("parse_data_size('1.9B')", DECIMAL, decimal("1", createDecimalType(38))); - assertFunction("parse_data_size('2.2kB')", DECIMAL, decimal("2252", createDecimalType(38))); - assertFunction("parse_data_size('2.23kB')", DECIMAL, decimal("2283", createDecimalType(38))); - assertFunction("parse_data_size('2.23kB')", DECIMAL, decimal("2283", createDecimalType(38))); - assertFunction("parse_data_size('2.234kB')", DECIMAL, decimal("2287", createDecimalType(38))); - assertFunction("parse_data_size('3MB')", DECIMAL, decimal("3145728", createDecimalType(38))); - assertFunction("parse_data_size('4GB')", DECIMAL, decimal("4294967296", createDecimalType(38))); - assertFunction("parse_data_size('4TB')", DECIMAL, decimal("4398046511104", createDecimalType(38))); - assertFunction("parse_data_size('5PB')", DECIMAL, decimal("5629499534213120", createDecimalType(38))); - assertFunction("parse_data_size('6EB')", DECIMAL, decimal("6917529027641081856", createDecimalType(38))); - assertFunction("parse_data_size('7ZB')", DECIMAL, decimal("8264141345021879123968", createDecimalType(38))); - assertFunction("parse_data_size('8YB')", DECIMAL, decimal("9671406556917033397649408", createDecimalType(38))); - assertFunction("parse_data_size('6917529027641081856EB')", DECIMAL, decimal("7975367974709495237422842361682067456", createDecimalType(38))); - assertFunction("parse_data_size('69175290276410818560EB')", DECIMAL, decimal("79753679747094952374228423616820674560", createDecimalType(38))); - - assertInvalidFunction("parse_data_size('')", "Invalid data size: ''"); - assertInvalidFunction("parse_data_size('0')", "Invalid data size: '0'"); - assertInvalidFunction("parse_data_size('10KB')", "Invalid data size: '10KB'"); - assertInvalidFunction("parse_data_size('KB')", "Invalid data size: 'KB'"); - assertInvalidFunction("parse_data_size('-1B')", "Invalid data size: '-1B'"); - assertInvalidFunction("parse_data_size('12345K')", "Invalid data size: '12345K'"); - assertInvalidFunction("parse_data_size('A12345B')", "Invalid data size: 'A12345B'"); - assertInvalidFunction("parse_data_size('99999999999999YB')", NUMERIC_VALUE_OUT_OF_RANGE, "Value out of range: '99999999999999YB' ('120892581961461708544797985370825293824B')"); + assertThat(assertions.function("parse_data_size", "'0B'")) + .isEqualTo(decimal("0", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'1B'")) + .isEqualTo(decimal("1", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'1.2B'")) + .isEqualTo(decimal("1", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'1.9B'")) + .isEqualTo(decimal("1", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'2.2kB'")) + .isEqualTo(decimal("2252", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'2.23kB'")) + .isEqualTo(decimal("2283", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'2.234kB'")) + .isEqualTo(decimal("2287", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'3MB'")) + .isEqualTo(decimal("3145728", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'4GB'")) + .isEqualTo(decimal("4294967296", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'4TB'")) + .isEqualTo(decimal("4398046511104", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'5PB'")) + .isEqualTo(decimal("5629499534213120", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'6EB'")) + .isEqualTo(decimal("6917529027641081856", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'7ZB'")) + .isEqualTo(decimal("8264141345021879123968", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'8YB'")) + .isEqualTo(decimal("9671406556917033397649408", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'6917529027641081856EB'")) + .isEqualTo(decimal("7975367974709495237422842361682067456", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'69175290276410818560EB'")) + .isEqualTo(decimal("79753679747094952374228423616820674560", DECIMAL)); + + assertTrinoExceptionThrownBy(() -> assertions.function("parse_data_size", "''").evaluate()) + .hasMessage("Invalid data size: ''"); + + assertTrinoExceptionThrownBy(() -> assertions.function("parse_data_size", "'0'").evaluate()) + .hasMessage("Invalid data size: '0'"); + + assertTrinoExceptionThrownBy(() -> assertions.function("parse_data_size", "'10KB'").evaluate()) + .hasMessage("Invalid data size: '10KB'"); + + assertTrinoExceptionThrownBy(() -> assertions.function("parse_data_size", "'KB'").evaluate()) + .hasMessage("Invalid data size: 'KB'"); + + assertTrinoExceptionThrownBy(() -> assertions.function("parse_data_size", "'-1B'").evaluate()) + .hasMessage("Invalid data size: '-1B'"); + + assertTrinoExceptionThrownBy(() -> assertions.function("parse_data_size", "'12345K'").evaluate()) + .hasMessage("Invalid data size: '12345K'"); + + assertTrinoExceptionThrownBy(() -> assertions.function("parse_data_size", "'A12345B'").evaluate()) + .hasMessage("Invalid data size: 'A12345B'"); + + assertTrinoExceptionThrownBy(() -> assertions.function("parse_data_size", "'99999999999999YB'").evaluate()) + .hasErrorCode(NUMERIC_VALUE_OUT_OF_RANGE) + .hasMessage("Value out of range: '99999999999999YB' ('120892581961461708544797985370825293824B')"); } } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestLikeFunctions.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestLikeFunctions.java index e2f7d39f4842..c0e556280894 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestLikeFunctions.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestLikeFunctions.java @@ -17,13 +17,10 @@ import io.airlift.slice.Slices; import io.trino.likematcher.LikeMatcher; import io.trino.spi.TrinoException; -import io.trino.spi.expression.StandardFunctions; -import io.trino.type.LikeFunctions; import org.testng.annotations.Test; import java.util.Optional; -import static com.google.common.base.Verify.verify; import static io.airlift.slice.Slices.utf8Slice; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.type.LikeFunctions.isLikePattern; @@ -48,14 +45,6 @@ private static Slice offsetHeapSlice(String value) return result.slice(2, source.length()); } - @Test - public void testFunctionNameConstantsInSync() - { - // Test may need to be updated when this changes. - verify(StandardFunctions.LIKE_PATTERN_FUNCTION_NAME.getCatalogSchema().isEmpty()); - assertEquals(StandardFunctions.LIKE_PATTERN_FUNCTION_NAME.getName(), LikeFunctions.LIKE_PATTERN_FUNCTION_NAME); - } - @Test public void testLikeBasic() { diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/interval/TestIntervalDayTime.java b/core/trino-main/src/test/java/io/trino/operator/scalar/interval/TestIntervalDayTime.java new file mode 100644 index 000000000000..c1e82b14262e --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/interval/TestIntervalDayTime.java @@ -0,0 +1,155 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.scalar.interval; + +import io.trino.sql.query.QueryAssertions; +import io.trino.type.SqlIntervalDayTime; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestIntervalDayTime +{ + protected QueryAssertions assertions; + + @BeforeClass + public void init() + { + assertions = new QueryAssertions(); + } + + @AfterClass(alwaysRun = true) + public void teardown() + { + assertions.close(); + assertions = null; + } + + @Test + public void testLiterals() + { + assertThat(assertions.expression("INTERVAL '12 10:45:32.123' DAY TO SECOND")) + .isEqualTo(interval(12, 10, 45, 32, 123)); + + assertThat(assertions.expression("INTERVAL '12 10:45:32.12' DAY TO SECOND")) + .isEqualTo(interval(12, 10, 45, 32, 120)); + + assertThat(assertions.expression("INTERVAL '12 10:45:32' DAY TO SECOND")) + .isEqualTo(interval(12, 10, 45, 32, 0)); + + assertThat(assertions.expression("INTERVAL '12 10:45' DAY TO SECOND")) + .isEqualTo(interval(12, 10, 45, 0, 0)); + + assertThat(assertions.expression("INTERVAL '12 10' DAY TO SECOND")) + .isEqualTo(interval(12, 10, 0, 0, 0)); + + assertThat(assertions.expression("INTERVAL '12' DAY TO SECOND")) + .isEqualTo(interval(12, 0, 0, 0, 0)); + + assertThat(assertions.expression("INTERVAL '12 10:45' DAY TO MINUTE")) + .isEqualTo(interval(12, 10, 45, 0, 0)); + + assertThat(assertions.expression("INTERVAL '12 10' DAY TO MINUTE")) + .isEqualTo(interval(12, 10, 0, 0, 0)); + + assertThat(assertions.expression("INTERVAL '12' DAY TO MINUTE")) + .isEqualTo(interval(12, 0, 0, 0, 0)); + + assertThat(assertions.expression("INTERVAL '12 10' DAY TO HOUR")) + .isEqualTo(interval(12, 10, 0, 0, 0)); + + assertThat(assertions.expression("INTERVAL '12' DAY TO HOUR")) + .isEqualTo(interval(12, 0, 0, 0, 0)); + + assertThat(assertions.expression("INTERVAL '12' DAY")) + .isEqualTo(interval(12, 0, 0, 0, 0)); + + assertThat(assertions.expression("INTERVAL '30' DAY")) + .isEqualTo(interval(30, 0, 0, 0, 0)); + + assertThat(assertions.expression("INTERVAL '90' DAY")) + .isEqualTo(interval(90, 0, 0, 0, 0)); + + assertThat(assertions.expression("INTERVAL '10:45:32.123' HOUR TO SECOND")) + .isEqualTo(interval(0, 10, 45, 32, 123)); + + assertThat(assertions.expression("INTERVAL '10:45:32.12' HOUR TO SECOND")) + .isEqualTo(interval(0, 10, 45, 32, 120)); + + assertThat(assertions.expression("INTERVAL '10:45:32' HOUR TO SECOND")) + .isEqualTo(interval(0, 10, 45, 32, 0)); + + assertThat(assertions.expression("INTERVAL '10:45' HOUR TO SECOND")) + .isEqualTo(interval(0, 10, 45, 0, 0)); + + assertThat(assertions.expression("INTERVAL '10' HOUR TO SECOND")) + .isEqualTo(interval(0, 10, 0, 0, 0)); + + assertThat(assertions.expression("INTERVAL '10:45' HOUR TO MINUTE")) + .isEqualTo(interval(0, 10, 45, 0, 0)); + + assertThat(assertions.expression("INTERVAL '10' HOUR TO MINUTE")) + .isEqualTo(interval(0, 10, 0, 0, 0)); + + assertThat(assertions.expression("INTERVAL '10' HOUR")) + .isEqualTo(interval(0, 10, 0, 0, 0)); + + assertThat(assertions.expression("INTERVAL '45:32.123' MINUTE TO SECOND")) + .isEqualTo(interval(0, 0, 45, 32, 123)); + + assertThat(assertions.expression("INTERVAL '45:32.12' MINUTE TO SECOND")) + .isEqualTo(interval(0, 0, 45, 32, 120)); + + assertThat(assertions.expression("INTERVAL '45:32' MINUTE TO SECOND")) + .isEqualTo(interval(0, 0, 45, 32, 0)); + + assertThat(assertions.expression("INTERVAL '45' MINUTE TO SECOND")) + .isEqualTo(interval(0, 0, 45, 0, 0)); + + assertThat(assertions.expression("INTERVAL '45' MINUTE")) + .isEqualTo(interval(0, 0, 45, 0, 0)); + + assertThat(assertions.expression("INTERVAL '32.123' SECOND")) + .isEqualTo(interval(0, 0, 0, 32, 123)); + + assertThat(assertions.expression("INTERVAL '32.12' SECOND")) + .isEqualTo(interval(0, 0, 0, 32, 120)); + + assertThat(assertions.expression("INTERVAL '32' SECOND")) + .isEqualTo(interval(0, 0, 0, 32, 0)); + + assertThatThrownBy(() -> assertions.expression("INTERVAL '12X' DAY").evaluate()) + .hasMessage("line 1:8: '12X' is not a valid interval literal"); + + assertThatThrownBy(() -> assertions.expression("INTERVAL '12 10' DAY").evaluate()) + .hasMessage("line 1:8: '12 10' is not a valid interval literal"); + + assertThatThrownBy(() -> assertions.expression("INTERVAL '12 X' DAY TO HOUR").evaluate()) + .hasMessage("line 1:8: '12 X' is not a valid interval literal"); + + assertThatThrownBy(() -> assertions.expression("INTERVAL '12 -10' DAY TO HOUR").evaluate()) + .hasMessage("line 1:8: '12 -10' is not a valid interval literal"); + + assertThatThrownBy(() -> assertions.expression("INTERVAL '--12 -10' DAY TO HOUR").evaluate()) + .hasMessage("line 1:8: '--12 -10' is not a valid interval literal"); + } + + private static SqlIntervalDayTime interval(int day, int hour, int minute, int second, int milliseconds) + { + return new SqlIntervalDayTime(day, hour, minute, second, milliseconds); + } +} diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/interval/TestIntervalYearMonth.java b/core/trino-main/src/test/java/io/trino/operator/scalar/interval/TestIntervalYearMonth.java new file mode 100644 index 000000000000..e8b89968681b --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/interval/TestIntervalYearMonth.java @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.scalar.interval; + +import io.trino.sql.query.QueryAssertions; +import io.trino.type.SqlIntervalYearMonth; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestIntervalYearMonth +{ + protected QueryAssertions assertions; + + @BeforeClass + public void init() + { + assertions = new QueryAssertions(); + } + + @AfterClass(alwaysRun = true) + public void teardown() + { + assertions.close(); + assertions = null; + } + + @Test + public void testLiterals() + { + assertThat(assertions.expression("INTERVAL '124-30' YEAR TO MONTH")) + .isEqualTo(interval(124, 30)); + + assertThat(assertions.expression("INTERVAL '124' YEAR TO MONTH")) + .isEqualTo(interval(124, 0)); + + assertThat(assertions.expression("INTERVAL '30' MONTH")) + .isEqualTo(interval(0, 30)); + + assertThat(assertions.expression("INTERVAL '32767' YEAR")) + .isEqualTo(interval(32767, 0)); + + assertThat(assertions.expression("INTERVAL '32767' MONTH")) + .isEqualTo(interval(0, 32767)); + + assertThat(assertions.expression("INTERVAL '32767-32767' YEAR TO MONTH")) + .isEqualTo(interval(32767, 32767)); + + assertThatThrownBy(() -> assertions.expression("INTERVAL '124X' YEAR").evaluate()) + .hasMessage("line 1:8: '124X' is not a valid interval literal"); + + assertThatThrownBy(() -> assertions.expression("INTERVAL '124-30' YEAR").evaluate()) + .hasMessage("line 1:8: '124-30' is not a valid interval literal"); + + assertThatThrownBy(() -> assertions.expression("INTERVAL '124-X' YEAR TO MONTH").evaluate()) + .hasMessage("line 1:8: '124-X' is not a valid interval literal"); + + assertThatThrownBy(() -> assertions.expression("INTERVAL '124--30' YEAR TO MONTH").evaluate()) + .hasMessage("line 1:8: '124--30' is not a valid interval literal"); + + assertThatThrownBy(() -> assertions.expression("INTERVAL '--124--30' YEAR TO MONTH").evaluate()) + .hasMessage("line 1:8: '--124--30' is not a valid interval literal"); + } + + private static SqlIntervalYearMonth interval(int year, int month) + { + return new SqlIntervalYearMonth(year, month); + } +} diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/time/TestExtract.java b/core/trino-main/src/test/java/io/trino/operator/scalar/time/TestExtract.java index ea20765cabf9..3b7dc94bd03c 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/time/TestExtract.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/time/TestExtract.java @@ -135,7 +135,7 @@ public void testSecond() @Test public void testMillisecond() { - assertThatThrownBy(() -> assertions.expression("EXTRACT(MILLISECOND FROM TIME '12:34:56')")) + assertThatThrownBy(() -> assertions.expression("EXTRACT(MILLISECOND FROM TIME '12:34:56')").evaluate()) .isInstanceOf(ParsingException.class) .hasMessage("line 1:8: Invalid EXTRACT field: MILLISECOND"); diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/time/TestTime.java b/core/trino-main/src/test/java/io/trino/operator/scalar/time/TestTime.java index 1b2015b5137a..3f1e17bb411a 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/time/TestTime.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/time/TestTime.java @@ -104,19 +104,19 @@ public void testLiterals() .hasType(createTimeType(12)) .isEqualTo(time(12, 12, 34, 56, 123_456_789_123L)); - assertThatThrownBy(() -> assertions.expression("TIME '12:34:56.1234567891234'")) + assertThatThrownBy(() -> assertions.expression("TIME '12:34:56.1234567891234'").evaluate()) .hasMessage("line 1:8: TIME precision must be in range [0, 12]: 13"); - assertThatThrownBy(() -> assertions.expression("TIME '25:00:00'")) + assertThatThrownBy(() -> assertions.expression("TIME '25:00:00'").evaluate()) .hasMessage("line 1:8: '25:00:00' is not a valid time literal"); - assertThatThrownBy(() -> assertions.expression("TIME '12:65:00'")) + assertThatThrownBy(() -> assertions.expression("TIME '12:65:00'").evaluate()) .hasMessage("line 1:8: '12:65:00' is not a valid time literal"); - assertThatThrownBy(() -> assertions.expression("TIME '12:00:65'")) + assertThatThrownBy(() -> assertions.expression("TIME '12:00:65'").evaluate()) .hasMessage("line 1:8: '12:00:65' is not a valid time literal"); - assertThatThrownBy(() -> assertions.expression("TIME 'xxx'")) + assertThatThrownBy(() -> assertions.expression("TIME 'xxx'").evaluate()) .hasMessage("line 1:8: 'xxx' is not a valid time literal"); } @@ -1461,31 +1461,31 @@ public void testCastFromVarchar() assertThat(assertions.expression("CAST('23:59:59.999999999999' AS TIME(11))")).matches("TIME '00:00:00.00000000000'"); // > 12 digits of precision - assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(0))")) + assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(0))").evaluate()) .hasMessage("Value cannot be cast to time: 12:34:56.1111111111111"); - assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(1))")) + assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(1))").evaluate()) .hasMessage("Value cannot be cast to time: 12:34:56.1111111111111"); - assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(2))")) + assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(2))").evaluate()) .hasMessage("Value cannot be cast to time: 12:34:56.1111111111111"); - assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(3))")) + assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(3))").evaluate()) .hasMessage("Value cannot be cast to time: 12:34:56.1111111111111"); - assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(4))")) + assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(4))").evaluate()) .hasMessage("Value cannot be cast to time: 12:34:56.1111111111111"); - assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(5))")) + assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(5))").evaluate()) .hasMessage("Value cannot be cast to time: 12:34:56.1111111111111"); - assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(6))")) + assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(6))").evaluate()) .hasMessage("Value cannot be cast to time: 12:34:56.1111111111111"); - assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(7))")) + assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(7))").evaluate()) .hasMessage("Value cannot be cast to time: 12:34:56.1111111111111"); - assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(8))")) + assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(8))").evaluate()) .hasMessage("Value cannot be cast to time: 12:34:56.1111111111111"); - assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(9))")) + assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(9))").evaluate()) .hasMessage("Value cannot be cast to time: 12:34:56.1111111111111"); - assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(10))")) + assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(10))").evaluate()) .hasMessage("Value cannot be cast to time: 12:34:56.1111111111111"); - assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(11))")) + assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(11))").evaluate()) .hasMessage("Value cannot be cast to time: 12:34:56.1111111111111"); - assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(12))")) + assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(12))").evaluate()) .hasMessage("Value cannot be cast to time: 12:34:56.1111111111111"); } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/timestamp/TestExtract.java b/core/trino-main/src/test/java/io/trino/operator/scalar/timestamp/TestExtract.java index c34b821731c0..29c1b22e52e2 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/timestamp/TestExtract.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/timestamp/TestExtract.java @@ -324,7 +324,7 @@ public void testSecond() @Test public void testMillisecond() { - assertThatThrownBy(() -> assertions.expression("EXTRACT(MILLISECOND FROM TIMESTAMP '2020-05-10 12:34:56')")) + assertThatThrownBy(() -> assertions.expression("EXTRACT(MILLISECOND FROM TIMESTAMP '2020-05-10 12:34:56')").evaluate()) .isInstanceOf(ParsingException.class) .hasMessage("line 1:8: Invalid EXTRACT field: MILLISECOND"); @@ -478,7 +478,7 @@ public void testQuarter() @Test public void testWeekOfYear() { - assertThatThrownBy(() -> assertions.expression("EXTRACT(WEEK_OF_YEAR FROM TIMESTAMP '2020-05-10 12:34:56')")) + assertThatThrownBy(() -> assertions.expression("EXTRACT(WEEK_OF_YEAR FROM TIMESTAMP '2020-05-10 12:34:56')").evaluate()) .isInstanceOf(ParsingException.class) .hasMessage("line 1:8: Invalid EXTRACT field: WEEK_OF_YEAR"); diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/timestamp/TestTimestamp.java b/core/trino-main/src/test/java/io/trino/operator/scalar/timestamp/TestTimestamp.java index 774b956f0ed2..bf40774ee3f7 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/timestamp/TestTimestamp.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/timestamp/TestTimestamp.java @@ -134,13 +134,13 @@ public void testLiterals() .hasType(createTimestampType(12)) .isEqualTo(timestamp(12, 2020, 5, 1, 12, 34, 56, 123_456_789_012L)); - assertThatThrownBy(() -> assertions.expression("TIMESTAMP '2020-05-01 12:34:56.1234567890123'")) + assertThatThrownBy(() -> assertions.expression("TIMESTAMP '2020-05-01 12:34:56.1234567890123'").evaluate()) .hasMessage("line 1:8: TIMESTAMP precision must be in range [0, 12]: 13"); - assertThatThrownBy(() -> assertions.expression("TIMESTAMP '2020-13-01'")) + assertThatThrownBy(() -> assertions.expression("TIMESTAMP '2020-13-01'").evaluate()) .hasMessage("line 1:8: '2020-13-01' is not a valid timestamp literal"); - assertThatThrownBy(() -> assertions.expression("TIMESTAMP 'xxx'")) + assertThatThrownBy(() -> assertions.expression("TIMESTAMP 'xxx'").evaluate()) .hasMessage("line 1:8: 'xxx' is not a valid timestamp literal"); // negative epoch @@ -1464,9 +1464,9 @@ public void testCastToTimestampWithTimeZone() assertThat(assertions.expression("CAST(TIMESTAMP '-12001-05-01 12:34:56' AS TIMESTAMP(0) WITH TIME ZONE)")).matches("TIMESTAMP '-12001-05-01 12:34:56 Pacific/Apia'"); // Overflow - assertThatThrownBy(() -> assertions.expression("CAST(TIMESTAMP '123001-05-01 12:34:56' AS TIMESTAMP WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST(TIMESTAMP '123001-05-01 12:34:56' AS TIMESTAMP WITH TIME ZONE)").evaluate()) .hasMessage("Out of range for timestamp with time zone: 3819379822496000"); - assertThatThrownBy(() -> assertions.expression("CAST(TIMESTAMP '-123001-05-01 12:34:56' AS TIMESTAMP WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST(TIMESTAMP '-123001-05-01 12:34:56' AS TIMESTAMP WITH TIME ZONE)").evaluate()) .hasMessage("Out of range for timestamp with time zone: -3943693439888000"); } @@ -2674,30 +2674,30 @@ public void testAtTimeZone() @Test public void testCastInvalidTimestamp() { - assertThatThrownBy(() -> assertions.expression("CAST('ABC' AS TIMESTAMP)")) + assertThatThrownBy(() -> assertions.expression("CAST('ABC' AS TIMESTAMP)").evaluate()) .hasMessage("Value cannot be cast to timestamp: ABC"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-00 00:00:00' AS TIMESTAMP)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-00 00:00:00' AS TIMESTAMP)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-00 00:00:00"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-00-01 00:00:00' AS TIMESTAMP)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-00-01 00:00:00' AS TIMESTAMP)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-00-01 00:00:00"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 25:00:00' AS TIMESTAMP)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 25:00:00' AS TIMESTAMP)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-01 25:00:00"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:61:00' AS TIMESTAMP)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:61:00' AS TIMESTAMP)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-01 00:61:00"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:00:61' AS TIMESTAMP)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:00:61' AS TIMESTAMP)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-01 00:00:61"); - assertThatThrownBy(() -> assertions.expression("CAST('ABC' AS TIMESTAMP(12))")) + assertThatThrownBy(() -> assertions.expression("CAST('ABC' AS TIMESTAMP(12))").evaluate()) .hasMessage("Value cannot be cast to timestamp: ABC"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-00 00:00:00' AS TIMESTAMP(12))")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-00 00:00:00' AS TIMESTAMP(12))").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-00 00:00:00"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-00-01 00:00:00' AS TIMESTAMP(12))")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-00-01 00:00:00' AS TIMESTAMP(12))").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-00-01 00:00:00"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 25:00:00' AS TIMESTAMP(12))")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 25:00:00' AS TIMESTAMP(12))").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-01 25:00:00"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:61:00' AS TIMESTAMP(12))")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:61:00' AS TIMESTAMP(12))").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-01 00:61:00"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:00:61' AS TIMESTAMP(12))")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:00:61' AS TIMESTAMP(12))").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-01 00:00:61"); } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/timestamptz/TestExtract.java b/core/trino-main/src/test/java/io/trino/operator/scalar/timestamptz/TestExtract.java index a6c921eb7a6c..a2cddd826330 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/timestamptz/TestExtract.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/timestamptz/TestExtract.java @@ -295,7 +295,7 @@ public void testSecond() @Test public void testMillisecond() { - assertThatThrownBy(() -> assertions.expression("EXTRACT(MILLISECOND FROM TIMESTAMP '2020-05-10 12:34:56 Asia/Kathmandu')")) + assertThatThrownBy(() -> assertions.expression("EXTRACT(MILLISECOND FROM TIMESTAMP '2020-05-10 12:34:56 Asia/Kathmandu')").evaluate()) .isInstanceOf(ParsingException.class) .hasMessage("line 1:8: Invalid EXTRACT field: MILLISECOND"); @@ -600,7 +600,7 @@ public void testQuarter() @Test public void testWeekOfYear() { - assertThatThrownBy(() -> assertions.expression("EXTRACT(WEEK_OF_YEAR FROM TIMESTAMP '2020-05-10 12:34:56 Asia/Kathmandu')")) + assertThatThrownBy(() -> assertions.expression("EXTRACT(WEEK_OF_YEAR FROM TIMESTAMP '2020-05-10 12:34:56 Asia/Kathmandu')").evaluate()) .isInstanceOf(ParsingException.class) .hasMessage("line 1:8: Invalid EXTRACT field: WEEK_OF_YEAR"); diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/timestamptz/TestTimestampWithTimeZone.java b/core/trino-main/src/test/java/io/trino/operator/scalar/timestamptz/TestTimestampWithTimeZone.java index 7900c185655c..ced42bbdea31 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/timestamptz/TestTimestampWithTimeZone.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/timestamptz/TestTimestampWithTimeZone.java @@ -106,10 +106,10 @@ public void testLiterals() .hasType(createTimestampWithTimeZoneType(12)) .isEqualTo(timestampWithTimeZone(12, 2020, 5, 1, 12, 34, 56, 123_456_789_012L, getTimeZoneKey("Asia/Kathmandu"))); - assertThatThrownBy(() -> assertions.expression("TIMESTAMP '2020-05-01 12:34:56.1234567890123 Asia/Kathmandu'")) + assertThatThrownBy(() -> assertions.expression("TIMESTAMP '2020-05-01 12:34:56.1234567890123 Asia/Kathmandu'").evaluate()) .hasMessage("line 1:8: TIMESTAMP WITH TIME ZONE precision must be in range [0, 12]: 13"); - assertThatThrownBy(() -> assertions.expression("TIMESTAMP '2020-13-01 Asia/Kathmandu'")) + assertThatThrownBy(() -> assertions.expression("TIMESTAMP '2020-13-01 Asia/Kathmandu'").evaluate()) .hasMessage("line 1:8: '2020-13-01 Asia/Kathmandu' is not a valid timestamp literal"); // negative epoch @@ -2488,34 +2488,34 @@ public void testJoin() @Test public void testCastInvalidTimestamp() { - assertThatThrownBy(() -> assertions.expression("CAST('ABC' AS TIMESTAMP WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST('ABC' AS TIMESTAMP WITH TIME ZONE)").evaluate()) .hasMessage("Value cannot be cast to timestamp: ABC"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-00 00:00:00 UTC' AS TIMESTAMP WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-00 00:00:00 UTC' AS TIMESTAMP WITH TIME ZONE)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-00 00:00:00 UTC"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-00-01 00:00:00 UTC' AS TIMESTAMP WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-00-01 00:00:00 UTC' AS TIMESTAMP WITH TIME ZONE)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-00-01 00:00:00 UTC"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 25:00:00 UTC' AS TIMESTAMP WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 25:00:00 UTC' AS TIMESTAMP WITH TIME ZONE)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-01 25:00:00 UTC"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:61:00 UTC' AS TIMESTAMP WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:61:00 UTC' AS TIMESTAMP WITH TIME ZONE)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-01 00:61:00 UTC"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:00:61 UTC' AS TIMESTAMP WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:00:61 UTC' AS TIMESTAMP WITH TIME ZONE)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-01 00:00:61 UTC"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:00:00 ABC' AS TIMESTAMP WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:00:00 ABC' AS TIMESTAMP WITH TIME ZONE)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-01 00:00:00 ABC"); - assertThatThrownBy(() -> assertions.expression("CAST('ABC' AS TIMESTAMP(12))")) + assertThatThrownBy(() -> assertions.expression("CAST('ABC' AS TIMESTAMP(12))").evaluate()) .hasMessage("Value cannot be cast to timestamp: ABC"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-00 00:00:00 UTC' AS TIMESTAMP(12) WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-00 00:00:00 UTC' AS TIMESTAMP(12) WITH TIME ZONE)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-00 00:00:00 UTC"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-00-01 00:00:00 UTC' AS TIMESTAMP(12) WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-00-01 00:00:00 UTC' AS TIMESTAMP(12) WITH TIME ZONE)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-00-01 00:00:00 UTC"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 25:00:00 UTC' AS TIMESTAMP(12) WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 25:00:00 UTC' AS TIMESTAMP(12) WITH TIME ZONE)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-01 25:00:00 UTC"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:61:00 UTC' AS TIMESTAMP(12) WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:61:00 UTC' AS TIMESTAMP(12) WITH TIME ZONE)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-01 00:61:00 UTC"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:00:61 UTC' AS TIMESTAMP(12) WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:00:61 UTC' AS TIMESTAMP(12) WITH TIME ZONE)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-01 00:00:61 UTC"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:00:00 ABC' AS TIMESTAMP(12) WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:00:00 ABC' AS TIMESTAMP(12) WITH TIME ZONE)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-01 00:00:00 ABC"); } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/timetz/TestExtract.java b/core/trino-main/src/test/java/io/trino/operator/scalar/timetz/TestExtract.java index ced4bebe8661..0426c69fc1d9 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/timetz/TestExtract.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/timetz/TestExtract.java @@ -135,7 +135,7 @@ public void testSecond() @Test public void testMillisecond() { - assertThatThrownBy(() -> assertions.expression("EXTRACT(MILLISECOND FROM TIME '12:34:56+08:35')")) + assertThatThrownBy(() -> assertions.expression("EXTRACT(MILLISECOND FROM TIME '12:34:56+08:35')").evaluate()) .isInstanceOf(ParsingException.class) .hasMessage("line 1:8: Invalid EXTRACT field: MILLISECOND"); diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/timetz/TestTimeWithTimeZone.java b/core/trino-main/src/test/java/io/trino/operator/scalar/timetz/TestTimeWithTimeZone.java index ca98cb410b29..6d47ed319ed6 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/timetz/TestTimeWithTimeZone.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/timetz/TestTimeWithTimeZone.java @@ -257,34 +257,34 @@ public void testLiterals() .hasType(createTimeWithTimeZoneType(12)) .isEqualTo(timeWithTimeZone(12, 12, 34, 56, 123_456_789_123L, -14 * 60)); - assertThatThrownBy(() -> assertions.expression("TIME '12:34:56.1234567891234+08:35'")) + assertThatThrownBy(() -> assertions.expression("TIME '12:34:56.1234567891234+08:35'").evaluate()) .hasMessage("line 1:8: TIME WITH TIME ZONE precision must be in range [0, 12]: 13"); - assertThatThrownBy(() -> assertions.expression("TIME '25:00:00+08:35'")) + assertThatThrownBy(() -> assertions.expression("TIME '25:00:00+08:35'").evaluate()) .hasMessage("line 1:8: '25:00:00+08:35' is not a valid time literal"); - assertThatThrownBy(() -> assertions.expression("TIME '12:65:00+08:35'")) + assertThatThrownBy(() -> assertions.expression("TIME '12:65:00+08:35'").evaluate()) .hasMessage("line 1:8: '12:65:00+08:35' is not a valid time literal"); - assertThatThrownBy(() -> assertions.expression("TIME '12:00:65+08:35'")) + assertThatThrownBy(() -> assertions.expression("TIME '12:00:65+08:35'").evaluate()) .hasMessage("line 1:8: '12:00:65+08:35' is not a valid time literal"); - assertThatThrownBy(() -> assertions.expression("TIME '12:00:00+15:00'")) + assertThatThrownBy(() -> assertions.expression("TIME '12:00:00+15:00'").evaluate()) .hasMessage("line 1:8: '12:00:00+15:00' is not a valid time literal"); - assertThatThrownBy(() -> assertions.expression("TIME '12:00:00-15:00'")) + assertThatThrownBy(() -> assertions.expression("TIME '12:00:00-15:00'").evaluate()) .hasMessage("line 1:8: '12:00:00-15:00' is not a valid time literal"); - assertThatThrownBy(() -> assertions.expression("TIME '12:00:00+14:01'")) + assertThatThrownBy(() -> assertions.expression("TIME '12:00:00+14:01'").evaluate()) .hasMessage("line 1:8: '12:00:00+14:01' is not a valid time literal"); - assertThatThrownBy(() -> assertions.expression("TIME '12:00:00-14:01'")) + assertThatThrownBy(() -> assertions.expression("TIME '12:00:00-14:01'").evaluate()) .hasMessage("line 1:8: '12:00:00-14:01' is not a valid time literal"); - assertThatThrownBy(() -> assertions.expression("TIME '12:00:00+13:60'")) + assertThatThrownBy(() -> assertions.expression("TIME '12:00:00+13:60'").evaluate()) .hasMessage("line 1:8: '12:00:00+13:60' is not a valid time literal"); - assertThatThrownBy(() -> assertions.expression("TIME '12:00:00-13:60'")) + assertThatThrownBy(() -> assertions.expression("TIME '12:00:00-13:60'").evaluate()) .hasMessage("line 1:8: '12:00:00-13:60' is not a valid time literal"); } diff --git a/core/trino-main/src/test/java/io/trino/spiller/TestAesSpillCipher.java b/core/trino-main/src/test/java/io/trino/spiller/TestAesSpillCipher.java index 46fb5d1868ef..a1abbbb0cc48 100644 --- a/core/trino-main/src/test/java/io/trino/spiller/TestAesSpillCipher.java +++ b/core/trino-main/src/test/java/io/trino/spiller/TestAesSpillCipher.java @@ -56,9 +56,7 @@ private static byte[] encryptExact(SpillCipher cipher, byte[] data) if (output.length == outLength) { return output; } - else { - return Arrays.copyOfRange(output, 0, outLength); - } + return Arrays.copyOfRange(output, 0, outLength); } private static byte[] decryptExact(SpillCipher cipher, byte[] encryptedData) @@ -68,9 +66,7 @@ private static byte[] decryptExact(SpillCipher cipher, byte[] encryptedData) if (outLength == output.length) { return output; } - else { - return Arrays.copyOfRange(output, 0, outLength); - } + return Arrays.copyOfRange(output, 0, outLength); } private static void assertFailure(ThrowingRunnable runnable, String expectedErrorMessage) diff --git a/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java b/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java index 0fc1a2f84935..6b0dd3790460 100644 --- a/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java +++ b/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java @@ -3890,6 +3890,37 @@ public void testInvalidDelete() assertFails("DELETE FROM v1 WHERE a = 1") .hasErrorCode(NOT_SUPPORTED) .hasMessage("line 1:1: Deleting from views is not supported"); + assertFails("DELETE FROM mv1") + .hasErrorCode(NOT_SUPPORTED) + .hasMessage("line 1:1: Deleting from materialized views is not supported"); + } + + @Test + public void testInvalidUpdate() + { + assertFails("UPDATE foo SET bar = 'test'") + .hasErrorCode(TABLE_NOT_FOUND) + .hasMessage("line 1:1: Table 'tpch.s1.foo' does not exist"); + assertFails("UPDATE v1 SET a = 2") + .hasErrorCode(NOT_SUPPORTED) + .hasMessage("line 1:1: Updating views is not supported"); + assertFails("UPDATE mv1 SET a = 1") + .hasErrorCode(NOT_SUPPORTED) + .hasMessage("line 1:1: Updating materialized views is not supported"); + } + + @Test + public void testInvalidMerge() + { + assertFails("MERGE INTO foo USING bar ON foo.id = bar.id WHEN MATCHED THEN DELETE") + .hasErrorCode(TABLE_NOT_FOUND) + .hasMessage("line 1:1: Table 'tpch.s1.foo' does not exist"); + assertFails("MERGE INTO v1 USING t1 ON v1.a = t1.a WHEN MATCHED THEN DELETE") + .hasErrorCode(NOT_SUPPORTED) + .hasMessage("line 1:1: Merging into views is not supported"); + assertFails("MERGE INTO mv1 USING t1 ON mv1.a = t1.a WHEN MATCHED THEN DELETE") + .hasErrorCode(NOT_SUPPORTED) + .hasMessage("line 1:1: Merging into materialized views is not supported"); } @Test diff --git a/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor2.java b/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor2.java index a142146a0d52..fa038e54bf53 100644 --- a/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor2.java +++ b/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor2.java @@ -189,9 +189,7 @@ private static Page createPage(List types, boolean dictionary) if (dictionary) { return SequencePageBuilder.createSequencePageWithDictionaryBlocks(types, POSITIONS); } - else { - return SequencePageBuilder.createSequencePage(types, POSITIONS); - } + return SequencePageBuilder.createSequencePage(types, POSITIONS); } public static void main(String[] args) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java index 0eea7ecde92a..d6a893b3432f 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java @@ -68,7 +68,7 @@ import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.IS_NULL_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME; -import static io.trino.spi.expression.StandardFunctions.LIKE_PATTERN_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.LIKE_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.NEGATE_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.NOT_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.NULLIF_FUNCTION_NAME; @@ -343,7 +343,7 @@ public void testTranslateLike() new StringLiteral(pattern), Optional.empty()), new Call(BOOLEAN, - LIKE_PATTERN_FUNCTION_NAME, + LIKE_FUNCTION_NAME, List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE), new Constant(Slices.wrappedBuffer(pattern.getBytes(UTF_8)), createVarcharType(pattern.length()))))); @@ -354,7 +354,7 @@ public void testTranslateLike() new StringLiteral(pattern), Optional.of(new StringLiteral(escape))), new Call(BOOLEAN, - LIKE_PATTERN_FUNCTION_NAME, + LIKE_FUNCTION_NAME, List.of( new Variable("varchar_symbol_1", VARCHAR_TYPE), new Constant(Slices.wrappedBuffer(pattern.getBytes(UTF_8)), createVarcharType(pattern.length())), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestTableScanRedirectionWithPushdown.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestTableScanRedirectionWithPushdown.java index fa64db21371a..0b1d58788edc 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestTableScanRedirectionWithPushdown.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestTableScanRedirectionWithPushdown.java @@ -337,7 +337,7 @@ private LocalQueryRunner createLocalQueryRunner( new ColumnMetadata(SOURCE_COLUMN_NAME_C, VARCHAR), new ColumnMetadata(SOURCE_COLUMN_NAME_D, ROW_TYPE)); } - else if (name.equals(DESTINATION_TABLE)) { + if (name.equals(DESTINATION_TABLE)) { return ImmutableList.of( new ColumnMetadata(DESTINATION_COLUMN_NAME_A, INTEGER), new ColumnMetadata(DESTINATION_COLUMN_NAME_B, INTEGER), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/CorrelationMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/CorrelationMatcher.java index 7243f6c4e984..ed6253d641f3 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/CorrelationMatcher.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/CorrelationMatcher.java @@ -72,12 +72,10 @@ private List getCorrelation(PlanNode node) if (node instanceof ApplyNode) { return ((ApplyNode) node).getCorrelation(); } - else if (node instanceof CorrelatedJoinNode) { + if (node instanceof CorrelatedJoinNode) { return ((CorrelatedJoinNode) node).getCorrelation(); } - else { - throw new IllegalStateException("Unexpected plan node: " + node); - } + throw new IllegalStateException("Unexpected plan node: " + node); } @Override diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionMatcher.java index 941721e83c0c..fcc39625005f 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionMatcher.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionMatcher.java @@ -96,13 +96,11 @@ private static Map getAssignments(PlanNode node) ProjectNode projectNode = (ProjectNode) node; return projectNode.getAssignments().getMap(); } - else if (node instanceof ApplyNode) { + if (node instanceof ApplyNode) { ApplyNode applyNode = (ApplyNode) node; return applyNode.getSubqueryAssignments().getMap(); } - else { - return null; - } + return null; } @Override diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/MatchResult.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/MatchResult.java index 8470f2aded8f..4e53c5a4f952 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/MatchResult.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/MatchResult.java @@ -69,8 +69,6 @@ public String toString() if (matches) { return "MATCH"; } - else { - return "NO MATCH"; - } + return "NO MATCH"; } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java index a808ba0f85f3..e4048a1e321c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java @@ -1389,19 +1389,13 @@ public SortOrder getSortOrder() if (nullOrdering == FIRST) { return ASC_NULLS_FIRST; } - else { - return ASC_NULLS_LAST; - } + return ASC_NULLS_LAST; } - else { - checkState(ordering == DESCENDING); - if (nullOrdering == FIRST) { - return DESC_NULLS_FIRST; - } - else { - return DESC_NULLS_LAST; - } + checkState(ordering == DESCENDING); + if (nullOrdering == FIRST) { + return DESC_NULLS_FIRST; } + return DESC_NULLS_LAST; } @Override diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestApplyTableScanRedirection.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestApplyTableScanRedirection.java index 760c49f08519..0783932d6713 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestApplyTableScanRedirection.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestApplyTableScanRedirection.java @@ -281,7 +281,7 @@ private MockConnectorFactory createMockFactory(Optional p.join(INNER, p.values(), p.values(), FALSE_LITERAL)) .doesNotFire(); } + + /** + * Test canonicalization of {@link CurrentTime} + */ + @Test + public void testCanonicalizeCurrentTime() + { + CanonicalizeExpressions canonicalizeExpressions = new CanonicalizeExpressions(tester().getPlannerContext(), tester().getTypeAnalyzer()); + tester().assertThat(canonicalizeExpressions.filterExpressionRewrite()) + .on(p -> p.filter(expression("LOCALTIMESTAMP > TIMESTAMP '2005-09-10 13:30:00'"), p.values(1))) + .matches( + filter( + "\"$localtimestamp\"(null) > TIMESTAMP '2005-09-10 13:30:00'", + values(1))); + } + + @Test + public void testCanonicalizeDateArgument() + { + CanonicalizeExpressions canonicalizeExpressions = new CanonicalizeExpressions(tester().getPlannerContext(), tester().getTypeAnalyzer()); + tester().assertThat(canonicalizeExpressions.filterExpressionRewrite()) + .on(p -> p.filter(expression("date(LOCALTIMESTAMP) > DATE '2005-09-10'"), p.values(1))) + .matches( + filter( + "CAST(\"$localtimestamp\"(null) AS date) > DATE '2005-09-10'", + values(1))); + } } diff --git a/core/trino-main/src/test/java/io/trino/sql/query/QueryAssertions.java b/core/trino-main/src/test/java/io/trino/sql/query/QueryAssertions.java index df402d455f47..8b518bcc4051 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/QueryAssertions.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/QueryAssertions.java @@ -13,10 +13,12 @@ */ package io.trino.sql.query; +import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import io.trino.Session; import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.FunctionBundle; +import io.trino.spi.function.OperatorType; import io.trino.spi.type.SqlTime; import io.trino.spi.type.SqlTimeWithTimeZone; import io.trino.spi.type.SqlTimestamp; @@ -42,7 +44,9 @@ import java.io.Closeable; import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.function.BiFunction; import java.util.function.Consumer; @@ -53,8 +57,8 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.testing.Assertions.assertEqualsIgnoreOrder; import static io.trino.cost.StatsCalculator.noopStatsCalculator; +import static io.trino.metadata.OperatorNameUtil.mangleOperatorName; import static io.trino.sql.planner.assertions.PlanAssert.assertPlan; -import static io.trino.sql.query.QueryAssertions.ExpressionAssert.newExpressionAssert; import static io.trino.sql.query.QueryAssertions.QueryAssert.newQueryAssert; import static io.trino.testing.TestingHandles.TEST_CATALOG_NAME; import static io.trino.testing.TestingSession.testSessionBuilder; @@ -114,14 +118,38 @@ public AssertProvider query(Session session, @Language("SQL") Strin return newQueryAssert(query, runner, session); } - public AssertProvider expression(@Language("SQL") String expression) + public ExpressionAssertProvider expression(@Language("SQL") String expression) { return expression(expression, runner.getDefaultSession()); } - public AssertProvider expression(@Language("SQL") String expression, Session session) + public ExpressionAssertProvider operator(OperatorType operator, @Language("SQL") String... arguments) { - return newExpressionAssert(expression, runner, session); + return function(mangleOperatorName(operator), arguments); + } + + public ExpressionAssertProvider function(String name, @Language("SQL") String... arguments) + { + ImmutableList.Builder builder = ImmutableList.builder(); + for (int i = 0; i < arguments.length; i++) { + builder.add("a" + i); + } + + List names = builder.build(); + ExpressionAssertProvider assertion = expression("\"%s\"(%s)".formatted( + name, + String.join(",", names))); + + for (int i = 0; i < arguments.length; i++) { + assertion.binding(names.get(i), arguments[i]); + } + + return assertion; + } + + public ExpressionAssertProvider expression(@Language("SQL") String expression, Session session) + { + return new ExpressionAssertProvider(runner, session, expression); } public void assertQueryAndPlan( @@ -527,6 +555,103 @@ public QueryAssert hasCorrectResultsRegardlessOfPushdown() } } + public static class ExpressionAssertProvider + implements AssertProvider + { + private final QueryRunner runner; + private final String expression; + private final Session session; + + private final Map bindings = new HashMap<>(); + + public ExpressionAssertProvider(QueryRunner runner, Session session, String expression) + { + this.runner = runner; + this.session = session; + this.expression = expression; + } + + public ExpressionAssertProvider binding(String variable, @Language("SQL") String value) + { + String previous = bindings.put(variable, value); + if (previous != null) { + fail("%s already bound to: %s".formatted(variable, value)); + } + return this; + } + + public Result evaluate() + { + if (bindings.isEmpty()) { + return run("VALUES %s".formatted(expression)); + } + else { + List> entries = ImmutableList.copyOf(bindings.entrySet()); + + List columns = entries.stream() + .map(Map.Entry::getKey) + .collect(toList()); + + List values = entries.stream() + .map(Map.Entry::getValue) + .collect(toList()); + + // Evaluate the expression using two modes: + // 1. Avoid constant folding -> exercises the compiler and evaluation engine + // 2. Force constant folding -> exercises the interpreter + + Result full = run(""" + SELECT %s + FROM ( + VALUES (%s) + ) t(%s) + WHERE rand() >= 0 + """ + .formatted( + expression, + Joiner.on(",").join(values), + Joiner.on(",").join(columns))); + + Result withConstantFolding = run(""" + SELECT %s + FROM ( + VALUES (%s) + ) t(%s) + """ + .formatted( + expression, + Joiner.on(",").join(values), + Joiner.on(",").join(columns))); + + if (!full.type().equals(withConstantFolding.type())) { + fail("Mismatched types between interpreter and evaluation engine: %s vs %s".formatted(full.type(), withConstantFolding.type())); + } + + if (!Objects.equals(full.value(), withConstantFolding.value())) { + fail("Mismatched results between interpreter and evaluation engine: %s vs %s".formatted(full.value(), withConstantFolding.value())); + } + + return new Result(full.type(), full.value); + } + } + + private Result run(String query) + { + MaterializedResult result = runner.execute(session, query); + return new Result(result.getTypes().get(0), result.getOnlyColumnAsSet().iterator().next()); + } + + @Override + public ExpressionAssert assertThat() + { + Result result = evaluate(); + return new ExpressionAssert(runner, session, result.value(), result.type()) + .withRepresentation(ExpressionAssert.TYPE_RENDERER); + } + + record Result(Type type, Object value) {} + } + public static class ExpressionAssert extends AbstractAssert { @@ -575,15 +700,6 @@ public String toStringOf(Object object) private final Session session; private final Type actualType; - static AssertProvider newExpressionAssert(String expression, QueryRunner runner, Session session) - { - MaterializedResult result = runner.execute(session, "VALUES " + expression); - Type type = result.getTypes().get(0); - Object value = result.getOnlyColumnAsSet().iterator().next(); - return () -> new ExpressionAssert(runner, session, value, type) - .withRepresentation(TYPE_RENDERER); - } - public ExpressionAssert(QueryRunner runner, Session session, Object actual, Type actualType) { super(actual, Object.class); @@ -613,6 +729,20 @@ public ExpressionAssert matches(@Language("SQL") String expression) }); } + /** + * Syntactic sugar for: + * + *
{@code
+         *     assertThat(...)
+         *         .hasType(type)
+         *         .isNull()
+         * }
+ */ + public void isNull(Type type) + { + hasType(type).isNull(); + } + public ExpressionAssert hasType(Type type) { objects.assertEqual(info, actualType, type); diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestJoin.java b/core/trino-main/src/test/java/io/trino/sql/query/TestJoin.java index 72fa7712355e..493dcaca3bad 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestJoin.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestJoin.java @@ -101,17 +101,12 @@ public void testJoinWithComplexCriteria() t2 (id, x, y) AS ( VALUES (1, 10, 'a'), - (2, 10, 'b')), - t AS ( - SELECT - x - , IF(t1.v = 0, 'cc', y) as z - FROM t1 JOIN t2 ON (t1.id = t2.id)) - SELECT * - FROM t - WHERE x = 10 AND z = 'b' + (2, 10, 'b')) + SELECT x, y + FROM t1 JOIN t2 ON (t1.id = t2.id) + WHERE IF(t1.v = 0, 'cc', y) = 'b' """)) - .matches("VALUES (10, CAST('b' AS varchar(2)))"); + .matches("VALUES (10, 'b')"); } @Test diff --git a/core/trino-main/src/test/java/io/trino/type/TestIntervalDayTime.java b/core/trino-main/src/test/java/io/trino/type/TestIntervalDayTime.java index 8843fe5f2d17..0b258bc2c8a1 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestIntervalDayTime.java +++ b/core/trino-main/src/test/java/io/trino/type/TestIntervalDayTime.java @@ -14,11 +14,9 @@ package io.trino.type; import io.trino.operator.scalar.AbstractTestFunctions; -import io.trino.spi.type.Type; import org.testng.annotations.Test; import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; -import static io.trino.spi.StandardErrorCode.INVALID_LITERAL; import static io.trino.spi.function.OperatorType.INDETERMINATE; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.VarcharType.VARCHAR; @@ -42,69 +40,6 @@ public void testObject() assertEquals(new SqlIntervalDayTime(-90, 0, 0, 0, 0), new SqlIntervalDayTime(-DAYS.toMillis(90))); } - @Test - public void testLiteral() - { - assertLiteral("INTERVAL '12 10:45:32.123' DAY TO SECOND", INTERVAL_DAY_TIME, new SqlIntervalDayTime(12, 10, 45, 32, 123)); - assertLiteral("INTERVAL '12 10:45:32.12' DAY TO SECOND", INTERVAL_DAY_TIME, new SqlIntervalDayTime(12, 10, 45, 32, 120)); - assertLiteral("INTERVAL '12 10:45:32' DAY TO SECOND", INTERVAL_DAY_TIME, new SqlIntervalDayTime(12, 10, 45, 32, 0)); - assertLiteral("INTERVAL '12 10:45' DAY TO SECOND", INTERVAL_DAY_TIME, new SqlIntervalDayTime(12, 10, 45, 0, 0)); - assertLiteral("INTERVAL '12 10' DAY TO SECOND", INTERVAL_DAY_TIME, new SqlIntervalDayTime(12, 10, 0, 0, 0)); - assertLiteral("INTERVAL '12' DAY TO SECOND", INTERVAL_DAY_TIME, new SqlIntervalDayTime(12, 0, 0, 0, 0)); - - assertLiteral("INTERVAL '12 10:45' DAY TO MINUTE", INTERVAL_DAY_TIME, new SqlIntervalDayTime(12, 10, 45, 0, 0)); - assertLiteral("INTERVAL '12 10' DAY TO MINUTE", INTERVAL_DAY_TIME, new SqlIntervalDayTime(12, 10, 0, 0, 0)); - assertLiteral("INTERVAL '12' DAY TO MINUTE", INTERVAL_DAY_TIME, new SqlIntervalDayTime(12, 0, 0, 0, 0)); - - assertLiteral("INTERVAL '12 10' DAY TO HOUR", INTERVAL_DAY_TIME, new SqlIntervalDayTime(12, 10, 0, 0, 0)); - assertLiteral("INTERVAL '12' DAY TO HOUR", INTERVAL_DAY_TIME, new SqlIntervalDayTime(12, 0, 0, 0, 0)); - - assertLiteral("INTERVAL '12' DAY", INTERVAL_DAY_TIME, new SqlIntervalDayTime(12, 0, 0, 0, 0)); - assertLiteral("INTERVAL '30' DAY", INTERVAL_DAY_TIME, new SqlIntervalDayTime(30, 0, 0, 0, 0)); - assertLiteral("INTERVAL '90' DAY", INTERVAL_DAY_TIME, new SqlIntervalDayTime(90, 0, 0, 0, 0)); - - assertLiteral("INTERVAL '10:45:32.123' HOUR TO SECOND", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 10, 45, 32, 123)); - assertLiteral("INTERVAL '10:45:32.12' HOUR TO SECOND", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 10, 45, 32, 120)); - assertLiteral("INTERVAL '10:45:32' HOUR TO SECOND", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 10, 45, 32, 0)); - assertLiteral("INTERVAL '10:45' HOUR TO SECOND", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 10, 45, 0, 0)); - assertLiteral("INTERVAL '10' HOUR TO SECOND", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 10, 0, 0, 0)); - - assertLiteral("INTERVAL '10:45' HOUR TO MINUTE", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 10, 45, 0, 0)); - assertLiteral("INTERVAL '10' HOUR TO MINUTE", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 10, 0, 0, 0)); - - assertLiteral("INTERVAL '10' HOUR", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 10, 0, 0, 0)); - - assertLiteral("INTERVAL '45:32.123' MINUTE TO SECOND", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 0, 45, 32, 123)); - assertLiteral("INTERVAL '45:32.12' MINUTE TO SECOND", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 0, 45, 32, 120)); - assertLiteral("INTERVAL '45:32' MINUTE TO SECOND", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 0, 45, 32, 0)); - assertLiteral("INTERVAL '45' MINUTE TO SECOND", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 0, 45, 0, 0)); - - assertLiteral("INTERVAL '45' MINUTE", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 0, 45, 0, 0)); - - assertLiteral("INTERVAL '32.123' SECOND", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 0, 0, 32, 123)); - assertLiteral("INTERVAL '32.12' SECOND", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 0, 0, 32, 120)); - assertLiteral("INTERVAL '32' SECOND", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 0, 0, 32, 0)); - } - - private void assertLiteral(String projection, Type expectedType, SqlIntervalDayTime expectedValue) - { - assertFunction(projection, expectedType, expectedValue); - - projection = projection.replace("INTERVAL '", "INTERVAL '-"); - expectedValue = new SqlIntervalDayTime(-expectedValue.getMillis()); - assertFunction(projection, expectedType, expectedValue); - } - - @Test - public void testInvalidLiteral() - { - assertInvalidFunction("INTERVAL '12X' DAY", INVALID_LITERAL, "line 1:1: '12X' is not a valid interval literal"); - assertInvalidFunction("INTERVAL '12 10' DAY", INVALID_LITERAL, "line 1:1: '12 10' is not a valid interval literal"); - assertInvalidFunction("INTERVAL '12 X' DAY TO HOUR", INVALID_LITERAL, "line 1:1: '12 X' is not a valid interval literal"); - assertInvalidFunction("INTERVAL '12 -10' DAY TO HOUR", INVALID_LITERAL, "line 1:1: '12 -10' is not a valid interval literal"); - assertInvalidFunction("INTERVAL '--12 -10' DAY TO HOUR", INVALID_LITERAL, "line 1:1: '--12 -10' is not a valid interval literal"); - } - @Test public void testAdd() { diff --git a/core/trino-main/src/test/java/io/trino/type/TestIntervalYearMonth.java b/core/trino-main/src/test/java/io/trino/type/TestIntervalYearMonth.java index b6e01eb24b77..14e8affc4626 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestIntervalYearMonth.java +++ b/core/trino-main/src/test/java/io/trino/type/TestIntervalYearMonth.java @@ -14,16 +14,13 @@ package io.trino.type; import io.trino.operator.scalar.AbstractTestFunctions; -import io.trino.spi.type.Type; import org.testng.annotations.Test; import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; -import static io.trino.spi.StandardErrorCode.INVALID_LITERAL; import static io.trino.spi.function.OperatorType.INDETERMINATE; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.type.IntervalYearMonthType.INTERVAL_YEAR_MONTH; -import static java.lang.String.format; import static org.testng.Assert.assertEquals; public class TestIntervalYearMonth @@ -44,40 +41,6 @@ public void testObject() assertEquals(new SqlIntervalYearMonth(-MAX_SHORT, -MAX_SHORT), new SqlIntervalYearMonth(-425_971)); } - @Test - public void testLiteral() - { - assertLiteral("INTERVAL '124-30' YEAR TO MONTH", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth(124, 30)); - assertLiteral("INTERVAL '124' YEAR TO MONTH", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth(124, 0)); - - assertLiteral("INTERVAL '124' YEAR", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth(124, 0)); - - assertLiteral("INTERVAL '30' MONTH", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth(0, 30)); - - assertLiteral(format("INTERVAL '%s' YEAR", MAX_SHORT), INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth(MAX_SHORT, 0)); - assertLiteral(format("INTERVAL '%s' MONTH", MAX_SHORT), INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth(0, MAX_SHORT)); - assertLiteral(format("INTERVAL '%s-%s' YEAR TO MONTH", MAX_SHORT, MAX_SHORT), INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth(MAX_SHORT, MAX_SHORT)); - } - - private void assertLiteral(String projection, Type expectedType, SqlIntervalYearMonth expectedValue) - { - assertFunction(projection, expectedType, expectedValue); - - projection = projection.replace("INTERVAL '", "INTERVAL '-"); - expectedValue = new SqlIntervalYearMonth(-expectedValue.getMonths()); - assertFunction(projection, expectedType, expectedValue); - } - - @Test - public void testInvalidLiteral() - { - assertInvalidFunction("INTERVAL '124X' YEAR", INVALID_LITERAL, "line 1:1: '124X' is not a valid interval literal"); - assertInvalidFunction("INTERVAL '124-30' YEAR", INVALID_LITERAL, "line 1:1: '124-30' is not a valid interval literal"); - assertInvalidFunction("INTERVAL '124-X' YEAR TO MONTH", INVALID_LITERAL, "line 1:1: '124-X' is not a valid interval literal"); - assertInvalidFunction("INTERVAL '124--30' YEAR TO MONTH", INVALID_LITERAL, "line 1:1: '124--30' is not a valid interval literal"); - assertInvalidFunction("INTERVAL '--124--30' YEAR TO MONTH", INVALID_LITERAL, "line 1:1: '--124--30' is not a valid interval literal"); - } - @Test public void testAdd() { diff --git a/core/trino-parser/pom.xml b/core/trino-parser/pom.xml index cafd239eafde..a96de8ee37e5 100644 --- a/core/trino-parser/pom.xml +++ b/core/trino-parser/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java index d5da5ac0d77b..0a6e47fbdffb 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java +++ b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java @@ -3674,15 +3674,13 @@ private GrantorSpecification getGrantorSpecification(SqlBaseParser.GrantorContex if (context instanceof SqlBaseParser.SpecifiedPrincipalContext) { return new GrantorSpecification(GrantorSpecification.Type.PRINCIPAL, Optional.of(getPrincipalSpecification(((SqlBaseParser.SpecifiedPrincipalContext) context).principal()))); } - else if (context instanceof SqlBaseParser.CurrentUserGrantorContext) { + if (context instanceof SqlBaseParser.CurrentUserGrantorContext) { return new GrantorSpecification(GrantorSpecification.Type.CURRENT_USER, Optional.empty()); } - else if (context instanceof SqlBaseParser.CurrentRoleGrantorContext) { + if (context instanceof SqlBaseParser.CurrentRoleGrantorContext) { return new GrantorSpecification(GrantorSpecification.Type.CURRENT_ROLE, Optional.empty()); } - else { - throw new IllegalArgumentException("Unsupported grantor: " + context); - } + throw new IllegalArgumentException("Unsupported grantor: " + context); } private PrincipalSpecification getPrincipalSpecification(SqlBaseParser.PrincipalContext context) @@ -3690,15 +3688,13 @@ private PrincipalSpecification getPrincipalSpecification(SqlBaseParser.Principal if (context instanceof SqlBaseParser.UnspecifiedPrincipalContext) { return new PrincipalSpecification(PrincipalSpecification.Type.UNSPECIFIED, (Identifier) visit(((SqlBaseParser.UnspecifiedPrincipalContext) context).identifier())); } - else if (context instanceof SqlBaseParser.UserPrincipalContext) { + if (context instanceof SqlBaseParser.UserPrincipalContext) { return new PrincipalSpecification(PrincipalSpecification.Type.USER, (Identifier) visit(((SqlBaseParser.UserPrincipalContext) context).identifier())); } - else if (context instanceof SqlBaseParser.RolePrincipalContext) { + if (context instanceof SqlBaseParser.RolePrincipalContext) { return new PrincipalSpecification(PrincipalSpecification.Type.ROLE, (Identifier) visit(((SqlBaseParser.RolePrincipalContext) context).identifier())); } - else { - throw new IllegalArgumentException("Unsupported principal: " + context); - } + throw new IllegalArgumentException("Unsupported principal: " + context); } private static void check(boolean condition, String message, ParserRuleContext context) diff --git a/core/trino-parser/src/main/java/io/trino/sql/parser/SqlParser.java b/core/trino-parser/src/main/java/io/trino/sql/parser/SqlParser.java index 3b8c856de314..dadd4e9c51fa 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/parser/SqlParser.java +++ b/core/trino-parser/src/main/java/io/trino/sql/parser/SqlParser.java @@ -124,9 +124,7 @@ public Token recoverInline(Parser recognizer) if (nextTokensContext == null) { throw new InputMismatchException(recognizer); } - else { - throw new InputMismatchException(recognizer, nextTokensState, nextTokensContext); - } + throw new InputMismatchException(recognizer, nextTokensState, nextTokensContext); } }); diff --git a/core/trino-server-main/pom.xml b/core/trino-server-main/pom.xml index 086c376b6997..03fae00ff9b3 100644 --- a/core/trino-server-main/pom.xml +++ b/core/trino-server-main/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/core/trino-server-rpm/pom.xml b/core/trino-server-rpm/pom.xml index ff527568590a..6897e4e2e5cd 100644 --- a/core/trino-server-rpm/pom.xml +++ b/core/trino-server-rpm/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/core/trino-server-rpm/src/test/java/io/trino/server/rpm/ServerIT.java b/core/trino-server-rpm/src/test/java/io/trino/server/rpm/ServerIT.java index 0d08ad90fe46..56c37532015b 100644 --- a/core/trino-server-rpm/src/test/java/io/trino/server/rpm/ServerIT.java +++ b/core/trino-server-rpm/src/test/java/io/trino/server/rpm/ServerIT.java @@ -131,7 +131,7 @@ private static void testServer(String rpmHostPath, String expectedJavaVersion) .withCommand("sh", "-xeuc", command) .waitingFor(forLogMessage(".*SERVER STARTED.*", 1).withStartupTimeout(Duration.ofMinutes(5))) .start(); - QueryRunner queryRunner = new QueryRunner(container.getContainerIpAddress(), container.getMappedPort(8080)); + QueryRunner queryRunner = new QueryRunner(container.getHost(), container.getMappedPort(8080)); assertEquals(queryRunner.execute("SHOW CATALOGS"), ImmutableSet.of(asList("system"), asList("hive"), asList("jmx"))); assertEquals(queryRunner.execute("SELECT node_id FROM system.runtime.nodes"), ImmutableSet.of(asList("test-node-id-injected-via-env"))); // TODO remove usage of assertEventually once https://github.com/trinodb/trino/issues/2214 is fixed diff --git a/core/trino-server/pom.xml b/core/trino-server/pom.xml index 6b9809d05366..b1c7d0f0d8c1 100644 --- a/core/trino-server/pom.xml +++ b/core/trino-server/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/core/trino-spi/pom.xml b/core/trino-spi/pom.xml index a87abecdd9ce..529571abfb78 100644 --- a/core/trino-spi/pom.xml +++ b/core/trino-spi/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml @@ -189,6 +189,14 @@ java.method.addedToInterface method io.trino.spi.block.BlockBuilder io.trino.spi.block.BlockBuilder::newBlockBuilderLike(int, io.trino.spi.block.BlockBuilderStatus) + + java.field.removed + field io.trino.spi.expression.StandardFunctions.LIKE_PATTERN_FUNCTION_NAME + + + java.method.returnTypeChanged + method io.trino.spi.block.Block io.trino.spi.block.ArrayBlockBuilder::build() + diff --git a/core/trino-spi/src/main/java/io/trino/spi/Experimental.java b/core/trino-spi/src/main/java/io/trino/spi/Experimental.java index 33781036b1e5..956081ee22ac 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/Experimental.java +++ b/core/trino-spi/src/main/java/io/trino/spi/Experimental.java @@ -13,6 +13,7 @@ */ package io.trino.spi; +import java.lang.annotation.Documented; import java.lang.annotation.Retention; import java.lang.annotation.Target; @@ -31,6 +32,7 @@ */ @Retention(RUNTIME) @Target({TYPE, FIELD, METHOD, CONSTRUCTOR}) +@Documented public @interface Experimental { /** diff --git a/core/trino-spi/src/main/java/io/trino/spi/HostAddress.java b/core/trino-spi/src/main/java/io/trino/spi/HostAddress.java index 900a54126203..c8c2abfe017e 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/HostAddress.java +++ b/core/trino-spi/src/main/java/io/trino/spi/HostAddress.java @@ -221,9 +221,7 @@ public static HostAddress fromUri(URI httpUri) if (port < 0) { return fromString(host); } - else { - return fromParts(host, port); - } + return fromParts(host, port); } /** diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/AbstractSingleMapBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/AbstractSingleMapBlock.java index 94fb15f73b63..72d5b2fb7d95 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/AbstractSingleMapBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/AbstractSingleMapBlock.java @@ -51,9 +51,7 @@ public boolean isNull(int position) } return false; } - else { - return getRawValueBlock().isNull(position / 2); - } + return getRawValueBlock().isNull(position / 2); } @Override @@ -63,9 +61,7 @@ public byte getByte(int position, int offset) if (position % 2 == 0) { return getRawKeyBlock().getByte(position / 2, offset); } - else { - return getRawValueBlock().getByte(position / 2, offset); - } + return getRawValueBlock().getByte(position / 2, offset); } @Override @@ -75,9 +71,7 @@ public short getShort(int position, int offset) if (position % 2 == 0) { return getRawKeyBlock().getShort(position / 2, offset); } - else { - return getRawValueBlock().getShort(position / 2, offset); - } + return getRawValueBlock().getShort(position / 2, offset); } @Override @@ -87,9 +81,7 @@ public int getInt(int position, int offset) if (position % 2 == 0) { return getRawKeyBlock().getInt(position / 2, offset); } - else { - return getRawValueBlock().getInt(position / 2, offset); - } + return getRawValueBlock().getInt(position / 2, offset); } @Override @@ -99,9 +91,7 @@ public long getLong(int position, int offset) if (position % 2 == 0) { return getRawKeyBlock().getLong(position / 2, offset); } - else { - return getRawValueBlock().getLong(position / 2, offset); - } + return getRawValueBlock().getLong(position / 2, offset); } @Override @@ -111,9 +101,7 @@ public Slice getSlice(int position, int offset, int length) if (position % 2 == 0) { return getRawKeyBlock().getSlice(position / 2, offset, length); } - else { - return getRawValueBlock().getSlice(position / 2, offset, length); - } + return getRawValueBlock().getSlice(position / 2, offset, length); } @Override @@ -123,9 +111,7 @@ public int getSliceLength(int position) if (position % 2 == 0) { return getRawKeyBlock().getSliceLength(position / 2); } - else { - return getRawValueBlock().getSliceLength(position / 2); - } + return getRawValueBlock().getSliceLength(position / 2); } @Override @@ -135,9 +121,7 @@ public int compareTo(int position, int offset, int length, Block otherBlock, int if (position % 2 == 0) { return getRawKeyBlock().compareTo(position / 2, offset, length, otherBlock, otherPosition, otherOffset, otherLength); } - else { - return getRawValueBlock().compareTo(position / 2, offset, length, otherBlock, otherPosition, otherOffset, otherLength); - } + return getRawValueBlock().compareTo(position / 2, offset, length, otherBlock, otherPosition, otherOffset, otherLength); } @Override @@ -147,9 +131,7 @@ public boolean bytesEqual(int position, int offset, Slice otherSlice, int otherO if (position % 2 == 0) { return getRawKeyBlock().bytesEqual(position / 2, offset, otherSlice, otherOffset, length); } - else { - return getRawValueBlock().bytesEqual(position / 2, offset, otherSlice, otherOffset, length); - } + return getRawValueBlock().bytesEqual(position / 2, offset, otherSlice, otherOffset, length); } @Override @@ -159,9 +141,7 @@ public int bytesCompare(int position, int offset, int length, Slice otherSlice, if (position % 2 == 0) { return getRawKeyBlock().bytesCompare(position / 2, offset, length, otherSlice, otherOffset, otherLength); } - else { - return getRawValueBlock().bytesCompare(position / 2, offset, length, otherSlice, otherOffset, otherLength); - } + return getRawValueBlock().bytesCompare(position / 2, offset, length, otherSlice, otherOffset, otherLength); } @Override @@ -183,9 +163,7 @@ public boolean equals(int position, int offset, Block otherBlock, int otherPosit if (position % 2 == 0) { return getRawKeyBlock().equals(position / 2, offset, otherBlock, otherPosition, otherOffset, length); } - else { - return getRawValueBlock().equals(position / 2, offset, otherBlock, otherPosition, otherOffset, length); - } + return getRawValueBlock().equals(position / 2, offset, otherBlock, otherPosition, otherOffset, length); } @Override @@ -195,9 +173,7 @@ public long hash(int position, int offset, int length) if (position % 2 == 0) { return getRawKeyBlock().hash(position / 2, offset, length); } - else { - return getRawValueBlock().hash(position / 2, offset, length); - } + return getRawValueBlock().hash(position / 2, offset, length); } @Override @@ -207,9 +183,7 @@ public T getObject(int position, Class clazz) if (position % 2 == 0) { return getRawKeyBlock().getObject(position / 2, clazz); } - else { - return getRawValueBlock().getObject(position / 2, clazz); - } + return getRawValueBlock().getObject(position / 2, clazz); } @Override @@ -219,9 +193,7 @@ public Block getSingleValueBlock(int position) if (position % 2 == 0) { return getRawKeyBlock().getSingleValueBlock(position / 2); } - else { - return getRawValueBlock().getSingleValueBlock(position / 2); - } + return getRawValueBlock().getSingleValueBlock(position / 2); } @Override @@ -231,9 +203,7 @@ public long getEstimatedDataSizeForStats(int position) if (position % 2 == 0) { return getRawKeyBlock().getEstimatedDataSizeForStats(position / 2); } - else { - return getRawValueBlock().getEstimatedDataSizeForStats(position / 2); - } + return getRawValueBlock().getEstimatedDataSizeForStats(position / 2); } @Override diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlockBuilder.java b/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlockBuilder.java index 1748121fa8fa..fc67f31ac585 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlockBuilder.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlockBuilder.java @@ -23,6 +23,8 @@ import static io.airlift.slice.SizeOf.sizeOf; import static io.trino.spi.block.ArrayBlock.createArrayBlockInternal; +import static io.trino.spi.block.BlockUtil.checkArrayRange; +import static io.trino.spi.block.BlockUtil.checkValidRegion; import static java.lang.Math.max; import static java.util.Objects.requireNonNull; @@ -42,6 +44,7 @@ public class ArrayBlockBuilder private int[] offsets = new int[1]; private boolean[] valueIsNull = new boolean[0]; private boolean hasNullValue; + private boolean hasNonNullRow; private final BlockBuilder values; private boolean currentEntryOpened; @@ -186,6 +189,7 @@ private void entryAdded(boolean isNull) offsets[positionCount + 1] = values.getPositionCount(); valueIsNull[positionCount] = isNull; hasNullValue |= isNull; + hasNonNullRow |= !isNull; positionCount++; if (blockBuilderStatus != null) { @@ -218,11 +222,14 @@ private void updateDataSize() } @Override - public ArrayBlock build() + public Block build() { if (currentEntryOpened) { throw new IllegalStateException("Current entry must be closed before the block can be built"); } + if (!hasNonNullRow) { + return nullRle(positionCount); + } return createArrayBlockInternal(0, positionCount, hasNullValue ? valueIsNull : null, offsets, values.build()); } @@ -240,4 +247,45 @@ public String toString() sb.append('}'); return sb.toString(); } + + @Override + public Block copyPositions(int[] positions, int offset, int length) + { + checkArrayRange(positions, offset, length); + + if (!hasNonNullRow) { + return nullRle(length); + } + return super.copyPositions(positions, offset, length); + } + + @Override + public Block getRegion(int position, int length) + { + int positionCount = getPositionCount(); + checkValidRegion(positionCount, position, length); + + if (!hasNonNullRow) { + return nullRle(length); + } + return super.getRegion(position, length); + } + + @Override + public Block copyRegion(int position, int length) + { + int positionCount = getPositionCount(); + checkValidRegion(positionCount, position, length); + + if (!hasNonNullRow) { + return nullRle(length); + } + return super.copyRegion(position, length); + } + + private RunLengthEncodedBlock nullRle(int positionCount) + { + ArrayBlock nullValueBlock = createArrayBlockInternal(0, 1, new boolean[] {true}, new int[] {0, 0}, values.newBlockBuilderLike(null).build()); + return new RunLengthEncodedBlock(nullValueBlock, positionCount); + } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlockEncoding.java b/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlockEncoding.java index 3d4486e797ef..0fc86d4549d1 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlockEncoding.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlockEncoding.java @@ -111,7 +111,7 @@ private Slice getValuesSlice(Block block) if (block instanceof ByteArrayBlock) { return ((ByteArrayBlock) block).getValuesSlice(); } - else if (block instanceof ByteArrayBlockBuilder) { + if (block instanceof ByteArrayBlockBuilder) { return ((ByteArrayBlockBuilder) block).getValuesSlice(); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ColumnarArray.java b/core/trino-spi/src/main/java/io/trino/spi/block/ColumnarArray.java index 399903b47ec4..6e61decf0689 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/ColumnarArray.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ColumnarArray.java @@ -117,7 +117,7 @@ private static ColumnarArray toColumnarArray(RunLengthEncodedBlock rleBlock) private ColumnarArray(Block nullCheckBlock, int offsetsOffset, int[] offsets, Block elementsBlock) { - this.nullCheckBlock = nullCheckBlock; + this.nullCheckBlock = requireNonNull(nullCheckBlock, "nullCheckBlock is null"); this.offsetsOffset = offsetsOffset; this.offsets = offsets; this.elementsBlock = elementsBlock; @@ -128,6 +128,11 @@ public int getPositionCount() return nullCheckBlock.getPositionCount(); } + public boolean mayHaveNull() + { + return nullCheckBlock.mayHaveNull(); + } + public boolean isNull(int position) { return nullCheckBlock.isNull(position); diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ColumnarMap.java b/core/trino-spi/src/main/java/io/trino/spi/block/ColumnarMap.java index 9ca7878a860d..ee8bbfc584cc 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/ColumnarMap.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ColumnarMap.java @@ -119,7 +119,7 @@ private static ColumnarMap toColumnarMap(RunLengthEncodedBlock rleBlock) private ColumnarMap(Block nullCheckBlock, int offsetsOffset, int[] offsets, Block keysBlock, Block valuesBlock) { - this.nullCheckBlock = nullCheckBlock; + this.nullCheckBlock = requireNonNull(nullCheckBlock, "nullCheckBlock is null"); this.offsetsOffset = offsetsOffset; this.offsets = offsets; this.keysBlock = keysBlock; @@ -131,6 +131,11 @@ public int getPositionCount() return nullCheckBlock.getPositionCount(); } + public boolean mayHaveNull() + { + return nullCheckBlock.mayHaveNull(); + } + public boolean isNull(int position) { return nullCheckBlock.isNull(position); diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlockEncoding.java b/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlockEncoding.java index 45625b697d24..6ce731e4e8cc 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlockEncoding.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlockEncoding.java @@ -90,7 +90,7 @@ private Slice getValuesSlice(Block block) if (block instanceof Int128ArrayBlock) { return ((Int128ArrayBlock) block).getValuesSlice(); } - else if (block instanceof Int128ArrayBlockBuilder) { + if (block instanceof Int128ArrayBlockBuilder) { return ((Int128ArrayBlockBuilder) block).getValuesSlice(); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/Int96ArrayBlockEncoding.java b/core/trino-spi/src/main/java/io/trino/spi/block/Int96ArrayBlockEncoding.java index 8e9da83a5be5..c94421cbfafe 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/Int96ArrayBlockEncoding.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/Int96ArrayBlockEncoding.java @@ -121,7 +121,7 @@ private Slice getHighSlice(Block block) if (block instanceof Int96ArrayBlock) { return ((Int96ArrayBlock) block).getHighSlice(); } - else if (block instanceof Int96ArrayBlockBuilder) { + if (block instanceof Int96ArrayBlockBuilder) { return ((Int96ArrayBlockBuilder) block).getHighSlice(); } @@ -133,7 +133,7 @@ private Slice getLowSlice(Block block) if (block instanceof Int96ArrayBlock) { return ((Int96ArrayBlock) block).getLowSlice(); } - else if (block instanceof Int96ArrayBlockBuilder) { + if (block instanceof Int96ArrayBlockBuilder) { return ((Int96ArrayBlockBuilder) block).getLowSlice(); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlockEncoding.java b/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlockEncoding.java index 86483ca9be0a..038ee2b98ee8 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlockEncoding.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlockEncoding.java @@ -111,7 +111,7 @@ private Slice getValuesSlice(Block block) if (block instanceof IntArrayBlock) { return ((IntArrayBlock) block).getValuesSlice(); } - else if (block instanceof IntArrayBlockBuilder) { + if (block instanceof IntArrayBlockBuilder) { return ((IntArrayBlockBuilder) block).getValuesSlice(); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlockEncoding.java b/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlockEncoding.java index 14819c8d363b..b453e854d114 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlockEncoding.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlockEncoding.java @@ -111,7 +111,7 @@ private Slice getValuesSlice(Block block) if (block instanceof LongArrayBlock) { return ((LongArrayBlock) block).getValuesSlice(); } - else if (block instanceof LongArrayBlockBuilder) { + if (block instanceof LongArrayBlockBuilder) { return ((LongArrayBlockBuilder) block).getValuesSlice(); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlockEncoding.java b/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlockEncoding.java index 5c98827bc2ad..49c3591f3270 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlockEncoding.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlockEncoding.java @@ -111,7 +111,7 @@ private Slice getValuesSlice(Block block) if (block instanceof ShortArrayBlock) { return ((ShortArrayBlock) block).getValuesSlice(); } - else if (block instanceof ShortArrayBlockBuilder) { + if (block instanceof ShortArrayBlockBuilder) { return ((ShortArrayBlockBuilder) block).getValuesSlice(); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/SingleRowBlockWriter.java b/core/trino-spi/src/main/java/io/trino/spi/block/SingleRowBlockWriter.java index 4d972f97db15..0b0e10800f13 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/SingleRowBlockWriter.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/SingleRowBlockWriter.java @@ -216,9 +216,7 @@ public String toString() if (!fieldBlockBuilderReturned) { return format("SingleRowBlockWriter{numFields=%d, fieldBlockBuilderReturned=false, positionCount=%d}", fieldBlockBuilders.length, getPositionCount()); } - else { - return format("SingleRowBlockWriter{numFields=%d, fieldBlockBuilderReturned=true}", fieldBlockBuilders.length); - } + return format("SingleRowBlockWriter{numFields=%d, fieldBlockBuilderReturned=true}", fieldBlockBuilders.length); } void setRowIndex(int rowIndex) diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlock.java index 2fcab2bd3e3c..71db3ba0e661 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlock.java @@ -85,6 +85,23 @@ public VariableWidthBlock(int positionCount, Slice slice, int[] offsets, Optiona retainedSizeInBytes = INSTANCE_SIZE + slice.getRetainedSize() + sizeOf(valueIsNull) + sizeOf(offsets); } + /** + * Gets the raw {@link Slice} that keeps the actual data bytes. + */ + public Slice getRawSlice() + { + return slice; + } + + /** + * Gets the offset of the value at the {@code position} in the {@link Slice} returned by {@link #getRawSlice())}. + */ + public int getRawSliceOffset(int position) + { + checkReadablePosition(this, position); + return getPositionOffset(position); + } + @Override protected final int getPositionOffset(int position) { diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java index aa0bc2f7a015..3844e12bf4b8 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java @@ -114,19 +114,6 @@ default ConnectorTableHandle getTableHandle( throw new TrinoException(NOT_SUPPORTED, "This connector does not support versioned tables"); } - /** - * Returns a table handle for the specified table name, or null if the connector does not contain the table. - * The returned table handle can contain information in analyzeProperties. - * - * @deprecated use {@link #getStatisticsCollectionMetadata(ConnectorSession, ConnectorTableHandle, Map)} - */ - @Deprecated - @Nullable - default ConnectorTableHandle getTableHandleForStatisticsCollection(ConnectorSession session, SchemaTableName tableName, Map analyzeProperties) - { - throw new TrinoException(NOT_SUPPORTED, "This connector does not support analyze"); - } - /** * Create initial handle for execution of table procedure. The handle will be used through planning process. It will be converted to final * handle used for execution via @{link {@link ConnectorMetadata#beginTableExecute} @@ -478,26 +465,12 @@ default TableStatisticsMetadata getStatisticsCollectionMetadataForWrite(Connecto return TableStatisticsMetadata.empty(); } - /** - * Describe statistics that must be collected during a statistics collection - * - * @deprecated use {@link #getStatisticsCollectionMetadata(ConnectorSession, ConnectorTableHandle, Map)} - */ - @Deprecated - default TableStatisticsMetadata getStatisticsCollectionMetadata(ConnectorSession session, ConnectorTableMetadata tableMetadata) - { - throw new TrinoException(GENERIC_INTERNAL_ERROR, "ConnectorMetadata getTableHandleForStatisticsCollection() is implemented without getStatisticsCollectionMetadata()"); - } - /** * Describe statistics that must be collected during a statistics collection */ default ConnectorAnalyzeMetadata getStatisticsCollectionMetadata(ConnectorSession session, ConnectorTableHandle tableHandle, Map analyzeProperties) { - SchemaTableName tableName = getTableMetadata(session, tableHandle).getTable(); - ConnectorTableHandle analyzeHandle = getTableHandleForStatisticsCollection(session, tableName, analyzeProperties); - TableStatisticsMetadata statisticsMetadata = getStatisticsCollectionMetadata(session, getTableMetadata(session, analyzeHandle)); - return new ConnectorAnalyzeMetadata(analyzeHandle, statisticsMetadata); + throw new TrinoException(NOT_SUPPORTED, "This connector does not support analyze"); } /** diff --git a/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSourceHandle.java b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSourceHandle.java index 20130c578eb2..0e5cfbb7518c 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSourceHandle.java +++ b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSourceHandle.java @@ -16,7 +16,7 @@ import io.trino.spi.Experimental; /* - * Implementation is expected to be Jackson serializable and include equals, hashCode and toString methods + * Implementation is expected to be Jackson serializable */ @Experimental(eta = "2023-01-01") public interface ExchangeSourceHandle @@ -26,13 +26,4 @@ public interface ExchangeSourceHandle long getDataSizeInBytes(); long getRetainedSizeInBytes(); - - @Override - boolean equals(Object obj); - - @Override - int hashCode(); - - @Override - String toString(); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/expression/StandardFunctions.java b/core/trino-spi/src/main/java/io/trino/spi/expression/StandardFunctions.java index 2784697f6575..a5d1bae08e5b 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/expression/StandardFunctions.java +++ b/core/trino-spi/src/main/java/io/trino/spi/expression/StandardFunctions.java @@ -81,7 +81,7 @@ private StandardFunctions() {} */ public static final FunctionName NEGATE_FUNCTION_NAME = new FunctionName("$negate"); - public static final FunctionName LIKE_PATTERN_FUNCTION_NAME = new FunctionName("$like_pattern"); + public static final FunctionName LIKE_FUNCTION_NAME = new FunctionName("$like"); /** * {@code $in(value, array)} returns {@code true} when value is equal to an element of the array, diff --git a/core/trino-spi/src/main/java/io/trino/spi/predicate/Domain.java b/core/trino-spi/src/main/java/io/trino/spi/predicate/Domain.java index 5571832479c9..3116f157b269 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/predicate/Domain.java +++ b/core/trino-spi/src/main/java/io/trino/spi/predicate/Domain.java @@ -146,9 +146,7 @@ public boolean isNullableSingleValue() if (nullAllowed) { return values.isNone(); } - else { - return values.isSingleValue(); - } + return values.isSingleValue(); } public boolean isOnlyNull() @@ -173,9 +171,7 @@ public Object getNullableSingleValue() if (nullAllowed) { return null; } - else { - return values.getSingleValue(); - } + return values.getSingleValue(); } public boolean includesNullableValue(Object value) diff --git a/core/trino-spi/src/main/java/io/trino/spi/predicate/EquatableValueSet.java b/core/trino-spi/src/main/java/io/trino/spi/predicate/EquatableValueSet.java index 9d543635c3e9..e05c10afb52f 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/predicate/EquatableValueSet.java +++ b/core/trino-spi/src/main/java/io/trino/spi/predicate/EquatableValueSet.java @@ -243,15 +243,13 @@ public EquatableValueSet intersect(ValueSet other) if (inclusive && otherValueSet.inclusive()) { return new EquatableValueSet(type, true, intersect(entries, otherValueSet.entries)); } - else if (inclusive) { + if (inclusive) { return new EquatableValueSet(type, true, subtract(entries, otherValueSet.entries)); } - else if (otherValueSet.inclusive()) { + if (otherValueSet.inclusive()) { return new EquatableValueSet(type, true, subtract(otherValueSet.entries, entries)); } - else { - return new EquatableValueSet(type, false, union(otherValueSet.entries, entries)); - } + return new EquatableValueSet(type, false, union(otherValueSet.entries, entries)); } @Override @@ -262,15 +260,13 @@ public boolean overlaps(ValueSet other) if (inclusive && otherValueSet.inclusive()) { return setsOverlap(entries, otherValueSet.entries); } - else if (inclusive) { + if (inclusive) { return !otherValueSet.entries.containsAll(entries); } - else if (otherValueSet.inclusive()) { + if (otherValueSet.inclusive()) { return !entries.containsAll(otherValueSet.entries); } - else { - return true; - } + return true; } @Override @@ -281,15 +277,13 @@ public EquatableValueSet union(ValueSet other) if (inclusive && otherValueSet.inclusive()) { return new EquatableValueSet(type, true, union(entries, otherValueSet.entries)); } - else if (inclusive) { + if (inclusive) { return new EquatableValueSet(type, false, subtract(otherValueSet.entries, entries)); } - else if (otherValueSet.inclusive()) { + if (otherValueSet.inclusive()) { return new EquatableValueSet(type, false, subtract(entries, otherValueSet.entries)); } - else { - return new EquatableValueSet(type, false, intersect(otherValueSet.entries, entries)); - } + return new EquatableValueSet(type, false, intersect(otherValueSet.entries, entries)); } @Override diff --git a/core/trino-spi/src/main/java/io/trino/spi/security/SystemAccessControl.java b/core/trino-spi/src/main/java/io/trino/spi/security/SystemAccessControl.java index 33d703fe8ccd..70f53bb3b386 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/security/SystemAccessControl.java +++ b/core/trino-spi/src/main/java/io/trino/spi/security/SystemAccessControl.java @@ -650,6 +650,17 @@ default void checkCanGrantExecuteFunctionPrivilege(SystemSecurityContext context denyGrantExecuteFunctionPrivilege(functionName, context.getIdentity(), granteeAsString); } + /** + * Check if identity is allowed to grant an access to the function execution to grantee. + * + * @throws AccessDeniedException if not allowed + */ + default void checkCanGrantExecuteFunctionPrivilege(SystemSecurityContext context, FunctionKind functionKind, CatalogSchemaRoutineName functionName, TrinoPrincipal grantee, boolean grantOption) + { + String granteeAsString = format("%s '%s'", grantee.getType().name().toLowerCase(Locale.ENGLISH), grantee.getName()); + denyGrantExecuteFunctionPrivilege(functionName.toString(), context.getIdentity(), granteeAsString); + } + /** * Check if identity is allowed to set the specified property in a catalog. * diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/ArrayType.java b/core/trino-spi/src/main/java/io/trino/spi/type/ArrayType.java index f1ed1506f37e..ac29e3095cd0 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/ArrayType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/ArrayType.java @@ -195,10 +195,8 @@ public Object getObjectValue(ConnectorSession session, Block block, int position if (block instanceof AbstractArrayBlock) { return ((AbstractArrayBlock) block).apply((valuesBlock, start, length) -> arrayBlockToObjectValues(session, valuesBlock, start, length), position); } - else { - Block arrayBlock = block.getObject(position, Block.class); - return arrayBlockToObjectValues(session, arrayBlock, 0, arrayBlock.getPositionCount()); - } + Block arrayBlock = block.getObject(position, Block.class); + return arrayBlockToObjectValues(session, arrayBlock, 0, arrayBlock.getPositionCount()); } private List arrayBlockToObjectValues(ConnectorSession session, Block block, int start, int length) diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/Chars.java b/core/trino-spi/src/main/java/io/trino/spi/type/Chars.java index 1a55fdbf50a8..4e3cd7512e53 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/Chars.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/Chars.java @@ -135,9 +135,7 @@ public static int compareChars(Slice left, Slice right) if (left.length() < right.length()) { return compareCharsShorterToLonger(left, right); } - else { - return -compareCharsShorterToLonger(right, left); - } + return -compareCharsShorterToLonger(right, left); } private static int compareCharsShorterToLonger(Slice shorter, Slice longer) diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/DecimalType.java b/core/trino-spi/src/main/java/io/trino/spi/type/DecimalType.java index 3617e94b4d20..c83c661e7583 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/DecimalType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/DecimalType.java @@ -43,9 +43,7 @@ public static DecimalType createDecimalType(int precision, int scale) if (precision <= MAX_SHORT_PRECISION) { return new ShortDecimalType(precision, scale); } - else { - return new LongDecimalType(precision, scale); - } + return new LongDecimalType(precision, scale); } public static DecimalType createDecimalType(int precision) diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/Int128.java b/core/trino-spi/src/main/java/io/trino/spi/type/Int128.java index 789b6c17a418..87c0a6591632 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/Int128.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/Int128.java @@ -66,7 +66,7 @@ public static Int128 fromBigEndian(byte[] bytes) return Int128.valueOf(high, low); } - else if (bytes.length > 8) { + if (bytes.length > 8) { // read the last 8 bytes into low int offset = bytes.length - Long.BYTES; long low = (long) BIG_ENDIAN_LONG_VIEW.get(bytes, offset); @@ -80,21 +80,19 @@ else if (bytes.length > 8) { return Int128.valueOf(high, low); } - else if (bytes.length == 8) { + if (bytes.length == 8) { long low = (long) BIG_ENDIAN_LONG_VIEW.get(bytes, 0); long high = (low >> 63); return Int128.valueOf(high, low); } - else { - long high = (bytes[0] >> 7); - long low = high; - for (int i = 0; i < bytes.length; i++) { - low = (low << 8) | (bytes[i] & 0xFF); - } - - return Int128.valueOf(high, low); + long high = (bytes[0] >> 7); + long low = high; + for (int i = 0; i < bytes.length; i++) { + low = (low << 8) | (bytes[i] & 0xFF); } + + return Int128.valueOf(high, low); } public static Int128 valueOf(long[] value) diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/SqlVarbinary.java b/core/trino-spi/src/main/java/io/trino/spi/type/SqlVarbinary.java index 2ef2ddb047d3..ea79baa64207 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/SqlVarbinary.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/SqlVarbinary.java @@ -40,7 +40,7 @@ public int compareTo(SqlVarbinary obj) if (bytes[i] < obj.bytes[i]) { return -1; } - else if (bytes[i] > obj.bytes[i]) { + if (bytes[i] > obj.bytes[i]) { return 1; } } diff --git a/core/trino-spi/src/test/java/io/trino/spi/block/TestArrayBlockBuilder.java b/core/trino-spi/src/test/java/io/trino/spi/block/TestArrayBlockBuilder.java index 5d669ac76f4c..9b69753502b2 100644 --- a/core/trino-spi/src/test/java/io/trino/spi/block/TestArrayBlockBuilder.java +++ b/core/trino-spi/src/test/java/io/trino/spi/block/TestArrayBlockBuilder.java @@ -73,4 +73,36 @@ public void testConcurrentWriting() .isInstanceOf(IllegalStateException.class) .hasMessage("Expected current entry to be closed but was opened"); } + + @Test + public void testBuilderProducesNullRleForNullRows() + { + // empty block + assertIsNullRle(blockBuilder().build(), 0); + + // single null + assertIsNullRle(blockBuilder().appendNull().build(), 1); + + // multiple nulls + assertIsNullRle(blockBuilder().appendNull().appendNull().build(), 2); + + BlockBuilder blockBuilder = blockBuilder().appendNull().appendNull(); + assertIsNullRle(blockBuilder.copyPositions(new int[] {0}, 0, 1), 1); + assertIsNullRle(blockBuilder.getRegion(0, 1), 1); + assertIsNullRle(blockBuilder.copyRegion(0, 1), 1); + } + + private static BlockBuilder blockBuilder() + { + return new ArrayBlockBuilder(BIGINT, null, 10); + } + + private void assertIsNullRle(Block block, int expectedPositionCount) + { + assertEquals(block.getPositionCount(), expectedPositionCount); + assertEquals(block.getClass(), RunLengthEncodedBlock.class); + if (expectedPositionCount > 0) { + assertTrue(block.isNull(0)); + } + } } diff --git a/docs/pom.xml b/docs/pom.xml index 9c1e59b5d4bc..e018ad1d857b 100644 --- a/docs/pom.xml +++ b/docs/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT trino-docs diff --git a/docs/src/main/sphinx/connector/bigquery.rst b/docs/src/main/sphinx/connector/bigquery.rst index 17befa3b7500..a5104fdc527b 100644 --- a/docs/src/main/sphinx/connector/bigquery.rst +++ b/docs/src/main/sphinx/connector/bigquery.rst @@ -299,6 +299,7 @@ the following features: * :doc:`/sql/drop-table` * :doc:`/sql/create-schema` * :doc:`/sql/drop-schema` +* :doc:`/sql/comment` Table functions --------------- diff --git a/docs/src/main/sphinx/connector/clickhouse.rst b/docs/src/main/sphinx/connector/clickhouse.rst index c78bab459c44..e8b9537815dc 100644 --- a/docs/src/main/sphinx/connector/clickhouse.rst +++ b/docs/src/main/sphinx/connector/clickhouse.rst @@ -92,6 +92,9 @@ configured connector to create a catalog named ``sales``. .. include:: jdbc-common-configurations.fragment +.. |default_domain_compaction_threshold| replace:: ``1000`` +.. include:: jdbc-domain-compaction-threshold.fragment + .. include:: jdbc-procedures.fragment .. include:: jdbc-case-insensitive-matching.fragment diff --git a/docs/src/main/sphinx/connector/delta-lake.rst b/docs/src/main/sphinx/connector/delta-lake.rst index 6226931431ee..cc17059d0413 100644 --- a/docs/src/main/sphinx/connector/delta-lake.rst +++ b/docs/src/main/sphinx/connector/delta-lake.rst @@ -237,6 +237,37 @@ configure processing of Parquet files. * - ``parquet_writer_page_size`` - The maximum page size created by the Parquet writer. +.. _delta-lake-authorization: + +Authorization checks +^^^^^^^^^^^^^^^^^^^^ + +You can enable authorization checks for the connector by setting +the ``delta.security`` property in the catalog properties file. This +property must be one of the following values: + +.. list-table:: Delta Lake security values + :widths: 30, 60 + :header-rows: 1 + + * - Property value + - Description + * - ``ALLOW_ALL`` (default value) + - No authorization checks are enforced. + * - ``SYSTEM`` + - The connector relies on system-level access control. + * - ``READ_ONLY`` + - Operations that read data or metadata, such as :doc:`/sql/select` are + permitted. No operations that write data or metadata, such as + :doc:`/sql/create-table`, :doc:`/sql/insert`, or :doc:`/sql/delete` are + allowed. + * - ``FILE`` + - Authorization checks are enforced using a catalog-level access control + configuration file whose path is specified in the ``security.config-file`` + catalog configuration property. See + :ref:`catalog-file-based-access-control` for information on the + authorization configuration file. + .. _delta-lake-type-mapping: Type mapping diff --git a/docs/src/main/sphinx/connector/druid.rst b/docs/src/main/sphinx/connector/druid.rst index b5536d513642..add277d748b2 100644 --- a/docs/src/main/sphinx/connector/druid.rst +++ b/docs/src/main/sphinx/connector/druid.rst @@ -48,6 +48,9 @@ name from the properties file. .. include:: jdbc-common-configurations.fragment +.. |default_domain_compaction_threshold| replace:: ``32`` +.. include:: jdbc-domain-compaction-threshold.fragment + .. include:: jdbc-procedures.fragment .. include:: jdbc-case-insensitive-matching.fragment @@ -57,6 +60,45 @@ name from the properties file. Type mapping ------------ +Because Trino and Druid each support types that the other does not, this +connector modifies some types when reading data. + +Druid type to Trino type mapping +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The connector maps Druid types to the corresponding Trino types according to the +following table: + +.. list-table:: Druid type to Trino type mapping + :widths: 30, 30, 50 + :header-rows: 1 + + * - Druid type + - Trino type + - Notes + * - ``STRING`` + - ``VARCHAR`` + - + * - ``FLOAT`` + - ``REAL`` + - + * - ``DOUBLE`` + - ``DOUBLE`` + - + * - ``LONG`` + - ``BIGINT`` + - Except for the special ``_time`` column, which is mapped to ``TIMESTAMP``. + * - ``TIMESTAMP`` + - ``TIMESTAMP`` + - Only applicable to the special ``_time`` column. + +No other data types are supported. + +Druid does not have a real ``NULL`` value for any data type. By +default, Druid treats ``NULL`` as the default value for a data type. For +example, ``LONG`` would be ``0``, ``DOUBLE`` would be ``0.0``, ``STRING`` would +be an empty string ``''``, and so forth. + .. include:: jdbc-type-mapping.fragment .. _druid-sql-support: diff --git a/docs/src/main/sphinx/connector/jdbc-common-configurations.fragment b/docs/src/main/sphinx/connector/jdbc-common-configurations.fragment index cce75a4b79fb..0fd8da27fef1 100644 --- a/docs/src/main/sphinx/connector/jdbc-common-configurations.fragment +++ b/docs/src/main/sphinx/connector/jdbc-common-configurations.fragment @@ -51,10 +51,3 @@ connector: Using a large timeout can potentially result in more detailed dynamic filters. However, it can also increase latency for some queries. - ``20s`` - * - ``domain-compaction-threshold`` - - Minimum size of query predicates above which Trino compacts the predicates. - Pushing down a large list of predicates to the data source can compromise - performance. For optimization in that situation, Trino can compact the large - predicates into a simpler range predicate. If necessary, adjust the threshold - to ensure a balance between performance and predicate pushdown. - - ``32`` diff --git a/docs/src/main/sphinx/connector/jdbc-domain-compaction-threshold.fragment b/docs/src/main/sphinx/connector/jdbc-domain-compaction-threshold.fragment new file mode 100644 index 000000000000..73b077f45836 --- /dev/null +++ b/docs/src/main/sphinx/connector/jdbc-domain-compaction-threshold.fragment @@ -0,0 +1,14 @@ +Domain compaction threshold +""""""""""""""""""""""""""" + +Pushing down a large list of predicates to the data source can compromise +performance. Trino compacts large predicates into a simpler range predicate +by default to ensure a balance between performance and predicate pushdown. +If necessary, the threshold for this compaction can be increased to improve +performance when the data source is capable of taking advantage of large +predicates. Increasing this threshold may improve pushdown of large +:doc:`dynamic filters `. +The ``domain-compaction-threshold`` catalog configuration property or the +``domain_compaction_threshold`` :ref:`catalog session property +` can be used to adjust the default value of +|default_domain_compaction_threshold| for this threshold. diff --git a/docs/src/main/sphinx/connector/mariadb.rst b/docs/src/main/sphinx/connector/mariadb.rst index 96815b7ca4b7..eb019d8c1d92 100644 --- a/docs/src/main/sphinx/connector/mariadb.rst +++ b/docs/src/main/sphinx/connector/mariadb.rst @@ -36,6 +36,9 @@ connection properties as appropriate for your setup: .. include:: jdbc-common-configurations.fragment +.. |default_domain_compaction_threshold| replace:: ``32`` +.. include:: jdbc-domain-compaction-threshold.fragment + .. include:: jdbc-case-insensitive-matching.fragment .. include:: non-transactional-insert.fragment diff --git a/docs/src/main/sphinx/connector/memsql.rst b/docs/src/main/sphinx/connector/memsql.rst index c7832367b3ab..cd9936272865 100644 --- a/docs/src/main/sphinx/connector/memsql.rst +++ b/docs/src/main/sphinx/connector/memsql.rst @@ -78,6 +78,9 @@ will create a catalog named ``sales`` using the configured connector. .. include:: jdbc-common-configurations.fragment +.. |default_domain_compaction_threshold| replace:: ``32`` +.. include:: jdbc-domain-compaction-threshold.fragment + .. include:: jdbc-procedures.fragment .. include:: jdbc-case-insensitive-matching.fragment diff --git a/docs/src/main/sphinx/connector/mysql.rst b/docs/src/main/sphinx/connector/mysql.rst index 1e9773d94d40..9650d09d441c 100644 --- a/docs/src/main/sphinx/connector/mysql.rst +++ b/docs/src/main/sphinx/connector/mysql.rst @@ -91,6 +91,9 @@ creates a catalog named ``sales`` using the configured connector. .. include:: jdbc-common-configurations.fragment +.. |default_domain_compaction_threshold| replace:: ``32`` +.. include:: jdbc-domain-compaction-threshold.fragment + .. include:: jdbc-procedures.fragment .. include:: jdbc-case-insensitive-matching.fragment diff --git a/docs/src/main/sphinx/connector/oracle.rst b/docs/src/main/sphinx/connector/oracle.rst index c9005e02f8dc..585f5082472e 100644 --- a/docs/src/main/sphinx/connector/oracle.rst +++ b/docs/src/main/sphinx/connector/oracle.rst @@ -85,6 +85,9 @@ you name the property file ``sales.properties``, Trino creates a catalog named .. include:: jdbc-common-configurations.fragment +.. |default_domain_compaction_threshold| replace:: ``1000`` +.. include:: jdbc-domain-compaction-threshold.fragment + .. include:: jdbc-procedures.fragment .. include:: jdbc-case-insensitive-matching.fragment diff --git a/docs/src/main/sphinx/connector/phoenix.rst b/docs/src/main/sphinx/connector/phoenix.rst index 6dc2bdeb6638..c983271bbe50 100644 --- a/docs/src/main/sphinx/connector/phoenix.rst +++ b/docs/src/main/sphinx/connector/phoenix.rst @@ -57,6 +57,9 @@ Property name Required Description .. include:: jdbc-common-configurations.fragment +.. |default_domain_compaction_threshold| replace:: ``5000`` +.. include:: jdbc-domain-compaction-threshold.fragment + .. include:: jdbc-procedures.fragment .. include:: jdbc-case-insensitive-matching.fragment diff --git a/docs/src/main/sphinx/connector/postgresql.rst b/docs/src/main/sphinx/connector/postgresql.rst index 50b7e44affef..7839cff37be0 100644 --- a/docs/src/main/sphinx/connector/postgresql.rst +++ b/docs/src/main/sphinx/connector/postgresql.rst @@ -86,6 +86,9 @@ catalog named ``sales`` using the configured connector. .. include:: jdbc-common-configurations.fragment +.. |default_domain_compaction_threshold| replace:: ``32`` +.. include:: jdbc-domain-compaction-threshold.fragment + .. include:: jdbc-procedures.fragment .. include:: jdbc-case-insensitive-matching.fragment diff --git a/docs/src/main/sphinx/connector/redshift.rst b/docs/src/main/sphinx/connector/redshift.rst index de19ab567fe9..3f22072da32d 100644 --- a/docs/src/main/sphinx/connector/redshift.rst +++ b/docs/src/main/sphinx/connector/redshift.rst @@ -72,6 +72,9 @@ catalog named ``sales`` using the configured connector. .. include:: jdbc-common-configurations.fragment +.. |default_domain_compaction_threshold| replace:: ``32`` +.. include:: jdbc-domain-compaction-threshold.fragment + .. include:: jdbc-procedures.fragment .. include:: jdbc-case-insensitive-matching.fragment diff --git a/docs/src/main/sphinx/connector/sqlserver.rst b/docs/src/main/sphinx/connector/sqlserver.rst index 5a25436ee2ec..7eb18c46024a 100644 --- a/docs/src/main/sphinx/connector/sqlserver.rst +++ b/docs/src/main/sphinx/connector/sqlserver.rst @@ -84,6 +84,9 @@ catalog named ``sales`` using the configured connector. .. include:: jdbc-common-configurations.fragment +.. |default_domain_compaction_threshold| replace:: ``500`` +.. include:: jdbc-domain-compaction-threshold.fragment + Specific configuration properties ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/src/main/sphinx/release.rst b/docs/src/main/sphinx/release.rst index 068b3ac15b69..eaa08d9860a6 100644 --- a/docs/src/main/sphinx/release.rst +++ b/docs/src/main/sphinx/release.rst @@ -10,6 +10,7 @@ Release notes .. toctree:: :maxdepth: 1 + release/release-395 release/release-394 release/release-393 release/release-392 diff --git a/docs/src/main/sphinx/release/release-383.md b/docs/src/main/sphinx/release/release-383.md index 18eebdde08de..a078992986a6 100644 --- a/docs/src/main/sphinx/release/release-383.md +++ b/docs/src/main/sphinx/release/release-383.md @@ -1,4 +1,4 @@ -# Release 383 (1 June 2022) +# Release 383 (1 Jun 2022) ```{warning} This release has a regression that may cause queries to fail. diff --git a/docs/src/main/sphinx/release/release-384.md b/docs/src/main/sphinx/release/release-384.md index 862f4271c120..cb5f406536ac 100644 --- a/docs/src/main/sphinx/release/release-384.md +++ b/docs/src/main/sphinx/release/release-384.md @@ -1,4 +1,4 @@ -# Release 384 (3 June 2022) +# Release 384 (3 Jun 2022) ## General diff --git a/docs/src/main/sphinx/release/release-385.md b/docs/src/main/sphinx/release/release-385.md index 48361f0d6d2f..75a940bcbc72 100644 --- a/docs/src/main/sphinx/release/release-385.md +++ b/docs/src/main/sphinx/release/release-385.md @@ -1,4 +1,4 @@ -# Release 385 (8 June 2022) +# Release 385 (8 Jun 2022) ## General diff --git a/docs/src/main/sphinx/release/release-386.md b/docs/src/main/sphinx/release/release-386.md index e8d739d9228d..50645ff16711 100644 --- a/docs/src/main/sphinx/release/release-386.md +++ b/docs/src/main/sphinx/release/release-386.md @@ -1,4 +1,4 @@ -# Release 386 (15 June 2022) +# Release 386 (15 Jun 2022) ## General diff --git a/docs/src/main/sphinx/release/release-387.md b/docs/src/main/sphinx/release/release-387.md index dcbc8bfcc464..2d34d27b571c 100644 --- a/docs/src/main/sphinx/release/release-387.md +++ b/docs/src/main/sphinx/release/release-387.md @@ -1,4 +1,4 @@ -# Release 387 (22 June 2022) +# Release 387 (22 Jun 2022) ## General diff --git a/docs/src/main/sphinx/release/release-388.md b/docs/src/main/sphinx/release/release-388.md index bd988153d4eb..04c82f40b890 100644 --- a/docs/src/main/sphinx/release/release-388.md +++ b/docs/src/main/sphinx/release/release-388.md @@ -1,4 +1,4 @@ -# Release 388 (29 June 2022) +# Release 388 (29 Jun 2022) ## General diff --git a/docs/src/main/sphinx/release/release-389.md b/docs/src/main/sphinx/release/release-389.md index 172e74c12893..59418811db1e 100644 --- a/docs/src/main/sphinx/release/release-389.md +++ b/docs/src/main/sphinx/release/release-389.md @@ -1,5 +1,4 @@ - -# Release 389 (7 July 2022) +# Release 389 (7 Jul 2022) ## General diff --git a/docs/src/main/sphinx/release/release-390.md b/docs/src/main/sphinx/release/release-390.md index 3be31c77422c..7460a5b50737 100644 --- a/docs/src/main/sphinx/release/release-390.md +++ b/docs/src/main/sphinx/release/release-390.md @@ -1,4 +1,4 @@ -# Release 390 (13 July 2022) +# Release 390 (13 Jul 2022) ## General diff --git a/docs/src/main/sphinx/release/release-391.md b/docs/src/main/sphinx/release/release-391.md index 6ca10aa53b51..ac3876a475b8 100644 --- a/docs/src/main/sphinx/release/release-391.md +++ b/docs/src/main/sphinx/release/release-391.md @@ -1,4 +1,4 @@ -# Release 391 (22 July 2022) +# Release 391 (22 Jul 2022) ## General diff --git a/docs/src/main/sphinx/release/release-392.md b/docs/src/main/sphinx/release/release-392.md index 06afffacf843..48ef3c7f8486 100644 --- a/docs/src/main/sphinx/release/release-392.md +++ b/docs/src/main/sphinx/release/release-392.md @@ -1,4 +1,4 @@ -# Release 392 (3 August 2022) +# Release 392 (3 Aug 2022) ## General diff --git a/docs/src/main/sphinx/release/release-393.md b/docs/src/main/sphinx/release/release-393.md index f00460750757..033df6588532 100644 --- a/docs/src/main/sphinx/release/release-393.md +++ b/docs/src/main/sphinx/release/release-393.md @@ -1,4 +1,4 @@ -# Release 393 (17 August 2022) +# Release 393 (17 Aug 2022) ## General diff --git a/docs/src/main/sphinx/release/release-394.md b/docs/src/main/sphinx/release/release-394.md index 50ee0290b5d3..80dddbe3fc68 100644 --- a/docs/src/main/sphinx/release/release-394.md +++ b/docs/src/main/sphinx/release/release-394.md @@ -1,4 +1,4 @@ -# Release 394 (29 August 2022) +# Release 394 (29 Aug 2022) ## General diff --git a/docs/src/main/sphinx/release/release-395.md b/docs/src/main/sphinx/release/release-395.md new file mode 100644 index 000000000000..a34414f2a077 --- /dev/null +++ b/docs/src/main/sphinx/release/release-395.md @@ -0,0 +1,102 @@ +# Release 395 (7 Sep 2022) + +## General + +* Reduce memory consumption when fault-tolerant execution is enabled. ({issue}`13855`) +* Reduce memory consumption of aggregations. ({issue}`12512`) +* Improve performance of aggregations with decimals. ({issue}`13573`) +* Improve concurrency for large clusters. ({issue}`13934`, `13986`) +* Remove `information_schema.role_authorization_descriptors` table. ({issue}`11341`) +* Fix `SHOW CREATE TABLE` or `SHOW COLUMNS` showing an invalid type for columns + that use a reserved keyword as column name. ({issue}`13483`) + +## ClickHouse connector + +* Fix query failure when renaming or dropping a column with a name which matches + a reserved keyword or has special characters which require it to be quoted. ({issue}`13839`) + +## Delta Lake connector + +* Add support for the `ALTER TABLE ... RENAME TO` statement with a Glue + metastore. ({issue}`12985`) +* Improve performance of inserts by automatically scaling the number of writers + within a worker node. ({issue}`13111`) +* Enforce `delta.checkpoint.writeStatsAsJson` and + `delta.checkpoint.writeStatsAsStruct` table properties to ensure table + statistics are written in the correct format. ({issue}`12031`) + +## Hive connector + +* Improve performance of inserts by automatically scaling the number of writers + within a worker node. ({issue}`13111`) +* Improve performance of S3 Select when using CSV files as an input. ({issue}`13754`) +* Fix error where the S3 KMS key is not searched in the proper AWS region when + S3 client-side encryption is used. ({issue}`13715`) + +## Iceberg connector + +* Improve performance of inserts by automatically scaling the number of writers + within a worker node. ({issue}`13111`) +* Fix creating metadata and manifest files with a URL-encoded name on S3 when + the metadata location has trailing slashes. ({issue}`13759`) + +## MariaDB connector + +* Fix query failure when renaming or dropping a column with a name which matches + a reserved keyword or has special characters which require it to be quoted. ({issue}`13839`) + +## Memory connector + +* Add support for table and column comments. ({issue}`13936`) + +## MongoDB connector + +* Fix query failure when filtering on columns of `json` type. ({issue}`13536`) + +## MySQL connector + +* Fix query failure when renaming or dropping a column with a name which matches + a reserved keyword or has special characters which require it to be quoted. ({issue}`13839`) + +## Oracle connector + +* Fix query failure when renaming or dropping a column with a name which matches + a reserved keyword or has special characters which require it to be quoted. ({issue}`13839`) + +## Phoenix connector + +* Fix query failure when adding, renaming, or dropping a column with a name + which matches a reserved keyword or has special characters which require it to + be quoted. ({issue}`13839`) + +## PostgreSQL connector + +* Fix query failure when renaming or dropping a column with a name which matches + a reserved keyword or has special characters which require it to be quoted. ({issue}`13839`) + +## Prometheus connector + +* Add support for case-insensitive table name matching with the + `prometheus.case-insensitive-name-matching` configuration property. ({issue}`8740`) + +## Redshift connector + +* Fix query failure when renaming or dropping a column with a name which matches + a reserved keyword or has special characters which require it to be quoted. ({issue}`13839`) + +## SingleStore (MemSQL) connector + +* Fix query failure when renaming or dropping a column with a name which matches + a reserved keyword or has special characters which require it to be quoted. ({issue}`13839`) + +## SQL Server connector + +* Fix query failure when renaming or dropping a column with a name which matches + a reserved keyword or has special characters which require it to be quoted. ({issue}`13839`) + +## SPI + +* Add support for dynamic function resolution. ({issue}`8`) +* Rename `LIKE_PATTERN_FUNCTION_NAME` to `LIKE_FUNCTION_NAME` in + `StandardFunctions`. ({issue}`13965`) +* Remove the `listAllRoleGrants` method from `ConnectorMetadata`. ({issue}`11341`) diff --git a/docs/src/main/sphinx/sql/merge.rst b/docs/src/main/sphinx/sql/merge.rst index d75f41167006..ea3597048c78 100644 --- a/docs/src/main/sphinx/sql/merge.rst +++ b/docs/src/main/sphinx/sql/merge.rst @@ -93,5 +93,7 @@ table row:: Limitations ----------- -Some connectors have limited or no support for ``MERGE``. -See connector documentation for more details. +Any connector can be used as a source table for a ``MERGE`` statement. +Only connectors which support the ``MERGE`` statement can be the target of a +merge operation. See the :doc:`connector documentation ` for more +information. diff --git a/lib/trino-array/pom.xml b/lib/trino-array/pom.xml index 74e93575db48..36b8f586e9c8 100644 --- a/lib/trino-array/pom.xml +++ b/lib/trino-array/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/lib/trino-collect/pom.xml b/lib/trino-collect/pom.xml index d90bcc24a5e6..e151b9937ec7 100644 --- a/lib/trino-collect/pom.xml +++ b/lib/trino-collect/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/lib/trino-filesystem/pom.xml b/lib/trino-filesystem/pom.xml index e470db291266..d9bcc70b08f8 100644 --- a/lib/trino-filesystem/pom.xml +++ b/lib/trino-filesystem/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/lib/trino-geospatial-toolkit/pom.xml b/lib/trino-geospatial-toolkit/pom.xml index f8cc4425b447..1d90c54c40b9 100644 --- a/lib/trino-geospatial-toolkit/pom.xml +++ b/lib/trino-geospatial-toolkit/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/lib/trino-hadoop-toolkit/pom.xml b/lib/trino-hadoop-toolkit/pom.xml index f217060b84ad..1946b07c73db 100644 --- a/lib/trino-hadoop-toolkit/pom.xml +++ b/lib/trino-hadoop-toolkit/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/lib/trino-hdfs/pom.xml b/lib/trino-hdfs/pom.xml index abf1a8b21bf8..27e23ff571f4 100644 --- a/lib/trino-hdfs/pom.xml +++ b/lib/trino-hdfs/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/lib/trino-matching/pom.xml b/lib/trino-matching/pom.xml index 8d7fcdbf7982..b51e89d09029 100644 --- a/lib/trino-matching/pom.xml +++ b/lib/trino-matching/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/lib/trino-matching/src/main/java/io/trino/matching/Captures.java b/lib/trino-matching/src/main/java/io/trino/matching/Captures.java index 27bb02fb1a1b..e587a8a3ba2d 100644 --- a/lib/trino-matching/src/main/java/io/trino/matching/Captures.java +++ b/lib/trino-matching/src/main/java/io/trino/matching/Captures.java @@ -45,9 +45,7 @@ public Captures addAll(Captures other) if (this == NIL) { return other; } - else { - return new Captures(capture, value, tail.addAll(other)); - } + return new Captures(capture, value, tail.addAll(other)); } @SuppressWarnings("unchecked cast") @@ -56,12 +54,10 @@ public T get(Capture capture) if (this.equals(NIL)) { throw new NoSuchElementException("Requested value for unknown Capture. Was it registered in the Pattern?"); } - else if (this.capture.equals(capture)) { + if (this.capture.equals(capture)) { return (T) value; } - else { - return tail.get(capture); - } + return tail.get(capture); } @Override diff --git a/lib/trino-matching/src/main/java/io/trino/matching/Pattern.java b/lib/trino-matching/src/main/java/io/trino/matching/Pattern.java index fe90541d06b3..708b89500a21 100644 --- a/lib/trino-matching/src/main/java/io/trino/matching/Pattern.java +++ b/lib/trino-matching/src/main/java/io/trino/matching/Pattern.java @@ -116,9 +116,7 @@ public final Stream match(Object object, Captures captures, C context return previous.get().match(object, captures, context) .flatMap(match -> accept(object, match.captures(), context)); } - else { - return accept(object, captures, context); - } + return accept(object, captures, context); } @Override diff --git a/lib/trino-memory-context/pom.xml b/lib/trino-memory-context/pom.xml index e1b78648916d..35249c18c8a6 100644 --- a/lib/trino-memory-context/pom.xml +++ b/lib/trino-memory-context/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/lib/trino-orc/pom.xml b/lib/trino-orc/pom.xml index 1d4ea126851d..76fd93af18aa 100644 --- a/lib/trino-orc/pom.xml +++ b/lib/trino-orc/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/lib/trino-orc/src/main/java/io/trino/orc/OrcWriteValidation.java b/lib/trino-orc/src/main/java/io/trino/orc/OrcWriteValidation.java index 162c819c6a3e..1d7543a5eb30 100644 --- a/lib/trino-orc/src/main/java/io/trino/orc/OrcWriteValidation.java +++ b/lib/trino-orc/src/main/java/io/trino/orc/OrcWriteValidation.java @@ -101,6 +101,7 @@ import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_NANOS; import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_MILLISECOND; import static io.trino.spi.type.TinyintType.TINYINT; +import static io.trino.spi.type.UuidType.UUID; import static io.trino.spi.type.VarbinaryType.VARBINARY; import static java.lang.Math.floorDiv; import static java.lang.String.format; @@ -622,7 +623,7 @@ else if (type instanceof CharType) { fieldExtractor = ignored -> ImmutableList.of(); fieldBuilders = ImmutableList.of(); } - else if (VARBINARY.equals(type)) { + else if (VARBINARY.equals(type) || UUID.equals(type)) { statisticsBuilder = new BinaryStatisticsBuilder(); fieldExtractor = ignored -> ImmutableList.of(); fieldBuilders = ImmutableList.of(); diff --git a/lib/trino-orc/src/main/java/io/trino/orc/checkpoint/Checkpoints.java b/lib/trino-orc/src/main/java/io/trino/orc/checkpoint/Checkpoints.java index c7f932eb94e9..81498c199677 100644 --- a/lib/trino-orc/src/main/java/io/trino/orc/checkpoint/Checkpoints.java +++ b/lib/trino-orc/src/main/java/io/trino/orc/checkpoint/Checkpoints.java @@ -144,7 +144,7 @@ public static StreamCheckpoint getDictionaryStreamCheckpoint(StreamId streamId, if (columnEncoding == DICTIONARY_V2) { return new LongStreamV2Checkpoint(0, createInputStreamCheckpoint(0, 0)); } - else if (columnEncoding == DICTIONARY) { + if (columnEncoding == DICTIONARY) { return new LongStreamV1Checkpoint(0, createInputStreamCheckpoint(0, 0)); } } diff --git a/lib/trino-orc/src/main/java/io/trino/orc/checkpoint/InputStreamCheckpoint.java b/lib/trino-orc/src/main/java/io/trino/orc/checkpoint/InputStreamCheckpoint.java index bd55e12e9416..178ae6985d3d 100644 --- a/lib/trino-orc/src/main/java/io/trino/orc/checkpoint/InputStreamCheckpoint.java +++ b/lib/trino-orc/src/main/java/io/trino/orc/checkpoint/InputStreamCheckpoint.java @@ -32,9 +32,7 @@ public static long createInputStreamCheckpoint(boolean compressed, ColumnPositio if (compressed) { return createInputStreamCheckpoint(positionsList.nextPosition(), positionsList.nextPosition()); } - else { - return createInputStreamCheckpoint(0, positionsList.nextPosition()); - } + return createInputStreamCheckpoint(0, positionsList.nextPosition()); } public static long createInputStreamCheckpoint(int compressedBlockOffset, int decompressedOffset) @@ -58,9 +56,7 @@ public static List createInputStreamPositionList(boolean compressed, lo if (compressed) { return ImmutableList.of(decodeCompressedBlockOffset(inputStreamCheckpoint), decodeDecompressedOffset(inputStreamCheckpoint)); } - else { - return ImmutableList.of(decodeDecompressedOffset(inputStreamCheckpoint)); - } + return ImmutableList.of(decodeDecompressedOffset(inputStreamCheckpoint)); } public static String inputStreamCheckpointToString(long inputStreamCheckpoint) diff --git a/lib/trino-orc/src/main/java/io/trino/orc/metadata/OrcType.java b/lib/trino-orc/src/main/java/io/trino/orc/metadata/OrcType.java index d6bb4bc42ac8..6600e284a88a 100644 --- a/lib/trino-orc/src/main/java/io/trino/orc/metadata/OrcType.java +++ b/lib/trino-orc/src/main/java/io/trino/orc/metadata/OrcType.java @@ -13,6 +13,7 @@ */ package io.trino.orc.metadata; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.spi.TrinoException; @@ -29,6 +30,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.function.Function; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; @@ -187,8 +189,12 @@ public String toString() .toString(); } - private static List toOrcType(int nextFieldTypeIndex, Type type) + private static List toOrcType(int nextFieldTypeIndex, Type type, Optional>> additionalTypeMapping) { + Optional> orcType = additionalTypeMapping.flatMap(mapping -> mapping.apply(type)).map(ImmutableList::of); + if (orcType.isPresent()) { + return orcType.get(); + } if (BOOLEAN.equals(type)) { return ImmutableList.of(new OrcType(OrcTypeKind.BOOLEAN)); } @@ -237,10 +243,10 @@ private static List toOrcType(int nextFieldTypeIndex, Type type) return ImmutableList.of(new OrcType(OrcTypeKind.DECIMAL, decimalType.getPrecision(), decimalType.getScale())); } if (type instanceof ArrayType) { - return createOrcArrayType(nextFieldTypeIndex, type.getTypeParameters().get(0)); + return createOrcArrayType(nextFieldTypeIndex, type.getTypeParameters().get(0), additionalTypeMapping); } if (type instanceof MapType) { - return createOrcMapType(nextFieldTypeIndex, type.getTypeParameters().get(0), type.getTypeParameters().get(1)); + return createOrcMapType(nextFieldTypeIndex, type.getTypeParameters().get(0), type.getTypeParameters().get(1), additionalTypeMapping); } if (type instanceof RowType) { List fieldNames = new ArrayList<>(); @@ -250,15 +256,15 @@ private static List toOrcType(int nextFieldTypeIndex, Type type) } List fieldTypes = type.getTypeParameters(); - return createOrcRowType(nextFieldTypeIndex, fieldNames, fieldTypes); + return createOrcRowType(nextFieldTypeIndex, fieldNames, fieldTypes, additionalTypeMapping); } throw new TrinoException(NOT_SUPPORTED, format("Unsupported Hive type: %s", type)); } - private static List createOrcArrayType(int nextFieldTypeIndex, Type itemType) + private static List createOrcArrayType(int nextFieldTypeIndex, Type itemType, Optional>> additionalTypeMapping) { nextFieldTypeIndex++; - List itemTypes = toOrcType(nextFieldTypeIndex, itemType); + List itemTypes = toOrcType(nextFieldTypeIndex, itemType, additionalTypeMapping); List orcTypes = new ArrayList<>(); orcTypes.add(new OrcType(OrcTypeKind.LIST, ImmutableList.of(new OrcColumnId(nextFieldTypeIndex)), ImmutableList.of("item"))); @@ -266,11 +272,11 @@ private static List createOrcArrayType(int nextFieldTypeIndex, Type ite return orcTypes; } - private static List createOrcMapType(int nextFieldTypeIndex, Type keyType, Type valueType) + private static List createOrcMapType(int nextFieldTypeIndex, Type keyType, Type valueType, Optional>> additionalTypeMapping) { nextFieldTypeIndex++; - List keyTypes = toOrcType(nextFieldTypeIndex, keyType); - List valueTypes = toOrcType(nextFieldTypeIndex + keyTypes.size(), valueType); + List keyTypes = toOrcType(nextFieldTypeIndex, keyType, additionalTypeMapping); + List valueTypes = toOrcType(nextFieldTypeIndex + keyTypes.size(), valueType, additionalTypeMapping); List orcTypes = new ArrayList<>(); orcTypes.add(new OrcType( @@ -284,17 +290,23 @@ private static List createOrcMapType(int nextFieldTypeIndex, Type keyTy public static ColumnMetadata createRootOrcType(List fieldNames, List fieldTypes) { - return new ColumnMetadata<>(createOrcRowType(0, fieldNames, fieldTypes)); + return createRootOrcType(fieldNames, fieldTypes, Optional.empty()); + } + + @VisibleForTesting + public static ColumnMetadata createRootOrcType(List fieldNames, List fieldTypes, Optional>> additionalTypeMapping) + { + return new ColumnMetadata<>(createOrcRowType(0, fieldNames, fieldTypes, additionalTypeMapping)); } - private static List createOrcRowType(int nextFieldTypeIndex, List fieldNames, List fieldTypes) + private static List createOrcRowType(int nextFieldTypeIndex, List fieldNames, List fieldTypes, Optional>> additionalTypeMapping) { nextFieldTypeIndex++; List fieldTypeIndexes = new ArrayList<>(); List> fieldTypesList = new ArrayList<>(); for (Type fieldType : fieldTypes) { fieldTypeIndexes.add(new OrcColumnId(nextFieldTypeIndex)); - List fieldOrcTypes = toOrcType(nextFieldTypeIndex, fieldType); + List fieldOrcTypes = toOrcType(nextFieldTypeIndex, fieldType, additionalTypeMapping); fieldTypesList.add(fieldOrcTypes); nextFieldTypeIndex += fieldOrcTypes.size(); } diff --git a/lib/trino-orc/src/main/java/io/trino/orc/metadata/statistics/StringStatisticsBuilder.java b/lib/trino-orc/src/main/java/io/trino/orc/metadata/statistics/StringStatisticsBuilder.java index 6198a41e56a4..320eab940e9e 100644 --- a/lib/trino-orc/src/main/java/io/trino/orc/metadata/statistics/StringStatisticsBuilder.java +++ b/lib/trino-orc/src/main/java/io/trino/orc/metadata/statistics/StringStatisticsBuilder.java @@ -187,9 +187,7 @@ private Slice computeStringMinMax(Slice minOrMax, boolean isMin) if (isMin) { return StringCompactor.truncateMin(minOrMax, stringStatisticsLimitInBytes); } - else { - return StringCompactor.truncateMax(minOrMax, stringStatisticsLimitInBytes); - } + return StringCompactor.truncateMax(minOrMax, stringStatisticsLimitInBytes); } // Do not hold the entire slice where the actual stats could be small diff --git a/lib/trino-orc/src/main/java/io/trino/orc/reader/ColumnReaders.java b/lib/trino-orc/src/main/java/io/trino/orc/reader/ColumnReaders.java index a0e42188ce2c..63b9205a3b70 100644 --- a/lib/trino-orc/src/main/java/io/trino/orc/reader/ColumnReaders.java +++ b/lib/trino-orc/src/main/java/io/trino/orc/reader/ColumnReaders.java @@ -21,7 +21,10 @@ import io.trino.orc.OrcReader.FieldMapperFactory; import io.trino.spi.type.TimeType; import io.trino.spi.type.Type; +import io.trino.spi.type.UuidType; +import static com.google.common.base.Preconditions.checkArgument; +import static io.trino.orc.metadata.OrcType.OrcTypeKind.BINARY; import static io.trino.orc.metadata.OrcType.OrcTypeKind.LONG; import static io.trino.orc.reader.ReaderUtils.invalidStreamType; import static io.trino.spi.type.IntegerType.INTEGER; @@ -29,6 +32,8 @@ public final class ColumnReaders { + public static final String ICEBERG_BINARY_TYPE = "iceberg.binary-type"; + private ColumnReaders() {} public static ColumnReader createColumnReader( @@ -47,6 +52,14 @@ public static ColumnReader createColumnReader( } return new TimeColumnReader(type, column, memoryContext.newLocalMemoryContext(ColumnReaders.class.getSimpleName())); } + if (type instanceof UuidType) { + checkArgument(column.getColumnType() == BINARY, "UUID type can only be read from BINARY column but got " + column); + checkArgument( + "UUID".equals(column.getAttributes().get(ICEBERG_BINARY_TYPE)), + "Expected ORC column for UUID data to be annotated with %s=UUID: %s", + ICEBERG_BINARY_TYPE, column); + return new UuidColumnReader(column); + } switch (column.getColumnType()) { case BOOLEAN: diff --git a/lib/trino-orc/src/main/java/io/trino/orc/reader/UuidColumnReader.java b/lib/trino-orc/src/main/java/io/trino/orc/reader/UuidColumnReader.java new file mode 100644 index 000000000000..18cda4c49851 --- /dev/null +++ b/lib/trino-orc/src/main/java/io/trino/orc/reader/UuidColumnReader.java @@ -0,0 +1,271 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.orc.reader; + +import io.airlift.units.DataSize; +import io.trino.orc.OrcColumn; +import io.trino.orc.OrcCorruptionException; +import io.trino.orc.metadata.ColumnEncoding; +import io.trino.orc.metadata.ColumnMetadata; +import io.trino.orc.stream.BooleanInputStream; +import io.trino.orc.stream.ByteArrayInputStream; +import io.trino.orc.stream.InputStreamSource; +import io.trino.orc.stream.InputStreamSources; +import io.trino.spi.TrinoException; +import io.trino.spi.block.Block; +import io.trino.spi.block.Int128ArrayBlock; +import io.trino.spi.block.RunLengthEncodedBlock; +import org.openjdk.jol.info.ClassLayout; + +import javax.annotation.Nullable; + +import java.io.IOException; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.VarHandle; +import java.nio.ByteOrder; +import java.time.ZoneId; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static io.airlift.units.DataSize.Unit.GIGABYTE; +import static io.trino.orc.metadata.Stream.StreamKind.DATA; +import static io.trino.orc.metadata.Stream.StreamKind.PRESENT; +import static io.trino.orc.stream.MissingInputStreamSource.missingStreamSource; +import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static java.lang.Math.toIntExact; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class UuidColumnReader + implements ColumnReader +{ + private static final int INSTANCE_SIZE = ClassLayout.parseClass(UuidColumnReader.class).instanceSize(); + private static final int ONE_GIGABYTE = toIntExact(DataSize.of(1, GIGABYTE).toBytes()); + + private static final VarHandle LONG_ARRAY_HANDLE = MethodHandles.byteArrayViewVarHandle(long[].class, ByteOrder.LITTLE_ENDIAN); + private final OrcColumn column; + + private int readOffset; + private int nextBatchSize; + + private InputStreamSource presentStreamSource = missingStreamSource(BooleanInputStream.class); + @Nullable + private BooleanInputStream presentStream; + + private InputStreamSource dataByteSource = missingStreamSource(ByteArrayInputStream.class); + @Nullable + private ByteArrayInputStream dataStream; + + private boolean rowGroupOpen; + + public UuidColumnReader(OrcColumn column) + { + this.column = requireNonNull(column, "column is null"); + } + + @Override + public void prepareNextRead(int batchSize) + { + readOffset += nextBatchSize; + nextBatchSize = batchSize; + } + + @Override + public Block readBlock() + throws IOException + { + if (!rowGroupOpen) { + openRowGroup(); + } + + if (readOffset > 0) { + skipToReadOffset(); + readOffset = 0; + } + + if (dataStream == null) { + if (presentStream == null) { + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is null but present stream is missing"); + } + // since dataStream is null, all values are null + presentStream.skip(nextBatchSize); + Block nullValueBlock = createAllNullsBlock(); + nextBatchSize = 0; + return nullValueBlock; + } + + boolean[] isNullVector = null; + int nullCount = 0; + if (presentStream != null) { + isNullVector = new boolean[nextBatchSize]; + nullCount = presentStream.getUnsetBits(nextBatchSize, isNullVector); + if (nullCount == nextBatchSize) { + // all nulls + Block nullValueBlock = createAllNullsBlock(); + nextBatchSize = 0; + return nullValueBlock; + } + + if (nullCount == 0) { + isNullVector = null; + } + } + + int numberOfLongValues = toIntExact(nextBatchSize * 2L); + int totalByteLength = toIntExact((long) numberOfLongValues * Long.BYTES); + + int currentBatchSize = nextBatchSize; + nextBatchSize = 0; + if (totalByteLength == 0) { + return new Int128ArrayBlock(currentBatchSize, Optional.empty(), new long[0]); + } + if (totalByteLength > ONE_GIGABYTE) { + throw new TrinoException(GENERIC_INTERNAL_ERROR, + format("Values in column \"%s\" are too large to process for Trino. %s column values are larger than 1GB [%s]", column.getPath(), nextBatchSize, column.getOrcDataSourceId())); + } + if (dataStream == null) { + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is not null but data stream is missing"); + } + + if (isNullVector == null) { + long[] values = readNonNullLongs(numberOfLongValues); + return new Int128ArrayBlock(currentBatchSize, Optional.empty(), values); + } + + int nonNullCount = currentBatchSize - nullCount; + long[] values = readNullableLongs(isNullVector, nonNullCount); + return new Int128ArrayBlock(currentBatchSize, Optional.of(isNullVector), values); + } + + @Override + public void startStripe(ZoneId fileTimeZone, InputStreamSources dictionaryStreamSources, ColumnMetadata encoding) + { + presentStreamSource = missingStreamSource(BooleanInputStream.class); + dataByteSource = missingStreamSource(ByteArrayInputStream.class); + + readOffset = 0; + nextBatchSize = 0; + + presentStream = null; + dataStream = null; + + rowGroupOpen = false; + } + + @Override + public void startRowGroup(InputStreamSources dataStreamSources) + { + presentStreamSource = dataStreamSources.getInputStreamSource(column, PRESENT, BooleanInputStream.class); + dataByteSource = dataStreamSources.getInputStreamSource(column, DATA, ByteArrayInputStream.class); + + readOffset = 0; + nextBatchSize = 0; + + presentStream = null; + dataStream = null; + + rowGroupOpen = false; + } + + @Override + public String toString() + { + return toStringHelper(this) + .addValue(column) + .toString(); + } + + @Override + public void close() + { + } + + @Override + public long getRetainedSizeInBytes() + { + return INSTANCE_SIZE; + } + + private void skipToReadOffset() + throws IOException + { + int dataReadOffset = readOffset; + if (presentStream != null) { + // skip ahead the present bit reader, but count the set bits + // and use this as the skip size for the dataStream + dataReadOffset = presentStream.countBitsSet(readOffset); + } + if (dataReadOffset > 0) { + if (dataStream == null) { + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is not null but data stream is missing"); + } + // dataReadOffset deals with positions. Each position is 2 longs in the dataStream. + long dataSkipSize = dataReadOffset * 2L * Long.BYTES; + + dataStream.skip(dataSkipSize); + } + } + + private long[] readNullableLongs(boolean[] isNullVector, int nonNullCount) + throws IOException + { + byte[] data = new byte[nonNullCount * 2 * Long.BYTES]; + + dataStream.next(data, 0, data.length); + + int[] offsets = new int[isNullVector.length]; + int offsetPosition = 0; + for (int i = 0; i < isNullVector.length; i++) { + offsets[i] = Math.min(offsetPosition * 2 * Long.BYTES, data.length - Long.BYTES * 2); + offsetPosition += isNullVector[i] ? 0 : 1; + } + + long[] values = new long[isNullVector.length * 2]; + + for (int i = 0; i < isNullVector.length; i++) { + int isNonNull = isNullVector[i] ? 0 : 1; + values[i * 2] = (long) LONG_ARRAY_HANDLE.get(data, offsets[i]) * isNonNull; + values[i * 2 + 1] = (long) LONG_ARRAY_HANDLE.get(data, offsets[i] + Long.BYTES) * isNonNull; + } + return values; + } + + private long[] readNonNullLongs(int valueCount) + throws IOException + { + byte[] data = new byte[valueCount * Long.BYTES]; + + dataStream.next(data, 0, data.length); + + long[] values = new long[valueCount]; + for (int i = 0; i < valueCount; i++) { + values[i] = (long) LONG_ARRAY_HANDLE.get(data, i * Long.BYTES); + } + return values; + } + + private RunLengthEncodedBlock createAllNullsBlock() + { + return new RunLengthEncodedBlock(new Int128ArrayBlock(1, Optional.of(new boolean[] {true}), new long[2]), nextBatchSize); + } + + private void openRowGroup() + throws IOException + { + presentStream = presentStreamSource.openStream(); + dataStream = dataByteSource.openStream(); + + rowGroupOpen = true; + } +} diff --git a/lib/trino-orc/src/main/java/io/trino/orc/stream/DecimalOutputStream.java b/lib/trino-orc/src/main/java/io/trino/orc/stream/DecimalOutputStream.java index 5771a9491380..908dfd5c78b3 100644 --- a/lib/trino-orc/src/main/java/io/trino/orc/stream/DecimalOutputStream.java +++ b/lib/trino-orc/src/main/java/io/trino/orc/stream/DecimalOutputStream.java @@ -72,10 +72,8 @@ public void writeUnscaledValue(Int128 decimal) buffer.write((byte) lowBits); return; } - else { - buffer.write((byte) (0x80 | (lowBits & 0x7f))); - lowBits >>>= 7; - } + buffer.write((byte) (0x80 | (lowBits & 0x7f))); + lowBits >>>= 7; } value = value.shiftRight(63); } diff --git a/lib/trino-orc/src/main/java/io/trino/orc/stream/LongDecode.java b/lib/trino-orc/src/main/java/io/trino/orc/stream/LongDecode.java index dd92dd7347b4..4c08e8c87c21 100644 --- a/lib/trino-orc/src/main/java/io/trino/orc/stream/LongDecode.java +++ b/lib/trino-orc/src/main/java/io/trino/orc/stream/LongDecode.java @@ -49,30 +49,28 @@ public static int decodeBitWidth(int n) if (n >= ONE.ordinal() && n <= TWENTY_FOUR.ordinal()) { return n + 1; } - else if (n == TWENTY_SIX.ordinal()) { + if (n == TWENTY_SIX.ordinal()) { return 26; } - else if (n == TWENTY_EIGHT.ordinal()) { + if (n == TWENTY_EIGHT.ordinal()) { return 28; } - else if (n == THIRTY.ordinal()) { + if (n == THIRTY.ordinal()) { return 30; } - else if (n == THIRTY_TWO.ordinal()) { + if (n == THIRTY_TWO.ordinal()) { return 32; } - else if (n == FORTY.ordinal()) { + if (n == FORTY.ordinal()) { return 40; } - else if (n == FORTY_EIGHT.ordinal()) { + if (n == FORTY_EIGHT.ordinal()) { return 48; } - else if (n == FIFTY_SIX.ordinal()) { + if (n == FIFTY_SIX.ordinal()) { return 56; } - else { - return 64; - } + return 64; } /** @@ -87,30 +85,28 @@ public static int getClosestFixedBits(int width) if (width >= 1 && width <= 24) { return width; } - else if (width > 24 && width <= 26) { + if (width > 24 && width <= 26) { return 26; } - else if (width > 26 && width <= 28) { + if (width > 26 && width <= 28) { return 28; } - else if (width > 28 && width <= 30) { + if (width > 28 && width <= 30) { return 30; } - else if (width > 30 && width <= 32) { + if (width > 30 && width <= 32) { return 32; } - else if (width > 32 && width <= 40) { + if (width > 32 && width <= 40) { return 40; } - else if (width > 40 && width <= 48) { + if (width > 40 && width <= 48) { return 48; } - else if (width > 48 && width <= 56) { + if (width > 48 && width <= 56) { return 56; } - else { - return 64; - } + return 64; } public static long readSignedVInt(OrcInputStream inputStream) @@ -144,9 +140,7 @@ public static long readVInt(boolean signed, OrcInputStream inputStream) if (signed) { return readSignedVInt(inputStream); } - else { - return readUnsignedVInt(inputStream); - } + return readUnsignedVInt(inputStream); } public static long zigzagDecode(long value) @@ -170,10 +164,8 @@ private static void writeVLongUnsigned(SliceOutput output, long value) output.write((byte) value); return; } - else { - output.write((byte) (0x80 | (value & 0x7f))); - value >>>= 7; - } + output.write((byte) (0x80 | (value & 0x7f))); + value >>>= 7; } } diff --git a/lib/trino-orc/src/main/java/io/trino/orc/stream/LongOutputStreamV2.java b/lib/trino-orc/src/main/java/io/trino/orc/stream/LongOutputStreamV2.java index 7246d88e0e7f..33980a6ee142 100644 --- a/lib/trino-orc/src/main/java/io/trino/orc/stream/LongOutputStreamV2.java +++ b/lib/trino-orc/src/main/java/io/trino/orc/stream/LongOutputStreamV2.java @@ -795,10 +795,8 @@ static void writeVulong(SliceOutput output, long value) output.write((byte) value); return; } - else { - output.write((byte) (0x80 | (value & 0x7f))); - value >>>= 7; - } + output.write((byte) (0x80 | (value & 0x7f))); + value >>>= 7; } } @@ -870,30 +868,28 @@ static int getClosestFixedBits(int n) if (n >= 1 && n <= 24) { return n; } - else if (n > 24 && n <= 26) { + if (n > 24 && n <= 26) { return 26; } - else if (n > 26 && n <= 28) { + if (n > 26 && n <= 28) { return 28; } - else if (n > 28 && n <= 30) { + if (n > 28 && n <= 30) { return 30; } - else if (n > 30 && n <= 32) { + if (n > 30 && n <= 32) { return 32; } - else if (n > 32 && n <= 40) { + if (n > 32 && n <= 40) { return 40; } - else if (n > 40 && n <= 48) { + if (n > 40 && n <= 48) { return 48; } - else if (n > 48 && n <= 56) { + if (n > 48 && n <= 56) { return 56; } - else { - return 64; - } + return 64; } public static int getClosestAlignedFixedBits(int n) @@ -901,36 +897,34 @@ public static int getClosestAlignedFixedBits(int n) if (n == 0 || n == 1) { return 1; } - else if (n > 1 && n <= 2) { + if (n > 1 && n <= 2) { return 2; } - else if (n > 2 && n <= 4) { + if (n > 2 && n <= 4) { return 4; } - else if (n > 4 && n <= 8) { + if (n > 4 && n <= 8) { return 8; } - else if (n > 8 && n <= 16) { + if (n > 8 && n <= 16) { return 16; } - else if (n > 16 && n <= 24) { + if (n > 16 && n <= 24) { return 24; } - else if (n > 24 && n <= 32) { + if (n > 24 && n <= 32) { return 32; } - else if (n > 32 && n <= 40) { + if (n > 32 && n <= 40) { return 40; } - else if (n > 40 && n <= 48) { + if (n > 40 && n <= 48) { return 48; } - else if (n > 48 && n <= 56) { + if (n > 48 && n <= 56) { return 56; } - else { - return 64; - } + return 64; } enum FixedBitSizes @@ -955,30 +949,28 @@ static int encodeBitWidth(int n) if (n >= 1 && n <= 24) { return n - 1; } - else if (n > 24 && n <= 26) { + if (n > 24 && n <= 26) { return FixedBitSizes.TWENTY_SIX.ordinal(); } - else if (n > 26 && n <= 28) { + if (n > 26 && n <= 28) { return FixedBitSizes.TWENTY_EIGHT.ordinal(); } - else if (n > 28 && n <= 30) { + if (n > 28 && n <= 30) { return FixedBitSizes.THIRTY.ordinal(); } - else if (n > 30 && n <= 32) { + if (n > 30 && n <= 32) { return FixedBitSizes.THIRTY_TWO.ordinal(); } - else if (n > 32 && n <= 40) { + if (n > 32 && n <= 40) { return FixedBitSizes.FORTY.ordinal(); } - else if (n > 40 && n <= 48) { + if (n > 40 && n <= 48) { return FixedBitSizes.FORTY_EIGHT.ordinal(); } - else if (n > 48 && n <= 56) { + if (n > 48 && n <= 56) { return FixedBitSizes.FIFTY_SIX.ordinal(); } - else { - return FixedBitSizes.SIXTY_FOUR.ordinal(); - } + return FixedBitSizes.SIXTY_FOUR.ordinal(); } /** @@ -989,30 +981,28 @@ static int decodeBitWidth(int n) if (n >= FixedBitSizes.ONE.ordinal() && n <= FixedBitSizes.TWENTY_FOUR.ordinal()) { return n + 1; } - else if (n == FixedBitSizes.TWENTY_SIX.ordinal()) { + if (n == FixedBitSizes.TWENTY_SIX.ordinal()) { return 26; } - else if (n == FixedBitSizes.TWENTY_EIGHT.ordinal()) { + if (n == FixedBitSizes.TWENTY_EIGHT.ordinal()) { return 28; } - else if (n == FixedBitSizes.THIRTY.ordinal()) { + if (n == FixedBitSizes.THIRTY.ordinal()) { return 30; } - else if (n == FixedBitSizes.THIRTY_TWO.ordinal()) { + if (n == FixedBitSizes.THIRTY_TWO.ordinal()) { return 32; } - else if (n == FixedBitSizes.FORTY.ordinal()) { + if (n == FixedBitSizes.FORTY.ordinal()) { return 40; } - else if (n == FixedBitSizes.FORTY_EIGHT.ordinal()) { + if (n == FixedBitSizes.FORTY_EIGHT.ordinal()) { return 48; } - else if (n == FixedBitSizes.FIFTY_SIX.ordinal()) { + if (n == FixedBitSizes.FIFTY_SIX.ordinal()) { return 56; } - else { - return 64; - } + return 64; } void writeInts(long[] input, int offset, int length, int bitSize, SliceOutput output) diff --git a/lib/trino-orc/src/main/java/io/trino/orc/stream/ValueStreams.java b/lib/trino-orc/src/main/java/io/trino/orc/stream/ValueStreams.java index 27445d32b15b..a5726d28b724 100644 --- a/lib/trino-orc/src/main/java/io/trino/orc/stream/ValueStreams.java +++ b/lib/trino-orc/src/main/java/io/trino/orc/stream/ValueStreams.java @@ -135,11 +135,9 @@ private static ValueInputStream createLongStream( if (encoding == DIRECT_V2 || encoding == DICTIONARY_V2) { return new LongInputStreamV2(inputStream, signed, false); } - else if (encoding == DIRECT || encoding == DICTIONARY) { + if (encoding == DIRECT || encoding == DICTIONARY) { return new LongInputStreamV1(inputStream, signed); } - else { - throw new IllegalArgumentException("Unsupported encoding for long stream: " + encoding); - } + throw new IllegalArgumentException("Unsupported encoding for long stream: " + encoding); } } diff --git a/lib/trino-orc/src/main/java/io/trino/orc/writer/DictionaryBuilder.java b/lib/trino-orc/src/main/java/io/trino/orc/writer/DictionaryBuilder.java index 80cc998df791..d9b75cdc331e 100644 --- a/lib/trino-orc/src/main/java/io/trino/orc/writer/DictionaryBuilder.java +++ b/lib/trino-orc/src/main/java/io/trino/orc/writer/DictionaryBuilder.java @@ -103,9 +103,7 @@ public boolean contains(Block block, int position) if (block.isNull(position)) { return containsNullElement; } - else { - return blockPositionByHash.get(getHashPositionOfElement(block, position)) != EMPTY_SLOT; - } + return blockPositionByHash.get(getHashPositionOfElement(block, position)) != EMPTY_SLOT; } public int putIfAbsent(Block block, int position) diff --git a/lib/trino-orc/src/test/java/io/trino/orc/AbstractTestOrcReader.java b/lib/trino-orc/src/test/java/io/trino/orc/AbstractTestOrcReader.java index 3fac1e7ef68a..517f64094376 100644 --- a/lib/trino-orc/src/test/java/io/trino/orc/AbstractTestOrcReader.java +++ b/lib/trino-orc/src/test/java/io/trino/orc/AbstractTestOrcReader.java @@ -37,6 +37,7 @@ import java.util.List; import java.util.Map; import java.util.Random; +import java.util.UUID; import static com.google.common.collect.Iterables.concat; import static com.google.common.collect.Iterables.cycle; @@ -60,11 +61,13 @@ import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS; import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_NANOS; import static io.trino.spi.type.TinyintType.TINYINT; +import static io.trino.spi.type.UuidType.UUID; import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.testing.DateTimeTestingUtils.sqlTimestampOf; import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Collections.nCopies; +import static java.util.UUID.randomUUID; import static java.util.stream.Collectors.toList; import static org.testng.Assert.assertEquals; @@ -507,6 +510,27 @@ public void testEmptyBinarySequence() tester.testRoundTrip(VARBINARY, nCopies(30_000, new SqlVarbinary(new byte[0]))); } + @Test + public void testUuidDirectSequence() + throws Exception + { + tester.testRoundTrip( + UUID, + intsBetween(0, 30_000).stream() + .map(i -> randomUUID()) + .collect(toList())); + } + + @Test + public void testUuidDictionarySequence() + throws Exception + { + tester.testRoundTrip( + UUID, ImmutableList.copyOf(limit(cycle(ImmutableList.of(1, 3, 5, 7, 11, 13, 17)), 30_000)).stream() + .map(i -> new UUID(i, i)) + .collect(toList())); + } + private static Iterable skipEvery(int n, Iterable iterable) { return () -> new AbstractIterator<>() diff --git a/lib/trino-orc/src/test/java/io/trino/orc/BenchmarkColumnReaders.java b/lib/trino-orc/src/test/java/io/trino/orc/BenchmarkColumnReaders.java index 83d84dea6e38..0aa631aae234 100644 --- a/lib/trino-orc/src/test/java/io/trino/orc/BenchmarkColumnReaders.java +++ b/lib/trino-orc/src/test/java/io/trino/orc/BenchmarkColumnReaders.java @@ -22,7 +22,9 @@ import io.trino.spi.type.DecimalType; import io.trino.spi.type.SqlDecimal; import io.trino.spi.type.SqlTimestamp; +import io.trino.spi.type.SqlVarbinary; import io.trino.spi.type.Type; +import io.trino.type.UuidOperators; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; import org.openjdk.jmh.annotations.Fork; @@ -46,6 +48,7 @@ import java.util.Iterator; import java.util.List; import java.util.Random; +import java.util.UUID; import java.util.concurrent.TimeUnit; import static com.google.common.io.MoreFiles.deleteRecursively; @@ -66,6 +69,8 @@ import static io.trino.spi.type.SmallintType.SMALLINT; import static io.trino.spi.type.TimestampType.TIMESTAMP_MILLIS; import static io.trino.spi.type.TinyintType.TINYINT; +import static io.trino.spi.type.UuidType.UUID; +import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER; import static java.nio.file.Files.createTempDirectory; @@ -330,6 +335,42 @@ public Object readLineitem(LineitemBenchmarkData data) return pages; } + @Benchmark + public Object readUuidNoNull(UuidNoNullBenchmarkData data) + throws Exception + { + try (OrcRecordReader recordReader = data.createRecordReader()) { + return readFirstColumn(recordReader); + } + } + + @Benchmark + public Object readUuidWithNull(UuidWithNullBenchmarkData data) + throws Exception + { + try (OrcRecordReader recordReader = data.createRecordReader()) { + return readFirstColumn(recordReader); + } + } + + @Benchmark + public Object readVarbinaryUuidNoNull(VarbinaryUuidNoNullBenchmarkData data) + throws Exception + { + try (OrcRecordReader recordReader = data.createRecordReader()) { + return readFirstColumn(recordReader); + } + } + + @Benchmark + public Object readVarbinaryUuidWithNull(VarbinaryUuidWithNullBenchmarkData data) + throws Exception + { + try (OrcRecordReader recordReader = data.createRecordReader()) { + return readFirstColumn(recordReader); + } + } + private Object readFirstColumn(OrcRecordReader recordReader) throws IOException { @@ -416,6 +457,7 @@ public static class AllNullBenchmarkData "varchar", "varbinary", + "uuid" }) private String typeName; @@ -997,6 +1039,100 @@ private Iterator createValues() } } + @State(Thread) + public static class UuidNoNullBenchmarkData + extends BenchmarkData + { + @Setup + public void setup() + throws Exception + { + setup(UUID, createValues()); + } + + private Iterator createValues() + { + List values = new ArrayList<>(); + for (int i = 0; i < ROWS; ++i) { + values.add(java.util.UUID.randomUUID()); + } + return values.iterator(); + } + } + + @State(Thread) + public static class UuidWithNullBenchmarkData + extends BenchmarkData + { + @Setup + public void setup() + throws Exception + { + setup(UUID, createValues()); + } + + private Iterator createValues() + { + List values = new ArrayList<>(); + for (int i = 0; i < ROWS; ++i) { + if (random.nextBoolean()) { + values.add(null); + } + else { + values.add(java.util.UUID.randomUUID()); + } + } + return values.iterator(); + } + } + + @State(Thread) + public static class VarbinaryUuidNoNullBenchmarkData + extends BenchmarkData + { + @Setup + public void setup() + throws Exception + { + setup(VARBINARY, createValues()); + } + + private Iterator createValues() + { + List values = new ArrayList<>(); + for (int i = 0; i < ROWS; ++i) { + values.add(new SqlVarbinary(UuidOperators.uuid().getBytes())); + } + return values.iterator(); + } + } + + @State(Thread) + public static class VarbinaryUuidWithNullBenchmarkData + extends BenchmarkData + { + @Setup + public void setup() + throws Exception + { + setup(VARBINARY, createValues()); + } + + private Iterator createValues() + { + List values = new ArrayList<>(); + for (int i = 0; i < ROWS; ++i) { + if (random.nextBoolean()) { + values.add(null); + } + else { + values.add(new SqlVarbinary(UuidOperators.uuid().getBytes())); + } + } + return values.iterator(); + } + } + private static List createDictionary(Random random) { List dictionary = new ArrayList<>(); diff --git a/lib/trino-orc/src/test/java/io/trino/orc/OrcTester.java b/lib/trino-orc/src/test/java/io/trino/orc/OrcTester.java index 8ab6da2e549a..f0393b82a350 100644 --- a/lib/trino-orc/src/test/java/io/trino/orc/OrcTester.java +++ b/lib/trino-orc/src/test/java/io/trino/orc/OrcTester.java @@ -23,6 +23,7 @@ import io.airlift.slice.Slices; import io.airlift.units.DataSize; import io.trino.hive.orc.OrcConf; +import io.trino.orc.metadata.ColumnMetadata; import io.trino.orc.metadata.CompressionKind; import io.trino.orc.metadata.OrcType; import io.trino.spi.Page; @@ -92,6 +93,7 @@ import java.io.FileOutputStream; import java.io.IOException; import java.math.BigInteger; +import java.nio.ByteBuffer; import java.time.LocalDateTime; import java.time.ZoneOffset; import java.util.ArrayList; @@ -104,6 +106,7 @@ import java.util.Optional; import java.util.Properties; import java.util.Set; +import java.util.UUID; import java.util.stream.IntStream; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -124,6 +127,8 @@ import static io.trino.orc.metadata.CompressionKind.SNAPPY; import static io.trino.orc.metadata.CompressionKind.ZLIB; import static io.trino.orc.metadata.CompressionKind.ZSTD; +import static io.trino.orc.metadata.OrcType.OrcTypeKind.BINARY; +import static io.trino.orc.reader.ColumnReaders.ICEBERG_BINARY_TYPE; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.Chars.truncateToLengthAndTrimSpaces; @@ -148,6 +153,8 @@ import static io.trino.spi.type.Timestamps.PICOSECONDS_PER_NANOSECOND; import static io.trino.spi.type.Timestamps.roundDiv; import static io.trino.spi.type.TinyintType.TINYINT; +import static io.trino.spi.type.UuidType.UUID; +import static io.trino.spi.type.UuidType.javaUuidToTrinoUuid; import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.spi.type.Varchars.truncateToLength; import static io.trino.testing.DateTimeTestingUtils.sqlTimestampOf; @@ -424,7 +431,7 @@ private void assertRoundTrip(Type writeType, Type readType, List writeValues, { OrcWriterStats stats = new OrcWriterStats(); for (CompressionKind compression : compressions) { - boolean hiveSupported = (compression != LZ4) && (compression != ZSTD) && !isTimestampTz(writeType) && !isTimestampTz(readType); + boolean hiveSupported = (compression != LZ4) && (compression != ZSTD) && !isTimestampTz(writeType) && !isTimestampTz(readType) && !isUuid(writeType) && !isUuid(readType); for (Format format : formats) { // write Hive, read Trino @@ -577,6 +584,10 @@ else if (type.equals(DOUBLE)) { assertEquals(actualDouble, expectedDouble, 0.001); } } + else if (type.equals(UUID)) { + UUID actualUUID = java.util.UUID.fromString((String) actual); + assertEquals(actualUUID, expected); + } else if (!Objects.equals(actual, expected)) { assertEquals(actual, expected); } @@ -635,11 +646,25 @@ public static void writeOrcColumnTrino(File outputFile, CompressionKind compress List columnNames = ImmutableList.of("test"); List types = ImmutableList.of(type); + ColumnMetadata orcType = OrcType.createRootOrcType(columnNames, types, Optional.of(mappedType -> { + if (UUID.equals(mappedType)) { + return Optional.of(new OrcType( + BINARY, + ImmutableList.of(), + ImmutableList.of(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + ImmutableMap.of(ICEBERG_BINARY_TYPE, "UUID"))); + } + return Optional.empty(); + })); + OrcWriter writer = new OrcWriter( new OutputStreamOrcDataSink(new FileOutputStream(outputFile)), ImmutableList.of("test"), types, - OrcType.createRootOrcType(columnNames, types), + orcType, compression, new OrcWriterOptions(), ImmutableMap.of(), @@ -694,6 +719,9 @@ else if (type instanceof CharType) { else if (VARBINARY.equals(type)) { type.writeSlice(blockBuilder, Slices.wrappedBuffer(((SqlVarbinary) value).getBytes())); } + else if (UUID.equals(type)) { + type.writeSlice(blockBuilder, javaUuidToTrinoUuid((java.util.UUID) value)); + } else if (DATE.equals(type)) { long days = ((SqlDate) value).getDays(); type.writeLong(blockBuilder, days); @@ -805,7 +833,13 @@ else if (actualValue instanceof ByteWritable) { actualValue = ((ByteWritable) actualValue).get(); } else if (actualValue instanceof BytesWritable) { - actualValue = new SqlVarbinary(((BytesWritable) actualValue).copyBytes()); + if (UUID.equals(type)) { + ByteBuffer bytes = ByteBuffer.wrap(((BytesWritable) actualValue).copyBytes()); + actualValue = new UUID(bytes.getLong(), bytes.getLong()).toString(); + } + else { + actualValue = new SqlVarbinary(((BytesWritable) actualValue).copyBytes()); + } } else if (actualValue instanceof DateWritableV2) { actualValue = new SqlDate(((DateWritableV2) actualValue).getDays()); @@ -984,6 +1018,9 @@ private static ObjectInspector getJavaObjectInspector(Type type) if (type instanceof VarbinaryType) { return javaByteArrayObjectInspector; } + if (type.equals(UUID)) { + return javaByteArrayObjectInspector; + } if (type.equals(DATE)) { return javaDateObjectInspector; } @@ -1053,6 +1090,9 @@ private static Object preprocessWriteValueHive(Type type, Object value) if (type.equals(VARBINARY)) { return ((SqlVarbinary) value).getBytes(); } + if (type.equals(UUID)) { + return javaUuidToTrinoUuid((java.util.UUID) value).getBytes(); + } if (type.equals(DATE)) { return Date.ofEpochDay(((SqlDate) value).getDays()); } @@ -1218,4 +1258,23 @@ private static boolean isTimestampTz(Type type) } return false; } + + private static boolean isUuid(Type type) + { + if (type.equals(UUID)) { + return true; + } + if (type instanceof ArrayType) { + return isUuid(((ArrayType) type).getElementType()); + } + if (type instanceof MapType) { + return isUuid(((MapType) type).getKeyType()) || isUuid(((MapType) type).getValueType()); + } + if (type instanceof RowType) { + return ((RowType) type).getFields().stream() + .map(RowType.Field::getType) + .anyMatch(OrcTester::isUuid); + } + return false; + } } diff --git a/lib/trino-orc/src/test/java/io/trino/orc/TestDictionaryCompressionOptimizer.java b/lib/trino-orc/src/test/java/io/trino/orc/TestDictionaryCompressionOptimizer.java index 0745eb0e5635..ed82fa0f3253 100644 --- a/lib/trino-orc/src/test/java/io/trino/orc/TestDictionaryCompressionOptimizer.java +++ b/lib/trino-orc/src/test/java/io/trino/orc/TestDictionaryCompressionOptimizer.java @@ -622,9 +622,7 @@ public OptionalInt tryConvertToDirect(int maxDirectBytes) direct = true; return OptionalInt.of(toIntExact(directBytes)); } - else { - return OptionalInt.empty(); - } + return OptionalInt.empty(); } public boolean isDirect() diff --git a/lib/trino-orc/src/test/java/io/trino/orc/TestingOrcPredicate.java b/lib/trino-orc/src/test/java/io/trino/orc/TestingOrcPredicate.java index ce6665bec29d..7e1dc5f18353 100644 --- a/lib/trino-orc/src/test/java/io/trino/orc/TestingOrcPredicate.java +++ b/lib/trino-orc/src/test/java/io/trino/orc/TestingOrcPredicate.java @@ -59,6 +59,7 @@ import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS; import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_NANOS; import static io.trino.spi.type.TinyintType.TINYINT; +import static io.trino.spi.type.UuidType.UUID; import static java.util.stream.Collectors.toList; import static org.testng.Assert.assertNull; import static org.testng.Assert.assertTrue; @@ -88,7 +89,7 @@ public static OrcPredicate createOrcPredicate(Type type, Iterable values) if (REAL.equals(type) || DOUBLE.equals(type)) { return new DoubleOrcPredicate(transform(expectedValues, value -> ((Number) value).doubleValue())); } - if (type instanceof VarbinaryType) { + if (type instanceof VarbinaryType || type.equals(UUID)) { // binary does not have stats return new BasicOrcPredicate<>(expectedValues, Object.class); } diff --git a/lib/trino-orc/src/test/java/io/trino/orc/stream/TestDecimalStream.java b/lib/trino-orc/src/test/java/io/trino/orc/stream/TestDecimalStream.java index ee52897428be..06b7575c529d 100644 --- a/lib/trino-orc/src/test/java/io/trino/orc/stream/TestDecimalStream.java +++ b/lib/trino-orc/src/test/java/io/trino/orc/stream/TestDecimalStream.java @@ -263,10 +263,8 @@ private static void writeBigInteger(OutputStream output, BigInteger value) output.write((byte) lowBits); return; } - else { - output.write((byte) (0x80 | (lowBits & 0x7f))); - lowBits >>>= 7; - } + output.write((byte) (0x80 | (lowBits & 0x7f))); + lowBits >>>= 7; } value = value.shiftRight(63); } diff --git a/lib/trino-orc/src/test/java/io/trino/orc/stream/TestLongDecode.java b/lib/trino-orc/src/test/java/io/trino/orc/stream/TestLongDecode.java index 17cd93e22bdf..1852bc12add9 100644 --- a/lib/trino-orc/src/test/java/io/trino/orc/stream/TestLongDecode.java +++ b/lib/trino-orc/src/test/java/io/trino/orc/stream/TestLongDecode.java @@ -111,10 +111,8 @@ private static void writeVulong(OutputStream output, long value) output.write((byte) value); return; } - else { - output.write((byte) (0x80 | (value & 0x7f))); - value >>>= 7; - } + output.write((byte) (0x80 | (value & 0x7f))); + value >>>= 7; } } diff --git a/lib/trino-parquet/pom.xml b/lib/trino-parquet/pom.xml index 4435ee2edb03..470a5ab1d443 100644 --- a/lib/trino-parquet/pom.xml +++ b/lib/trino-parquet/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/DecimalColumnReaderFactory.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/DecimalColumnReaderFactory.java index b7c3f5fe1394..d16c896f7517 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/DecimalColumnReaderFactory.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/DecimalColumnReaderFactory.java @@ -25,8 +25,6 @@ public static PrimitiveColumnReader createReader(PrimitiveField field, DecimalTy if (parquetDecimalType.isShort()) { return new ShortDecimalColumnReader(field, parquetDecimalType); } - else { - return new LongDecimalColumnReader(field, parquetDecimalType); - } + return new LongDecimalColumnReader(field, parquetDecimalType); } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/PageReader.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/PageReader.java index bcf368e08f14..43aa635b950a 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/PageReader.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/PageReader.java @@ -78,27 +78,25 @@ public DataPage readPage() dataPageV1.getDefinitionLevelEncoding(), dataPageV1.getValueEncoding()); } - else { - DataPageV2 dataPageV2 = (DataPageV2) compressedPage; - if (!dataPageV2.isCompressed()) { - return dataPageV2; - } - int uncompressedSize = dataPageV2.getUncompressedSize() - - dataPageV2.getDefinitionLevels().length() - - dataPageV2.getRepetitionLevels().length(); - return new DataPageV2( - dataPageV2.getRowCount(), - dataPageV2.getNullCount(), - dataPageV2.getValueCount(), - dataPageV2.getRepetitionLevels(), - dataPageV2.getDefinitionLevels(), - dataPageV2.getDataEncoding(), - decompress(codec, dataPageV2.getSlice(), uncompressedSize), - dataPageV2.getUncompressedSize(), - firstRowIndex, - dataPageV2.getStatistics(), - false); + DataPageV2 dataPageV2 = (DataPageV2) compressedPage; + if (!dataPageV2.isCompressed()) { + return dataPageV2; } + int uncompressedSize = dataPageV2.getUncompressedSize() + - dataPageV2.getDefinitionLevels().length() + - dataPageV2.getRepetitionLevels().length(); + return new DataPageV2( + dataPageV2.getRowCount(), + dataPageV2.getNullCount(), + dataPageV2.getValueCount(), + dataPageV2.getRepetitionLevels(), + dataPageV2.getDefinitionLevels(), + dataPageV2.getDataEncoding(), + decompress(codec, dataPageV2.getSlice(), uncompressedSize), + dataPageV2.getUncompressedSize(), + firstRowIndex, + dataPageV2.getStatistics(), + false); } catch (IOException e) { throw new RuntimeException("Could not decompress page", e); diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetColumnChunk.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetColumnChunk.java index 14e25b75df2b..cc5a2fa1ced7 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetColumnChunk.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetColumnChunk.java @@ -123,9 +123,7 @@ private boolean hasMorePages(long valuesCountReadSoFar, int dataPageCountReadSoF if (offsetIndex == null) { return valuesCountReadSoFar < descriptor.getColumnChunkMetaData().getValueCount(); } - else { - return dataPageCountReadSoFar < offsetIndex.getPageCount(); - } + return dataPageCountReadSoFar < offsetIndex.getPageCount(); } private Slice getSlice(int size) diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java index af5787402641..582fd4ca8d6e 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java @@ -76,6 +76,7 @@ import static java.lang.Math.max; import static java.lang.Math.min; import static java.lang.Math.toIntExact; +import static java.util.Collections.nCopies; import static java.util.Objects.requireNonNull; public class ParquetReader @@ -130,10 +131,11 @@ public ParquetReader( ParquetDataSource dataSource, DateTimeZone timeZone, AggregatedMemoryContext memoryContext, - ParquetReaderOptions options) + ParquetReaderOptions options, + Optional parquetPredicate) throws IOException { - this(fileCreatedBy, fields, blocks, firstRowsOfBlocks, dataSource, timeZone, memoryContext, options, null, null); + this(fileCreatedBy, fields, blocks, firstRowsOfBlocks, dataSource, timeZone, memoryContext, options, parquetPredicate, nCopies(blocks.size(), Optional.empty())); } public ParquetReader( @@ -145,7 +147,7 @@ public ParquetReader( DateTimeZone timeZone, AggregatedMemoryContext memoryContext, ParquetReaderOptions options, - Predicate parquetPredicate, + Optional parquetPredicate, List> columnIndexStore) throws IOException { @@ -164,14 +166,16 @@ public ParquetReader( checkArgument(blocks.size() == firstRowsOfBlocks.size(), "elements of firstRowsOfBlocks must correspond to blocks"); - this.columnIndexStore = columnIndexStore; this.blockRowRanges = listWithNulls(this.blocks.size()); for (PrimitiveField field : primitiveFields) { ColumnDescriptor columnDescriptor = field.getDescriptor(); this.paths.put(ColumnPath.get(columnDescriptor.getPath()), columnDescriptor); } - if (parquetPredicate != null && options.isUseColumnIndex()) { - this.filter = parquetPredicate.toParquetFilter(timeZone); + + requireNonNull(parquetPredicate, "parquetPredicate is null"); + this.columnIndexStore = requireNonNull(columnIndexStore, "columnIndexStore is null"); + if (parquetPredicate.isPresent() && options.isUseColumnIndex()) { + this.filter = parquetPredicate.get().toParquetFilter(timeZone); } else { this.filter = Optional.empty(); diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ArrayColumnWriter.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ArrayColumnWriter.java index a46263d2f575..2634741980f8 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ArrayColumnWriter.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ArrayColumnWriter.java @@ -14,8 +14,8 @@ package io.trino.parquet.writer; import com.google.common.collect.ImmutableList; -import io.trino.parquet.writer.repdef.DefLevelIterable; -import io.trino.parquet.writer.repdef.DefLevelIterables; +import io.trino.parquet.writer.repdef.DefLevelWriterProvider; +import io.trino.parquet.writer.repdef.DefLevelWriterProviders; import io.trino.parquet.writer.repdef.RepLevelIterable; import io.trino.parquet.writer.repdef.RepLevelIterables; import io.trino.spi.block.ColumnarArray; @@ -49,9 +49,9 @@ public void writeBlock(ColumnChunk columnChunk) ColumnarArray columnarArray = ColumnarArray.toColumnarArray(columnChunk.getBlock()); elementWriter.writeBlock( new ColumnChunk(columnarArray.getElementsBlock(), - ImmutableList.builder() - .addAll(columnChunk.getDefLevelIterables()) - .add(DefLevelIterables.of(columnarArray, maxDefinitionLevel)) + ImmutableList.builder() + .addAll(columnChunk.getDefLevelWriterProviders()) + .add(DefLevelWriterProviders.of(columnarArray, maxDefinitionLevel)) .build(), ImmutableList.builder() .addAll(columnChunk.getRepLevelIterables()) diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ColumnChunk.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ColumnChunk.java index 463e68028864..186e5e77cba7 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ColumnChunk.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ColumnChunk.java @@ -14,7 +14,7 @@ package io.trino.parquet.writer; import com.google.common.collect.ImmutableList; -import io.trino.parquet.writer.repdef.DefLevelIterable; +import io.trino.parquet.writer.repdef.DefLevelWriterProvider; import io.trino.parquet.writer.repdef.RepLevelIterable; import io.trino.spi.block.Block; @@ -25,7 +25,7 @@ public class ColumnChunk { private final Block block; - private final List defLevelIterables; + private final List defLevelWriterProviders; private final List repLevelIterables; ColumnChunk(Block block) @@ -33,16 +33,16 @@ public class ColumnChunk this(block, ImmutableList.of(), ImmutableList.of()); } - ColumnChunk(Block block, List defLevelIterables, List repLevelIterables) + ColumnChunk(Block block, List defLevelWriterProviders, List repLevelIterables) { this.block = requireNonNull(block, "block is null"); - this.defLevelIterables = ImmutableList.copyOf(defLevelIterables); + this.defLevelWriterProviders = ImmutableList.copyOf(defLevelWriterProviders); this.repLevelIterables = ImmutableList.copyOf(repLevelIterables); } - List getDefLevelIterables() + List getDefLevelWriterProviders() { - return defLevelIterables; + return defLevelWriterProviders; } List getRepLevelIterables() diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/MapColumnWriter.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/MapColumnWriter.java index 8102259d9218..fa8176f57717 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/MapColumnWriter.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/MapColumnWriter.java @@ -14,8 +14,8 @@ package io.trino.parquet.writer; import com.google.common.collect.ImmutableList; -import io.trino.parquet.writer.repdef.DefLevelIterable; -import io.trino.parquet.writer.repdef.DefLevelIterables; +import io.trino.parquet.writer.repdef.DefLevelWriterProvider; +import io.trino.parquet.writer.repdef.DefLevelWriterProviders; import io.trino.parquet.writer.repdef.RepLevelIterable; import io.trino.parquet.writer.repdef.RepLevelIterables; import io.trino.spi.block.ColumnarMap; @@ -50,16 +50,16 @@ public void writeBlock(ColumnChunk columnChunk) { ColumnarMap columnarMap = ColumnarMap.toColumnarMap(columnChunk.getBlock()); - ImmutableList defLevelIterables = ImmutableList.builder() - .addAll(columnChunk.getDefLevelIterables()) - .add(DefLevelIterables.of(columnarMap, maxDefinitionLevel)).build(); + ImmutableList defLevelWriterProviders = ImmutableList.builder() + .addAll(columnChunk.getDefLevelWriterProviders()) + .add(DefLevelWriterProviders.of(columnarMap, maxDefinitionLevel)).build(); ImmutableList repLevelIterables = ImmutableList.builder() .addAll(columnChunk.getRepLevelIterables()) .add(RepLevelIterables.of(columnarMap, maxRepetitionLevel)).build(); - keyWriter.writeBlock(new ColumnChunk(columnarMap.getKeysBlock(), defLevelIterables, repLevelIterables)); - valueWriter.writeBlock(new ColumnChunk(columnarMap.getValuesBlock(), defLevelIterables, repLevelIterables)); + keyWriter.writeBlock(new ColumnChunk(columnarMap.getKeysBlock(), defLevelWriterProviders, repLevelIterables)); + valueWriter.writeBlock(new ColumnChunk(columnarMap.getValuesBlock(), defLevelWriterProviders, repLevelIterables)); } @Override diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetSchemaConverter.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetSchemaConverter.java index c3a51f8a71ed..ea02339c21e0 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetSchemaConverter.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetSchemaConverter.java @@ -100,15 +100,13 @@ private org.apache.parquet.schema.Type convert(Type type, String name, List parent, Repetition repetition) diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetTypeVisitor.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetTypeVisitor.java index 72b6153f688d..d494770fe64a 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetTypeVisitor.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetTypeVisitor.java @@ -36,77 +36,75 @@ public static T visit(Type type, ParquetTypeVisitor visitor) if (type instanceof MessageType) { return visitor.message((MessageType) type, visitFields(type.asGroupType(), visitor)); } - else if (type.isPrimitive()) { + if (type.isPrimitive()) { return visitor.primitive(type.asPrimitiveType()); } - else { - // if not a primitive, the typeId must be a group - GroupType group = type.asGroupType(); - LogicalTypeAnnotation annotation = group.getLogicalTypeAnnotation(); - if (LogicalTypeAnnotation.listType().equals(annotation)) { - checkArgument(!group.isRepetition(REPEATED), - "Invalid list: top-level group is repeated: " + group); - checkArgument(group.getFieldCount() == 1, - "Invalid list: does not contain single repeated field: " + group); - - GroupType repeatedElement = group.getFields().get(0).asGroupType(); - checkArgument(repeatedElement.isRepetition(REPEATED), - "Invalid list: inner group is not repeated"); - checkArgument(repeatedElement.getFieldCount() <= 1, - "Invalid list: repeated group is not a single field: " + group); - - visitor.fieldNames.push(repeatedElement.getName()); - try { - T elementResult = null; - if (repeatedElement.getFieldCount() > 0) { - elementResult = visitField(repeatedElement.getType(0), visitor); - } - - return visitor.list(group, elementResult); - } - finally { - visitor.fieldNames.pop(); + // if not a primitive, the typeId must be a group + GroupType group = type.asGroupType(); + LogicalTypeAnnotation annotation = group.getLogicalTypeAnnotation(); + if (LogicalTypeAnnotation.listType().equals(annotation)) { + checkArgument(!group.isRepetition(REPEATED), + "Invalid list: top-level group is repeated: " + group); + checkArgument(group.getFieldCount() == 1, + "Invalid list: does not contain single repeated field: " + group); + + GroupType repeatedElement = group.getFields().get(0).asGroupType(); + checkArgument(repeatedElement.isRepetition(REPEATED), + "Invalid list: inner group is not repeated"); + checkArgument(repeatedElement.getFieldCount() <= 1, + "Invalid list: repeated group is not a single field: " + group); + + visitor.fieldNames.push(repeatedElement.getName()); + try { + T elementResult = null; + if (repeatedElement.getFieldCount() > 0) { + elementResult = visitField(repeatedElement.getType(0), visitor); } + + return visitor.list(group, elementResult); + } + finally { + visitor.fieldNames.pop(); } - else if (LogicalTypeAnnotation.mapType().equals(annotation)) { - checkArgument(!group.isRepetition(REPEATED), - "Invalid map: top-level group is repeated: " + group); - checkArgument(group.getFieldCount() == 1, - "Invalid map: does not contain single repeated field: " + group); - - GroupType repeatedKeyValue = group.getType(0).asGroupType(); - checkArgument(repeatedKeyValue.isRepetition(REPEATED), - "Invalid map: inner group is not repeated"); - checkArgument(repeatedKeyValue.getFieldCount() <= 2, - "Invalid map: repeated group does not have 2 fields"); - - visitor.fieldNames.push(repeatedKeyValue.getName()); - try { - T keyResult = null; - T valueResult = null; - if (repeatedKeyValue.getFieldCount() == 2) { - keyResult = visitField(repeatedKeyValue.getType(0), visitor); - valueResult = visitField(repeatedKeyValue.getType(1), visitor); + } + if (LogicalTypeAnnotation.mapType().equals(annotation)) { + checkArgument(!group.isRepetition(REPEATED), + "Invalid map: top-level group is repeated: " + group); + checkArgument(group.getFieldCount() == 1, + "Invalid map: does not contain single repeated field: " + group); + + GroupType repeatedKeyValue = group.getType(0).asGroupType(); + checkArgument(repeatedKeyValue.isRepetition(REPEATED), + "Invalid map: inner group is not repeated"); + checkArgument(repeatedKeyValue.getFieldCount() <= 2, + "Invalid map: repeated group does not have 2 fields"); + + visitor.fieldNames.push(repeatedKeyValue.getName()); + try { + T keyResult = null; + T valueResult = null; + if (repeatedKeyValue.getFieldCount() == 2) { + keyResult = visitField(repeatedKeyValue.getType(0), visitor); + valueResult = visitField(repeatedKeyValue.getType(1), visitor); + } + else if (repeatedKeyValue.getFieldCount() == 1) { + Type keyOrValue = repeatedKeyValue.getType(0); + if (keyOrValue.getName().equalsIgnoreCase("key")) { + keyResult = visitField(keyOrValue, visitor); + // value result remains null } - else if (repeatedKeyValue.getFieldCount() == 1) { - Type keyOrValue = repeatedKeyValue.getType(0); - if (keyOrValue.getName().equalsIgnoreCase("key")) { - keyResult = visitField(keyOrValue, visitor); - // value result remains null - } - else { - valueResult = visitField(keyOrValue, visitor); - // key result remains null - } + else { + valueResult = visitField(keyOrValue, visitor); + // key result remains null } - return visitor.map(group, keyResult, valueResult); - } - finally { - visitor.fieldNames.pop(); } + return visitor.map(group, keyResult, valueResult); + } + finally { + visitor.fieldNames.pop(); } - return visitor.struct(group, visitFields(group, visitor)); } + return visitor.struct(group, visitFields(group, visitor)); } private static T visitField(Type field, ParquetTypeVisitor visitor) diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/PrimitiveColumnWriter.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/PrimitiveColumnWriter.java index b4581fb19e21..c895cb92b73f 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/PrimitiveColumnWriter.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/PrimitiveColumnWriter.java @@ -15,12 +15,11 @@ import com.google.common.collect.ImmutableList; import io.airlift.slice.Slices; -import io.trino.parquet.writer.repdef.DefLevelIterable; -import io.trino.parquet.writer.repdef.DefLevelIterables; +import io.trino.parquet.writer.repdef.DefLevelWriterProvider; +import io.trino.parquet.writer.repdef.DefLevelWriterProviders; import io.trino.parquet.writer.repdef.RepLevelIterable; import io.trino.parquet.writer.repdef.RepLevelIterables; import io.trino.parquet.writer.valuewriter.PrimitiveValueWriter; -import io.trino.spi.block.Block; import org.apache.parquet.bytes.BytesInput; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.Encoding; @@ -50,6 +49,8 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.parquet.writer.ParquetCompressor.getCompressor; import static io.trino.parquet.writer.ParquetDataOutput.createDataOutput; +import static io.trino.parquet.writer.repdef.DefLevelWriterProvider.DefinitionLevelWriter; +import static io.trino.parquet.writer.repdef.DefLevelWriterProvider.getRootDefinitionLevelWriter; import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; @@ -116,38 +117,15 @@ public void writeBlock(ColumnChunk columnChunk) // write values primitiveValueWriter.write(columnChunk.getBlock()); - if (columnChunk.getDefLevelIterables().isEmpty()) { - // write definition levels for flat data types - Block block = columnChunk.getBlock(); - if (!block.mayHaveNull()) { - for (int position = 0; position < block.getPositionCount(); position++) { - definitionLevelWriter.writeInteger(maxDefinitionLevel); - } - } - else { - for (int position = 0; position < block.getPositionCount(); position++) { - byte isNull = (byte) (block.isNull(position) ? 1 : 0); - definitionLevelWriter.writeInteger(maxDefinitionLevel - isNull); - currentPageNullCounts += isNull; - } - } - valueCount += block.getPositionCount(); - } - else { - // write definition levels for nested data types - Iterator defIterator = DefLevelIterables.getIterator(ImmutableList.builder() - .addAll(columnChunk.getDefLevelIterables()) - .add(DefLevelIterables.of(columnChunk.getBlock(), maxDefinitionLevel)) - .build()); - while (defIterator.hasNext()) { - int next = defIterator.next(); - definitionLevelWriter.writeInteger(next); - if (next != maxDefinitionLevel) { - currentPageNullCounts++; - } - valueCount++; - } - } + List defLevelWriterProviders = ImmutableList.builder() + .addAll(columnChunk.getDefLevelWriterProviders()) + .add(DefLevelWriterProviders.of(columnChunk.getBlock(), maxDefinitionLevel)) + .build(); + DefinitionLevelWriter rootDefinitionLevelWriter = getRootDefinitionLevelWriter(defLevelWriterProviders, definitionLevelWriter); + + DefLevelWriterProvider.ValuesCount valuesCount = rootDefinitionLevelWriter.writeDefinitionLevels(); + currentPageNullCounts += valuesCount.totalValuesCount() - valuesCount.maxDefinitionLevelValuesCount(); + valueCount += valuesCount.totalValuesCount(); if (columnDescriptor.getMaxRepetitionLevel() > 0) { // write repetition levels for nested types diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/StructColumnWriter.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/StructColumnWriter.java index ffcc08a2f921..524e282063f3 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/StructColumnWriter.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/StructColumnWriter.java @@ -14,8 +14,8 @@ package io.trino.parquet.writer; import com.google.common.collect.ImmutableList; -import io.trino.parquet.writer.repdef.DefLevelIterable; -import io.trino.parquet.writer.repdef.DefLevelIterables; +import io.trino.parquet.writer.repdef.DefLevelWriterProvider; +import io.trino.parquet.writer.repdef.DefLevelWriterProviders; import io.trino.parquet.writer.repdef.RepLevelIterable; import io.trino.parquet.writer.repdef.RepLevelIterables; import io.trino.spi.block.Block; @@ -50,9 +50,9 @@ public void writeBlock(ColumnChunk columnChunk) ColumnarRow columnarRow = toColumnarRow(columnChunk.getBlock()); checkArgument(columnarRow.getFieldCount() == columnWriters.size(), "ColumnarRow field size %s is not equal to columnWriters size %s", columnarRow.getFieldCount(), columnWriters.size()); - List defLevelIterables = ImmutableList.builder() - .addAll(columnChunk.getDefLevelIterables()) - .add(DefLevelIterables.of(columnarRow, maxDefinitionLevel)) + List defLevelWriterProviders = ImmutableList.builder() + .addAll(columnChunk.getDefLevelWriterProviders()) + .add(DefLevelWriterProviders.of(columnarRow, maxDefinitionLevel)) .build(); List repLevelIterables = ImmutableList.builder() .addAll(columnChunk.getRepLevelIterables()) @@ -62,7 +62,7 @@ public void writeBlock(ColumnChunk columnChunk) for (int i = 0; i < columnWriters.size(); ++i) { ColumnWriter columnWriter = columnWriters.get(i); Block block = columnarRow.getField(i); - columnWriter.writeBlock(new ColumnChunk(block, defLevelIterables, repLevelIterables)); + columnWriter.writeBlock(new ColumnChunk(block, defLevelWriterProviders, repLevelIterables)); } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/repdef/DefLevelIterables.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/repdef/DefLevelIterables.java deleted file mode 100644 index 3710a43c9307..000000000000 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/repdef/DefLevelIterables.java +++ /dev/null @@ -1,275 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.parquet.writer.repdef; - -import com.google.common.collect.AbstractIterator; -import io.trino.parquet.writer.repdef.DefLevelIterable.DefLevelIterator; -import io.trino.spi.block.Block; -import io.trino.spi.block.ColumnarArray; -import io.trino.spi.block.ColumnarMap; -import io.trino.spi.block.ColumnarRow; - -import java.util.Iterator; -import java.util.List; -import java.util.OptionalInt; - -import static com.google.common.base.Preconditions.checkState; -import static com.google.common.collect.ImmutableList.toImmutableList; -import static java.util.Collections.nCopies; -import static java.util.Objects.requireNonNull; - -public class DefLevelIterables -{ - private DefLevelIterables() {} - - public static DefLevelIterable of(Block block, int maxDefinitionLevel) - { - return new PrimitiveDefLevelIterable(block, maxDefinitionLevel); - } - - public static DefLevelIterable of(ColumnarRow columnarRow, int maxDefinitionLevel) - { - return new ColumnRowDefLevelIterable(columnarRow, maxDefinitionLevel); - } - - public static DefLevelIterable of(ColumnarArray columnarArray, int maxDefinitionLevel) - { - return new ColumnArrayDefLevelIterable(columnarArray, maxDefinitionLevel); - } - - public static DefLevelIterable of(ColumnarMap columnarMap, int maxDefinitionLevel) - { - return new ColumnMapDefLevelIterable(columnarMap, maxDefinitionLevel); - } - - public static Iterator getIterator(List iterables) - { - return new NestedDefLevelIterator(iterables); - } - - static class PrimitiveDefLevelIterable - implements DefLevelIterable - { - private final Block block; - private final int maxDefinitionLevel; - - PrimitiveDefLevelIterable(Block block, int maxDefinitionLevel) - { - this.block = requireNonNull(block, "block is null"); - this.maxDefinitionLevel = maxDefinitionLevel; - } - - @Override - public DefLevelIterator getIterator() - { - return new DefLevelIterator() - { - private int position = -1; - - @Override - boolean end() - { - return true; - } - - @Override - protected OptionalInt computeNext() - { - position++; - if (position == block.getPositionCount()) { - return endOfData(); - } - if (block.isNull(position)) { - return OptionalInt.of(maxDefinitionLevel - 1); - } - return OptionalInt.of(maxDefinitionLevel); - } - }; - } - } - - static class ColumnRowDefLevelIterable - implements DefLevelIterable - { - private final ColumnarRow columnarRow; - private final int maxDefinitionLevel; - - ColumnRowDefLevelIterable(ColumnarRow columnarRow, int maxDefinitionLevel) - { - this.columnarRow = requireNonNull(columnarRow, "columnarRow is null"); - this.maxDefinitionLevel = maxDefinitionLevel; - } - - @Override - public DefLevelIterator getIterator() - { - return new DefLevelIterator() - { - private int position = -1; - - @Override - boolean end() - { - return true; - } - - @Override - protected OptionalInt computeNext() - { - position++; - if (position == columnarRow.getPositionCount()) { - return endOfData(); - } - if (columnarRow.isNull(position)) { - return OptionalInt.of(maxDefinitionLevel - 1); - } - return OptionalInt.empty(); - } - }; - } - } - - static class ColumnMapDefLevelIterable - implements DefLevelIterable - { - private final ColumnarMap columnarMap; - private final int maxDefinitionLevel; - - ColumnMapDefLevelIterable(ColumnarMap columnarMap, int maxDefinitionLevel) - { - this.columnarMap = requireNonNull(columnarMap, "columnarMap is null"); - this.maxDefinitionLevel = maxDefinitionLevel; - } - - @Override - public DefLevelIterator getIterator() - { - return new DefLevelIterator() - { - private int position = -1; - private Iterator iterator; - - @Override - boolean end() - { - return iterator == null || !iterator.hasNext(); - } - - @Override - protected OptionalInt computeNext() - { - if (iterator != null && iterator.hasNext()) { - return iterator.next(); - } - position++; - if (position == columnarMap.getPositionCount()) { - return endOfData(); - } - if (columnarMap.isNull(position)) { - return OptionalInt.of(maxDefinitionLevel - 2); - } - int arrayLength = columnarMap.getEntryCount(position); - if (arrayLength == 0) { - return OptionalInt.of(maxDefinitionLevel - 1); - } - iterator = nCopies(arrayLength, OptionalInt.empty()).iterator(); - return iterator.next(); - } - }; - } - } - - static class ColumnArrayDefLevelIterable - implements DefLevelIterable - { - private final ColumnarArray columnarArray; - private final int maxDefinitionLevel; - - ColumnArrayDefLevelIterable(ColumnarArray columnarArray, int maxDefinitionLevel) - { - this.columnarArray = requireNonNull(columnarArray, "columnarArray is null"); - this.maxDefinitionLevel = maxDefinitionLevel; - } - - @Override - public DefLevelIterator getIterator() - { - return new DefLevelIterator() - { - private int position = -1; - private Iterator iterator; - - @Override - boolean end() - { - return iterator == null || !iterator.hasNext(); - } - - @Override - protected OptionalInt computeNext() - { - if (iterator != null && iterator.hasNext()) { - return iterator.next(); - } - position++; - if (position == columnarArray.getPositionCount()) { - return endOfData(); - } - if (columnarArray.isNull(position)) { - return OptionalInt.of(maxDefinitionLevel - 2); - } - int arrayLength = columnarArray.getLength(position); - if (arrayLength == 0) { - return OptionalInt.of(maxDefinitionLevel - 1); - } - iterator = nCopies(arrayLength, OptionalInt.empty()).iterator(); - return iterator.next(); - } - }; - } - } - - static class NestedDefLevelIterator - extends AbstractIterator - { - private final List iterators; - private int iteratorIndex; - - NestedDefLevelIterator(List iterables) - { - this.iterators = iterables.stream().map(DefLevelIterable::getIterator).collect(toImmutableList()); - } - - @Override - protected Integer computeNext() - { - DefLevelIterator current = iterators.get(iteratorIndex); - while (iteratorIndex > 0 && current.end()) { - iteratorIndex--; - current = iterators.get(iteratorIndex); - } - - while (current.hasNext()) { - OptionalInt next = current.next(); - if (next.isPresent()) { - return next.getAsInt(); - } - iteratorIndex++; - current = iterators.get(iteratorIndex); - } - checkState(iterators.stream().noneMatch(AbstractIterator::hasNext)); - return endOfData(); - } - } -} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/repdef/DefLevelWriterProvider.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/repdef/DefLevelWriterProvider.java new file mode 100644 index 000000000000..8b3bff8e8c97 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/repdef/DefLevelWriterProvider.java @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.writer.repdef; + +import com.google.common.collect.Iterables; +import org.apache.parquet.column.values.ValuesWriter; + +import java.util.List; +import java.util.Optional; + +public interface DefLevelWriterProvider +{ + DefinitionLevelWriter getDefinitionLevelWriter(Optional nestedWriter, ValuesWriter encoder); + + interface DefinitionLevelWriter + { + ValuesCount writeDefinitionLevels(int positionsCount); + + ValuesCount writeDefinitionLevels(); + } + + record ValuesCount(int totalValuesCount, int maxDefinitionLevelValuesCount) + { + } + + static DefinitionLevelWriter getRootDefinitionLevelWriter(List defLevelWriterProviders, ValuesWriter encoder) + { + // Constructs hierarchy of DefinitionLevelWriter from leaf to root + DefinitionLevelWriter rootDefinitionLevelWriter = Iterables.getLast(defLevelWriterProviders) + .getDefinitionLevelWriter(Optional.empty(), encoder); + for (int nestedLevel = defLevelWriterProviders.size() - 2; nestedLevel >= 0; nestedLevel--) { + DefinitionLevelWriter nestedWriter = rootDefinitionLevelWriter; + rootDefinitionLevelWriter = defLevelWriterProviders.get(nestedLevel) + .getDefinitionLevelWriter(Optional.of(nestedWriter), encoder); + } + return rootDefinitionLevelWriter; + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/repdef/DefLevelWriterProviders.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/repdef/DefLevelWriterProviders.java new file mode 100644 index 000000000000..7de391070c78 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/repdef/DefLevelWriterProviders.java @@ -0,0 +1,342 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.writer.repdef; + +import io.trino.spi.block.Block; +import io.trino.spi.block.ColumnarArray; +import io.trino.spi.block.ColumnarMap; +import io.trino.spi.block.ColumnarRow; +import org.apache.parquet.column.values.ValuesWriter; + +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class DefLevelWriterProviders +{ + private DefLevelWriterProviders() {} + + public static DefLevelWriterProvider of(Block block, int maxDefinitionLevel) + { + return new PrimitiveDefLevelWriterProvider(block, maxDefinitionLevel); + } + + public static DefLevelWriterProvider of(ColumnarRow columnarRow, int maxDefinitionLevel) + { + return new ColumnRowDefLevelWriterProvider(columnarRow, maxDefinitionLevel); + } + + public static DefLevelWriterProvider of(ColumnarArray columnarArray, int maxDefinitionLevel) + { + return new ColumnArrayDefLevelWriterProvider(columnarArray, maxDefinitionLevel); + } + + public static DefLevelWriterProvider of(ColumnarMap columnarMap, int maxDefinitionLevel) + { + return new ColumnMapDefLevelWriterProvider(columnarMap, maxDefinitionLevel); + } + + static class PrimitiveDefLevelWriterProvider + implements DefLevelWriterProvider + { + private final Block block; + private final int maxDefinitionLevel; + + PrimitiveDefLevelWriterProvider(Block block, int maxDefinitionLevel) + { + this.block = requireNonNull(block, "block is null"); + this.maxDefinitionLevel = maxDefinitionLevel; + } + + @Override + public DefinitionLevelWriter getDefinitionLevelWriter(Optional nestedWriter, ValuesWriter encoder) + { + checkArgument(nestedWriter.isEmpty(), "nestedWriter should be empty for primitive definition level writer"); + return new DefinitionLevelWriter() + { + private int offset; + + @Override + public ValuesCount writeDefinitionLevels() + { + return writeDefinitionLevels(block.getPositionCount()); + } + + @Override + public ValuesCount writeDefinitionLevels(int positionsCount) + { + checkValidPosition(offset, positionsCount, block.getPositionCount()); + int nonNullsCount = 0; + if (!block.mayHaveNull()) { + for (int position = offset; position < offset + positionsCount; position++) { + encoder.writeInteger(maxDefinitionLevel); + } + nonNullsCount = positionsCount; + } + else { + for (int position = offset; position < offset + positionsCount; position++) { + int isNull = block.isNull(position) ? 1 : 0; + encoder.writeInteger(maxDefinitionLevel - isNull); + nonNullsCount += isNull ^ 1; + } + } + offset += positionsCount; + return new ValuesCount(positionsCount, nonNullsCount); + } + }; + } + } + + static class ColumnRowDefLevelWriterProvider + implements DefLevelWriterProvider + { + private final ColumnarRow columnarRow; + private final int maxDefinitionLevel; + + ColumnRowDefLevelWriterProvider(ColumnarRow columnarRow, int maxDefinitionLevel) + { + this.columnarRow = requireNonNull(columnarRow, "columnarRow is null"); + this.maxDefinitionLevel = maxDefinitionLevel; + } + + @Override + public DefinitionLevelWriter getDefinitionLevelWriter(Optional nestedWriterOptional, ValuesWriter encoder) + { + checkArgument(nestedWriterOptional.isPresent(), "nestedWriter should be present for column row definition level writer"); + return new DefinitionLevelWriter() + { + private final DefinitionLevelWriter nestedWriter = nestedWriterOptional.orElseThrow(); + + private int offset; + + @Override + public ValuesCount writeDefinitionLevels() + { + return writeDefinitionLevels(columnarRow.getPositionCount()); + } + + @Override + public ValuesCount writeDefinitionLevels(int positionsCount) + { + checkValidPosition(offset, positionsCount, columnarRow.getPositionCount()); + if (!columnarRow.mayHaveNull()) { + offset += positionsCount; + return nestedWriter.writeDefinitionLevels(positionsCount); + } + int maxDefinitionValuesCount = 0; + int totalValuesCount = 0; + for (int position = offset; position < offset + positionsCount; ) { + if (columnarRow.isNull(position)) { + encoder.writeInteger(maxDefinitionLevel - 1); + totalValuesCount++; + position++; + } + else { + int consecutiveNonNullsCount = 1; + position++; + while (position < offset + positionsCount && !columnarRow.isNull(position)) { + position++; + consecutiveNonNullsCount++; + } + ValuesCount valuesCount = nestedWriter.writeDefinitionLevels(consecutiveNonNullsCount); + maxDefinitionValuesCount += valuesCount.maxDefinitionLevelValuesCount(); + totalValuesCount += valuesCount.totalValuesCount(); + } + } + offset += positionsCount; + return new ValuesCount(totalValuesCount, maxDefinitionValuesCount); + } + }; + } + } + + static class ColumnMapDefLevelWriterProvider + implements DefLevelWriterProvider + { + private final ColumnarMap columnarMap; + private final int maxDefinitionLevel; + + ColumnMapDefLevelWriterProvider(ColumnarMap columnarMap, int maxDefinitionLevel) + { + this.columnarMap = requireNonNull(columnarMap, "columnarMap is null"); + this.maxDefinitionLevel = maxDefinitionLevel; + } + + @Override + public DefinitionLevelWriter getDefinitionLevelWriter(Optional nestedWriterOptional, ValuesWriter encoder) + { + checkArgument(nestedWriterOptional.isPresent(), "nestedWriter should be present for column map definition level writer"); + return new DefinitionLevelWriter() + { + private final DefinitionLevelWriter nestedWriter = nestedWriterOptional.orElseThrow(); + + private int offset; + + @Override + public ValuesCount writeDefinitionLevels() + { + return writeDefinitionLevels(columnarMap.getPositionCount()); + } + + @Override + public ValuesCount writeDefinitionLevels(int positionsCount) + { + checkValidPosition(offset, positionsCount, columnarMap.getPositionCount()); + int maxDefinitionValuesCount = 0; + int totalValuesCount = 0; + if (!columnarMap.mayHaveNull()) { + for (int position = offset; position < offset + positionsCount; ) { + int mapLength = columnarMap.getEntryCount(position); + if (mapLength == 0) { + encoder.writeInteger(maxDefinitionLevel - 1); + totalValuesCount++; + position++; + } + else { + int consecutiveNonEmptyArrayLength = mapLength; + position++; + while (position < offset + positionsCount) { + mapLength = columnarMap.getEntryCount(position); + if (mapLength == 0) { + break; + } + position++; + consecutiveNonEmptyArrayLength += mapLength; + } + ValuesCount valuesCount = nestedWriter.writeDefinitionLevels(consecutiveNonEmptyArrayLength); + maxDefinitionValuesCount += valuesCount.maxDefinitionLevelValuesCount(); + totalValuesCount += valuesCount.totalValuesCount(); + } + } + } + else { + for (int position = offset; position < offset + positionsCount; position++) { + if (columnarMap.isNull(position)) { + encoder.writeInteger(maxDefinitionLevel - 2); + totalValuesCount++; + continue; + } + int mapLength = columnarMap.getEntryCount(position); + if (mapLength == 0) { + encoder.writeInteger(maxDefinitionLevel - 1); + totalValuesCount++; + } + else { + ValuesCount valuesCount = nestedWriter.writeDefinitionLevels(mapLength); + maxDefinitionValuesCount += valuesCount.maxDefinitionLevelValuesCount(); + totalValuesCount += valuesCount.totalValuesCount(); + } + } + } + offset += positionsCount; + return new ValuesCount(totalValuesCount, maxDefinitionValuesCount); + } + }; + } + } + + static class ColumnArrayDefLevelWriterProvider + implements DefLevelWriterProvider + { + private final ColumnarArray columnarArray; + private final int maxDefinitionLevel; + + ColumnArrayDefLevelWriterProvider(ColumnarArray columnarArray, int maxDefinitionLevel) + { + this.columnarArray = requireNonNull(columnarArray, "columnarArray is null"); + this.maxDefinitionLevel = maxDefinitionLevel; + } + + @Override + public DefinitionLevelWriter getDefinitionLevelWriter(Optional nestedWriterOptional, ValuesWriter encoder) + { + checkArgument(nestedWriterOptional.isPresent(), "nestedWriter should be present for column map definition level writer"); + return new DefinitionLevelWriter() + { + private final DefinitionLevelWriter nestedWriter = nestedWriterOptional.orElseThrow(); + + private int offset; + + @Override + public ValuesCount writeDefinitionLevels() + { + return writeDefinitionLevels(columnarArray.getPositionCount()); + } + + @Override + public ValuesCount writeDefinitionLevels(int positionsCount) + { + checkValidPosition(offset, positionsCount, columnarArray.getPositionCount()); + int maxDefinitionValuesCount = 0; + int totalValuesCount = 0; + if (!columnarArray.mayHaveNull()) { + for (int position = offset; position < offset + positionsCount; ) { + int arrayLength = columnarArray.getLength(position); + if (arrayLength == 0) { + encoder.writeInteger(maxDefinitionLevel - 1); + totalValuesCount++; + position++; + } + else { + int consecutiveNonEmptyArrayLength = arrayLength; + position++; + while (position < offset + positionsCount) { + arrayLength = columnarArray.getLength(position); + if (arrayLength == 0) { + break; + } + position++; + consecutiveNonEmptyArrayLength += arrayLength; + } + ValuesCount valuesCount = nestedWriter.writeDefinitionLevels(consecutiveNonEmptyArrayLength); + maxDefinitionValuesCount += valuesCount.maxDefinitionLevelValuesCount(); + totalValuesCount += valuesCount.totalValuesCount(); + } + } + } + else { + for (int position = offset; position < offset + positionsCount; position++) { + if (columnarArray.isNull(position)) { + encoder.writeInteger(maxDefinitionLevel - 2); + totalValuesCount++; + continue; + } + int arrayLength = columnarArray.getLength(position); + if (arrayLength == 0) { + encoder.writeInteger(maxDefinitionLevel - 1); + totalValuesCount++; + } + else { + ValuesCount valuesCount = nestedWriter.writeDefinitionLevels(arrayLength); + maxDefinitionValuesCount += valuesCount.maxDefinitionLevelValuesCount(); + totalValuesCount += valuesCount.totalValuesCount(); + } + } + } + offset += positionsCount; + return new ValuesCount(totalValuesCount, maxDefinitionValuesCount); + } + }; + } + } + + private static void checkValidPosition(int offset, int positionsCount, int totalPositionsCount) + { + if (offset < 0 || positionsCount < 0 || offset + positionsCount > totalPositionsCount) { + throw new IndexOutOfBoundsException(format("Invalid offset %s and positionsCount %s in block with %s positions", offset, positionsCount, totalPositionsCount)); + } + } +} diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestDefinitionLevelWriter.java b/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestDefinitionLevelWriter.java new file mode 100644 index 000000000000..8391a5ed6d99 --- /dev/null +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestDefinitionLevelWriter.java @@ -0,0 +1,738 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.writer; + +import com.google.common.collect.ImmutableList; +import io.trino.parquet.writer.repdef.DefLevelWriterProviders; +import io.trino.spi.block.Block; +import io.trino.spi.block.ColumnarArray; +import io.trino.spi.block.ColumnarMap; +import io.trino.spi.block.ColumnarRow; +import io.trino.spi.block.LongArrayBlock; +import io.trino.spi.type.MapType; +import io.trino.spi.type.TypeOperators; +import it.unimi.dsi.fastutil.ints.IntArrayList; +import it.unimi.dsi.fastutil.ints.IntList; +import org.apache.parquet.bytes.BytesInput; +import org.apache.parquet.column.Encoding; +import org.apache.parquet.column.values.ValuesWriter; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.Random; +import java.util.stream.Stream; + +import static io.trino.parquet.writer.repdef.DefLevelWriterProvider.DefinitionLevelWriter; +import static io.trino.parquet.writer.repdef.DefLevelWriterProvider.ValuesCount; +import static io.trino.parquet.writer.repdef.DefLevelWriterProvider.getRootDefinitionLevelWriter; +import static io.trino.spi.block.ArrayBlock.fromElementBlock; +import static io.trino.spi.block.ColumnarArray.toColumnarArray; +import static io.trino.spi.block.ColumnarMap.toColumnarMap; +import static io.trino.spi.block.ColumnarRow.toColumnarRow; +import static io.trino.spi.block.MapBlock.fromKeyValueBlock; +import static io.trino.spi.block.RowBlock.fromFieldBlocks; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.testing.DataProviders.toDataProvider; +import static java.lang.Math.toIntExact; +import static java.util.Collections.nCopies; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestDefinitionLevelWriter +{ + private static final int POSITIONS = 8096; + private static final Random RANDOM = new Random(42); + private static final TypeOperators TYPE_OPERATORS = new TypeOperators(); + + private static final boolean[] ALL_NULLS_ARRAY = new boolean[POSITIONS]; + private static final boolean[] RANDOM_NULLS_ARRAY = new boolean[POSITIONS]; + private static final boolean[] GROUPED_NULLS_ARRAY = new boolean[POSITIONS]; + + static { + Arrays.fill(ALL_NULLS_ARRAY, true); + for (int i = 0; i < POSITIONS; i++) { + RANDOM_NULLS_ARRAY[i] = RANDOM.nextBoolean(); + } + + int maxGroupSize = 23; + int position = 0; + while (position < POSITIONS) { + int remaining = POSITIONS - position; + int groupSize = Math.min(RANDOM.nextInt(maxGroupSize) + 1, remaining); + Arrays.fill(GROUPED_NULLS_ARRAY, position, position + groupSize, RANDOM.nextBoolean()); + position += groupSize; + } + } + + @Test(dataProvider = "primitiveBlockProvider") + public void testWritePrimitiveDefinitionLevels(PrimitiveBlockProvider blockProvider) + { + Block block = blockProvider.getInputBlock(); + int maxDefinitionLevel = 3; + // Write definition levels for all positions + assertDefinitionLevels(block, ImmutableList.of(), maxDefinitionLevel); + + // Write definition levels for all positions one-at-a-time + assertDefinitionLevels(block, nCopies(block.getPositionCount(), 1), maxDefinitionLevel); + + // Write definition levels for all positions with different group sizes + assertDefinitionLevels(block, generateGroupSizes(block.getPositionCount()), maxDefinitionLevel); + } + + @DataProvider + public static Object[][] primitiveBlockProvider() + { + return Stream.of(PrimitiveBlockProvider.values()) + .collect(toDataProvider()); + } + + private enum PrimitiveBlockProvider + { + NO_NULLS { + @Override + Block getInputBlock() + { + return new LongArrayBlock(POSITIONS, Optional.empty(), new long[POSITIONS]); + } + }, + NO_NULLS_WITH_MAY_HAVE_NULL { + @Override + Block getInputBlock() + { + return new LongArrayBlock(POSITIONS, Optional.of(new boolean[POSITIONS]), new long[POSITIONS]); + } + }, + ALL_NULLS { + @Override + Block getInputBlock() + { + return new LongArrayBlock(POSITIONS, Optional.of(ALL_NULLS_ARRAY), new long[POSITIONS]); + } + }, + RANDOM_NULLS { + @Override + Block getInputBlock() + { + return new LongArrayBlock(POSITIONS, Optional.of(RANDOM_NULLS_ARRAY), new long[POSITIONS]); + } + }, + GROUPED_NULLS { + @Override + Block getInputBlock() + { + return new LongArrayBlock(POSITIONS, Optional.of(GROUPED_NULLS_ARRAY), new long[POSITIONS]); + } + }; + + abstract Block getInputBlock(); + } + + @Test(dataProvider = "rowBlockProvider") + public void testWriteRowDefinitionLevels(RowBlockProvider blockProvider) + { + ColumnarRow columnarRow = toColumnarRow(blockProvider.getInputBlock()); + int fieldMaxDefinitionLevel = 2; + // Write definition levels for all positions + for (int field = 0; field < columnarRow.getFieldCount(); field++) { + assertDefinitionLevels(columnarRow, ImmutableList.of(), field, fieldMaxDefinitionLevel); + } + + // Write definition levels for all positions one-at-a-time + for (int field = 0; field < columnarRow.getFieldCount(); field++) { + assertDefinitionLevels( + columnarRow, + nCopies(columnarRow.getPositionCount(), 1), + field, + fieldMaxDefinitionLevel); + } + + // Write definition levels for all positions with different group sizes + for (int field = 0; field < columnarRow.getFieldCount(); field++) { + assertDefinitionLevels( + columnarRow, + generateGroupSizes(columnarRow.getPositionCount()), + field, + fieldMaxDefinitionLevel); + } + } + + @DataProvider + public static Object[][] rowBlockProvider() + { + return Stream.of(RowBlockProvider.values()) + .collect(toDataProvider()); + } + + private enum RowBlockProvider + { + NO_NULLS { + @Override + Block getInputBlock() + { + return createRowBlock(Optional.empty()); + } + }, + NO_NULLS_WITH_MAY_HAVE_NULL { + @Override + Block getInputBlock() + { + return createRowBlock(Optional.of(new boolean[POSITIONS])); + } + }, + ALL_NULLS { + @Override + Block getInputBlock() + { + return createRowBlock(Optional.of(ALL_NULLS_ARRAY)); + } + }, + RANDOM_NULLS { + @Override + Block getInputBlock() + { + return createRowBlock(Optional.of(RANDOM_NULLS_ARRAY)); + } + }, + GROUPED_NULLS { + @Override + Block getInputBlock() + { + return createRowBlock(Optional.of(GROUPED_NULLS_ARRAY)); + } + }; + + abstract Block getInputBlock(); + + private static Block createRowBlock(Optional rowIsNull) + { + int positionCount = rowIsNull.map(isNull -> isNull.length).orElse(0) - toIntExact(rowIsNull.stream().count()); + int fieldCount = 4; + Block[] fieldBlocks = new Block[fieldCount]; + // no nulls block + fieldBlocks[0] = new LongArrayBlock(positionCount, Optional.empty(), new long[positionCount]); + // no nulls with mayHaveNull block + fieldBlocks[1] = new LongArrayBlock(positionCount, Optional.of(new boolean[positionCount]), new long[positionCount]); + // all nulls block + boolean[] allNulls = new boolean[positionCount]; + Arrays.fill(allNulls, false); + fieldBlocks[2] = new LongArrayBlock(positionCount, Optional.of(allNulls), new long[positionCount]); + // random nulls block + fieldBlocks[3] = createLongsBlockWithRandomNulls(positionCount); + + return fromFieldBlocks(positionCount, rowIsNull, fieldBlocks); + } + } + + @Test(dataProvider = "arrayBlockProvider") + public void testWriteArrayDefinitionLevels(ArrayBlockProvider blockProvider) + { + ColumnarArray columnarArray = toColumnarArray(blockProvider.getInputBlock()); + int maxDefinitionLevel = 3; + // Write definition levels for all positions + assertDefinitionLevels( + columnarArray, + ImmutableList.of(), + maxDefinitionLevel); + + // Write definition levels for all positions one-at-a-time + assertDefinitionLevels( + columnarArray, + nCopies(columnarArray.getPositionCount(), 1), + maxDefinitionLevel); + + // Write definition levels for all positions with different group sizes + assertDefinitionLevels( + columnarArray, + generateGroupSizes(columnarArray.getPositionCount()), + maxDefinitionLevel); + } + + @DataProvider + public static Object[][] arrayBlockProvider() + { + return Stream.of(ArrayBlockProvider.values()) + .collect(toDataProvider()); + } + + private enum ArrayBlockProvider + { + NO_NULLS { + @Override + Block getInputBlock() + { + return createArrayBlock(Optional.empty()); + } + }, + NO_NULLS_WITH_MAY_HAVE_NULL { + @Override + Block getInputBlock() + { + return createArrayBlock(Optional.of(new boolean[POSITIONS])); + } + }, + ALL_NULLS { + @Override + Block getInputBlock() + { + return createArrayBlock(Optional.of(ALL_NULLS_ARRAY)); + } + }, + RANDOM_NULLS { + @Override + Block getInputBlock() + { + return createArrayBlock(Optional.of(RANDOM_NULLS_ARRAY)); + } + }, + GROUPED_NULLS { + @Override + Block getInputBlock() + { + return createArrayBlock(Optional.of(GROUPED_NULLS_ARRAY)); + } + }; + + abstract Block getInputBlock(); + + private static Block createArrayBlock(Optional valueIsNull) + { + int[] arrayOffset = generateOffsets(valueIsNull); + return fromElementBlock(POSITIONS, valueIsNull, arrayOffset, createLongsBlockWithRandomNulls(arrayOffset[POSITIONS])); + } + } + + @Test(dataProvider = "mapBlockProvider") + public void testWriteMapDefinitionLevels(MapBlockProvider blockProvider) + { + ColumnarMap columnarMap = toColumnarMap(blockProvider.getInputBlock()); + int keysMaxDefinitionLevel = 2; + int valuesMaxDefinitionLevel = 3; + // Write definition levels for all positions + assertDefinitionLevels( + columnarMap, + ImmutableList.of(), + keysMaxDefinitionLevel, + valuesMaxDefinitionLevel); + + // Write definition levels for all positions one-at-a-time + assertDefinitionLevels( + columnarMap, + nCopies(columnarMap.getPositionCount(), 1), + keysMaxDefinitionLevel, + valuesMaxDefinitionLevel); + + // Write definition levels for all positions with different group sizes + assertDefinitionLevels( + columnarMap, + generateGroupSizes(columnarMap.getPositionCount()), + keysMaxDefinitionLevel, + valuesMaxDefinitionLevel); + } + + @DataProvider + public static Object[][] mapBlockProvider() + { + return Stream.of(MapBlockProvider.values()) + .collect(toDataProvider()); + } + + private enum MapBlockProvider + { + NO_NULLS { + @Override + Block getInputBlock() + { + return createMapBlock(Optional.empty()); + } + }, + NO_NULLS_WITH_MAY_HAVE_NULL { + @Override + Block getInputBlock() + { + return createMapBlock(Optional.of(new boolean[POSITIONS])); + } + }, + ALL_NULLS { + @Override + Block getInputBlock() + { + return createMapBlock(Optional.of(ALL_NULLS_ARRAY)); + } + }, + RANDOM_NULLS { + @Override + Block getInputBlock() + { + return createMapBlock(Optional.of(RANDOM_NULLS_ARRAY)); + } + }, + GROUPED_NULLS { + @Override + Block getInputBlock() + { + return createMapBlock(Optional.of(GROUPED_NULLS_ARRAY)); + } + }; + + abstract Block getInputBlock(); + + private static Block createMapBlock(Optional mapIsNull) + { + int[] offsets = generateOffsets(mapIsNull); + int positionCount = offsets[POSITIONS]; + Block keyBlock = new LongArrayBlock(positionCount, Optional.empty(), new long[positionCount]); + Block valueBlock = createLongsBlockWithRandomNulls(positionCount); + return fromKeyValueBlock(mapIsNull, offsets, keyBlock, valueBlock, new MapType(BIGINT, BIGINT, TYPE_OPERATORS)); + } + } + + private static class TestingValuesWriter + extends ValuesWriter + { + private final IntList values = new IntArrayList(); + + @Override + public long getBufferedSize() + { + throw new UnsupportedOperationException(); + } + + @Override + public BytesInput getBytes() + { + throw new UnsupportedOperationException(); + } + + @Override + public Encoding getEncoding() + { + throw new UnsupportedOperationException(); + } + + @Override + public void reset() + { + throw new UnsupportedOperationException(); + } + + @Override + public long getAllocatedSize() + { + throw new UnsupportedOperationException(); + } + + @Override + public String memUsageString(String prefix) + { + throw new UnsupportedOperationException(); + } + + @Override + public void writeInteger(int v) + { + values.add(v); + } + + List getWrittenValues() + { + return values; + } + } + + private static void assertDefinitionLevels(Block block, List writePositionCounts, int maxDefinitionLevel) + { + TestingValuesWriter valuesWriter = new TestingValuesWriter(); + DefinitionLevelWriter primitiveDefLevelWriter = DefLevelWriterProviders.of(block, maxDefinitionLevel) + .getDefinitionLevelWriter(Optional.empty(), valuesWriter); + ValuesCount primitiveValuesCount; + if (writePositionCounts.isEmpty()) { + primitiveValuesCount = primitiveDefLevelWriter.writeDefinitionLevels(); + } + else { + int totalValuesCount = 0; + int maxDefinitionLevelValuesCount = 0; + for (int position = 0; position < block.getPositionCount(); position++) { + ValuesCount valuesCount = primitiveDefLevelWriter.writeDefinitionLevels(1); + totalValuesCount += valuesCount.totalValuesCount(); + maxDefinitionLevelValuesCount += valuesCount.maxDefinitionLevelValuesCount(); + } + primitiveValuesCount = new ValuesCount(totalValuesCount, maxDefinitionLevelValuesCount); + } + + int maxDefinitionValuesCount = 0; + ImmutableList.Builder expectedDefLevelsBuilder = ImmutableList.builder(); + for (int position = 0; position < block.getPositionCount(); position++) { + if (block.isNull(position)) { + expectedDefLevelsBuilder.add(maxDefinitionLevel - 1); + } + else { + expectedDefLevelsBuilder.add(maxDefinitionLevel); + maxDefinitionValuesCount++; + } + } + assertThat(primitiveValuesCount.totalValuesCount()).isEqualTo(block.getPositionCount()); + assertThat(primitiveValuesCount.maxDefinitionLevelValuesCount()).isEqualTo(maxDefinitionValuesCount); + assertThat(valuesWriter.getWrittenValues()).isEqualTo(expectedDefLevelsBuilder.build()); + } + + private static void assertDefinitionLevels( + ColumnarRow columnarRow, + List writePositionCounts, + int field, + int maxDefinitionLevel) + { + // Write definition levels + TestingValuesWriter valuesWriter = new TestingValuesWriter(); + DefinitionLevelWriter fieldRootDefLevelWriter = getRootDefinitionLevelWriter( + ImmutableList.of( + DefLevelWriterProviders.of(columnarRow, maxDefinitionLevel - 1), + DefLevelWriterProviders.of(columnarRow.getField(field), maxDefinitionLevel)), + valuesWriter); + ValuesCount fieldValuesCount; + if (writePositionCounts.isEmpty()) { + fieldValuesCount = fieldRootDefLevelWriter.writeDefinitionLevels(); + } + else { + int totalValuesCount = 0; + int maxDefinitionLevelValuesCount = 0; + for (int positionsCount : writePositionCounts) { + ValuesCount valuesCount = fieldRootDefLevelWriter.writeDefinitionLevels(positionsCount); + totalValuesCount += valuesCount.totalValuesCount(); + maxDefinitionLevelValuesCount += valuesCount.maxDefinitionLevelValuesCount(); + } + fieldValuesCount = new ValuesCount(totalValuesCount, maxDefinitionLevelValuesCount); + } + + // Verify written definition levels + int maxDefinitionValuesCount = 0; + ImmutableList.Builder expectedDefLevelsBuilder = ImmutableList.builder(); + int fieldOffset = 0; + for (int position = 0; position < columnarRow.getPositionCount(); position++) { + if (columnarRow.isNull(position)) { + expectedDefLevelsBuilder.add(maxDefinitionLevel - 2); + continue; + } + Block fieldBlock = columnarRow.getField(field); + if (fieldBlock.isNull(fieldOffset)) { + expectedDefLevelsBuilder.add(maxDefinitionLevel - 1); + } + else { + expectedDefLevelsBuilder.add(maxDefinitionLevel); + maxDefinitionValuesCount++; + } + fieldOffset++; + } + assertThat(fieldValuesCount.totalValuesCount()).isEqualTo(columnarRow.getPositionCount()); + assertThat(fieldValuesCount.maxDefinitionLevelValuesCount()).isEqualTo(maxDefinitionValuesCount); + assertThat(valuesWriter.getWrittenValues()).isEqualTo(expectedDefLevelsBuilder.build()); + } + + private static void assertDefinitionLevels( + ColumnarArray columnarArray, + List writePositionCounts, + int maxDefinitionLevel) + { + // Write definition levels + TestingValuesWriter valuesWriter = new TestingValuesWriter(); + DefinitionLevelWriter elementsRootDefLevelWriter = getRootDefinitionLevelWriter( + ImmutableList.of( + DefLevelWriterProviders.of(columnarArray, maxDefinitionLevel - 1), + DefLevelWriterProviders.of(columnarArray.getElementsBlock(), maxDefinitionLevel)), + valuesWriter); + ValuesCount elementsValuesCount; + if (writePositionCounts.isEmpty()) { + elementsValuesCount = elementsRootDefLevelWriter.writeDefinitionLevels(); + } + else { + int totalValuesCount = 0; + int maxDefinitionLevelValuesCount = 0; + for (int positionsCount : writePositionCounts) { + ValuesCount valuesCount = elementsRootDefLevelWriter.writeDefinitionLevels(positionsCount); + totalValuesCount += valuesCount.totalValuesCount(); + maxDefinitionLevelValuesCount += valuesCount.maxDefinitionLevelValuesCount(); + } + elementsValuesCount = new ValuesCount(totalValuesCount, maxDefinitionLevelValuesCount); + } + + // Verify written definition levels + int maxDefinitionValuesCount = 0; + int totalValuesCount = 0; + ImmutableList.Builder expectedDefLevelsBuilder = ImmutableList.builder(); + int elementsOffset = 0; + for (int position = 0; position < columnarArray.getPositionCount(); position++) { + if (columnarArray.isNull(position)) { + expectedDefLevelsBuilder.add(maxDefinitionLevel - 3); + totalValuesCount++; + continue; + } + int arrayLength = columnarArray.getLength(position); + if (arrayLength == 0) { + expectedDefLevelsBuilder.add(maxDefinitionLevel - 2); + totalValuesCount++; + continue; + } + totalValuesCount += arrayLength; + Block elementsBlock = columnarArray.getElementsBlock(); + for (int i = elementsOffset; i < elementsOffset + arrayLength; i++) { + if (elementsBlock.isNull(i)) { + expectedDefLevelsBuilder.add(maxDefinitionLevel - 1); + } + else { + expectedDefLevelsBuilder.add(maxDefinitionLevel); + maxDefinitionValuesCount++; + } + } + elementsOffset += arrayLength; + } + assertThat(elementsValuesCount.totalValuesCount()).isEqualTo(totalValuesCount); + assertThat(elementsValuesCount.maxDefinitionLevelValuesCount()).isEqualTo(maxDefinitionValuesCount); + assertThat(valuesWriter.getWrittenValues()).isEqualTo(expectedDefLevelsBuilder.build()); + } + + private static void assertDefinitionLevels( + ColumnarMap columnarMap, + List writePositionCounts, + int keysMaxDefinitionLevel, + int valuesMaxDefinitionLevel) + { + // Write definition levels for map keys + TestingValuesWriter keysWriter = new TestingValuesWriter(); + DefinitionLevelWriter keysRootDefLevelWriter = getRootDefinitionLevelWriter( + ImmutableList.of( + DefLevelWriterProviders.of(columnarMap, keysMaxDefinitionLevel), + DefLevelWriterProviders.of(columnarMap.getKeysBlock(), keysMaxDefinitionLevel)), + keysWriter); + ValuesCount keysValueCount; + if (writePositionCounts.isEmpty()) { + keysValueCount = keysRootDefLevelWriter.writeDefinitionLevels(); + } + else { + int totalValuesCount = 0; + int maxDefinitionLevelValuesCount = 0; + for (int positionsCount : writePositionCounts) { + ValuesCount valuesCount = keysRootDefLevelWriter.writeDefinitionLevels(positionsCount); + totalValuesCount += valuesCount.totalValuesCount(); + maxDefinitionLevelValuesCount += valuesCount.maxDefinitionLevelValuesCount(); + } + keysValueCount = new ValuesCount(totalValuesCount, maxDefinitionLevelValuesCount); + } + + // Write definition levels for map values + TestingValuesWriter valuesWriter = new TestingValuesWriter(); + DefinitionLevelWriter valuesRootDefLevelWriter = getRootDefinitionLevelWriter( + ImmutableList.of( + DefLevelWriterProviders.of(columnarMap, keysMaxDefinitionLevel), + DefLevelWriterProviders.of(columnarMap.getValuesBlock(), valuesMaxDefinitionLevel)), + valuesWriter); + ValuesCount valuesValueCount; + if (writePositionCounts.isEmpty()) { + valuesValueCount = valuesRootDefLevelWriter.writeDefinitionLevels(); + } + else { + int totalValuesCount = 0; + int maxDefinitionLevelValuesCount = 0; + for (int positionsCount : writePositionCounts) { + ValuesCount valuesCount = valuesRootDefLevelWriter.writeDefinitionLevels(positionsCount); + totalValuesCount += valuesCount.totalValuesCount(); + maxDefinitionLevelValuesCount += valuesCount.maxDefinitionLevelValuesCount(); + } + valuesValueCount = new ValuesCount(totalValuesCount, maxDefinitionLevelValuesCount); + } + + // Verify written definition levels + int maxDefinitionKeysCount = 0; + int maxDefinitionValuesCount = 0; + int totalValuesCount = 0; + ImmutableList.Builder keysExpectedDefLevelsBuilder = ImmutableList.builder(); + ImmutableList.Builder valuesExpectedDefLevelsBuilder = ImmutableList.builder(); + int valuesOffset = 0; + for (int position = 0; position < columnarMap.getPositionCount(); position++) { + if (columnarMap.isNull(position)) { + keysExpectedDefLevelsBuilder.add(keysMaxDefinitionLevel - 2); + valuesExpectedDefLevelsBuilder.add(valuesMaxDefinitionLevel - 3); + totalValuesCount++; + continue; + } + int mapLength = columnarMap.getEntryCount(position); + if (mapLength == 0) { + keysExpectedDefLevelsBuilder.add(keysMaxDefinitionLevel - 1); + valuesExpectedDefLevelsBuilder.add(valuesMaxDefinitionLevel - 2); + totalValuesCount++; + continue; + } + totalValuesCount += mapLength; + // Map keys cannot be null + keysExpectedDefLevelsBuilder.addAll(nCopies(mapLength, keysMaxDefinitionLevel)); + maxDefinitionKeysCount += mapLength; + Block valuesBlock = columnarMap.getValuesBlock(); + for (int i = valuesOffset; i < valuesOffset + mapLength; i++) { + if (valuesBlock.isNull(i)) { + valuesExpectedDefLevelsBuilder.add(valuesMaxDefinitionLevel - 1); + } + else { + valuesExpectedDefLevelsBuilder.add(valuesMaxDefinitionLevel); + maxDefinitionValuesCount++; + } + } + valuesOffset += mapLength; + } + assertThat(keysValueCount.totalValuesCount()).isEqualTo(totalValuesCount); + assertThat(keysValueCount.maxDefinitionLevelValuesCount()).isEqualTo(maxDefinitionKeysCount); + assertThat(keysWriter.getWrittenValues()).isEqualTo(keysExpectedDefLevelsBuilder.build()); + + assertThat(valuesValueCount.totalValuesCount()).isEqualTo(totalValuesCount); + assertThat(valuesValueCount.maxDefinitionLevelValuesCount()).isEqualTo(maxDefinitionValuesCount); + assertThat(valuesWriter.getWrittenValues()).isEqualTo(valuesExpectedDefLevelsBuilder.build()); + } + + private static List generateGroupSizes(int positionsCount) + { + int maxGroupSize = 17; + int offset = 0; + ImmutableList.Builder groupsBuilder = ImmutableList.builder(); + while (offset < positionsCount) { + int remaining = positionsCount - offset; + int groupSize = Math.min(RANDOM.nextInt(maxGroupSize) + 1, remaining); + groupsBuilder.add(groupSize); + offset += groupSize; + } + return groupsBuilder.build(); + } + + private static int[] generateOffsets(Optional valueIsNull) + { + int maxCardinality = 7; // array length or map size at the current position + int[] offsets = new int[POSITIONS + 1]; + for (int position = 0; position < POSITIONS; position++) { + if (valueIsNull.isPresent() && valueIsNull.get()[position]) { + offsets[position + 1] = offsets[position]; + } + else { + offsets[position + 1] = offsets[position] + RANDOM.nextInt(maxCardinality); + } + } + return offsets; + } + + private static Block createLongsBlockWithRandomNulls(int positionCount) + { + boolean[] valueIsNull = new boolean[positionCount]; + for (int i = 0; i < positionCount; i++) { + valueIsNull[i] = RANDOM.nextBoolean(); + } + return new LongArrayBlock(positionCount, Optional.of(valueIsNull), new long[positionCount]); + } +} diff --git a/lib/trino-phoenix5-patched/pom.xml b/lib/trino-phoenix5-patched/pom.xml index 05363954d3b9..e0b9f64ee9e6 100644 --- a/lib/trino-phoenix5-patched/pom.xml +++ b/lib/trino-phoenix5-patched/pom.xml @@ -6,7 +6,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/lib/trino-plugin-toolkit/pom.xml b/lib/trino-plugin-toolkit/pom.xml index a688cd219932..25d39428b506 100644 --- a/lib/trino-plugin-toolkit/pom.xml +++ b/lib/trino-plugin-toolkit/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java index 5f0ec0cf1d3a..e5110a67c470 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java @@ -143,14 +143,6 @@ public TableStatisticsMetadata getStatisticsCollectionMetadataForWrite(Connector } } - @Override - public TableStatisticsMetadata getStatisticsCollectionMetadata(ConnectorSession session, ConnectorTableMetadata tableMetadata) - { - try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { - return delegate.getStatisticsCollectionMetadata(session, tableMetadata); - } - } - @Override public ConnectorAnalyzeMetadata getStatisticsCollectionMetadata(ConnectorSession session, ConnectorTableHandle tableHandle, Map analyzeProperties) { @@ -199,14 +191,6 @@ public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTable } } - @Override - public ConnectorTableHandle getTableHandleForStatisticsCollection(ConnectorSession session, SchemaTableName tableName, Map analyzeProperties) - { - try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { - return delegate.getTableHandleForStatisticsCollection(session, tableName, analyzeProperties); - } - } - @Override public Optional getTableHandleForExecute(ConnectorSession session, ConnectorTableHandle tableHandle, String procedureName, Map executeProperties, RetryMode retryMode) { diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/AllowAllSystemAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/AllowAllSystemAccessControl.java index 2d39f782ef1b..3cb3a95ea6f6 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/AllowAllSystemAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/AllowAllSystemAccessControl.java @@ -336,6 +336,11 @@ public void checkCanGrantExecuteFunctionPrivilege(SystemSecurityContext context, { } + @Override + public void checkCanGrantExecuteFunctionPrivilege(SystemSecurityContext context, FunctionKind functionKind, CatalogSchemaRoutineName functionName, TrinoPrincipal grantee, boolean grantOption) + { + } + @Override public void checkCanSetCatalogSessionProperty(SystemSecurityContext context, String catalogName, String propertyName) { diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedAccessControl.java index 9141f220d818..3b1af0cb4b16 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedAccessControl.java @@ -45,7 +45,6 @@ import static io.trino.plugin.base.security.TableAccessControlRule.TablePrivilege.SELECT; import static io.trino.plugin.base.security.TableAccessControlRule.TablePrivilege.UPDATE; import static io.trino.plugin.base.util.JsonUtils.parseJson; -import static io.trino.spi.function.FunctionKind.TABLE; import static io.trino.spi.security.AccessDeniedException.denyAddColumn; import static io.trino.spi.security.AccessDeniedException.denyCommentColumn; import static io.trino.spi.security.AccessDeniedException.denyCommentTable; @@ -606,9 +605,13 @@ public void checkCanExecuteTableProcedure(ConnectorSecurityContext context, Sche @Override public void checkCanExecuteFunction(ConnectorSecurityContext context, FunctionKind functionKind, SchemaRoutineName function) { - if (functionKind == TABLE) { - denyExecuteFunction(function.toString()); + switch (functionKind) { + case SCALAR, AGGREGATE, WINDOW: + return; + case TABLE: + denyExecuteFunction(function.toString()); } + throw new UnsupportedOperationException("Unsupported function kind: " + functionKind); } @Override diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedSystemAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedSystemAccessControl.java index c316e9b1d02d..475884194b27 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedSystemAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedSystemAccessControl.java @@ -61,7 +61,6 @@ import static io.trino.plugin.base.security.TableAccessControlRule.TablePrivilege.UPDATE; import static io.trino.plugin.base.util.JsonUtils.parseJson; import static io.trino.spi.StandardErrorCode.CONFIGURATION_INVALID; -import static io.trino.spi.function.FunctionKind.TABLE; import static io.trino.spi.security.AccessDeniedException.denyAddColumn; import static io.trino.spi.security.AccessDeniedException.denyCatalogAccess; import static io.trino.spi.security.AccessDeniedException.denyCommentColumn; @@ -83,6 +82,7 @@ import static io.trino.spi.security.AccessDeniedException.denyDropTable; import static io.trino.spi.security.AccessDeniedException.denyDropView; import static io.trino.spi.security.AccessDeniedException.denyExecuteFunction; +import static io.trino.spi.security.AccessDeniedException.denyGrantExecuteFunctionPrivilege; import static io.trino.spi.security.AccessDeniedException.denyGrantRoles; import static io.trino.spi.security.AccessDeniedException.denyGrantSchemaPrivilege; import static io.trino.spi.security.AccessDeniedException.denyGrantTablePrivilege; @@ -118,6 +118,7 @@ import static io.trino.spi.security.AccessDeniedException.denyViewQuery; import static io.trino.spi.security.AccessDeniedException.denyWriteSystemInformationAccess; import static java.lang.String.format; +import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; @@ -812,6 +813,20 @@ public void checkCanGrantExecuteFunctionPrivilege(SystemSecurityContext context, { } + @Override + public void checkCanGrantExecuteFunctionPrivilege(SystemSecurityContext context, FunctionKind functionKind, CatalogSchemaRoutineName functionName, TrinoPrincipal grantee, boolean grantOption) + { + switch (functionKind) { + case SCALAR, AGGREGATE, WINDOW: + return; + case TABLE: + // TODO (https://github.com/trinodb/trino/issues/12833) implement + String granteeAsString = format("%s '%s'", grantee.getType().name().toLowerCase(ENGLISH), grantee.getName()); + denyGrantExecuteFunctionPrivilege(functionName.toString(), context.getIdentity(), granteeAsString); + } + throw new UnsupportedOperationException("Unsupported function kind: " + functionKind); + } + @Override public void checkCanSetCatalogSessionProperty(SystemSecurityContext context, String catalogName, String propertyName) { @@ -953,9 +968,13 @@ public void checkCanExecuteFunction(SystemSecurityContext systemSecurityContext, @Override public void checkCanExecuteFunction(SystemSecurityContext systemSecurityContext, FunctionKind functionKind, CatalogSchemaRoutineName functionName) { - if (functionKind == TABLE) { - denyExecuteFunction(functionName.toString()); + switch (functionKind) { + case SCALAR, AGGREGATE, WINDOW: + return; + case TABLE: + denyExecuteFunction(functionName.toString()); } + throw new UnsupportedOperationException("Unsupported function kind: " + functionKind); } @Override diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ForwardingSystemAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ForwardingSystemAccessControl.java index 6a7c61f3147f..0f9fc4264932 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ForwardingSystemAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ForwardingSystemAccessControl.java @@ -373,6 +373,12 @@ public void checkCanGrantExecuteFunctionPrivilege(SystemSecurityContext context, delegate().checkCanGrantExecuteFunctionPrivilege(context, functionName, grantee, grantOption); } + @Override + public void checkCanGrantExecuteFunctionPrivilege(SystemSecurityContext context, FunctionKind functionKind, CatalogSchemaRoutineName functionName, TrinoPrincipal grantee, boolean grantOption) + { + delegate().checkCanGrantExecuteFunctionPrivilege(context, functionKind, functionName, grantee, grantOption); + } + @Override public void checkCanSetCatalogSessionProperty(SystemSecurityContext context, String catalogName, String propertyName) { diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ReadOnlySystemAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ReadOnlySystemAccessControl.java index 506b172f95f3..6f5676f05f69 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ReadOnlySystemAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ReadOnlySystemAccessControl.java @@ -14,8 +14,10 @@ package io.trino.plugin.base.security; import io.trino.spi.connector.CatalogSchemaName; +import io.trino.spi.connector.CatalogSchemaRoutineName; import io.trino.spi.connector.CatalogSchemaTableName; import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.function.FunctionKind; import io.trino.spi.security.SystemAccessControl; import io.trino.spi.security.SystemAccessControlFactory; import io.trino.spi.security.SystemSecurityContext; @@ -27,6 +29,9 @@ import java.util.Set; import static com.google.common.base.Preconditions.checkArgument; +import static io.trino.spi.security.AccessDeniedException.denyGrantExecuteFunctionPrivilege; +import static java.lang.String.format; +import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; public class ReadOnlySystemAccessControl @@ -105,6 +110,20 @@ public void checkCanGrantExecuteFunctionPrivilege(SystemSecurityContext context, { } + @Override + public void checkCanGrantExecuteFunctionPrivilege(SystemSecurityContext context, FunctionKind functionKind, CatalogSchemaRoutineName functionName, TrinoPrincipal grantee, boolean grantOption) + { + switch (functionKind) { + case SCALAR, AGGREGATE, WINDOW: + return; + case TABLE: + // May not be read-only, so deny + String granteeAsString = format("%s '%s'", grantee.getType().name().toLowerCase(ENGLISH), grantee.getName()); + denyGrantExecuteFunctionPrivilege(functionName.toString(), context.getIdentity(), granteeAsString); + } + throw new UnsupportedOperationException("Unsupported function kind: " + functionKind); + } + @Override public Set filterCatalogs(SystemSecurityContext context, Set catalogs) { diff --git a/lib/trino-rcfile/pom.xml b/lib/trino-rcfile/pom.xml index 3f104432a836..69a2f8b53aa1 100644 --- a/lib/trino-rcfile/pom.xml +++ b/lib/trino-rcfile/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/lib/trino-rcfile/src/main/java/io/trino/rcfile/RcFileDecoderUtils.java b/lib/trino-rcfile/src/main/java/io/trino/rcfile/RcFileDecoderUtils.java index d7d41bf25ea1..ca496ad6d563 100644 --- a/lib/trino-rcfile/src/main/java/io/trino/rcfile/RcFileDecoderUtils.java +++ b/lib/trino-rcfile/src/main/java/io/trino/rcfile/RcFileDecoderUtils.java @@ -155,12 +155,10 @@ public static long findFirstSyncPosition(RcFileDataSource dataSource, long offse long startOfSyncSequence = offset + position + index; return startOfSyncSequence; } - else { - // Otherwise, this is not a match for this region - // Note: this case isn't strictly needed as the loop will exit, but it is - // simpler to explicitly call it out. - return -1; - } + // Otherwise, this is not a match for this region + // Note: this case isn't strictly needed as the loop will exit, but it is + // simpler to explicitly call it out. + return -1; } } return -1; diff --git a/lib/trino-rcfile/src/main/java/io/trino/rcfile/TimestampHolder.java b/lib/trino-rcfile/src/main/java/io/trino/rcfile/TimestampHolder.java index 1049baf75651..c770ec18a18e 100644 --- a/lib/trino-rcfile/src/main/java/io/trino/rcfile/TimestampHolder.java +++ b/lib/trino-rcfile/src/main/java/io/trino/rcfile/TimestampHolder.java @@ -62,11 +62,9 @@ public static BiFunction getFactory(TimestampTy if (type.isShort()) { return (block, position) -> new TimestampHolder(type.getLong(block, position), 0); } - else { - return (block, position) -> { - LongTimestamp longTimestamp = (LongTimestamp) type.getObject(block, position); - return new TimestampHolder(longTimestamp.getEpochMicros(), longTimestamp.getPicosOfMicro()); - }; - } + return (block, position) -> { + LongTimestamp longTimestamp = (LongTimestamp) type.getObject(block, position); + return new TimestampHolder(longTimestamp.getEpochMicros(), longTimestamp.getPicosOfMicro()); + }; } } diff --git a/lib/trino-record-decoder/pom.xml b/lib/trino-record-decoder/pom.xml index 67e3e8cfd8ed..62e30827f51c 100644 --- a/lib/trino-record-decoder/pom.xml +++ b/lib/trino-record-decoder/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/lib/trino-record-decoder/src/main/java/io/trino/decoder/avro/AvroColumnDecoder.java b/lib/trino-record-decoder/src/main/java/io/trino/decoder/avro/AvroColumnDecoder.java index f276a5260646..4e65a72421ec 100644 --- a/lib/trino-record-decoder/src/main/java/io/trino/decoder/avro/AvroColumnDecoder.java +++ b/lib/trino-record-decoder/src/main/java/io/trino/decoder/avro/AvroColumnDecoder.java @@ -216,7 +216,7 @@ private static Slice getSlice(Object value, Type type, String columnName) if (value instanceof ByteBuffer) { return Slices.wrappedBuffer((ByteBuffer) value); } - else if (value instanceof GenericFixed) { + if (value instanceof GenericFixed) { return Slices.wrappedBuffer(((GenericFixed) value).bytes()); } } diff --git a/lib/trino-record-decoder/src/main/java/io/trino/decoder/avro/AvroRowDecoderFactory.java b/lib/trino-record-decoder/src/main/java/io/trino/decoder/avro/AvroRowDecoderFactory.java index a872ac532ccb..c63410ece71b 100644 --- a/lib/trino-record-decoder/src/main/java/io/trino/decoder/avro/AvroRowDecoderFactory.java +++ b/lib/trino-record-decoder/src/main/java/io/trino/decoder/avro/AvroRowDecoderFactory.java @@ -61,10 +61,8 @@ public RowDecoder create(Map decoderParams, Set dataDecoder = avroDeserializerFactory.create(avroReaderSupplier); return new GenericRecordRowDecoder(dataDecoder, columns); } - else { - AvroReaderSupplier avroReaderSupplier = avroReaderSupplierFactory.create(parsedSchema); - AvroDeserializer dataDecoder = avroDeserializerFactory.create(avroReaderSupplier); - return new SingleValueRowDecoder(dataDecoder, getOnlyElement(columns)); - } + AvroReaderSupplier avroReaderSupplier = avroReaderSupplierFactory.create(parsedSchema); + AvroDeserializer dataDecoder = avroDeserializerFactory.create(avroReaderSupplier); + return new SingleValueRowDecoder(dataDecoder, getOnlyElement(columns)); } } diff --git a/lib/trino-record-decoder/src/main/java/io/trino/decoder/csv/CsvColumnDecoder.java b/lib/trino-record-decoder/src/main/java/io/trino/decoder/csv/CsvColumnDecoder.java index f360e341748d..23a80643457b 100644 --- a/lib/trino-record-decoder/src/main/java/io/trino/decoder/csv/CsvColumnDecoder.java +++ b/lib/trino-record-decoder/src/main/java/io/trino/decoder/csv/CsvColumnDecoder.java @@ -84,55 +84,53 @@ public FieldValueProvider decodeField(String[] tokens) if (columnIndex >= tokens.length) { return nullValueProvider(); } - else { - return new FieldValueProvider() + return new FieldValueProvider() + { + @Override + public boolean isNull() { - @Override - public boolean isNull() - { - return tokens[columnIndex].isEmpty(); - } + return tokens[columnIndex].isEmpty(); + } - @SuppressWarnings("SimplifiableConditionalExpression") - @Override - public boolean getBoolean() - { - try { - return Boolean.parseBoolean(tokens[columnIndex].trim()); - } - catch (NumberFormatException e) { - throw new TrinoException(DECODER_CONVERSION_NOT_SUPPORTED, format("could not parse value '%s' as '%s' for column '%s'", tokens[columnIndex].trim(), columnType, columnName)); - } + @SuppressWarnings("SimplifiableConditionalExpression") + @Override + public boolean getBoolean() + { + try { + return Boolean.parseBoolean(tokens[columnIndex].trim()); } - - @Override - public long getLong() - { - try { - return Long.parseLong(tokens[columnIndex].trim()); - } - catch (NumberFormatException e) { - throw new TrinoException(DECODER_CONVERSION_NOT_SUPPORTED, format("could not parse value '%s' as '%s' for column '%s'", tokens[columnIndex].trim(), columnType, columnName)); - } + catch (NumberFormatException e) { + throw new TrinoException(DECODER_CONVERSION_NOT_SUPPORTED, format("could not parse value '%s' as '%s' for column '%s'", tokens[columnIndex].trim(), columnType, columnName)); } + } - @Override - public double getDouble() - { - try { - return Double.parseDouble(tokens[columnIndex].trim()); - } - catch (NumberFormatException e) { - throw new TrinoException(DECODER_CONVERSION_NOT_SUPPORTED, format("could not parse value '%s' as '%s' for column '%s'", tokens[columnIndex].trim(), columnType, columnName)); - } + @Override + public long getLong() + { + try { + return Long.parseLong(tokens[columnIndex].trim()); } + catch (NumberFormatException e) { + throw new TrinoException(DECODER_CONVERSION_NOT_SUPPORTED, format("could not parse value '%s' as '%s' for column '%s'", tokens[columnIndex].trim(), columnType, columnName)); + } + } - @Override - public Slice getSlice() - { - return truncateToLength(utf8Slice(tokens[columnIndex]), columnType); + @Override + public double getDouble() + { + try { + return Double.parseDouble(tokens[columnIndex].trim()); } - }; - } + catch (NumberFormatException e) { + throw new TrinoException(DECODER_CONVERSION_NOT_SUPPORTED, format("could not parse value '%s' as '%s' for column '%s'", tokens[columnIndex].trim(), columnType, columnName)); + } + } + + @Override + public Slice getSlice() + { + return truncateToLength(utf8Slice(tokens[columnIndex]), columnType); + } + }; } } diff --git a/plugin/trino-accumulo-iterators/pom.xml b/plugin/trino-accumulo-iterators/pom.xml index c8a2693021ab..73a09e71db1c 100644 --- a/plugin/trino-accumulo-iterators/pom.xml +++ b/plugin/trino-accumulo-iterators/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-accumulo/pom.xml b/plugin/trino-accumulo/pom.xml index d9ace39f20ad..13685eac6302 100644 --- a/plugin/trino-accumulo/pom.xml +++ b/plugin/trino-accumulo/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/AccumuloClient.java b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/AccumuloClient.java index 63e0ed04eec4..1134f6cde8dd 100644 --- a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/AccumuloClient.java +++ b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/AccumuloClient.java @@ -852,20 +852,18 @@ private Optional getTabletLocation(String table, Key key) location = Optional.of(entry.getValue().toString()); break; } - else { - // Chop off some magic nonsense - scannedCompareKey.set(keyBytes, 3, keyBytes.length - 3); - - // Compare the keys, moving along the tablets until the location is found - if (scannedCompareKey.getLength() > 0) { - int compareTo = splitCompareKey.compareTo(scannedCompareKey); - if (compareTo <= 0) { - location = Optional.of(entry.getValue().toString()); - } - else { - // all future tablets will be greater than this key - break; - } + // Chop off some magic nonsense + scannedCompareKey.set(keyBytes, 3, keyBytes.length - 3); + + // Compare the keys, moving along the tablets until the location is found + if (scannedCompareKey.getLength() > 0) { + int compareTo = splitCompareKey.compareTo(scannedCompareKey); + if (compareTo <= 0) { + location = Optional.of(entry.getValue().toString()); + } + else { + // all future tablets will be greater than this key + break; } } } diff --git a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/index/IndexLookup.java b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/index/IndexLookup.java index 3dc25663a53c..e2914148e1e9 100644 --- a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/index/IndexLookup.java +++ b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/index/IndexLookup.java @@ -172,11 +172,9 @@ public boolean applyIndex( return true; } - else { - LOG.debug("Use of index metrics is enabled"); - // Get ranges using the metrics - return getRangesWithMetrics(session, schema, table, constraintRanges, rowIdRanges, tabletSplits, auths); - } + LOG.debug("Use of index metrics is enabled"); + // Get ranges using the metrics + return getRangesWithMetrics(session, schema, table, constraintRanges, rowIdRanges, tabletSplits, auths); } private static Multimap getIndexedConstraintRanges(Collection constraints, AccumuloRowSerializer serializer) @@ -279,10 +277,8 @@ private boolean getRangesWithMetrics( LOG.debug("Number of splits for %s.%s is %d with %d ranges", schema, table, tabletSplits.size(), indexRanges.size()); return true; } - else { - // We are going to do too much work to use the secondary index, so return false - return false; - } + // We are going to do too much work to use the secondary index, so return false + return false; } private static boolean smallestCardAboveThreshold(ConnectorSession session, long numRows, long smallestCardinality) diff --git a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/io/AccumuloPageSink.java b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/io/AccumuloPageSink.java index aacfde5a14f4..1f48e0fc745d 100644 --- a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/io/AccumuloPageSink.java +++ b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/io/AccumuloPageSink.java @@ -92,16 +92,11 @@ public AccumuloPageSink( this.columns = table.getColumns(); // Fetch the row ID ordinal, throwing an exception if not found for safety - Optional ordinal = columns.stream() + this.rowIdOrdinal = columns.stream() .filter(columnHandle -> columnHandle.getName().equals(table.getRowId())) .map(AccumuloColumnHandle::getOrdinal) - .findAny(); - - if (ordinal.isEmpty()) { - throw new TrinoException(FUNCTION_IMPLEMENTATION_ERROR, "Row ID ordinal not found"); - } - - this.rowIdOrdinal = ordinal.get(); + .findAny() + .orElseThrow(() -> new TrinoException(FUNCTION_IMPLEMENTATION_ERROR, "Row ID ordinal not found")); this.serializer = table.getSerializerInstance(); try { diff --git a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/io/AccumuloRecordSet.java b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/io/AccumuloRecordSet.java index 09267f7a460d..2b1c6c068423 100644 --- a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/io/AccumuloRecordSet.java +++ b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/io/AccumuloRecordSet.java @@ -122,11 +122,9 @@ private static Authorizations getScanAuthorizations(ConnectorSession session, Ac LOG.debug("scan_auths table property set: %s", auths); return auths; } - else { - Authorizations auths = connector.securityOperations().getUserAuthorizations(username); - LOG.debug("scan_auths table property not set, using user auths: %s", auths); - return auths; - } + Authorizations auths = connector.securityOperations().getUserAuthorizations(username); + LOG.debug("scan_auths table property not set, using user auths: %s", auths); + return auths; } @Override diff --git a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/serializers/AccumuloRowSerializer.java b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/serializers/AccumuloRowSerializer.java index ec170e37f00e..889b2ff96ce6 100644 --- a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/serializers/AccumuloRowSerializer.java +++ b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/serializers/AccumuloRowSerializer.java @@ -609,16 +609,14 @@ static Object readObject(Type type, Block block, int position) Type elementType = Types.getElementType(type); return getArrayFromBlock(elementType, block.getObject(position, Block.class)); } - else if (Types.isMapType(type)) { + if (Types.isMapType(type)) { return getMapFromBlock(type, block.getObject(position, Block.class)); } - else { - if (type.getJavaType() == Slice.class) { - Slice slice = (Slice) TypeUtils.readNativeValue(type, block, position); - return type.equals(VarcharType.VARCHAR) ? slice.toStringUtf8() : slice.getBytes(); - } - - return TypeUtils.readNativeValue(type, block, position); + if (type.getJavaType() == Slice.class) { + Slice slice = (Slice) TypeUtils.readNativeValue(type, block, position); + return type.equals(VarcharType.VARCHAR) ? slice.toStringUtf8() : slice.getBytes(); } + + return TypeUtils.readNativeValue(type, block, position); } } diff --git a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/serializers/LexicoderRowSerializer.java b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/serializers/LexicoderRowSerializer.java index d4d8613d0434..1d758620ce5c 100644 --- a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/serializers/LexicoderRowSerializer.java +++ b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/serializers/LexicoderRowSerializer.java @@ -382,19 +382,17 @@ public static Lexicoder getLexicoder(Type type) if (Types.isArrayType(type)) { return getListLexicoder(type); } - else if (Types.isMapType(type)) { + if (Types.isMapType(type)) { return getMapLexicoder(type); } - else if (type instanceof VarcharType) { + if (type instanceof VarcharType) { return LEXICODER_MAP.get(VARCHAR); } - else { - Lexicoder lexicoder = LEXICODER_MAP.get(type); - if (lexicoder == null) { - throw new TrinoException(NOT_SUPPORTED, "No lexicoder for type " + type); - } - return lexicoder; + Lexicoder lexicoder = LEXICODER_MAP.get(type); + if (lexicoder == null) { + throw new TrinoException(NOT_SUPPORTED, "No lexicoder for type " + type); } + return lexicoder; } private static ListLexicoder getListLexicoder(Type elementType) diff --git a/plugin/trino-atop/pom.xml b/plugin/trino-atop/pom.xml index 85b623827e44..eeca597afaf3 100644 --- a/plugin/trino-atop/pom.xml +++ b/plugin/trino-atop/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-base-jdbc/pom.xml b/plugin/trino-base-jdbc/pom.xml index 190e7306ebd2..0d4da15b3550 100644 --- a/plugin/trino-base-jdbc/pom.xml +++ b/plugin/trino-base-jdbc/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/StandardColumnMappings.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/StandardColumnMappings.java index 60bc4b487ae8..df4a7f56a8b0 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/StandardColumnMappings.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/StandardColumnMappings.java @@ -574,7 +574,7 @@ public static LongReadFunction timestampReadFunction(TimestampType timestampType return (resultSet, columnIndex) -> toTrinoTimestamp(timestampType, resultSet.getObject(columnIndex, LocalDateTime.class)); } - private static ObjectReadFunction longTimestampReadFunction(TimestampType timestampType) + public static ObjectReadFunction longTimestampReadFunction(TimestampType timestampType) { checkArgument(timestampType.getPrecision() > TimestampType.MAX_SHORT_PRECISION && timestampType.getPrecision() <= MAX_LOCAL_DATE_TIME_PRECISION, "Precision is out of range: %s", timestampType.getPrecision()); diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java index fb50609f33c2..c6eb3d9bbed2 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java @@ -1307,9 +1307,7 @@ protected QueryAssert assertConditionallyPushedDown( if (condition) { return queryAssert.isFullyPushedDown(); } - else { - return queryAssert.isNotFullyPushedDown(otherwiseExpected); - } + return queryAssert.isNotFullyPushedDown(otherwiseExpected); } protected void assertConditionallyOrderedPushedDown( diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/expression/TestExpressionMappingParser.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/expression/TestExpressionMappingParser.java index 246b31dbfa1a..2842a0670ab2 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/expression/TestExpressionMappingParser.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/expression/TestExpressionMappingParser.java @@ -66,18 +66,18 @@ public void testCallPattern() Optional.empty())); assertExpressionPattern( - "$like_pattern(a: varchar(n), b: varchar(m))", + "$like(a: varchar(n), b: varchar(m))", new CallPattern( - "$like_pattern", + "$like", List.of( new ExpressionCapture("a", type("varchar", parameter("n"))), new ExpressionCapture("b", type("varchar", parameter("m")))), Optional.empty())); assertExpressionPattern( - "$like_pattern(a: varchar(n), b: varchar(m)): boolean", + "$like(a: varchar(n), b: varchar(m)): boolean", new CallPattern( - "$like_pattern", + "$like", List.of( new ExpressionCapture("a", type("varchar", parameter("n"))), new ExpressionCapture("b", type("varchar", parameter("m")))), diff --git a/plugin/trino-bigquery/pom.xml b/plugin/trino-bigquery/pom.xml index 34c0c16526f4..652e907ad722 100644 --- a/plugin/trino-bigquery/pom.xml +++ b/plugin/trino-bigquery/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryClient.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryClient.java index ce66521ecd80..67a24f9f05de 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryClient.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryClient.java @@ -217,6 +217,18 @@ Job create(JobInfo jobInfo) return bigQuery.create(jobInfo); } + public void executeUpdate(QueryJobConfiguration job) + { + log.debug("Execute query: %s", job.getQuery()); + try { + bigQuery.query(job); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new BigQueryException(BaseHttpServiceException.UNKNOWN_CODE, format("Failed to run the query [%s]", job.getQuery()), e); + } + } + public TableResult query(String sql, boolean useQueryResultsCache, CreateDisposition createDisposition) { log.debug("Execute query: %s", sql); diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryFilterQueryBuilder.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryFilterQueryBuilder.java index 2511fdda0424..f0f80ac0408f 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryFilterQueryBuilder.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryFilterQueryBuilder.java @@ -26,14 +26,13 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.Iterables.getOnlyElement; +import static io.trino.plugin.bigquery.BigQueryUtil.quote; import static io.trino.plugin.bigquery.BigQueryUtil.toBigQueryColumnName; import static java.util.stream.Collectors.joining; import static java.util.stream.Collectors.toList; public class BigQueryFilterQueryBuilder { - private static final String QUOTE = "`"; - private static final String ESCAPED_QUOTE = "``"; private final TupleDomain tupleDomain; public static Optional buildFilter(TupleDomain tupleDomain) @@ -151,9 +150,4 @@ private Optional toPredicate(String columnName, String operator, Object } return Optional.of(quote(columnName) + " " + operator + " " + valueAsString.get()); } - - private String quote(String name) - { - return QUOTE + name.replace(QUOTE, ESCAPED_QUOTE) + QUOTE; - } } diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java index 6b0e7de7f433..64629addb815 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java @@ -17,6 +17,8 @@ import com.google.cloud.bigquery.DatasetId; import com.google.cloud.bigquery.DatasetInfo; import com.google.cloud.bigquery.Field; +import com.google.cloud.bigquery.QueryJobConfiguration; +import com.google.cloud.bigquery.QueryParameterValue; import com.google.cloud.bigquery.Schema; import com.google.cloud.bigquery.StandardTableDefinition; import com.google.cloud.bigquery.Table; @@ -94,7 +96,9 @@ import static io.trino.plugin.bigquery.BigQueryTableHandle.getPartitionType; import static io.trino.plugin.bigquery.BigQueryType.toField; import static io.trino.plugin.bigquery.BigQueryUtil.isWildcardTable; +import static io.trino.plugin.bigquery.BigQueryUtil.quote; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; +import static java.lang.String.format; import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; import static java.util.function.Function.identity; @@ -502,6 +506,44 @@ public Optional finishInsert(ConnectorSession session, return Optional.empty(); } + @Override + public void setTableComment(ConnectorSession session, ConnectorTableHandle tableHandle, Optional newComment) + { + BigQueryTableHandle table = (BigQueryTableHandle) tableHandle; + BigQueryClient client = bigQueryClientFactory.createBigQueryClient(session); + + RemoteTableName remoteTableName = table.asPlainTable().getRemoteTableName(); + String sql = format( + "ALTER TABLE %s.%s.%s SET OPTIONS (description = ?)", + quote(remoteTableName.getProjectId()), + quote(remoteTableName.getDatasetName()), + quote(remoteTableName.getTableName())); + client.executeUpdate(QueryJobConfiguration.newBuilder(sql) + .setQuery(sql) + .addPositionalParameter(QueryParameterValue.string(newComment.orElse(null))) + .build()); + } + + @Override + public void setColumnComment(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle columnHandle, Optional newComment) + { + BigQueryTableHandle table = (BigQueryTableHandle) tableHandle; + BigQueryColumnHandle column = (BigQueryColumnHandle) columnHandle; + BigQueryClient client = bigQueryClientFactory.createBigQueryClient(session); + + RemoteTableName remoteTableName = table.asPlainTable().getRemoteTableName(); + String sql = format( + "ALTER TABLE %s.%s.%s ALTER COLUMN %s SET OPTIONS (description = ?)", + quote(remoteTableName.getProjectId()), + quote(remoteTableName.getDatasetName()), + quote(remoteTableName.getTableName()), + column.getName()); + client.executeUpdate(QueryJobConfiguration.newBuilder(sql) + .setQuery(sql) + .addPositionalParameter(QueryParameterValue.string(newComment.orElse(null))) + .build()); + } + @Override public Optional> applyProjection( ConnectorSession session, diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryUtil.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryUtil.java index e98bcc0dba01..7269662ab766 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryUtil.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryUtil.java @@ -30,6 +30,9 @@ public final class BigQueryUtil { + private static final String QUOTE = "`"; + private static final String ESCAPED_QUOTE = "``"; + private static final Set INTERNAL_ERROR_MESSAGES = ImmutableSet.of( "HTTP/2 error code: INTERNAL_ERROR", "Connection closed with unknown cause", @@ -73,4 +76,9 @@ public static boolean isWildcardTable(TableDefinition.Type type, String tableNam { return type == TABLE && tableName.contains("*"); } + + public static String quote(String name) + { + return QUOTE + name.replace(QUOTE, ESCAPED_QUOTE) + QUOTE; + } } diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/ReadSessionCreator.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/ReadSessionCreator.java index 7d174b0824bc..4ae570a34d3e 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/ReadSessionCreator.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/ReadSessionCreator.java @@ -111,10 +111,8 @@ private TableInfo getActualTable( // get it from the view return client.getCachedTable(viewExpiration, remoteTable, requiredColumns); } - else { - // not regular table or a view - throw new TrinoException(NOT_SUPPORTED, format("Table type '%s' of table '%s.%s' is not supported", - tableType, remoteTable.getTableId().getDataset(), remoteTable.getTableId().getTable())); - } + // not regular table or a view + throw new TrinoException(NOT_SUPPORTED, format("Table type '%s' of table '%s.%s' is not supported", + tableType, remoteTable.getTableId().getDataset(), remoteTable.getTableId().getTable())); } } diff --git a/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/TestBigQueryConnectorTest.java b/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/TestBigQueryConnectorTest.java index 37cb7d978260..36ccabf7a364 100644 --- a/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/TestBigQueryConnectorTest.java +++ b/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/TestBigQueryConnectorTest.java @@ -74,8 +74,6 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) case SUPPORTS_DELETE: case SUPPORTS_ADD_COLUMN: case SUPPORTS_RENAME_COLUMN: - case SUPPORTS_COMMENT_ON_TABLE: - case SUPPORTS_COMMENT_ON_COLUMN: case SUPPORTS_NEGATIVE_DATE: case SUPPORTS_ARRAY: case SUPPORTS_ROW_TYPE: diff --git a/plugin/trino-blackhole/pom.xml b/plugin/trino-blackhole/pom.xml index 732b1b2815ce..1070d4106b58 100644 --- a/plugin/trino-blackhole/pom.xml +++ b/plugin/trino-blackhole/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHolePageSource.java b/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHolePageSource.java index e9bd4d219afa..49f8ca3b135d 100644 --- a/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHolePageSource.java +++ b/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHolePageSource.java @@ -67,10 +67,8 @@ public Page getNextPage() if (pageProcessingDelayInMillis == 0) { return page; } - else { - currentPage = toCompletableFuture(executorService.schedule(() -> page, pageProcessingDelayInMillis, MILLISECONDS)); - return null; - } + currentPage = toCompletableFuture(executorService.schedule(() -> page, pageProcessingDelayInMillis, MILLISECONDS)); + return null; } @Override diff --git a/plugin/trino-cassandra/pom.xml b/plugin/trino-cassandra/pom.xml index 8d832e6d98f2..623a6f95f60d 100644 --- a/plugin/trino-cassandra/pom.xml +++ b/plugin/trino-cassandra/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraMetadata.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraMetadata.java index 9bbdad28d956..18a6b652ca52 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraMetadata.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraMetadata.java @@ -417,11 +417,9 @@ public Optional applyDelete(ConnectorSession session, Conn public OptionalLong executeDelete(ConnectorSession session, ConnectorTableHandle deleteHandle) { CassandraTableHandle handle = (CassandraTableHandle) deleteHandle; - Optional> partitions = handle.getPartitions(); + List partitions = handle.getPartitions() + .orElseThrow(() -> new TrinoException(NOT_SUPPORTED, "Deleting without partition key is not supported")); if (partitions.isEmpty()) { - throw new TrinoException(NOT_SUPPORTED, "Deleting without partition key is not supported"); - } - if (partitions.get().isEmpty()) { // there are no records of a given partition key return OptionalLong.empty(); } diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraSession.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraSession.java index ec0e8fcec1f1..1ecda23397d9 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraSession.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraSession.java @@ -570,17 +570,15 @@ private T executeWithSession(SessionCallable sessionCallable) if (timeLeft <= 0) { throw e; } - else { - long delay = Math.min(schedule.nextDelay().toMillis(), timeLeft); - log.warn(e.getMessage()); - log.warn("Reconnecting in %dms", delay); - try { - Thread.sleep(delay); - } - catch (InterruptedException interrupted) { - Thread.currentThread().interrupt(); - throw new RuntimeException("interrupted", interrupted); - } + long delay = Math.min(schedule.nextDelay().toMillis(), timeLeft); + log.warn(e.getMessage()); + log.warn("Reconnecting in %dms", delay); + try { + Thread.sleep(delay); + } + catch (InterruptedException interrupted) { + Thread.currentThread().interrupt(); + throw new RuntimeException("interrupted", interrupted); } } } diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraSplit.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraSplit.java index adfa7d3e2f8c..698b78d8b07d 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraSplit.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraSplit.java @@ -107,17 +107,11 @@ public String getWhereClause() if (splitCondition != null) { return " WHERE " + splitCondition; } - else { - return ""; - } + return ""; } - else { - if (splitCondition != null) { - return " WHERE " + partitionId + " AND " + splitCondition; - } - else { - return " WHERE " + partitionId; - } + if (splitCondition != null) { + return " WHERE " + partitionId + " AND " + splitCondition; } + return " WHERE " + partitionId; } } diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/Murmur3PartitionerTokenRing.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/Murmur3PartitionerTokenRing.java index 42228051b096..3c1280ab5916 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/Murmur3PartitionerTokenRing.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/Murmur3PartitionerTokenRing.java @@ -47,9 +47,7 @@ public BigInteger getTokenCountInRange(Token startToken, Token endToken) if (start == MIN_TOKEN) { return TOTAL_TOKEN_COUNT; } - else { - return ZERO; - } + return ZERO; } BigInteger result = BigInteger.valueOf(end).subtract(BigInteger.valueOf(start)); diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/RandomPartitionerTokenRing.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/RandomPartitionerTokenRing.java index ef8d5acdcaef..4e3d2a2570d7 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/RandomPartitionerTokenRing.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/RandomPartitionerTokenRing.java @@ -50,9 +50,7 @@ public BigInteger getTokenCountInRange(Token startToken, Token endToken) if (start.equals(MIN_TOKEN)) { return TOTAL_TOKEN_COUNT; } - else { - return ZERO; - } + return ZERO; } BigInteger result = end.subtract(start); diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/CassandraServer.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/CassandraServer.java index 2d6da5d952db..c47fc5247856 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/CassandraServer.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/CassandraServer.java @@ -57,7 +57,7 @@ public class CassandraServer implements Closeable { - private static Logger log = Logger.get(CassandraServer.class); + private static final Logger log = Logger.get(CassandraServer.class); private static final int PORT = 9142; @@ -99,7 +99,7 @@ public CassandraServer(String imageName, Map environmentVariable CqlSessionBuilder cqlSessionBuilder = CqlSession.builder() .withApplicationName("TestCluster") - .addContactPoint(new InetSocketAddress(this.dockerContainer.getContainerIpAddress(), this.dockerContainer.getMappedPort(PORT))) + .addContactPoint(new InetSocketAddress(this.dockerContainer.getHost(), this.dockerContainer.getMappedPort(PORT))) .withLocalDatacenter("datacenter1") .withConfigLoader(driverConfigLoaderBuilder.build()); @@ -146,7 +146,7 @@ public CassandraSession getSession() public String getHost() { - return dockerContainer.getContainerIpAddress(); + return dockerContainer.getHost(); } public int getPort() diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestingScyllaServer.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestingScyllaServer.java index 75e8ccbf4ad4..5888408b06d5 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestingScyllaServer.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestingScyllaServer.java @@ -52,13 +52,11 @@ public class TestingScyllaServer private final CassandraSession session; public TestingScyllaServer() - throws Exception { this("2.2.0"); } public TestingScyllaServer(String version) - throws Exception { container = new GenericContainer<>("scylladb/scylla:" + version) .withCommand("--smp", "1") // Limit SMP to run in a machine having many cores https://github.com/scylladb/scylla/issues/5638 @@ -74,7 +72,7 @@ public TestingScyllaServer(String version) CqlSessionBuilder cqlSessionBuilder = CqlSession.builder() .withApplicationName("TestCluster") - .addContactPoint(new InetSocketAddress(this.container.getContainerIpAddress(), this.container.getMappedPort(PORT))) + .addContactPoint(new InetSocketAddress(this.container.getHost(), this.container.getMappedPort(PORT))) .withLocalDatacenter("datacenter1") .withConfigLoader(config.build()); @@ -92,7 +90,7 @@ public CassandraSession getSession() public String getHost() { - return container.getContainerIpAddress(); + return container.getHost(); } public int getPort() diff --git a/plugin/trino-clickhouse/pom.xml b/plugin/trino-clickhouse/pom.xml index 128516fef660..11066d1aca6e 100644 --- a/plugin/trino-clickhouse/pom.xml +++ b/plugin/trino-clickhouse/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java index d577be7a3d4d..05d25944494f 100644 --- a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java +++ b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java @@ -15,6 +15,7 @@ import com.clickhouse.client.ClickHouseColumn; import com.clickhouse.client.ClickHouseDataType; +import com.clickhouse.client.ClickHouseVersion; import com.google.common.base.Enums; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; @@ -80,13 +81,13 @@ import java.sql.Types; import java.time.LocalDate; import java.time.LocalDateTime; -import java.time.format.DateTimeFormatter; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Optional; import java.util.OptionalLong; import java.util.UUID; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiFunction; import static com.google.common.base.Preconditions.checkArgument; @@ -101,6 +102,11 @@ import static io.trino.plugin.clickhouse.ClickHouseTableProperties.PARTITION_BY_PROPERTY; import static io.trino.plugin.clickhouse.ClickHouseTableProperties.PRIMARY_KEY_PROPERTY; import static io.trino.plugin.clickhouse.ClickHouseTableProperties.SAMPLE_BY_PROPERTY; +import static io.trino.plugin.clickhouse.TrinoToClickHouseWriteChecker.DATETIME; +import static io.trino.plugin.clickhouse.TrinoToClickHouseWriteChecker.UINT16; +import static io.trino.plugin.clickhouse.TrinoToClickHouseWriteChecker.UINT32; +import static io.trino.plugin.clickhouse.TrinoToClickHouseWriteChecker.UINT64; +import static io.trino.plugin.clickhouse.TrinoToClickHouseWriteChecker.UINT8; import static io.trino.plugin.jdbc.DecimalConfig.DecimalMapping.ALLOW_OVERFLOW; import static io.trino.plugin.jdbc.DecimalSessionSessionProperties.getDecimalDefaultScale; import static io.trino.plugin.jdbc.DecimalSessionSessionProperties.getDecimalRounding; @@ -133,7 +139,6 @@ import static io.trino.plugin.jdbc.TypeHandlingJdbcSessionProperties.getUnsupportedTypeHandling; import static io.trino.plugin.jdbc.UnsupportedTypeHandling.CONVERT_TO_VARCHAR; import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; -import static io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS; import static io.trino.spi.StandardErrorCode.INVALID_TABLE_PROPERTY; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.type.BigintType.BIGINT; @@ -169,34 +174,18 @@ public class ClickHouseClient { private static final Splitter TABLE_PROPERTY_SPLITTER = Splitter.on(',').omitEmptyStrings().trimResults(); - private static final long UINT8_MIN_VALUE = 0L; - private static final long UINT8_MAX_VALUE = 255L; - - private static final long UINT16_MIN_VALUE = 0L; - private static final long UINT16_MAX_VALUE = 65535L; - - private static final long UINT32_MIN_VALUE = 0L; - private static final long UINT32_MAX_VALUE = 4294967295L; - private static final DecimalType UINT64_TYPE = createDecimalType(20, 0); - private static final BigDecimal UINT64_MIN_VALUE = BigDecimal.ZERO; - private static final BigDecimal UINT64_MAX_VALUE = new BigDecimal("18446744073709551615"); - - private static final long MIN_SUPPORTED_DATE_EPOCH = LocalDate.parse("1970-01-01").toEpochDay(); - private static final long MAX_SUPPORTED_DATE_EPOCH = LocalDate.parse("2106-02-07").toEpochDay(); // The max date is '2148-12-31' in new ClickHouse version - - private static final LocalDateTime MIN_SUPPORTED_TIMESTAMP = LocalDateTime.parse("1970-01-01T00:00:00"); - private static final LocalDateTime MAX_SUPPORTED_TIMESTAMP = LocalDateTime.parse("2105-12-31T23:59:59"); - private static final long MIN_SUPPORTED_TIMESTAMP_EPOCH = MIN_SUPPORTED_TIMESTAMP.toEpochSecond(UTC); - private static final long MAX_SUPPORTED_TIMESTAMP_EPOCH = MAX_SUPPORTED_TIMESTAMP.toEpochSecond(UTC); // An empty character means that the table doesn't have a comment in ClickHouse private static final String NO_COMMENT = ""; + public static final int CLICK_HOUSE_MAX_LIST_EXPRESSIONS = 1_000; + private final ConnectorExpressionRewriter connectorExpressionRewriter; private final AggregateFunctionRewriter aggregateFunctionRewriter; private final Type uuidType; private final Type ipAddressType; + private final AtomicReference clickHouseVersion = new AtomicReference<>(); @Inject public ClickHouseClient( @@ -511,16 +500,16 @@ public Optional toColumnMapping(ConnectorSession session, Connect ClickHouseDataType columnDataType = column.getDataType(); switch (columnDataType) { case UInt8: - return Optional.of(ColumnMapping.longMapping(SMALLINT, ResultSet::getShort, uInt8WriteFunction())); + return Optional.of(ColumnMapping.longMapping(SMALLINT, ResultSet::getShort, uInt8WriteFunction(getClickHouseServerVersion(session)))); case UInt16: - return Optional.of(ColumnMapping.longMapping(INTEGER, ResultSet::getInt, uInt16WriteFunction())); + return Optional.of(ColumnMapping.longMapping(INTEGER, ResultSet::getInt, uInt16WriteFunction(getClickHouseServerVersion(session)))); case UInt32: - return Optional.of(ColumnMapping.longMapping(BIGINT, ResultSet::getLong, uInt32WriteFunction())); + return Optional.of(ColumnMapping.longMapping(BIGINT, ResultSet::getLong, uInt32WriteFunction(getClickHouseServerVersion(session)))); case UInt64: return Optional.of(ColumnMapping.objectMapping( UINT64_TYPE, longDecimalReadFunction(UINT64_TYPE, UNNECESSARY), - uInt64WriteFunction())); + uInt64WriteFunction(getClickHouseServerVersion(session)))); case IPv4: return Optional.of(ipAddressColumnMapping("IPv4StringToNum(?)")); case IPv6: @@ -595,7 +584,7 @@ public Optional toColumnMapping(ConnectorSession session, Connect DISABLE_PUSHDOWN)); case Types.DATE: - return Optional.of(dateColumnMappingUsingLocalDate()); + return Optional.of(dateColumnMappingUsingLocalDate(getClickHouseServerVersion(session))); case Types.TIMESTAMP: if (columnDataType == ClickHouseDataType.DateTime) { @@ -603,7 +592,7 @@ public Optional toColumnMapping(ConnectorSession session, Connect return Optional.of(ColumnMapping.longMapping( TIMESTAMP_SECONDS, timestampReadFunction(TIMESTAMP_SECONDS), - timestampSecondsWriteFunction())); + timestampSecondsWriteFunction(getClickHouseServerVersion(session)))); } // TODO (https://github.com/trinodb/trino/issues/10537) Add support for Datetime64 type return Optional.of(timestampColumnMappingUsingSqlTimestampWithRounding(TIMESTAMP_MILLIS)); @@ -658,10 +647,10 @@ public WriteMapping toWriteMapping(ConnectorSession session, Type type) return WriteMapping.sliceMapping("String", varbinaryWriteFunction()); } if (type == DATE) { - return WriteMapping.longMapping("Date", dateWriteFunctionUsingLocalDate()); + return WriteMapping.longMapping("Date", dateWriteFunctionUsingLocalDate(getClickHouseServerVersion(session))); } if (type == TIMESTAMP_SECONDS) { - return WriteMapping.longMapping("DateTime", timestampSecondsWriteFunction()); + return WriteMapping.longMapping("DateTime", timestampSecondsWriteFunction(getClickHouseServerVersion(session))); } if (type.equals(uuidType)) { return WriteMapping.sliceMapping("UUID", uuidWriteFunction()); @@ -669,6 +658,27 @@ public WriteMapping toWriteMapping(ConnectorSession session, Type type) throw new TrinoException(NOT_SUPPORTED, "Unsupported column type: " + type); } + private ClickHouseVersion getClickHouseServerVersion(ConnectorSession session) + { + return clickHouseVersion.updateAndGet(current -> { + if (current != null) { + return current; + } + + try (Connection connection = connectionFactory.openConnection(session); + PreparedStatement statement = connection.prepareStatement("SELECT version()"); + ResultSet resultSet = statement.executeQuery()) { + if (resultSet.next()) { + current = ClickHouseVersion.of(resultSet.getString(1)); + } + return current; + } + catch (SQLException e) { + throw new TrinoException(JDBC_ERROR, e); + } + }); + } + /** * format property to match ClickHouse create table statement * @@ -688,40 +698,34 @@ private Optional formatProperty(List prop) return Optional.of("(" + String.join(",", prop) + ")"); } - private static LongWriteFunction uInt8WriteFunction() + private static LongWriteFunction uInt8WriteFunction(ClickHouseVersion version) { return (statement, index, value) -> { // ClickHouse stores incorrect results when the values are out of supported range. - if (value < UINT8_MIN_VALUE || value > UINT8_MAX_VALUE) { - throw new TrinoException(INVALID_ARGUMENTS, format("Value must be between %s and %s in ClickHouse: %s", UINT8_MIN_VALUE, UINT8_MAX_VALUE, value)); - } + UINT8.validate(version, value); statement.setShort(index, Shorts.checkedCast(value)); }; } - private static LongWriteFunction uInt16WriteFunction() + private static LongWriteFunction uInt16WriteFunction(ClickHouseVersion version) { return (statement, index, value) -> { // ClickHouse stores incorrect results when the values are out of supported range. - if (value < UINT16_MIN_VALUE || value > UINT16_MAX_VALUE) { - throw new TrinoException(INVALID_ARGUMENTS, format("Value must be between %s and %s in ClickHouse: %s", UINT16_MIN_VALUE, UINT16_MAX_VALUE, value)); - } + UINT16.validate(version, value); statement.setInt(index, toIntExact(value)); }; } - private static LongWriteFunction uInt32WriteFunction() + private static LongWriteFunction uInt32WriteFunction(ClickHouseVersion version) { return (preparedStatement, parameterIndex, value) -> { // ClickHouse stores incorrect results when the values are out of supported range. - if (value < UINT32_MIN_VALUE || value > UINT32_MAX_VALUE) { - throw new TrinoException(INVALID_ARGUMENTS, format("Value must be between %s and %s in ClickHouse: %s", UINT32_MIN_VALUE, UINT32_MAX_VALUE, value)); - } + UINT32.validate(version, value); preparedStatement.setLong(parameterIndex, value); }; } - private static ObjectWriteFunction uInt64WriteFunction() + private static ObjectWriteFunction uInt64WriteFunction(ClickHouseVersion version) { return ObjectWriteFunction.of( Int128.class, @@ -729,56 +733,42 @@ private static ObjectWriteFunction uInt64WriteFunction() BigInteger unscaledValue = value.toBigInteger(); BigDecimal bigDecimal = new BigDecimal(unscaledValue, UINT64_TYPE.getScale(), new MathContext(UINT64_TYPE.getPrecision())); // ClickHouse stores incorrect results when the values are out of supported range. - if (bigDecimal.compareTo(UINT64_MIN_VALUE) < 0 || bigDecimal.compareTo(UINT64_MAX_VALUE) > 0) { - throw new TrinoException(INVALID_ARGUMENTS, format("Value must be between %s and %s in ClickHouse: %s", UINT64_MIN_VALUE, UINT64_MAX_VALUE, bigDecimal)); - } + UINT64.validate(version, bigDecimal); statement.setBigDecimal(index, bigDecimal); }); } - private static ColumnMapping dateColumnMappingUsingLocalDate() + private static ColumnMapping dateColumnMappingUsingLocalDate(ClickHouseVersion version) { return ColumnMapping.longMapping( DATE, dateReadFunctionUsingLocalDate(), - dateWriteFunctionUsingLocalDate()); + dateWriteFunctionUsingLocalDate(version)); } - private static LongWriteFunction dateWriteFunctionUsingLocalDate() + private static LongWriteFunction dateWriteFunctionUsingLocalDate(ClickHouseVersion version) { return (statement, index, value) -> { - verifySupportedDate(value); - statement.setObject(index, LocalDate.ofEpochDay(value)); + LocalDate date = LocalDate.ofEpochDay(value); + // Deny unsupported dates eagerly to prevent unexpected results. ClickHouse stores '1970-01-01' when the date is out of supported range. + TrinoToClickHouseWriteChecker.DATE.validate(version, date); + statement.setObject(index, date); }; } - private static void verifySupportedDate(long value) - { - // Deny unsupported dates eagerly to prevent unexpected results. ClickHouse stores '1970-01-01' when the date is out of supported range. - if (value < MIN_SUPPORTED_DATE_EPOCH || value > MAX_SUPPORTED_DATE_EPOCH) { - throw new TrinoException(INVALID_ARGUMENTS, format("Date must be between %s and %s in ClickHouse: %s", LocalDate.ofEpochDay(MIN_SUPPORTED_DATE_EPOCH), LocalDate.ofEpochDay(MAX_SUPPORTED_DATE_EPOCH), LocalDate.ofEpochDay(value))); - } - } - - private static LongWriteFunction timestampSecondsWriteFunction() + private static LongWriteFunction timestampSecondsWriteFunction(ClickHouseVersion version) { return (statement, index, value) -> { long epochSecond = floorDiv(value, MICROSECONDS_PER_SECOND); int nanoFraction = floorMod(value, MICROSECONDS_PER_SECOND) * NANOSECONDS_PER_MICROSECOND; verify(nanoFraction == 0, "Nanos of second must be zero: '%s'", value); - verifySupportedTimestamp(epochSecond); - statement.setObject(index, LocalDateTime.ofEpochSecond(epochSecond, 0, UTC)); + LocalDateTime timestamp = LocalDateTime.ofEpochSecond(epochSecond, 0, UTC); + // ClickHouse stores incorrect results when the values are out of supported range. + DATETIME.validate(version, timestamp); + statement.setObject(index, timestamp); }; } - private static void verifySupportedTimestamp(long epochSecond) - { - if (epochSecond < MIN_SUPPORTED_TIMESTAMP_EPOCH || epochSecond > MAX_SUPPORTED_TIMESTAMP_EPOCH) { - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("uuuu-MM-dd HH:mm:ss"); - throw new TrinoException(INVALID_ARGUMENTS, format("Timestamp must be between %s and %s in ClickHouse: %s", MIN_SUPPORTED_TIMESTAMP.format(formatter), MAX_SUPPORTED_TIMESTAMP.format(formatter), LocalDateTime.ofEpochSecond(epochSecond, 0, UTC).format(formatter))); - } - } - private ColumnMapping ipAddressColumnMapping(String writeBindExpression) { return ColumnMapping.sliceMapping( diff --git a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClientModule.java b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClientModule.java index c310d1adf729..978249b8519f 100644 --- a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClientModule.java +++ b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClientModule.java @@ -14,6 +14,7 @@ package io.trino.plugin.clickhouse; import com.google.inject.Binder; +import com.google.inject.Key; import com.google.inject.Module; import com.google.inject.Provides; import com.google.inject.Scopes; @@ -25,10 +26,13 @@ import io.trino.plugin.jdbc.DriverConnectionFactory; import io.trino.plugin.jdbc.ForBaseJdbc; import io.trino.plugin.jdbc.JdbcClient; +import io.trino.plugin.jdbc.MaxDomainCompactionThreshold; import io.trino.plugin.jdbc.credential.CredentialProvider; import java.sql.Driver; +import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder; +import static io.trino.plugin.clickhouse.ClickHouseClient.CLICK_HOUSE_MAX_LIST_EXPRESSIONS; import static io.trino.plugin.jdbc.JdbcModule.bindSessionPropertiesProvider; import static io.trino.plugin.jdbc.JdbcModule.bindTablePropertiesProvider; @@ -42,6 +46,7 @@ public void configure(Binder binder) bindSessionPropertiesProvider(binder, ClickHouseSessionProperties.class); binder.bind(JdbcClient.class).annotatedWith(ForBaseJdbc.class).to(ClickHouseClient.class).in(Scopes.SINGLETON); bindTablePropertiesProvider(binder, ClickHouseTableProperties.class); + newOptionalBinder(binder, Key.get(int.class, MaxDomainCompactionThreshold.class)).setBinding().toInstance(CLICK_HOUSE_MAX_LIST_EXPRESSIONS); binder.install(new DecimalModule()); } diff --git a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/TrinoToClickHouseWriteChecker.java b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/TrinoToClickHouseWriteChecker.java new file mode 100644 index 000000000000..98b240cec92d --- /dev/null +++ b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/TrinoToClickHouseWriteChecker.java @@ -0,0 +1,222 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.clickhouse; + +import com.clickhouse.client.ClickHouseVersion; +import com.google.common.collect.ImmutableList; +import io.trino.spi.TrinoException; + +import java.math.BigDecimal; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.format.DateTimeFormatter; +import java.time.format.DateTimeFormatterBuilder; +import java.time.temporal.ChronoField; +import java.util.List; +import java.util.function.Predicate; + +import static com.google.common.base.Predicates.alwaysTrue; +import static io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class TrinoToClickHouseWriteChecker +{ + // Different versions of ClickHouse may support different min/max values for the + // same data type, you can refer to the table below: + // + // | version | column type | min value | max value | + // |---------|-------------|---------------------|----------------------| + // | any | UInt8 | 0 | 255 | + // | any | UInt16 | 0 | 65535 | + // | any | UInt32 | 0 | 4294967295 | + // | any | UInt64 | 0 | 18446744073709551615 | + // | < 21.4 | Date | 1970-01-01 | 2106-02-07 | + // | < 21.4 | DateTime | 1970-01-01 00:00:00 | 2106-02-06 06:28:15 | + // | >= 21.4 | Date | 1970-01-01 | 2149-06-06 | + // | >= 21.4 | DateTime | 1970-01-01 00:00:00 | 2106-02-07 06:28:15 | + // + // And when the value written to ClickHouse is out of range, ClickHouse will store + // the incorrect result, so we need to check the range of the written value to + // prevent ClickHouse from storing the incorrect value. + + public static final TrinoToClickHouseWriteChecker UINT8 = new TrinoToClickHouseWriteChecker<>(ImmutableList.of(new LongWriteValueChecker(alwaysTrue(), new Range<>(0L, 255L)))); + public static final TrinoToClickHouseWriteChecker UINT16 = new TrinoToClickHouseWriteChecker<>(ImmutableList.of(new LongWriteValueChecker(alwaysTrue(), new Range<>(0L, 65535L)))); + public static final TrinoToClickHouseWriteChecker UINT32 = new TrinoToClickHouseWriteChecker<>(ImmutableList.of(new LongWriteValueChecker(alwaysTrue(), new Range<>(0L, 4294967295L)))); + public static final TrinoToClickHouseWriteChecker UINT64 = new TrinoToClickHouseWriteChecker<>( + ImmutableList.of(new BigDecimalWriteValueChecker(alwaysTrue(), new Range<>(BigDecimal.ZERO, new BigDecimal("18446744073709551615"))))); + public static final TrinoToClickHouseWriteChecker DATE = new TrinoToClickHouseWriteChecker<>( + ImmutableList.of( + new DateWriteValueChecker(version -> version.isOlderThan("21.4"), new Range<>(LocalDate.parse("1970-01-01"), LocalDate.parse("2106-02-07"))), + new DateWriteValueChecker(version -> version.isNewerOrEqualTo("21.4"), new Range<>(LocalDate.parse("1970-01-01"), LocalDate.parse("2149-06-06"))))); + public static final TrinoToClickHouseWriteChecker DATETIME = new TrinoToClickHouseWriteChecker<>( + ImmutableList.of( + new TimestampWriteValueChecker( + version -> version.isOlderThan("21.4"), + new Range<>(LocalDateTime.parse("1970-01-01T00:00:00"), LocalDateTime.parse("2106-02-06T06:28:15"))), + new TimestampWriteValueChecker( + version -> version.isNewerOrEqualTo("21.4"), + new Range<>(LocalDateTime.parse("1970-01-01T00:00:00"), LocalDateTime.parse("2106-02-07T06:28:15"))))); + + private final List> checkers; + + private TrinoToClickHouseWriteChecker(List> checkers) + { + this.checkers = ImmutableList.copyOf(requireNonNull(checkers, "checkers is null")); + } + + public void validate(ClickHouseVersion version, T value) + { + for (Checker checker : checkers) { + checker.validate(version, value); + } + } + + private interface Checker + { + void validate(ClickHouseVersion version, T value); + } + + private static class LongWriteValueChecker + implements Checker + { + private final Predicate predicate; + private final Range range; + + public LongWriteValueChecker(Predicate predicate, Range range) + { + this.predicate = requireNonNull(predicate, "predicate is null"); + this.range = requireNonNull(range, "range is null"); + } + + @Override + public void validate(ClickHouseVersion version, Long value) + { + if (!predicate.test(version)) { + return; + } + + if (value >= range.getMin() && value <= range.getMax()) { + return; + } + + throw new TrinoException(INVALID_ARGUMENTS, format("Value must be between %d and %d in ClickHouse: %d", range.getMin(), range.getMax(), value)); + } + } + + private static class BigDecimalWriteValueChecker + implements Checker + { + private final Predicate predicate; + private final Range range; + + public BigDecimalWriteValueChecker(Predicate predicate, Range range) + { + this.predicate = requireNonNull(predicate, "predicate is null"); + this.range = requireNonNull(range, "range is null"); + } + + @Override + public void validate(ClickHouseVersion version, BigDecimal value) + { + if (!predicate.test(version)) { + return; + } + + if (value.compareTo(range.getMin()) >= 0 && value.compareTo(range.getMax()) <= 0) { + return; + } + + throw new TrinoException(INVALID_ARGUMENTS, format("Value must be between %s and %s in ClickHouse: %s", range.getMin(), range.getMax(), value)); + } + } + + private static class DateWriteValueChecker + implements Checker + { + private final Predicate predicate; + private final Range range; + + public DateWriteValueChecker(Predicate predicate, Range range) + { + this.predicate = requireNonNull(predicate, "predicate is null"); + this.range = requireNonNull(range, "range is null"); + } + + @Override + public void validate(ClickHouseVersion version, LocalDate value) + { + if (!predicate.test(version)) { + return; + } + + if (value.isBefore(range.getMin()) || value.isAfter(range.getMax())) { + throw new TrinoException(INVALID_ARGUMENTS, format("Date must be between %s and %s in ClickHouse: %s", range.getMin(), range.getMax(), value)); + } + } + } + + private static class TimestampWriteValueChecker + implements Checker + { + private final Predicate predicate; + private final Range range; + + public TimestampWriteValueChecker(Predicate predicate, Range range) + { + this.predicate = requireNonNull(predicate, "predicate is null"); + this.range = requireNonNull(range, "range is null"); + } + + @Override + public void validate(ClickHouseVersion version, LocalDateTime value) + { + if (!predicate.test(version)) { + return; + } + + if (value.isBefore(range.getMin()) || value.isAfter(range.getMax())) { + DateTimeFormatter formatter = new DateTimeFormatterBuilder() + .appendPattern("uuuu-MM-dd HH:mm:ss") + .appendFraction(ChronoField.NANO_OF_SECOND, 0, 9, true) + .toFormatter(); + throw new TrinoException( + INVALID_ARGUMENTS, + format("Timestamp must be between %s and %s in ClickHouse: %s", formatter.format(range.getMin()), formatter.format(range.getMax()), formatter.format(value))); + } + } + } + + private static class Range + { + private final T min; + private final T max; + + public Range(T min, T max) + { + this.min = requireNonNull(min, "min is null"); + this.max = requireNonNull(max, "max is null"); + } + + public T getMin() + { + return min; + } + + public T getMax() + { + return max; + } + } +} diff --git a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/BaseClickHouseConnectorTest.java b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/BaseClickHouseConnectorTest.java index 0e3867ce4781..586ad85ba3f6 100644 --- a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/BaseClickHouseConnectorTest.java +++ b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/BaseClickHouseConnectorTest.java @@ -641,7 +641,12 @@ public void testDateYearOfEraPredicate() assertQuery("SELECT orderdate FROM orders WHERE orderdate = DATE '1997-09-14'", "VALUES DATE '1997-09-14'"); assertQueryFails( "SELECT * FROM orders WHERE orderdate = DATE '-1996-09-14'", - "Date must be between 1970-01-01 and 2106-02-07 in ClickHouse: -1996-09-14"); + errorMessageForDateYearOfEraPredicate("-1996-09-14")); + } + + protected String errorMessageForDateYearOfEraPredicate(String date) + { + return "Date must be between 1970-01-01 and 2106-02-07 in ClickHouse: " + date; } @Override diff --git a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/BaseClickHouseTypeMapping.java b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/BaseClickHouseTypeMapping.java new file mode 100644 index 000000000000..3e39c763769f --- /dev/null +++ b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/BaseClickHouseTypeMapping.java @@ -0,0 +1,887 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.clickhouse; + +import io.trino.Session; +import io.trino.spi.type.TimeZoneKey; +import io.trino.spi.type.UuidType; +import io.trino.testing.AbstractTestQueryFramework; +import io.trino.testing.TestingSession; +import io.trino.testing.datatype.CreateAndInsertDataSetup; +import io.trino.testing.datatype.CreateAndTrinoInsertDataSetup; +import io.trino.testing.datatype.CreateAsSelectDataSetup; +import io.trino.testing.datatype.DataSetup; +import io.trino.testing.datatype.SqlDataTypeTest; +import io.trino.testing.sql.SqlExecutor; +import io.trino.testing.sql.TestTable; +import io.trino.testing.sql.TrinoSqlExecutor; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.ZoneId; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; +import static io.trino.plugin.clickhouse.ClickHouseQueryRunner.TPCH_SCHEMA; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.DateType.DATE; +import static io.trino.spi.type.DecimalType.createDecimalType; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.RealType.REAL; +import static io.trino.spi.type.SmallintType.SMALLINT; +import static io.trino.spi.type.TimestampType.createTimestampType; +import static io.trino.spi.type.TinyintType.TINYINT; +import static io.trino.spi.type.VarbinaryType.VARBINARY; +import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; +import static io.trino.testing.TestingSession.testSessionBuilder; +import static io.trino.type.IpAddressType.IPADDRESS; +import static java.lang.String.format; +import static java.time.ZoneOffset.UTC; + +public abstract class BaseClickHouseTypeMapping + extends AbstractTestQueryFramework +{ + private final ZoneId jvmZone = ZoneId.systemDefault(); + + // no DST in 1970, but has DST in later years (e.g. 2018) + private final ZoneId vilnius = ZoneId.of("Europe/Vilnius"); + + // minutes offset change since 1970-01-01, no DST + private final ZoneId kathmandu = ZoneId.of("Asia/Kathmandu"); + + protected TestingClickHouseServer clickhouseServer; + + @BeforeClass + public void setUp() + { + checkState(jvmZone.getId().equals("America/Bahia_Banderas"), "This test assumes certain JVM time zone"); + LocalDate dateOfLocalTimeChangeForwardAtMidnightInJvmZone = LocalDate.of(1970, 1, 1); + checkIsGap(jvmZone, dateOfLocalTimeChangeForwardAtMidnightInJvmZone.atStartOfDay()); + + LocalDate dateOfLocalTimeChangeForwardAtMidnightInSomeZone = LocalDate.of(1983, 4, 1); + checkIsGap(vilnius, dateOfLocalTimeChangeForwardAtMidnightInSomeZone.atStartOfDay()); + LocalDate dateOfLocalTimeChangeBackwardAtMidnightInSomeZone = LocalDate.of(1983, 10, 1); + checkIsDoubled(vilnius, dateOfLocalTimeChangeBackwardAtMidnightInSomeZone.atStartOfDay().minusMinutes(1)); + + LocalDate timeGapInKathmandu = LocalDate.of(1986, 1, 1); + checkIsGap(kathmandu, timeGapInKathmandu.atStartOfDay()); + } + + private static void checkIsGap(ZoneId zone, LocalDateTime dateTime) + { + verify(isGap(zone, dateTime), "Expected %s to be a gap in %s", dateTime, zone); + } + + private static boolean isGap(ZoneId zone, LocalDateTime dateTime) + { + return zone.getRules().getValidOffsets(dateTime).isEmpty(); + } + + private static void checkIsDoubled(ZoneId zone, LocalDateTime dateTime) + { + verify(zone.getRules().getValidOffsets(dateTime).size() == 2, "Expected %s to be doubled in %s", dateTime, zone); + } + + @Test + public void testTinyint() + { + SqlDataTypeTest.create() + .addRoundTrip("tinyint", "-128", TINYINT, "TINYINT '-128'") // min value in ClickHouse and Trino + .addRoundTrip("tinyint", "5", TINYINT, "TINYINT '5'") + .addRoundTrip("tinyint", "127", TINYINT, "TINYINT '127'") // max value in ClickHouse and Trino + .execute(getQueryRunner(), trinoCreateAsSelect("test_tinyint")) + .execute(getQueryRunner(), trinoCreateAndInsert("test_tinyint")) + + .addRoundTrip("Nullable(tinyint)", "NULL", TINYINT, "CAST(NULL AS TINYINT)") + .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_tinyint")); + + SqlDataTypeTest.create() + .addRoundTrip("tinyint", "NULL", TINYINT, "CAST(NULL AS TINYINT)") + .execute(getQueryRunner(), trinoCreateAsSelect("test_tinyint")) + .execute(getQueryRunner(), trinoCreateAndInsert("test_tinyint")); + } + + @Test + public void testUnsupportedTinyint() + { + // ClickHouse stores incorrect results when the values are out of supported range. This test should be fixed when ClickHouse changes the behavior. + SqlDataTypeTest.create() + .addRoundTrip("tinyint", "-129", TINYINT, "TINYINT '127'") + .addRoundTrip("tinyint", "128", TINYINT, "TINYINT '-128'") + .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_unsupported_tinyint")); + } + + @Test + public void testSmallint() + { + SqlDataTypeTest.create() + .addRoundTrip("smallint", "-32768", SMALLINT, "SMALLINT '-32768'") // min value in ClickHouse and Trino + .addRoundTrip("smallint", "32456", SMALLINT, "SMALLINT '32456'") + .addRoundTrip("smallint", "32767", SMALLINT, "SMALLINT '32767'") // max value in ClickHouse and Trino + .execute(getQueryRunner(), trinoCreateAsSelect("test_smallint")) + .execute(getQueryRunner(), trinoCreateAndInsert("test_smallint")) + + .addRoundTrip("Nullable(smallint)", "NULL", SMALLINT, "CAST(NULL AS SMALLINT)") + .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_smallint")); + + SqlDataTypeTest.create() + .addRoundTrip("smallint", "NULL", SMALLINT, "CAST(NULL AS SMALLINT)") + .execute(getQueryRunner(), trinoCreateAsSelect("test_smallint")) + .execute(getQueryRunner(), trinoCreateAndInsert("test_smallint")); + } + + @Test + public void testUnsupportedSmallint() + { + // ClickHouse stores incorrect results when the values are out of supported range. This test should be fixed when ClickHouse changes the behavior. + SqlDataTypeTest.create() + .addRoundTrip("smallint", "-32769", SMALLINT, "SMALLINT '32767'") + .addRoundTrip("smallint", "32768", SMALLINT, "SMALLINT '-32768'") + .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_unsupported_smallint")); + } + + @Test + public void testInteger() + { + SqlDataTypeTest.create() + .addRoundTrip("integer", "-2147483648", INTEGER, "-2147483648") // min value in ClickHouse and Trino + .addRoundTrip("integer", "1234567890", INTEGER, "1234567890") + .addRoundTrip("integer", "2147483647", INTEGER, "2147483647") // max value in ClickHouse and Trino + .execute(getQueryRunner(), trinoCreateAsSelect("test_int")) + .execute(getQueryRunner(), trinoCreateAndInsert("test_int")) + + .addRoundTrip("Nullable(integer)", "NULL", INTEGER, "CAST(NULL AS INTEGER)") + .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_int")); + + SqlDataTypeTest.create() + .addRoundTrip("integer", "NULL", INTEGER, "CAST(NULL AS INTEGER)") + .execute(getQueryRunner(), trinoCreateAsSelect("test_int")) + .execute(getQueryRunner(), trinoCreateAndInsert("test_int")); + } + + @Test + public void testUnsupportedInteger() + { + // ClickHouse stores incorrect results when the values are out of supported range. This test should be fixed when ClickHouse changes the behavior. + SqlDataTypeTest.create() + .addRoundTrip("integer", "-2147483649", INTEGER, "INTEGER '2147483647'") + .addRoundTrip("integer", "2147483648", INTEGER, "INTEGER '-2147483648'") + .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_unsupported_integer")); + } + + @Test + public void testBigint() + { + SqlDataTypeTest.create() + .addRoundTrip("bigint", "-9223372036854775808", BIGINT, "-9223372036854775808") // min value in ClickHouse and Trino + .addRoundTrip("bigint", "123456789012", BIGINT, "123456789012") + .addRoundTrip("bigint", "9223372036854775807", BIGINT, "9223372036854775807") // max value in ClickHouse and Trino + .execute(getQueryRunner(), trinoCreateAsSelect("test_bigint")) + .execute(getQueryRunner(), trinoCreateAndInsert("test_bigint")) + + .addRoundTrip("Nullable(bigint)", "NULL", BIGINT, "CAST(NULL AS BIGINT)") + .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_bigint")); + + SqlDataTypeTest.create() + .addRoundTrip("bigint", "NULL", BIGINT, "CAST(NULL AS BIGINT)") + .execute(getQueryRunner(), trinoCreateAsSelect("test_bigint")) + .execute(getQueryRunner(), trinoCreateAndInsert("test_bigint")); + } + + @Test + public void testUnsupportedBigint() + { + // ClickHouse stores incorrect results when the values are out of supported range. This test should be fixed when ClickHouse changes the behavior. + SqlDataTypeTest.create() + .addRoundTrip("bigint", "-9223372036854775809", BIGINT, "BIGINT '9223372036854775807'") + .addRoundTrip("bigint", "9223372036854775808", BIGINT, "BIGINT '-9223372036854775808'") + .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_unsupported_bigint")); + } + + @Test + public void testUint8() + { + SqlDataTypeTest.create() + .addRoundTrip("UInt8", "0", SMALLINT, "SMALLINT '0'") // min value in ClickHouse + .addRoundTrip("UInt8", "255", SMALLINT, "SMALLINT '255'") // max value in ClickHouse + .addRoundTrip("Nullable(UInt8)", "NULL", SMALLINT, "CAST(null AS SMALLINT)") + .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_uint8")); + + SqlDataTypeTest.create() + .addRoundTrip("UInt8", "0", SMALLINT, "SMALLINT '0'") // min value in ClickHouse + .addRoundTrip("UInt8", "255", SMALLINT, "SMALLINT '255'") // max value in ClickHouse + .addRoundTrip("Nullable(UInt8)", "NULL", SMALLINT, "CAST(null AS SMALLINT)") + .execute(getQueryRunner(), clickhouseCreateAndTrinoInsert("tpch.test_uint8")); + } + + @Test + public void testUnsupportedUint8() + { + // ClickHouse stores incorrect results when the values are out of supported range. This test should be fixed when ClickHouse changes the behavior. + SqlDataTypeTest.create() + .addRoundTrip("UInt8", "-1", SMALLINT, "SMALLINT '255'") + .addRoundTrip("UInt8", "256", SMALLINT, "SMALLINT '0'") + .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_unsupported_uint8")); + + // Prevent writing incorrect results in the connector + try (TestTable table = new TestTable(onRemoteDatabase(), "tpch.test_unsupported_uint8", "(value UInt8) ENGINE=Log")) { + assertQueryFails( + format("INSERT INTO %s VALUES (-1)", table.getName()), + "Value must be between 0 and 255 in ClickHouse: -1"); + assertQueryFails( + format("INSERT INTO %s VALUES (256)", table.getName()), + "Value must be between 0 and 255 in ClickHouse: 256"); + } + } + + @Test + public void testUint16() + { + SqlDataTypeTest.create() + .addRoundTrip("UInt16", "0", INTEGER, "0") // min value in ClickHouse + .addRoundTrip("UInt16", "65535", INTEGER, "65535") // max value in ClickHouse + .addRoundTrip("Nullable(UInt16)", "NULL", INTEGER, "CAST(null AS INTEGER)") + .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_uint16")); + + SqlDataTypeTest.create() + .addRoundTrip("UInt16", "0", INTEGER, "0") // min value in ClickHouse + .addRoundTrip("UInt16", "65535", INTEGER, "65535") // max value in ClickHouse + .addRoundTrip("Nullable(UInt16)", "NULL", INTEGER, "CAST(null AS INTEGER)") + .execute(getQueryRunner(), clickhouseCreateAndTrinoInsert("tpch.test_uint16")); + } + + @Test + public void testUnsupportedUint16() + { + // ClickHouse stores incorrect results when the values are out of supported range. This test should be fixed when ClickHouse changes the behavior. + SqlDataTypeTest.create() + .addRoundTrip("UInt16", "-1", INTEGER, "65535") + .addRoundTrip("UInt16", "65536", INTEGER, "0") + .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_unsupported_uint16")); + + // Prevent writing incorrect results in the connector + try (TestTable table = new TestTable(onRemoteDatabase(), "tpch.test_unsupported_uint16", "(value UInt16) ENGINE=Log")) { + assertQueryFails( + format("INSERT INTO %s VALUES (-1)", table.getName()), + "Value must be between 0 and 65535 in ClickHouse: -1"); + assertQueryFails( + format("INSERT INTO %s VALUES (65536)", table.getName()), + "Value must be between 0 and 65535 in ClickHouse: 65536"); + } + } + + @Test + public void testUint32() + { + SqlDataTypeTest.create() + .addRoundTrip("UInt32", "0", BIGINT, "BIGINT '0'") // min value in ClickHouse + .addRoundTrip("UInt32", "4294967295", BIGINT, "BIGINT '4294967295'") // max value in ClickHouse + .addRoundTrip("Nullable(UInt32)", "NULL", BIGINT, "CAST(null AS BIGINT)") + .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_uint32")); + + SqlDataTypeTest.create() + .addRoundTrip("UInt32", "BIGINT '0'", BIGINT, "BIGINT '0'") // min value in ClickHouse + .addRoundTrip("UInt32", "BIGINT '4294967295'", BIGINT, "BIGINT '4294967295'") // max value in ClickHouse + .addRoundTrip("Nullable(UInt32)", "NULL", BIGINT, "CAST(null AS BIGINT)") + .execute(getQueryRunner(), clickhouseCreateAndTrinoInsert("tpch.test_uint32")); + } + + @Test + public void testUnsupportedUint32() + { + // ClickHouse stores incorrect results when the values are out of supported range. This test should be fixed when ClickHouse changes the behavior. + SqlDataTypeTest.create() + .addRoundTrip("UInt32", "-1", BIGINT, "BIGINT '4294967295'") + .addRoundTrip("UInt32", "4294967296", BIGINT, "BIGINT '0'") + .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_unsupported_uint32")); + + // Prevent writing incorrect results in the connector + try (TestTable table = new TestTable(onRemoteDatabase(), "tpch.test_unsupported_uint32", "(value UInt32) ENGINE=Log")) { + assertQueryFails( + format("INSERT INTO %s VALUES (CAST('-1' AS BIGINT))", table.getName()), + "Value must be between 0 and 4294967295 in ClickHouse: -1"); + assertQueryFails( + format("INSERT INTO %s VALUES (CAST('4294967296' AS BIGINT))", table.getName()), + "Value must be between 0 and 4294967295 in ClickHouse: 4294967296"); + } + } + + @Test + public void testUint64() + { + SqlDataTypeTest.create() + .addRoundTrip("UInt64", "0", createDecimalType(20), "CAST('0' AS decimal(20, 0))") // min value in ClickHouse + .addRoundTrip("UInt64", "18446744073709551615", createDecimalType(20), "CAST('18446744073709551615' AS decimal(20, 0))") // max value in ClickHouse + .addRoundTrip("Nullable(UInt64)", "NULL", createDecimalType(20), "CAST(null AS decimal(20, 0))") + .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_uint64")); + + SqlDataTypeTest.create() + .addRoundTrip("UInt64", "CAST('0' AS decimal(20, 0))", createDecimalType(20), "CAST('0' AS decimal(20, 0))") // min value in ClickHouse + .addRoundTrip("UInt64", "CAST('18446744073709551615' AS decimal(20, 0))", createDecimalType(20), "CAST('18446744073709551615' AS decimal(20, 0))") // max value in ClickHouse + .addRoundTrip("Nullable(UInt64)", "NULL", createDecimalType(20), "CAST(null AS decimal(20, 0))") + .execute(getQueryRunner(), clickhouseCreateAndTrinoInsert("tpch.test_uint64")); + } + + @Test + public void testUnsupportedUint64() + { + // ClickHouse stores incorrect results when the values are out of supported range. This test should be fixed when ClickHouse changes the behavior. + SqlDataTypeTest.create() + .addRoundTrip("UInt64", "-1", createDecimalType(20), "CAST('18446744073709551615' AS decimal(20, 0))") + .addRoundTrip("UInt64", "18446744073709551616", createDecimalType(20), "CAST('0' AS decimal(20, 0))") + .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_unsupported_uint64")); + + // Prevent writing incorrect results in the connector + try (TestTable table = new TestTable(onRemoteDatabase(), "tpch.test_unsupported_uint64", "(value UInt64) ENGINE=Log")) { + assertQueryFails( + format("INSERT INTO %s VALUES (CAST('-1' AS decimal(20, 0)))", table.getName()), + "Value must be between 0 and 18446744073709551615 in ClickHouse: -1"); + assertQueryFails( + format("INSERT INTO %s VALUES (CAST('18446744073709551616' AS decimal(20, 0)))", table.getName()), + "Value must be between 0 and 18446744073709551615 in ClickHouse: 18446744073709551616"); + } + } + + @Test + public void testReal() + { + SqlDataTypeTest.create() + .addRoundTrip("real", "12.5", REAL, "REAL '12.5'") + .addRoundTrip("real", "nan()", REAL, "CAST(nan() AS REAL)") + .addRoundTrip("real", "-infinity()", REAL, "CAST(-infinity() AS REAL)") + .addRoundTrip("real", "+infinity()", REAL, "CAST(+infinity() AS REAL)") + .addRoundTrip("real", "NULL", REAL, "CAST(NULL AS REAL)") + .execute(getQueryRunner(), trinoCreateAsSelect("trino_test_real")); + + SqlDataTypeTest.create() + .addRoundTrip("real", "12.5", REAL, "REAL '12.5'") + .addRoundTrip("real", "nan", REAL, "CAST(nan() AS REAL)") + .addRoundTrip("real", "-inf", REAL, "CAST(-infinity() AS REAL)") + .addRoundTrip("real", "+inf", REAL, "CAST(+infinity() AS REAL)") + .addRoundTrip("Nullable(real)", "NULL", REAL, "CAST(NULL AS REAL)") + .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_real")); + } + + @Test + public void testDouble() + { + SqlDataTypeTest.create() + .addRoundTrip("double", "3.1415926835", DOUBLE, "DOUBLE '3.1415926835'") + .addRoundTrip("double", "1.79769E308", DOUBLE, "DOUBLE '1.79769E308'") + .addRoundTrip("double", "2.225E-307", DOUBLE, "DOUBLE '2.225E-307'") + + .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_double")) + + .addRoundTrip("double", "NULL", DOUBLE, "CAST(NULL AS DOUBLE)") + + .execute(getQueryRunner(), trinoCreateAsSelect("trino_test_double")); + + SqlDataTypeTest.create() + .addRoundTrip("Nullable(double)", "NULL", DOUBLE, "CAST(NULL AS DOUBLE)") + .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.trino_test_nullable_double")); + } + + @Test + public void testDecimal() + { + SqlDataTypeTest.create() + .addRoundTrip("decimal(3, 0)", "CAST('193' AS decimal(3, 0))", createDecimalType(3, 0), "CAST('193' AS decimal(3, 0))") + .addRoundTrip("decimal(3, 0)", "CAST('19' AS decimal(3, 0))", createDecimalType(3, 0), "CAST('19' AS decimal(3, 0))") + .addRoundTrip("decimal(3, 0)", "CAST('-193' AS decimal(3, 0))", createDecimalType(3, 0), "CAST('-193' AS decimal(3, 0))") + .addRoundTrip("decimal(3, 1)", "CAST('10.0' AS decimal(3, 1))", createDecimalType(3, 1), "CAST('10.0' AS decimal(3, 1))") + .addRoundTrip("decimal(3, 1)", "CAST('10.1' AS decimal(3, 1))", createDecimalType(3, 1), "CAST('10.1' AS decimal(3, 1))") + .addRoundTrip("decimal(3, 1)", "CAST('-10.1' AS decimal(3, 1))", createDecimalType(3, 1), "CAST('-10.1' AS decimal(3, 1))") + .addRoundTrip("decimal(4, 2)", "CAST('2' AS decimal(4, 2))", createDecimalType(4, 2), "CAST('2' AS decimal(4, 2))") + .addRoundTrip("decimal(4, 2)", "CAST('2.3' AS decimal(4, 2))", createDecimalType(4, 2), "CAST('2.3' AS decimal(4, 2))") + .addRoundTrip("decimal(24, 2)", "CAST('2' AS decimal(24, 2))", createDecimalType(24, 2), "CAST('2' AS decimal(24, 2))") + .addRoundTrip("decimal(24, 2)", "CAST('2.3' AS decimal(24, 2))", createDecimalType(24, 2), "CAST('2.3' AS decimal(24, 2))") + .addRoundTrip("decimal(24, 2)", "CAST('123456789.3' AS decimal(24, 2))", createDecimalType(24, 2), "CAST('123456789.3' AS decimal(24, 2))") + .addRoundTrip("decimal(24, 4)", "CAST('12345678901234567890.31' AS decimal(24, 4))", createDecimalType(24, 4), "CAST('12345678901234567890.31' AS decimal(24, 4))") + .addRoundTrip("decimal(30, 5)", "CAST('3141592653589793238462643.38327' AS decimal(30, 5))", createDecimalType(30, 5), "CAST('3141592653589793238462643.38327' AS decimal(30, 5))") + .addRoundTrip("decimal(30, 5)", "CAST('-3141592653589793238462643.38327' AS decimal(30, 5))", createDecimalType(30, 5), "CAST('-3141592653589793238462643.38327' AS decimal(30, 5))") + .addRoundTrip("decimal(38, 0)", "CAST('27182818284590452353602874713526624977' AS decimal(38, 0))", createDecimalType(38, 0), "CAST('27182818284590452353602874713526624977' AS decimal(38, 0))") + .addRoundTrip("decimal(38, 0)", "CAST('-27182818284590452353602874713526624977' AS decimal(38, 0))", createDecimalType(38, 0), "CAST('-27182818284590452353602874713526624977' AS decimal(38, 0))") + + .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_decimal")) + + .addRoundTrip("decimal(3, 1)", "NULL", createDecimalType(3, 1), "CAST(NULL AS decimal(3,1))") + .addRoundTrip("decimal(30, 5)", "NULL", createDecimalType(30, 5), "CAST(NULL AS decimal(30,5))") + + .execute(getQueryRunner(), trinoCreateAsSelect("test_decimal")); + + SqlDataTypeTest.create() + .addRoundTrip("Nullable(decimal(3, 1))", "NULL", createDecimalType(3, 1), "CAST(NULL AS decimal(3,1))") + .addRoundTrip("Nullable(decimal(30, 5))", "NULL", createDecimalType(30, 5), "CAST(NULL AS decimal(30,5))") + .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_nullable_decimal")); + } + + @Test + public void testClickHouseChar() + { + // ClickHouse char is FixedString, which is arbitrary bytes + SqlDataTypeTest.create() + // plain + .addRoundTrip("char(10)", "'text_a'", VARBINARY, "to_utf8('text_a')") + .addRoundTrip("char(255)", "'text_b'", VARBINARY, "to_utf8('text_b')") + .addRoundTrip("char(5)", "'攻殻機動隊'", VARBINARY, "to_utf8('攻殻機動隊')") + .addRoundTrip("char(32)", "'攻殻機動隊'", VARBINARY, "to_utf8('攻殻機動隊')") + .addRoundTrip("char(1)", "'😂'", VARBINARY, "to_utf8('😂')") + .addRoundTrip("char(77)", "'Ну, погоди!'", VARBINARY, "to_utf8('Ну, погоди!')") + // nullable + .addRoundTrip("Nullable(char(10))", "NULL", VARBINARY, "CAST(NULL AS varbinary)") + .addRoundTrip("Nullable(char(10))", "'text_a'", VARBINARY, "to_utf8('text_a')") + .addRoundTrip("Nullable(char(1))", "'😂'", VARBINARY, "to_utf8('😂')") + .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_char")); + + // Set map_string_as_varchar session property as true + SqlDataTypeTest.create() + // plain + .addRoundTrip("char(10)", "'text_a'", VARCHAR, "CAST('text_a' AS varchar)") + .addRoundTrip("char(255)", "'text_b'", VARCHAR, "CAST('text_b' AS varchar)") + .addRoundTrip("char(5)", "'攻殻機動隊'", VARCHAR, "CAST('攻殻機動隊' AS varchar)") + .addRoundTrip("char(32)", "'攻殻機動隊'", VARCHAR, "CAST('攻殻機動隊' AS varchar)") + .addRoundTrip("char(1)", "'😂'", VARCHAR, "CAST('😂' AS varchar)") + .addRoundTrip("char(77)", "'Ну, погоди!'", VARCHAR, "CAST('Ну, погоди!' AS varchar)") + // nullable + .addRoundTrip("Nullable(char(10))", "NULL", VARCHAR, "CAST(NULL AS varchar)") + .addRoundTrip("Nullable(char(10))", "'text_a'", VARCHAR, "CAST('text_a' AS varchar)") + .addRoundTrip("Nullable(char(1))", "'😂'", VARCHAR, "CAST('😂' AS varchar)") + .execute(getQueryRunner(), mapStringAsVarcharSession(), clickhouseCreateAndInsert("tpch.test_char")); + } + + @Test + public void testClickHouseFixedString() + { + SqlDataTypeTest.create() + // plain + .addRoundTrip("FixedString(10)", "'c12345678b'", VARBINARY, "to_utf8('c12345678b')") + .addRoundTrip("FixedString(10)", "'c123'", VARBINARY, "to_utf8('c123\0\0\0\0\0\0')") + // nullable + .addRoundTrip("Nullable(FixedString(10))", "NULL", VARBINARY, "CAST(NULL AS varbinary)") + .addRoundTrip("Nullable(FixedString(10))", "'c12345678b'", VARBINARY, "to_utf8('c12345678b')") + .addRoundTrip("Nullable(FixedString(10))", "'c123'", VARBINARY, "to_utf8('c123\0\0\0\0\0\0')") + .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_fixed_string")); + + // Set map_string_as_varchar session property as true + SqlDataTypeTest.create() + // plain + .addRoundTrip("FixedString(10)", "'c12345678b'", VARCHAR, "CAST('c12345678b' AS varchar)") + .addRoundTrip("FixedString(10)", "'c123'", VARCHAR, "CAST('c123\0\0\0\0\0\0' AS varchar)") + // nullable + .addRoundTrip("Nullable(FixedString(10))", "NULL", VARCHAR, "CAST(NULL AS varchar)") + .addRoundTrip("Nullable(FixedString(10))", "'c12345678b'", VARCHAR, "CAST('c12345678b' AS varchar)") + .addRoundTrip("Nullable(FixedString(10))", "'c123'", VARCHAR, "CAST('c123\0\0\0\0\0\0' AS varchar)") + .execute(getQueryRunner(), mapStringAsVarcharSession(), clickhouseCreateAndInsert("tpch.test_fixed_string")); + } + + @Test + public void testTrinoChar() + { + SqlDataTypeTest.create() + .addRoundTrip("char(10)", "NULL", VARBINARY, "CAST(NULL AS varbinary)") + .addRoundTrip("char(10)", "'text_a'", VARBINARY, "to_utf8('text_a')") + .addRoundTrip("char(255)", "'text_b'", VARBINARY, "to_utf8('text_b')") + .addRoundTrip("char(5)", "'攻殻機動隊'", VARBINARY, "to_utf8('攻殻機動隊')") + .addRoundTrip("char(32)", "'攻殻機動隊'", VARBINARY, "to_utf8('攻殻機動隊')") + .addRoundTrip("char(1)", "'😂'", VARBINARY, "to_utf8('😂')") + .addRoundTrip("char(77)", "'Ну, погоди!'", VARBINARY, "to_utf8('Ну, погоди!')") + .execute(getQueryRunner(), trinoCreateAsSelect("test_char")) + .execute(getQueryRunner(), trinoCreateAsSelect(mapStringAsVarcharSession(), "test_char")); + + // Set map_string_as_varchar session property as true + SqlDataTypeTest.create() + .addRoundTrip("char(10)", "NULL", VARCHAR, "CAST(NULL AS varchar)") + .addRoundTrip("char(10)", "'text_a'", VARCHAR, "CAST('text_a' AS varchar)") + .addRoundTrip("char(255)", "'text_b'", VARCHAR, "CAST('text_b' AS varchar)") + .addRoundTrip("char(5)", "'攻殻機動隊'", VARCHAR, "CAST('攻殻機動隊' AS varchar)") + .addRoundTrip("char(32)", "'攻殻機動隊'", VARCHAR, "CAST('攻殻機動隊' AS varchar)") + .addRoundTrip("char(1)", "'😂'", VARCHAR, "CAST('😂' AS varchar)") + .addRoundTrip("char(77)", "'Ну, погоди!'", VARCHAR, "CAST('Ну, погоди!' AS varchar)") + .execute(getQueryRunner(), mapStringAsVarcharSession(), trinoCreateAsSelect("test_char")) + .execute(getQueryRunner(), mapStringAsVarcharSession(), trinoCreateAsSelect(mapStringAsVarcharSession(), "test_char")); + } + + @Test + public void testClickHouseVarchar() + { + // TODO add more test cases + // ClickHouse varchar is String, which is arbitrary bytes + SqlDataTypeTest.create() + // plain + .addRoundTrip("varchar(30)", "'Piękna łąka w 東京都'", VARBINARY, "to_utf8('Piękna łąka w 東京都')") + // nullable + .addRoundTrip("Nullable(varchar(30))", "NULL", VARBINARY, "CAST(NULL AS varbinary)") + .addRoundTrip("Nullable(varchar(30))", "'Piękna łąka w 東京都'", VARBINARY, "to_utf8('Piękna łąka w 東京都')") + .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_varchar")); + + // Set map_string_as_varchar session property as true + SqlDataTypeTest.create() + // plain + .addRoundTrip("varchar(30)", "'Piękna łąka w 東京都'", VARCHAR, "CAST('Piękna łąka w 東京都' AS varchar)") + // nullable + .addRoundTrip("Nullable(varchar(30))", "NULL", VARCHAR, "CAST(NULL AS varchar)") + .addRoundTrip("Nullable(varchar(30))", "'Piękna łąka w 東京都'", VARCHAR, "CAST('Piękna łąka w 東京都' AS varchar)") + .execute(getQueryRunner(), mapStringAsVarcharSession(), clickhouseCreateAndInsert("tpch.test_varchar")); + } + + @Test + public void testClickHouseString() + { + // TODO add more test cases + SqlDataTypeTest.create() + // plain + .addRoundTrip("String", "'Piękna łąka w 東京都'", VARBINARY, "to_utf8('Piękna łąka w 東京都')") + // nullable + .addRoundTrip("Nullable(String)", "NULL", VARBINARY, "CAST(NULL AS varbinary)") + .addRoundTrip("Nullable(String)", "'Piękna łąka w 東京都'", VARBINARY, "to_utf8('Piękna łąka w 東京都')") + .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_varchar")); + + // Set map_string_as_varchar session property as true + SqlDataTypeTest.create() + // plain + .addRoundTrip("String", "'Piękna łąka w 東京都'", VARCHAR, "CAST('Piękna łąka w 東京都' AS varchar)") + // nullable + .addRoundTrip("Nullable(String)", "NULL", VARCHAR, "CAST(NULL AS varchar)") + .addRoundTrip("Nullable(String)", "'Piękna łąka w 東京都'", VARCHAR, "CAST('Piękna łąka w 東京都' AS varchar)") + .execute(getQueryRunner(), mapStringAsVarcharSession(), clickhouseCreateAndInsert("tpch.test_varchar")); + } + + @Test + public void testTrinoVarchar() + { + SqlDataTypeTest.create() + .addRoundTrip("varchar(30)", "NULL", VARBINARY, "CAST(NULL AS varbinary)") + .addRoundTrip("varchar(30)", "'Piękna łąka w 東京都'", VARBINARY, "to_utf8('Piękna łąka w 東京都')") + .execute(getQueryRunner(), trinoCreateAsSelect("test_varchar")) + .execute(getQueryRunner(), trinoCreateAsSelect(mapStringAsVarcharSession(), "test_varchar")); + + // Set map_string_as_varchar session property as true + SqlDataTypeTest.create() + .addRoundTrip("varchar(30)", "NULL", VARCHAR, "CAST(NULL AS varchar)") + .addRoundTrip("varchar(30)", "'Piękna łąka w 東京都'", VARCHAR, "CAST('Piękna łąka w 東京都' AS varchar)") + .execute(getQueryRunner(), mapStringAsVarcharSession(), trinoCreateAsSelect("test_varchar")) + .execute(getQueryRunner(), mapStringAsVarcharSession(), trinoCreateAsSelect(mapStringAsVarcharSession(), "test_varchar")); + } + + @Test + public void testTrinoVarbinary() + { + SqlDataTypeTest.create() + .addRoundTrip("varbinary", "NULL", VARBINARY, "CAST(NULL AS varbinary)") + .addRoundTrip("varbinary", "X''", VARBINARY, "X''") + .addRoundTrip("varbinary", "X'68656C6C6F'", VARBINARY, "to_utf8('hello')") + .addRoundTrip("varbinary", "X'5069C4996B6E6120C582C4856B61207720E69DB1E4BAACE983BD'", VARBINARY, "to_utf8('Piękna łąka w 東京都')") + .addRoundTrip("varbinary", "X'4261672066756C6C206F6620F09F92B0'", VARBINARY, "to_utf8('Bag full of 💰')") + .addRoundTrip("varbinary", "X'0001020304050607080DF9367AA7000000'", VARBINARY, "X'0001020304050607080DF9367AA7000000'") // non-text + .addRoundTrip("varbinary", "X'000000000000'", VARBINARY, "X'000000000000'") + .execute(getQueryRunner(), trinoCreateAsSelect("test_varbinary")); + } + + @Test(dataProvider = "sessionZonesDataProvider") + public void testDate(ZoneId sessionZone) + { + Session session = Session.builder(getSession()) + .setTimeZoneKey(TimeZoneKey.getTimeZoneKey(sessionZone.getId())) + .build(); + SqlDataTypeTest.create() + .addRoundTrip("date", "DATE '1970-02-03'", DATE, "DATE '1970-02-03'") + .addRoundTrip("date", "DATE '2017-07-01'", DATE, "DATE '2017-07-01'") // summer on northern hemisphere (possible DST) + .addRoundTrip("date", "DATE '2017-01-01'", DATE, "DATE '2017-01-01'") // winter on northern hemisphere (possible DST on southern hemisphere) + .addRoundTrip("date", "DATE '1970-01-01'", DATE, "DATE '1970-01-01'") + .addRoundTrip("date", "DATE '1983-04-01'", DATE, "DATE '1983-04-01'") + .addRoundTrip("date", "DATE '1983-10-01'", DATE, "DATE '1983-10-01'") + .execute(getQueryRunner(), session, clickhouseCreateAndInsert("tpch.test_date")) + .execute(getQueryRunner(), session, trinoCreateAsSelect(session, "test_date")) + .execute(getQueryRunner(), session, trinoCreateAsSelect("test_date")) + .execute(getQueryRunner(), session, trinoCreateAndInsert(session, "test_date")) + .execute(getQueryRunner(), session, trinoCreateAndInsert("test_date")); + + // Null + SqlDataTypeTest.create() + .addRoundTrip("date", "NULL", DATE, "CAST(NULL AS DATE)") + .execute(getQueryRunner(), session, trinoCreateAsSelect(session, "test_date")) + .execute(getQueryRunner(), session, trinoCreateAsSelect("test_date")) + .execute(getQueryRunner(), session, trinoCreateAndInsert(session, "test_date")) + .execute(getQueryRunner(), session, trinoCreateAndInsert("test_date")); + SqlDataTypeTest.create() + .addRoundTrip("Nullable(date)", "NULL", DATE, "CAST(NULL AS DATE)") + .execute(getQueryRunner(), session, clickhouseCreateAndInsert("tpch.test_date")); + } + + @Test(dataProvider = "clickHouseDateMinMaxValuesDataProvider") + public void testClickHouseDateMinMaxValues(String date) + { + SqlDataTypeTest dateTests = SqlDataTypeTest.create() + .addRoundTrip("date", format("DATE '%s'", date), DATE, format("DATE '%s'", date)); + + for (Object[] timeZoneIds : sessionZonesDataProvider()) { + Session session = Session.builder(getSession()) + .setTimeZoneKey(TimeZoneKey.getTimeZoneKey(((ZoneId) timeZoneIds[0]).getId())) + .build(); + dateTests + .execute(getQueryRunner(), session, clickhouseCreateAndInsert("tpch.test_date")) + .execute(getQueryRunner(), session, trinoCreateAsSelect(session, "test_date")) + .execute(getQueryRunner(), session, trinoCreateAsSelect("test_date")) + .execute(getQueryRunner(), session, trinoCreateAndInsert(session, "test_date")) + .execute(getQueryRunner(), session, trinoCreateAndInsert("test_date")); + } + } + + @DataProvider + public Object[][] clickHouseDateMinMaxValuesDataProvider() + { + return new Object[][] { + {"1970-01-01"}, // min value in ClickHouse + {"2106-02-07"}, // max value in ClickHouse + }; + } + + @Test(dataProvider = "unsupportedClickHouseDateValuesDataProvider") + public void testUnsupportedDate(String unsupportedDate) + { + String minSupportedDate = (String) clickHouseDateMinMaxValuesDataProvider()[0][0]; + String maxSupportedDate = (String) clickHouseDateMinMaxValuesDataProvider()[1][0]; + + try (TestTable table = new TestTable(getQueryRunner()::execute, "test_unsupported_date", "(dt date)")) { + assertQueryFails( + format("INSERT INTO %s VALUES (DATE '%s')", table.getName(), unsupportedDate), + format("Date must be between %s and %s in ClickHouse: %s", minSupportedDate, maxSupportedDate, unsupportedDate)); + } + + try (TestTable table = new TestTable(onRemoteDatabase(), "tpch.test_unsupported_date", "(dt date) ENGINE=Log")) { + onRemoteDatabase().execute(format("INSERT INTO %s VALUES ('%s')", table.getName(), unsupportedDate)); + assertQuery(format("SELECT dt <> DATE '%s' FROM %s", unsupportedDate, table.getName()), "SELECT true"); // Inserting an unsupported date in ClickHouse will turn it into another date + } + } + + @DataProvider + public Object[][] unsupportedClickHouseDateValuesDataProvider() + { + return new Object[][] { + {"1969-12-31"}, // min - 1 day + {"2106-02-08"}, // max + 1 day + }; + } + + @Test(dataProvider = "sessionZonesDataProvider") + public void testTimestamp(ZoneId sessionZone) + { + Session session = Session.builder(getSession()) + .setTimeZoneKey(TimeZoneKey.getTimeZoneKey(sessionZone.getId())) + .build(); + + SqlDataTypeTest.create() + .addRoundTrip("timestamp(0)", "timestamp '1986-01-01 00:13:07'", createTimestampType(0), "TIMESTAMP '1986-01-01 00:13:07'") // time gap in Kathmandu + .addRoundTrip("timestamp(0)", "timestamp '2018-03-25 03:17:17'", createTimestampType(0), "TIMESTAMP '2018-03-25 03:17:17'") // time gap in Vilnius + .addRoundTrip("timestamp(0)", "timestamp '2018-10-28 01:33:17'", createTimestampType(0), "TIMESTAMP '2018-10-28 01:33:17'") // time doubled in JVM zone + .addRoundTrip("timestamp(0)", "timestamp '2018-10-28 03:33:33'", createTimestampType(0), "TIMESTAMP '2018-10-28 03:33:33'") // time double in Vilnius + .execute(getQueryRunner(), session, trinoCreateAsSelect(session, "test_timestamp")) + .execute(getQueryRunner(), session, trinoCreateAsSelect("test_timestamp")) + .execute(getQueryRunner(), session, trinoCreateAndInsert(session, "test_timestamp")) + .execute(getQueryRunner(), session, trinoCreateAndInsert("test_timestamp")); + + timestampTest("timestamp") + .execute(getQueryRunner(), session, clickhouseCreateAndInsert("tpch.test_timestamp")); + timestampTest("datetime") + .execute(getQueryRunner(), session, clickhouseCreateAndInsert("tpch.test_datetime")); + } + + private SqlDataTypeTest timestampTest(String inputType) + { + return unsupportedTimestampBecomeUnexpectedValueTest(inputType) + .addRoundTrip(inputType, "'1986-01-01 00:13:07'", createTimestampType(0), "TIMESTAMP '1986-01-01 00:13:07'") // time gap in Kathmandu + .addRoundTrip(inputType, "'2018-03-25 03:17:17'", createTimestampType(0), "TIMESTAMP '2018-03-25 03:17:17'") // time gap in Vilnius + .addRoundTrip(inputType, "'2018-10-28 01:33:17'", createTimestampType(0), "TIMESTAMP '2018-10-28 01:33:17'") // time doubled in JVM zone + .addRoundTrip(inputType, "'2018-10-28 03:33:33'", createTimestampType(0), "TIMESTAMP '2018-10-28 03:33:33'") // time double in Vilnius + .addRoundTrip(format("Nullable(%s)", inputType), "NULL", createTimestampType(0), "CAST(NULL AS TIMESTAMP(0))"); + } + + protected SqlDataTypeTest unsupportedTimestampBecomeUnexpectedValueTest(String inputType) + { + return SqlDataTypeTest.create() + .addRoundTrip(inputType, "'1969-12-31 23:59:59'", createTimestampType(0), "TIMESTAMP '1970-01-01 23:59:59'"); // unsupported timestamp become 1970-01-01 23:59:59 + } + + @Test(dataProvider = "clickHouseDateTimeMinMaxValuesDataProvider") + public void testClickHouseDateTimeMinMaxValues(String timestamp) + { + SqlDataTypeTest dateTests1 = SqlDataTypeTest.create() + .addRoundTrip("timestamp(0)", format("timestamp '%s'", timestamp), createTimestampType(0), format("TIMESTAMP '%s'", timestamp)); + SqlDataTypeTest dateTests2 = SqlDataTypeTest.create() + .addRoundTrip("timestamp", format("'%s'", timestamp), createTimestampType(0), format("TIMESTAMP '%s'", timestamp)); + SqlDataTypeTest dateTests3 = SqlDataTypeTest.create() + .addRoundTrip("datetime", format("'%s'", timestamp), createTimestampType(0), format("TIMESTAMP '%s'", timestamp)); + + for (Object[] timeZoneIds : sessionZonesDataProvider()) { + Session session = Session.builder(getSession()) + .setTimeZoneKey(TimeZoneKey.getTimeZoneKey(((ZoneId) timeZoneIds[0]).getId())) + .build(); + dateTests1 + .execute(getQueryRunner(), session, trinoCreateAsSelect(session, "test_timestamp")) + .execute(getQueryRunner(), session, trinoCreateAsSelect("test_timestamp")) + .execute(getQueryRunner(), session, trinoCreateAndInsert(session, "test_timestamp")) + .execute(getQueryRunner(), session, trinoCreateAndInsert("test_timestamp")); + dateTests2.execute(getQueryRunner(), session, clickhouseCreateAndInsert("tpch.test_timestamp")); + dateTests3.execute(getQueryRunner(), session, clickhouseCreateAndInsert("tpch.test_datetime")); + } + } + + @DataProvider + public Object[][] clickHouseDateTimeMinMaxValuesDataProvider() + { + return new Object[][] { + {"1970-01-01 00:00:00"}, // min value in ClickHouse + {"2106-02-06 06:28:15"}, // max value in ClickHouse + }; + } + + @Test(dataProvider = "unsupportedTimestampDataProvider") + public void testUnsupportedTimestamp(String unsupportedTimestamp) + { + String minSupportedTimestamp = (String) clickHouseDateTimeMinMaxValuesDataProvider()[0][0]; + String maxSupportedTimestamp = (String) clickHouseDateTimeMinMaxValuesDataProvider()[1][0]; + + try (TestTable table = new TestTable(getQueryRunner()::execute, "test_unsupported_timestamp", "(dt timestamp(0))")) { + assertQueryFails( + format("INSERT INTO %s VALUES (TIMESTAMP '%s')", table.getName(), unsupportedTimestamp), + format("Timestamp must be between %s and %s in ClickHouse: %s", minSupportedTimestamp, maxSupportedTimestamp, unsupportedTimestamp)); + } + + try (TestTable table = new TestTable(onRemoteDatabase(), "tpch.test_unsupported_timestamp", "(dt datetime) ENGINE=Log")) { + onRemoteDatabase().execute(format("INSERT INTO %s VALUES ('%s')", table.getName(), unsupportedTimestamp)); + assertQuery(format("SELECT dt <> TIMESTAMP '%s' FROM %s", unsupportedTimestamp, table.getName()), "SELECT true"); // Inserting an unsupported datetime in ClickHouse will turn it into another datetime + } + } + + @DataProvider + public Object[][] unsupportedTimestampDataProvider() + { + return new Object[][] { + {"1969-12-31 23:59:59"}, // min - 1 second + {"2106-02-06 06:28:16"}, // max + 1 second + }; + } + + @DataProvider + public Object[][] sessionZonesDataProvider() + { + return new Object[][] { + {UTC}, + {jvmZone}, + // using two non-JVM zones so that we don't need to worry what ClickHouse system zone is + {vilnius}, + {kathmandu}, + {ZoneId.of(TestingSession.DEFAULT_TIME_ZONE_KEY.getId())}, + }; + } + + @Test + public void testEnum() + { + SqlDataTypeTest.create() + .addRoundTrip("Enum('hello' = 1, 'world' = 2)", "'hello'", createUnboundedVarcharType(), "VARCHAR 'hello'") + .addRoundTrip("Enum('hello' = 1, 'world' = 2)", "'world'", createUnboundedVarcharType(), "VARCHAR 'world'") + .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_enum")); + } + + @Test + public void testUuid() + { + SqlDataTypeTest.create() + .addRoundTrip("Nullable(UUID)", "NULL", UuidType.UUID, "CAST(NULL AS UUID)") + .addRoundTrip("Nullable(UUID)", "'114514ea-0601-1981-1142-e9b55b0abd6d'", UuidType.UUID, "CAST('114514ea-0601-1981-1142-e9b55b0abd6d' AS UUID)") + .execute(getQueryRunner(), clickhouseCreateAndInsert("default.ck_test_uuid")); + + SqlDataTypeTest.create() + .addRoundTrip("CAST(NULL AS UUID)", "cast(NULL as UUID)") + .addRoundTrip("UUID '114514ea-0601-1981-1142-e9b55b0abd6d'", "CAST('114514ea-0601-1981-1142-e9b55b0abd6d' AS UUID)") + .execute(getQueryRunner(), trinoCreateAsSelect("default.ck_test_uuid")) + .execute(getQueryRunner(), trinoCreateAndInsert("default.ck_test_uuid")); + } + + @Test + public void testIp() + { + SqlDataTypeTest.create() + .addRoundTrip("IPv4", "'0.0.0.0'", IPADDRESS, "IPADDRESS '0.0.0.0'") + .addRoundTrip("IPv4", "'116.253.40.133'", IPADDRESS, "IPADDRESS '116.253.40.133'") + .addRoundTrip("IPv4", "'255.255.255.255'", IPADDRESS, "IPADDRESS '255.255.255.255'") + .addRoundTrip("IPv6", "'::'", IPADDRESS, "IPADDRESS '::'") + .addRoundTrip("IPv6", "'2001:44c8:129:2632:33:0:252:2'", IPADDRESS, "IPADDRESS '2001:44c8:129:2632:33:0:252:2'") + .addRoundTrip("IPv6", "'ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff'", IPADDRESS, "IPADDRESS 'ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff'") + .addRoundTrip("Nullable(IPv4)", "NULL", IPADDRESS, "CAST(NULL AS IPADDRESS)") + .addRoundTrip("Nullable(IPv6)", "NULL", IPADDRESS, "CAST(NULL AS IPADDRESS)") + .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_ip")); + + SqlDataTypeTest.create() + .addRoundTrip("IPv4", "IPADDRESS '0.0.0.0'", IPADDRESS, "IPADDRESS '0.0.0.0'") + .addRoundTrip("IPv4", "IPADDRESS '116.253.40.133'", IPADDRESS, "IPADDRESS '116.253.40.133'") + .addRoundTrip("IPv4", "IPADDRESS '255.255.255.255'", IPADDRESS, "IPADDRESS '255.255.255.255'") + .addRoundTrip("IPv6", "IPADDRESS '::'", IPADDRESS, "IPADDRESS '::'") + .addRoundTrip("IPv6", "IPADDRESS '2001:44c8:129:2632:33:0:252:2'", IPADDRESS, "IPADDRESS '2001:44c8:129:2632:33:0:252:2'") + .addRoundTrip("IPv6", "IPADDRESS 'ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff'", IPADDRESS, "IPADDRESS 'ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff'") + .addRoundTrip("Nullable(IPv4)", "NULL", IPADDRESS, "CAST(NULL AS IPADDRESS)") + .addRoundTrip("Nullable(IPv6)", "NULL", IPADDRESS, "CAST(NULL AS IPADDRESS)") + .execute(getQueryRunner(), clickhouseCreateAndTrinoInsert("tpch.test_ip")); + } + + protected static Session mapStringAsVarcharSession() + { + return testSessionBuilder() + .setCatalog("clickhouse") + .setSchema(TPCH_SCHEMA) + .setCatalogSessionProperty("clickhouse", "map_string_as_varchar", "true") + .build(); + } + + protected DataSetup trinoCreateAsSelect(String tableNamePrefix) + { + return trinoCreateAsSelect(getSession(), tableNamePrefix); + } + + protected DataSetup trinoCreateAsSelect(Session session, String tableNamePrefix) + { + return new CreateAsSelectDataSetup(new TrinoSqlExecutor(getQueryRunner(), session), tableNamePrefix); + } + + protected DataSetup trinoCreateAndInsert(String tableNamePrefix) + { + return trinoCreateAndInsert(getSession(), tableNamePrefix); + } + + protected DataSetup trinoCreateAndInsert(Session session, String tableNamePrefix) + { + return new CreateAndInsertDataSetup(new TrinoSqlExecutor(getQueryRunner(), session), tableNamePrefix); + } + + protected DataSetup clickhouseCreateAndInsert(String tableNamePrefix) + { + return new CreateAndInsertDataSetup(new ClickHouseSqlExecutor(onRemoteDatabase()), tableNamePrefix); + } + + protected DataSetup clickhouseCreateAndTrinoInsert(String tableNamePrefix) + { + return new CreateAndTrinoInsertDataSetup(new ClickHouseSqlExecutor(onRemoteDatabase()), new TrinoSqlExecutor(getQueryRunner()), tableNamePrefix); + } + + protected SqlExecutor onRemoteDatabase() + { + return clickhouseServer::execute; + } +} diff --git a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseLatestConnectorTest.java b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseLatestConnectorTest.java index 45550e4b6168..47b0e22ff467 100644 --- a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseLatestConnectorTest.java +++ b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseLatestConnectorTest.java @@ -44,4 +44,25 @@ protected OptionalInt maxTableNameLength() // The numeric value depends on file system return OptionalInt.of(255 - ".sql.detached".length()); } + + @Override + protected String errorMessageForCreateTableAsSelectNegativeDate(String date) + { + // Override because the DateTime range was expanded in version 21.4 and later + return "Date must be between 1970-01-01 and 2149-06-06 in ClickHouse: " + date; + } + + @Override + protected String errorMessageForInsertNegativeDate(String date) + { + // Override because the DateTime range was expanded in version 21.4 and later + return "Date must be between 1970-01-01 and 2149-06-06 in ClickHouse: " + date; + } + + @Override + protected String errorMessageForDateYearOfEraPredicate(String date) + { + // Override because the DateTime range was expanded in version 21.4 and later + return "Date must be between 1970-01-01 and 2149-06-06 in ClickHouse: " + date; + } } diff --git a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseLatestTypeMapping.java b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseLatestTypeMapping.java new file mode 100644 index 000000000000..5208050caaba --- /dev/null +++ b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseLatestTypeMapping.java @@ -0,0 +1,95 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.clickhouse; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.testing.QueryRunner; +import io.trino.testing.datatype.SqlDataTypeTest; +import org.testng.annotations.DataProvider; + +import static io.trino.plugin.clickhouse.ClickHouseQueryRunner.createClickHouseQueryRunner; +import static io.trino.plugin.clickhouse.TestingClickHouseServer.CLICKHOUSE_LATEST_IMAGE; +import static io.trino.spi.type.TimestampType.createTimestampType; + +public class TestClickHouseLatestTypeMapping + extends BaseClickHouseTypeMapping +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + clickhouseServer = closeAfterClass(new TestingClickHouseServer(CLICKHOUSE_LATEST_IMAGE)); + return createClickHouseQueryRunner(clickhouseServer, ImmutableMap.of(), + ImmutableMap.builder() + .put("metadata.cache-ttl", "10m") + .put("metadata.cache-missing", "true") + .buildOrThrow(), + ImmutableList.of()); + } + + @DataProvider + @Override + public Object[][] clickHouseDateMinMaxValuesDataProvider() + { + // Override because the Date range was expanded in version 21.4 and later + return new Object[][] { + {"1970-01-01"}, // min value in ClickHouse + {"2149-06-06"}, // max value in ClickHouse + }; + } + + @DataProvider + @Override + public Object[][] unsupportedClickHouseDateValuesDataProvider() + { + // Override because the Date range was expanded in version 21.4 and later + return new Object[][] { + {"1969-12-31"}, // min - 1 day + {"2149-06-07"}, // max + 1 day + }; + } + + @Override + protected SqlDataTypeTest unsupportedTimestampBecomeUnexpectedValueTest(String inputType) + { + // Override because insert DateTime '1969-12-31 23:59:59' directly in ClickHouse will + // become '1970-01-01 00:00:00' in version 21.4 and later, however in versions prior + // to 21.4 the value will become '1970-01-01 23:59:59'. + return SqlDataTypeTest.create() + .addRoundTrip(inputType, "'1969-12-31 23:59:59'", createTimestampType(0), "TIMESTAMP '1970-01-01 00:00:00'"); + } + + @DataProvider + @Override + public Object[][] clickHouseDateTimeMinMaxValuesDataProvider() + { + // Override because the DateTime range was expanded in version 21.4 and later + return new Object[][] { + {"1970-01-01 00:00:00"}, // min value in ClickHouse + {"2106-02-07 06:28:15"}, // max value in ClickHouse + }; + } + + @DataProvider + @Override + public Object[][] unsupportedTimestampDataProvider() + { + // Override because the DateTime range was expanded in version 21.4 and later + return new Object[][] { + {"1969-12-31 23:59:59"}, // min - 1 second + {"2106-02-07 06:28:16"}, // max + 1 second + }; + } +} diff --git a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseTypeMapping.java b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseTypeMapping.java index cbbd379825fc..eaf4161bfe92 100644 --- a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseTypeMapping.java +++ b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseTypeMapping.java @@ -15,92 +15,13 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.Session; -import io.trino.spi.type.TimeZoneKey; -import io.trino.spi.type.UuidType; -import io.trino.testing.AbstractTestQueryFramework; import io.trino.testing.QueryRunner; -import io.trino.testing.TestingSession; -import io.trino.testing.datatype.CreateAndInsertDataSetup; -import io.trino.testing.datatype.CreateAndTrinoInsertDataSetup; -import io.trino.testing.datatype.CreateAsSelectDataSetup; -import io.trino.testing.datatype.DataSetup; -import io.trino.testing.datatype.SqlDataTypeTest; -import io.trino.testing.sql.TestTable; -import io.trino.testing.sql.TrinoSqlExecutor; -import org.testng.annotations.BeforeClass; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; -import java.time.LocalDate; -import java.time.LocalDateTime; -import java.time.ZoneId; - -import static com.google.common.base.Preconditions.checkState; -import static com.google.common.base.Verify.verify; -import static io.trino.plugin.clickhouse.ClickHouseQueryRunner.TPCH_SCHEMA; import static io.trino.plugin.clickhouse.ClickHouseQueryRunner.createClickHouseQueryRunner; -import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.spi.type.DateType.DATE; -import static io.trino.spi.type.DecimalType.createDecimalType; -import static io.trino.spi.type.DoubleType.DOUBLE; -import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.spi.type.RealType.REAL; -import static io.trino.spi.type.SmallintType.SMALLINT; -import static io.trino.spi.type.TimestampType.createTimestampType; -import static io.trino.spi.type.TinyintType.TINYINT; -import static io.trino.spi.type.VarbinaryType.VARBINARY; -import static io.trino.spi.type.VarcharType.VARCHAR; -import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; -import static io.trino.testing.TestingSession.testSessionBuilder; -import static io.trino.type.IpAddressType.IPADDRESS; -import static java.lang.String.format; -import static java.time.ZoneOffset.UTC; public class TestClickHouseTypeMapping - extends AbstractTestQueryFramework + extends BaseClickHouseTypeMapping { - private final ZoneId jvmZone = ZoneId.systemDefault(); - - // no DST in 1970, but has DST in later years (e.g. 2018) - private final ZoneId vilnius = ZoneId.of("Europe/Vilnius"); - - // minutes offset change since 1970-01-01, no DST - private final ZoneId kathmandu = ZoneId.of("Asia/Kathmandu"); - - private TestingClickHouseServer clickhouseServer; - - @BeforeClass - public void setUp() - { - checkState(jvmZone.getId().equals("America/Bahia_Banderas"), "This test assumes certain JVM time zone"); - LocalDate dateOfLocalTimeChangeForwardAtMidnightInJvmZone = LocalDate.of(1970, 1, 1); - checkIsGap(jvmZone, dateOfLocalTimeChangeForwardAtMidnightInJvmZone.atStartOfDay()); - - LocalDate dateOfLocalTimeChangeForwardAtMidnightInSomeZone = LocalDate.of(1983, 4, 1); - checkIsGap(vilnius, dateOfLocalTimeChangeForwardAtMidnightInSomeZone.atStartOfDay()); - LocalDate dateOfLocalTimeChangeBackwardAtMidnightInSomeZone = LocalDate.of(1983, 10, 1); - checkIsDoubled(vilnius, dateOfLocalTimeChangeBackwardAtMidnightInSomeZone.atStartOfDay().minusMinutes(1)); - - LocalDate timeGapInKathmandu = LocalDate.of(1986, 1, 1); - checkIsGap(kathmandu, timeGapInKathmandu.atStartOfDay()); - } - - private static void checkIsGap(ZoneId zone, LocalDateTime dateTime) - { - verify(isGap(zone, dateTime), "Expected %s to be a gap in %s", dateTime, zone); - } - - private static boolean isGap(ZoneId zone, LocalDateTime dateTime) - { - return zone.getRules().getValidOffsets(dateTime).isEmpty(); - } - - private static void checkIsDoubled(ZoneId zone, LocalDateTime dateTime) - { - verify(zone.getRules().getValidOffsets(dateTime).size() == 2, "Expected %s to be doubled in %s", dateTime, zone); - } - @Override protected QueryRunner createQueryRunner() throws Exception @@ -113,704 +34,4 @@ protected QueryRunner createQueryRunner() .buildOrThrow(), ImmutableList.of()); } - - @Test - public void testTinyint() - { - SqlDataTypeTest.create() - .addRoundTrip("tinyint", "-128", TINYINT, "TINYINT '-128'") // min value in ClickHouse and Trino - .addRoundTrip("tinyint", "5", TINYINT, "TINYINT '5'") - .addRoundTrip("tinyint", "127", TINYINT, "TINYINT '127'") // max value in ClickHouse and Trino - .execute(getQueryRunner(), trinoCreateAsSelect("test_tinyint")) - .execute(getQueryRunner(), trinoCreateAndInsert("test_tinyint")) - - .addRoundTrip("Nullable(tinyint)", "NULL", TINYINT, "CAST(NULL AS TINYINT)") - .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_tinyint")); - - SqlDataTypeTest.create() - .addRoundTrip("tinyint", "NULL", TINYINT, "CAST(NULL AS TINYINT)") - .execute(getQueryRunner(), trinoCreateAsSelect("test_tinyint")) - .execute(getQueryRunner(), trinoCreateAndInsert("test_tinyint")); - } - - @Test - public void testUnsupportedTinyint() - { - // ClickHouse stores incorrect results when the values are out of supported range. This test should be fixed when ClickHouse changes the behavior. - SqlDataTypeTest.create() - .addRoundTrip("tinyint", "-129", TINYINT, "TINYINT '127'") - .addRoundTrip("tinyint", "128", TINYINT, "TINYINT '-128'") - .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_unsupported_tinyint")); - } - - @Test - public void testSmallint() - { - SqlDataTypeTest.create() - .addRoundTrip("smallint", "-32768", SMALLINT, "SMALLINT '-32768'") // min value in ClickHouse and Trino - .addRoundTrip("smallint", "32456", SMALLINT, "SMALLINT '32456'") - .addRoundTrip("smallint", "32767", SMALLINT, "SMALLINT '32767'") // max value in ClickHouse and Trino - .execute(getQueryRunner(), trinoCreateAsSelect("test_smallint")) - .execute(getQueryRunner(), trinoCreateAndInsert("test_smallint")) - - .addRoundTrip("Nullable(smallint)", "NULL", SMALLINT, "CAST(NULL AS SMALLINT)") - .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_smallint")); - - SqlDataTypeTest.create() - .addRoundTrip("smallint", "NULL", SMALLINT, "CAST(NULL AS SMALLINT)") - .execute(getQueryRunner(), trinoCreateAsSelect("test_smallint")) - .execute(getQueryRunner(), trinoCreateAndInsert("test_smallint")); - } - - @Test - public void testUnsupportedSmallint() - { - // ClickHouse stores incorrect results when the values are out of supported range. This test should be fixed when ClickHouse changes the behavior. - SqlDataTypeTest.create() - .addRoundTrip("smallint", "-32769", SMALLINT, "SMALLINT '32767'") - .addRoundTrip("smallint", "32768", SMALLINT, "SMALLINT '-32768'") - .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_unsupported_smallint")); - } - - @Test - public void testInteger() - { - SqlDataTypeTest.create() - .addRoundTrip("integer", "-2147483648", INTEGER, "-2147483648") // min value in ClickHouse and Trino - .addRoundTrip("integer", "1234567890", INTEGER, "1234567890") - .addRoundTrip("integer", "2147483647", INTEGER, "2147483647") // max value in ClickHouse and Trino - .execute(getQueryRunner(), trinoCreateAsSelect("test_int")) - .execute(getQueryRunner(), trinoCreateAndInsert("test_int")) - - .addRoundTrip("Nullable(integer)", "NULL", INTEGER, "CAST(NULL AS INTEGER)") - .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_int")); - - SqlDataTypeTest.create() - .addRoundTrip("integer", "NULL", INTEGER, "CAST(NULL AS INTEGER)") - .execute(getQueryRunner(), trinoCreateAsSelect("test_int")) - .execute(getQueryRunner(), trinoCreateAndInsert("test_int")); - } - - @Test - public void testUnsupportedInteger() - { - // ClickHouse stores incorrect results when the values are out of supported range. This test should be fixed when ClickHouse changes the behavior. - SqlDataTypeTest.create() - .addRoundTrip("integer", "-2147483649", INTEGER, "INTEGER '2147483647'") - .addRoundTrip("integer", "2147483648", INTEGER, "INTEGER '-2147483648'") - .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_unsupported_integer")); - } - - @Test - public void testBigint() - { - SqlDataTypeTest.create() - .addRoundTrip("bigint", "-9223372036854775808", BIGINT, "-9223372036854775808") // min value in ClickHouse and Trino - .addRoundTrip("bigint", "123456789012", BIGINT, "123456789012") - .addRoundTrip("bigint", "9223372036854775807", BIGINT, "9223372036854775807") // max value in ClickHouse and Trino - .execute(getQueryRunner(), trinoCreateAsSelect("test_bigint")) - .execute(getQueryRunner(), trinoCreateAndInsert("test_bigint")) - - .addRoundTrip("Nullable(bigint)", "NULL", BIGINT, "CAST(NULL AS BIGINT)") - .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_bigint")); - - SqlDataTypeTest.create() - .addRoundTrip("bigint", "NULL", BIGINT, "CAST(NULL AS BIGINT)") - .execute(getQueryRunner(), trinoCreateAsSelect("test_bigint")) - .execute(getQueryRunner(), trinoCreateAndInsert("test_bigint")); - } - - @Test - public void testUnsupportedBigint() - { - // ClickHouse stores incorrect results when the values are out of supported range. This test should be fixed when ClickHouse changes the behavior. - SqlDataTypeTest.create() - .addRoundTrip("bigint", "-9223372036854775809", BIGINT, "BIGINT '9223372036854775807'") - .addRoundTrip("bigint", "9223372036854775808", BIGINT, "BIGINT '-9223372036854775808'") - .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_unsupported_bigint")); - } - - @Test - public void testUint8() - { - SqlDataTypeTest.create() - .addRoundTrip("UInt8", "0", SMALLINT, "SMALLINT '0'") // min value in ClickHouse - .addRoundTrip("UInt8", "255", SMALLINT, "SMALLINT '255'") // max value in ClickHouse - .addRoundTrip("Nullable(UInt8)", "NULL", SMALLINT, "CAST(null AS SMALLINT)") - .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_uint8")); - - SqlDataTypeTest.create() - .addRoundTrip("UInt8", "0", SMALLINT, "SMALLINT '0'") // min value in ClickHouse - .addRoundTrip("UInt8", "255", SMALLINT, "SMALLINT '255'") // max value in ClickHouse - .addRoundTrip("Nullable(UInt8)", "NULL", SMALLINT, "CAST(null AS SMALLINT)") - .execute(getQueryRunner(), clickhouseCreateAndTrinoInsert("tpch.test_uint8")); - } - - @Test - public void testUnsupportedUint8() - { - // ClickHouse stores incorrect results when the values are out of supported range. This test should be fixed when ClickHouse changes the behavior. - SqlDataTypeTest.create() - .addRoundTrip("UInt8", "-1", SMALLINT, "SMALLINT '255'") - .addRoundTrip("UInt8", "256", SMALLINT, "SMALLINT '0'") - .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_unsupported_uint8")); - - // Prevent writing incorrect results in the connector - try (TestTable table = new TestTable(clickhouseServer::execute, "tpch.test_unsupported_uint8", "(value UInt8) ENGINE=Log")) { - assertQueryFails( - format("INSERT INTO %s VALUES (-1)", table.getName()), - "Value must be between 0 and 255 in ClickHouse: -1"); - assertQueryFails( - format("INSERT INTO %s VALUES (256)", table.getName()), - "Value must be between 0 and 255 in ClickHouse: 256"); - } - } - - @Test - public void testUint16() - { - SqlDataTypeTest.create() - .addRoundTrip("UInt16", "0", INTEGER, "0") // min value in ClickHouse - .addRoundTrip("UInt16", "65535", INTEGER, "65535") // max value in ClickHouse - .addRoundTrip("Nullable(UInt16)", "NULL", INTEGER, "CAST(null AS INTEGER)") - .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_uint16")); - - SqlDataTypeTest.create() - .addRoundTrip("UInt16", "0", INTEGER, "0") // min value in ClickHouse - .addRoundTrip("UInt16", "65535", INTEGER, "65535") // max value in ClickHouse - .addRoundTrip("Nullable(UInt16)", "NULL", INTEGER, "CAST(null AS INTEGER)") - .execute(getQueryRunner(), clickhouseCreateAndTrinoInsert("tpch.test_uint16")); - } - - @Test - public void testUnsupportedUint16() - { - // ClickHouse stores incorrect results when the values are out of supported range. This test should be fixed when ClickHouse changes the behavior. - SqlDataTypeTest.create() - .addRoundTrip("UInt16", "-1", INTEGER, "65535") - .addRoundTrip("UInt16", "65536", INTEGER, "0") - .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_unsupported_uint16")); - - // Prevent writing incorrect results in the connector - try (TestTable table = new TestTable(clickhouseServer::execute, "tpch.test_unsupported_uint16", "(value UInt16) ENGINE=Log")) { - assertQueryFails( - format("INSERT INTO %s VALUES (-1)", table.getName()), - "Value must be between 0 and 65535 in ClickHouse: -1"); - assertQueryFails( - format("INSERT INTO %s VALUES (65536)", table.getName()), - "Value must be between 0 and 65535 in ClickHouse: 65536"); - } - } - - @Test - public void testUint32() - { - SqlDataTypeTest.create() - .addRoundTrip("UInt32", "0", BIGINT, "BIGINT '0'") // min value in ClickHouse - .addRoundTrip("UInt32", "4294967295", BIGINT, "BIGINT '4294967295'") // max value in ClickHouse - .addRoundTrip("Nullable(UInt32)", "NULL", BIGINT, "CAST(null AS BIGINT)") - .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_uint32")); - - SqlDataTypeTest.create() - .addRoundTrip("UInt32", "BIGINT '0'", BIGINT, "BIGINT '0'") // min value in ClickHouse - .addRoundTrip("UInt32", "BIGINT '4294967295'", BIGINT, "BIGINT '4294967295'") // max value in ClickHouse - .addRoundTrip("Nullable(UInt32)", "NULL", BIGINT, "CAST(null AS BIGINT)") - .execute(getQueryRunner(), clickhouseCreateAndTrinoInsert("tpch.test_uint32")); - } - - @Test - public void testUnsupportedUint32() - { - // ClickHouse stores incorrect results when the values are out of supported range. This test should be fixed when ClickHouse changes the behavior. - SqlDataTypeTest.create() - .addRoundTrip("UInt32", "-1", BIGINT, "BIGINT '4294967295'") - .addRoundTrip("UInt32", "4294967296", BIGINT, "BIGINT '0'") - .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_unsupported_uint32")); - - // Prevent writing incorrect results in the connector - try (TestTable table = new TestTable(clickhouseServer::execute, "tpch.test_unsupported_uint32", "(value UInt32) ENGINE=Log")) { - assertQueryFails( - format("INSERT INTO %s VALUES (CAST('-1' AS BIGINT))", table.getName()), - "Value must be between 0 and 4294967295 in ClickHouse: -1"); - assertQueryFails( - format("INSERT INTO %s VALUES (CAST('4294967296' AS BIGINT))", table.getName()), - "Value must be between 0 and 4294967295 in ClickHouse: 4294967296"); - } - } - - @Test - public void testUint64() - { - SqlDataTypeTest.create() - .addRoundTrip("UInt64", "0", createDecimalType(20), "CAST('0' AS decimal(20, 0))") // min value in ClickHouse - .addRoundTrip("UInt64", "18446744073709551615", createDecimalType(20), "CAST('18446744073709551615' AS decimal(20, 0))") // max value in ClickHouse - .addRoundTrip("Nullable(UInt64)", "NULL", createDecimalType(20), "CAST(null AS decimal(20, 0))") - .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_uint64")); - - SqlDataTypeTest.create() - .addRoundTrip("UInt64", "CAST('0' AS decimal(20, 0))", createDecimalType(20), "CAST('0' AS decimal(20, 0))") // min value in ClickHouse - .addRoundTrip("UInt64", "CAST('18446744073709551615' AS decimal(20, 0))", createDecimalType(20), "CAST('18446744073709551615' AS decimal(20, 0))") // max value in ClickHouse - .addRoundTrip("Nullable(UInt64)", "NULL", createDecimalType(20), "CAST(null AS decimal(20, 0))") - .execute(getQueryRunner(), clickhouseCreateAndTrinoInsert("tpch.test_uint64")); - } - - @Test - public void testUnsupportedUint64() - { - // ClickHouse stores incorrect results when the values are out of supported range. This test should be fixed when ClickHouse changes the behavior. - SqlDataTypeTest.create() - .addRoundTrip("UInt64", "-1", createDecimalType(20), "CAST('18446744073709551615' AS decimal(20, 0))") - .addRoundTrip("UInt64", "18446744073709551616", createDecimalType(20), "CAST('0' AS decimal(20, 0))") - .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_unsupported_uint64")); - - // Prevent writing incorrect results in the connector - try (TestTable table = new TestTable(clickhouseServer::execute, "tpch.test_unsupported_uint64", "(value UInt64) ENGINE=Log")) { - assertQueryFails( - format("INSERT INTO %s VALUES (CAST('-1' AS decimal(20, 0)))", table.getName()), - "Value must be between 0 and 18446744073709551615 in ClickHouse: -1"); - assertQueryFails( - format("INSERT INTO %s VALUES (CAST('18446744073709551616' AS decimal(20, 0)))", table.getName()), - "Value must be between 0 and 18446744073709551615 in ClickHouse: 18446744073709551616"); - } - } - - @Test - public void testReal() - { - SqlDataTypeTest.create() - .addRoundTrip("real", "12.5", REAL, "REAL '12.5'") - .addRoundTrip("real", "nan()", REAL, "CAST(nan() AS REAL)") - .addRoundTrip("real", "-infinity()", REAL, "CAST(-infinity() AS REAL)") - .addRoundTrip("real", "+infinity()", REAL, "CAST(+infinity() AS REAL)") - .addRoundTrip("real", "NULL", REAL, "CAST(NULL AS REAL)") - .execute(getQueryRunner(), trinoCreateAsSelect("trino_test_real")); - - SqlDataTypeTest.create() - .addRoundTrip("real", "12.5", REAL, "REAL '12.5'") - .addRoundTrip("real", "nan", REAL, "CAST(nan() AS REAL)") - .addRoundTrip("real", "-inf", REAL, "CAST(-infinity() AS REAL)") - .addRoundTrip("real", "+inf", REAL, "CAST(+infinity() AS REAL)") - .addRoundTrip("Nullable(real)", "NULL", REAL, "CAST(NULL AS REAL)") - .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_real")); - } - - @Test - public void testDouble() - { - SqlDataTypeTest.create() - .addRoundTrip("double", "3.1415926835", DOUBLE, "DOUBLE '3.1415926835'") - .addRoundTrip("double", "1.79769E308", DOUBLE, "DOUBLE '1.79769E308'") - .addRoundTrip("double", "2.225E-307", DOUBLE, "DOUBLE '2.225E-307'") - - .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_double")) - - .addRoundTrip("double", "NULL", DOUBLE, "CAST(NULL AS DOUBLE)") - - .execute(getQueryRunner(), trinoCreateAsSelect("trino_test_double")); - - SqlDataTypeTest.create() - .addRoundTrip("Nullable(double)", "NULL", DOUBLE, "CAST(NULL AS DOUBLE)") - .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.trino_test_nullable_double")); - } - - @Test - public void testDecimal() - { - SqlDataTypeTest.create() - .addRoundTrip("decimal(3, 0)", "CAST('193' AS decimal(3, 0))", createDecimalType(3, 0), "CAST('193' AS decimal(3, 0))") - .addRoundTrip("decimal(3, 0)", "CAST('19' AS decimal(3, 0))", createDecimalType(3, 0), "CAST('19' AS decimal(3, 0))") - .addRoundTrip("decimal(3, 0)", "CAST('-193' AS decimal(3, 0))", createDecimalType(3, 0), "CAST('-193' AS decimal(3, 0))") - .addRoundTrip("decimal(3, 1)", "CAST('10.0' AS decimal(3, 1))", createDecimalType(3, 1), "CAST('10.0' AS decimal(3, 1))") - .addRoundTrip("decimal(3, 1)", "CAST('10.1' AS decimal(3, 1))", createDecimalType(3, 1), "CAST('10.1' AS decimal(3, 1))") - .addRoundTrip("decimal(3, 1)", "CAST('-10.1' AS decimal(3, 1))", createDecimalType(3, 1), "CAST('-10.1' AS decimal(3, 1))") - .addRoundTrip("decimal(4, 2)", "CAST('2' AS decimal(4, 2))", createDecimalType(4, 2), "CAST('2' AS decimal(4, 2))") - .addRoundTrip("decimal(4, 2)", "CAST('2.3' AS decimal(4, 2))", createDecimalType(4, 2), "CAST('2.3' AS decimal(4, 2))") - .addRoundTrip("decimal(24, 2)", "CAST('2' AS decimal(24, 2))", createDecimalType(24, 2), "CAST('2' AS decimal(24, 2))") - .addRoundTrip("decimal(24, 2)", "CAST('2.3' AS decimal(24, 2))", createDecimalType(24, 2), "CAST('2.3' AS decimal(24, 2))") - .addRoundTrip("decimal(24, 2)", "CAST('123456789.3' AS decimal(24, 2))", createDecimalType(24, 2), "CAST('123456789.3' AS decimal(24, 2))") - .addRoundTrip("decimal(24, 4)", "CAST('12345678901234567890.31' AS decimal(24, 4))", createDecimalType(24, 4), "CAST('12345678901234567890.31' AS decimal(24, 4))") - .addRoundTrip("decimal(30, 5)", "CAST('3141592653589793238462643.38327' AS decimal(30, 5))", createDecimalType(30, 5), "CAST('3141592653589793238462643.38327' AS decimal(30, 5))") - .addRoundTrip("decimal(30, 5)", "CAST('-3141592653589793238462643.38327' AS decimal(30, 5))", createDecimalType(30, 5), "CAST('-3141592653589793238462643.38327' AS decimal(30, 5))") - .addRoundTrip("decimal(38, 0)", "CAST('27182818284590452353602874713526624977' AS decimal(38, 0))", createDecimalType(38, 0), "CAST('27182818284590452353602874713526624977' AS decimal(38, 0))") - .addRoundTrip("decimal(38, 0)", "CAST('-27182818284590452353602874713526624977' AS decimal(38, 0))", createDecimalType(38, 0), "CAST('-27182818284590452353602874713526624977' AS decimal(38, 0))") - - .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_decimal")) - - .addRoundTrip("decimal(3, 1)", "NULL", createDecimalType(3, 1), "CAST(NULL AS decimal(3,1))") - .addRoundTrip("decimal(30, 5)", "NULL", createDecimalType(30, 5), "CAST(NULL AS decimal(30,5))") - - .execute(getQueryRunner(), trinoCreateAsSelect("test_decimal")); - - SqlDataTypeTest.create() - .addRoundTrip("Nullable(decimal(3, 1))", "NULL", createDecimalType(3, 1), "CAST(NULL AS decimal(3,1))") - .addRoundTrip("Nullable(decimal(30, 5))", "NULL", createDecimalType(30, 5), "CAST(NULL AS decimal(30,5))") - .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_nullable_decimal")); - } - - @Test - public void testClickHouseChar() - { - // ClickHouse char is FixedString, which is arbitrary bytes - SqlDataTypeTest.create() - // plain - .addRoundTrip("char(10)", "'text_a'", VARBINARY, "to_utf8('text_a')") - .addRoundTrip("char(255)", "'text_b'", VARBINARY, "to_utf8('text_b')") - .addRoundTrip("char(5)", "'攻殻機動隊'", VARBINARY, "to_utf8('攻殻機動隊')") - .addRoundTrip("char(32)", "'攻殻機動隊'", VARBINARY, "to_utf8('攻殻機動隊')") - .addRoundTrip("char(1)", "'😂'", VARBINARY, "to_utf8('😂')") - .addRoundTrip("char(77)", "'Ну, погоди!'", VARBINARY, "to_utf8('Ну, погоди!')") - // nullable - .addRoundTrip("Nullable(char(10))", "NULL", VARBINARY, "CAST(NULL AS varbinary)") - .addRoundTrip("Nullable(char(10))", "'text_a'", VARBINARY, "to_utf8('text_a')") - .addRoundTrip("Nullable(char(1))", "'😂'", VARBINARY, "to_utf8('😂')") - .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_char")); - - // Set map_string_as_varchar session property as true - SqlDataTypeTest.create() - // plain - .addRoundTrip("char(10)", "'text_a'", VARCHAR, "CAST('text_a' AS varchar)") - .addRoundTrip("char(255)", "'text_b'", VARCHAR, "CAST('text_b' AS varchar)") - .addRoundTrip("char(5)", "'攻殻機動隊'", VARCHAR, "CAST('攻殻機動隊' AS varchar)") - .addRoundTrip("char(32)", "'攻殻機動隊'", VARCHAR, "CAST('攻殻機動隊' AS varchar)") - .addRoundTrip("char(1)", "'😂'", VARCHAR, "CAST('😂' AS varchar)") - .addRoundTrip("char(77)", "'Ну, погоди!'", VARCHAR, "CAST('Ну, погоди!' AS varchar)") - // nullable - .addRoundTrip("Nullable(char(10))", "NULL", VARCHAR, "CAST(NULL AS varchar)") - .addRoundTrip("Nullable(char(10))", "'text_a'", VARCHAR, "CAST('text_a' AS varchar)") - .addRoundTrip("Nullable(char(1))", "'😂'", VARCHAR, "CAST('😂' AS varchar)") - .execute(getQueryRunner(), mapStringAsVarcharSession(), clickhouseCreateAndInsert("tpch.test_char")); - } - - @Test - public void testClickHouseFixedString() - { - SqlDataTypeTest.create() - // plain - .addRoundTrip("FixedString(10)", "'c12345678b'", VARBINARY, "to_utf8('c12345678b')") - .addRoundTrip("FixedString(10)", "'c123'", VARBINARY, "to_utf8('c123\0\0\0\0\0\0')") - // nullable - .addRoundTrip("Nullable(FixedString(10))", "NULL", VARBINARY, "CAST(NULL AS varbinary)") - .addRoundTrip("Nullable(FixedString(10))", "'c12345678b'", VARBINARY, "to_utf8('c12345678b')") - .addRoundTrip("Nullable(FixedString(10))", "'c123'", VARBINARY, "to_utf8('c123\0\0\0\0\0\0')") - .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_fixed_string")); - - // Set map_string_as_varchar session property as true - SqlDataTypeTest.create() - // plain - .addRoundTrip("FixedString(10)", "'c12345678b'", VARCHAR, "CAST('c12345678b' AS varchar)") - .addRoundTrip("FixedString(10)", "'c123'", VARCHAR, "CAST('c123\0\0\0\0\0\0' AS varchar)") - // nullable - .addRoundTrip("Nullable(FixedString(10))", "NULL", VARCHAR, "CAST(NULL AS varchar)") - .addRoundTrip("Nullable(FixedString(10))", "'c12345678b'", VARCHAR, "CAST('c12345678b' AS varchar)") - .addRoundTrip("Nullable(FixedString(10))", "'c123'", VARCHAR, "CAST('c123\0\0\0\0\0\0' AS varchar)") - .execute(getQueryRunner(), mapStringAsVarcharSession(), clickhouseCreateAndInsert("tpch.test_fixed_string")); - } - - @Test - public void testTrinoChar() - { - SqlDataTypeTest.create() - .addRoundTrip("char(10)", "NULL", VARBINARY, "CAST(NULL AS varbinary)") - .addRoundTrip("char(10)", "'text_a'", VARBINARY, "to_utf8('text_a')") - .addRoundTrip("char(255)", "'text_b'", VARBINARY, "to_utf8('text_b')") - .addRoundTrip("char(5)", "'攻殻機動隊'", VARBINARY, "to_utf8('攻殻機動隊')") - .addRoundTrip("char(32)", "'攻殻機動隊'", VARBINARY, "to_utf8('攻殻機動隊')") - .addRoundTrip("char(1)", "'😂'", VARBINARY, "to_utf8('😂')") - .addRoundTrip("char(77)", "'Ну, погоди!'", VARBINARY, "to_utf8('Ну, погоди!')") - .execute(getQueryRunner(), trinoCreateAsSelect("test_char")) - .execute(getQueryRunner(), trinoCreateAsSelect(mapStringAsVarcharSession(), "test_char")); - - // Set map_string_as_varchar session property as true - SqlDataTypeTest.create() - .addRoundTrip("char(10)", "NULL", VARCHAR, "CAST(NULL AS varchar)") - .addRoundTrip("char(10)", "'text_a'", VARCHAR, "CAST('text_a' AS varchar)") - .addRoundTrip("char(255)", "'text_b'", VARCHAR, "CAST('text_b' AS varchar)") - .addRoundTrip("char(5)", "'攻殻機動隊'", VARCHAR, "CAST('攻殻機動隊' AS varchar)") - .addRoundTrip("char(32)", "'攻殻機動隊'", VARCHAR, "CAST('攻殻機動隊' AS varchar)") - .addRoundTrip("char(1)", "'😂'", VARCHAR, "CAST('😂' AS varchar)") - .addRoundTrip("char(77)", "'Ну, погоди!'", VARCHAR, "CAST('Ну, погоди!' AS varchar)") - .execute(getQueryRunner(), mapStringAsVarcharSession(), trinoCreateAsSelect("test_char")) - .execute(getQueryRunner(), mapStringAsVarcharSession(), trinoCreateAsSelect(mapStringAsVarcharSession(), "test_char")); - } - - @Test - public void testClickHouseVarchar() - { - // TODO add more test cases - // ClickHouse varchar is String, which is arbitrary bytes - SqlDataTypeTest.create() - // plain - .addRoundTrip("varchar(30)", "'Piękna łąka w 東京都'", VARBINARY, "to_utf8('Piękna łąka w 東京都')") - // nullable - .addRoundTrip("Nullable(varchar(30))", "NULL", VARBINARY, "CAST(NULL AS varbinary)") - .addRoundTrip("Nullable(varchar(30))", "'Piękna łąka w 東京都'", VARBINARY, "to_utf8('Piękna łąka w 東京都')") - .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_varchar")); - - // Set map_string_as_varchar session property as true - SqlDataTypeTest.create() - // plain - .addRoundTrip("varchar(30)", "'Piękna łąka w 東京都'", VARCHAR, "CAST('Piękna łąka w 東京都' AS varchar)") - // nullable - .addRoundTrip("Nullable(varchar(30))", "NULL", VARCHAR, "CAST(NULL AS varchar)") - .addRoundTrip("Nullable(varchar(30))", "'Piękna łąka w 東京都'", VARCHAR, "CAST('Piękna łąka w 東京都' AS varchar)") - .execute(getQueryRunner(), mapStringAsVarcharSession(), clickhouseCreateAndInsert("tpch.test_varchar")); - } - - @Test - public void testClickHouseString() - { - // TODO add more test cases - SqlDataTypeTest.create() - // plain - .addRoundTrip("String", "'Piękna łąka w 東京都'", VARBINARY, "to_utf8('Piękna łąka w 東京都')") - // nullable - .addRoundTrip("Nullable(String)", "NULL", VARBINARY, "CAST(NULL AS varbinary)") - .addRoundTrip("Nullable(String)", "'Piękna łąka w 東京都'", VARBINARY, "to_utf8('Piękna łąka w 東京都')") - .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_varchar")); - - // Set map_string_as_varchar session property as true - SqlDataTypeTest.create() - // plain - .addRoundTrip("String", "'Piękna łąka w 東京都'", VARCHAR, "CAST('Piękna łąka w 東京都' AS varchar)") - // nullable - .addRoundTrip("Nullable(String)", "NULL", VARCHAR, "CAST(NULL AS varchar)") - .addRoundTrip("Nullable(String)", "'Piękna łąka w 東京都'", VARCHAR, "CAST('Piękna łąka w 東京都' AS varchar)") - .execute(getQueryRunner(), mapStringAsVarcharSession(), clickhouseCreateAndInsert("tpch.test_varchar")); - } - - @Test - public void testTrinoVarchar() - { - SqlDataTypeTest.create() - .addRoundTrip("varchar(30)", "NULL", VARBINARY, "CAST(NULL AS varbinary)") - .addRoundTrip("varchar(30)", "'Piękna łąka w 東京都'", VARBINARY, "to_utf8('Piękna łąka w 東京都')") - .execute(getQueryRunner(), trinoCreateAsSelect("test_varchar")) - .execute(getQueryRunner(), trinoCreateAsSelect(mapStringAsVarcharSession(), "test_varchar")); - - // Set map_string_as_varchar session property as true - SqlDataTypeTest.create() - .addRoundTrip("varchar(30)", "NULL", VARCHAR, "CAST(NULL AS varchar)") - .addRoundTrip("varchar(30)", "'Piękna łąka w 東京都'", VARCHAR, "CAST('Piękna łąka w 東京都' AS varchar)") - .execute(getQueryRunner(), mapStringAsVarcharSession(), trinoCreateAsSelect("test_varchar")) - .execute(getQueryRunner(), mapStringAsVarcharSession(), trinoCreateAsSelect(mapStringAsVarcharSession(), "test_varchar")); - } - - @Test - public void testTrinoVarbinary() - { - SqlDataTypeTest.create() - .addRoundTrip("varbinary", "NULL", VARBINARY, "CAST(NULL AS varbinary)") - .addRoundTrip("varbinary", "X''", VARBINARY, "X''") - .addRoundTrip("varbinary", "X'68656C6C6F'", VARBINARY, "to_utf8('hello')") - .addRoundTrip("varbinary", "X'5069C4996B6E6120C582C4856B61207720E69DB1E4BAACE983BD'", VARBINARY, "to_utf8('Piękna łąka w 東京都')") - .addRoundTrip("varbinary", "X'4261672066756C6C206F6620F09F92B0'", VARBINARY, "to_utf8('Bag full of 💰')") - .addRoundTrip("varbinary", "X'0001020304050607080DF9367AA7000000'", VARBINARY, "X'0001020304050607080DF9367AA7000000'") // non-text - .addRoundTrip("varbinary", "X'000000000000'", VARBINARY, "X'000000000000'") - .execute(getQueryRunner(), trinoCreateAsSelect("test_varbinary")); - } - - @Test(dataProvider = "sessionZonesDataProvider") - public void testDate(ZoneId sessionZone) - { - Session session = Session.builder(getSession()) - .setTimeZoneKey(TimeZoneKey.getTimeZoneKey(sessionZone.getId())) - .build(); - SqlDataTypeTest.create() - .addRoundTrip("date", "DATE '1970-01-01'", DATE, "DATE '1970-01-01'") // min value in ClickHouse - .addRoundTrip("date", "DATE '1970-02-03'", DATE, "DATE '1970-02-03'") - .addRoundTrip("date", "DATE '2017-07-01'", DATE, "DATE '2017-07-01'") // summer on northern hemisphere (possible DST) - .addRoundTrip("date", "DATE '2017-01-01'", DATE, "DATE '2017-01-01'") // winter on northern hemisphere (possible DST on southern hemisphere) - .addRoundTrip("date", "DATE '1970-01-01'", DATE, "DATE '1970-01-01'") - .addRoundTrip("date", "DATE '1983-04-01'", DATE, "DATE '1983-04-01'") - .addRoundTrip("date", "DATE '1983-10-01'", DATE, "DATE '1983-10-01'") - .addRoundTrip("date", "DATE '2106-02-07'", DATE, "DATE '2106-02-07'") // max value in ClickHouse - .execute(getQueryRunner(), session, clickhouseCreateAndInsert("tpch.test_date")) - .execute(getQueryRunner(), session, trinoCreateAsSelect(session, "test_date")) - .execute(getQueryRunner(), session, trinoCreateAsSelect("test_date")) - .execute(getQueryRunner(), session, trinoCreateAndInsert(session, "test_date")) - .execute(getQueryRunner(), session, trinoCreateAndInsert("test_date")); - - // Null - SqlDataTypeTest.create() - .addRoundTrip("date", "NULL", DATE, "CAST(NULL AS DATE)") - .execute(getQueryRunner(), session, trinoCreateAsSelect(session, "test_date")) - .execute(getQueryRunner(), session, trinoCreateAsSelect("test_date")) - .execute(getQueryRunner(), session, trinoCreateAndInsert(session, "test_date")) - .execute(getQueryRunner(), session, trinoCreateAndInsert("test_date")); - SqlDataTypeTest.create() - .addRoundTrip("Nullable(date)", "NULL", DATE, "CAST(NULL AS DATE)") - .execute(getQueryRunner(), session, clickhouseCreateAndInsert("tpch.test_date")); - } - - @Test - public void testUnsupportedDate() - { - try (TestTable table = new TestTable(getQueryRunner()::execute, "test_unsupported_date", "(dt date)")) { - assertQueryFails( - format("INSERT INTO %s VALUES (DATE '1969-12-31')", table.getName()), - "Date must be between 1970-01-01 and 2106-02-07 in ClickHouse: 1969-12-31"); - assertQueryFails( - format("INSERT INTO %s VALUES (DATE '2106-02-08')", table.getName()), - "Date must be between 1970-01-01 and 2106-02-07 in ClickHouse: 2106-02-08"); - } - } - - @Test(dataProvider = "sessionZonesDataProvider") - public void testTimestamp(ZoneId sessionZone) - { - Session session = Session.builder(getSession()) - .setTimeZoneKey(TimeZoneKey.getTimeZoneKey(sessionZone.getId())) - .build(); - - SqlDataTypeTest.create() - .addRoundTrip("timestamp(0)", "timestamp '1970-01-01 00:00:00'", createTimestampType(0), "TIMESTAMP '1970-01-01 00:00:00'") // min value in ClickHouse - .addRoundTrip("timestamp(0)", "timestamp '1986-01-01 00:13:07'", createTimestampType(0), "TIMESTAMP '1986-01-01 00:13:07'") // time gap in Kathmandu - .addRoundTrip("timestamp(0)", "timestamp '2018-03-25 03:17:17'", createTimestampType(0), "TIMESTAMP '2018-03-25 03:17:17'") // time gap in Vilnius - .addRoundTrip("timestamp(0)", "timestamp '2018-10-28 01:33:17'", createTimestampType(0), "TIMESTAMP '2018-10-28 01:33:17'") // time doubled in JVM zone - .addRoundTrip("timestamp(0)", "timestamp '2018-10-28 03:33:33'", createTimestampType(0), "TIMESTAMP '2018-10-28 03:33:33'") // time double in Vilnius - .addRoundTrip("timestamp(0)", "timestamp '2105-12-31 23:59:59'", createTimestampType(0), "TIMESTAMP '2105-12-31 23:59:59'") // max value in ClickHouse - .execute(getQueryRunner(), session, trinoCreateAsSelect(session, "test_timestamp")) - .execute(getQueryRunner(), session, trinoCreateAsSelect("test_timestamp")) - .execute(getQueryRunner(), session, trinoCreateAndInsert(session, "test_timestamp")) - .execute(getQueryRunner(), session, trinoCreateAndInsert("test_timestamp")); - - addTimestampRoundTrips("timestamp") - .execute(getQueryRunner(), session, clickhouseCreateAndInsert("tpch.test_timestamp")); - addTimestampRoundTrips("datetime") - .execute(getQueryRunner(), session, clickhouseCreateAndInsert("tpch.test_datetime")); - } - - private SqlDataTypeTest addTimestampRoundTrips(String inputType) - { - return SqlDataTypeTest.create() - .addRoundTrip(inputType, "'1969-12-31 23:59:59'", createTimestampType(0), "TIMESTAMP '1970-01-01 23:59:59'") // unsupported timestamp become 1970-01-01 23:59:59 - .addRoundTrip(inputType, "'1970-01-01 00:00:00'", createTimestampType(0), "TIMESTAMP '1970-01-01 00:00:00'") // min value in ClickHouse - .addRoundTrip(inputType, "'1986-01-01 00:13:07'", createTimestampType(0), "TIMESTAMP '1986-01-01 00:13:07'") // time gap in Kathmandu - .addRoundTrip(inputType, "'2018-03-25 03:17:17'", createTimestampType(0), "TIMESTAMP '2018-03-25 03:17:17'") // time gap in Vilnius - .addRoundTrip(inputType, "'2018-10-28 01:33:17'", createTimestampType(0), "TIMESTAMP '2018-10-28 01:33:17'") // time doubled in JVM zone - .addRoundTrip(inputType, "'2018-10-28 03:33:33'", createTimestampType(0), "TIMESTAMP '2018-10-28 03:33:33'") // time double in Vilnius - .addRoundTrip(inputType, "'2105-12-31 23:59:59'", createTimestampType(0), "TIMESTAMP '2105-12-31 23:59:59'") // max value in ClickHouse - .addRoundTrip(format("Nullable(%s)", inputType), "NULL", createTimestampType(0), "CAST(NULL AS TIMESTAMP(0))"); - } - - @Test - public void testUnsupportedTimestamp() - { - try (TestTable table = new TestTable(getQueryRunner()::execute, "test_unsupported_timestamp", "(dt timestamp(0))")) { - assertQueryFails( - format("INSERT INTO %s VALUES (TIMESTAMP '-9999-12-31 23:59:59')", table.getName()), - "Timestamp must be between 1970-01-01 00:00:00 and 2105-12-31 23:59:59 in ClickHouse: -9999-12-31 23:59:59"); - assertQueryFails( - format("INSERT INTO %s VALUES (TIMESTAMP '1969-12-31 23:59:59')", table.getName()), - "Timestamp must be between 1970-01-01 00:00:00 and 2105-12-31 23:59:59 in ClickHouse: 1969-12-31 23:59:59"); - assertQueryFails( - format("INSERT INTO %s VALUES (TIMESTAMP '2106-01-01 00:00:00')", table.getName()), - "Timestamp must be between 1970-01-01 00:00:00 and 2105-12-31 23:59:59 in ClickHouse: 2106-01-01 00:00:00"); - assertQueryFails( - format("INSERT INTO %s VALUES (TIMESTAMP '9999-12-31 23:59:59')", table.getName()), - "Timestamp must be between 1970-01-01 00:00:00 and 2105-12-31 23:59:59 in ClickHouse: 9999-12-31 23:59:59"); - } - } - - @DataProvider - public Object[][] sessionZonesDataProvider() - { - return new Object[][] { - {UTC}, - {jvmZone}, - // using two non-JVM zones so that we don't need to worry what ClickHouse system zone is - {vilnius}, - {kathmandu}, - {ZoneId.of(TestingSession.DEFAULT_TIME_ZONE_KEY.getId())}, - }; - } - - @Test - public void testEnum() - { - SqlDataTypeTest.create() - .addRoundTrip("Enum('hello' = 1, 'world' = 2)", "'hello'", createUnboundedVarcharType(), "VARCHAR 'hello'") - .addRoundTrip("Enum('hello' = 1, 'world' = 2)", "'world'", createUnboundedVarcharType(), "VARCHAR 'world'") - .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_enum")); - } - - @Test - public void testUuid() - { - SqlDataTypeTest.create() - .addRoundTrip("Nullable(UUID)", "NULL", UuidType.UUID, "CAST(NULL AS UUID)") - .addRoundTrip("Nullable(UUID)", "'114514ea-0601-1981-1142-e9b55b0abd6d'", UuidType.UUID, "CAST('114514ea-0601-1981-1142-e9b55b0abd6d' AS UUID)") - .execute(getQueryRunner(), clickhouseCreateAndInsert("default.ck_test_uuid")); - - SqlDataTypeTest.create() - .addRoundTrip("CAST(NULL AS UUID)", "cast(NULL as UUID)") - .addRoundTrip("UUID '114514ea-0601-1981-1142-e9b55b0abd6d'", "CAST('114514ea-0601-1981-1142-e9b55b0abd6d' AS UUID)") - .execute(getQueryRunner(), trinoCreateAsSelect("default.ck_test_uuid")) - .execute(getQueryRunner(), trinoCreateAndInsert("default.ck_test_uuid")); - } - - @Test - public void testIp() - { - SqlDataTypeTest.create() - .addRoundTrip("IPv4", "'0.0.0.0'", IPADDRESS, "IPADDRESS '0.0.0.0'") - .addRoundTrip("IPv4", "'116.253.40.133'", IPADDRESS, "IPADDRESS '116.253.40.133'") - .addRoundTrip("IPv4", "'255.255.255.255'", IPADDRESS, "IPADDRESS '255.255.255.255'") - .addRoundTrip("IPv6", "'::'", IPADDRESS, "IPADDRESS '::'") - .addRoundTrip("IPv6", "'2001:44c8:129:2632:33:0:252:2'", IPADDRESS, "IPADDRESS '2001:44c8:129:2632:33:0:252:2'") - .addRoundTrip("IPv6", "'ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff'", IPADDRESS, "IPADDRESS 'ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff'") - .addRoundTrip("Nullable(IPv4)", "NULL", IPADDRESS, "CAST(NULL AS IPADDRESS)") - .addRoundTrip("Nullable(IPv6)", "NULL", IPADDRESS, "CAST(NULL AS IPADDRESS)") - .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_ip")); - - SqlDataTypeTest.create() - .addRoundTrip("IPv4", "IPADDRESS '0.0.0.0'", IPADDRESS, "IPADDRESS '0.0.0.0'") - .addRoundTrip("IPv4", "IPADDRESS '116.253.40.133'", IPADDRESS, "IPADDRESS '116.253.40.133'") - .addRoundTrip("IPv4", "IPADDRESS '255.255.255.255'", IPADDRESS, "IPADDRESS '255.255.255.255'") - .addRoundTrip("IPv6", "IPADDRESS '::'", IPADDRESS, "IPADDRESS '::'") - .addRoundTrip("IPv6", "IPADDRESS '2001:44c8:129:2632:33:0:252:2'", IPADDRESS, "IPADDRESS '2001:44c8:129:2632:33:0:252:2'") - .addRoundTrip("IPv6", "IPADDRESS 'ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff'", IPADDRESS, "IPADDRESS 'ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff'") - .addRoundTrip("Nullable(IPv4)", "NULL", IPADDRESS, "CAST(NULL AS IPADDRESS)") - .addRoundTrip("Nullable(IPv6)", "NULL", IPADDRESS, "CAST(NULL AS IPADDRESS)") - .execute(getQueryRunner(), clickhouseCreateAndTrinoInsert("tpch.test_ip")); - } - - private static Session mapStringAsVarcharSession() - { - return testSessionBuilder() - .setCatalog("clickhouse") - .setSchema(TPCH_SCHEMA) - .setCatalogSessionProperty("clickhouse", "map_string_as_varchar", "true") - .build(); - } - - private DataSetup trinoCreateAsSelect(String tableNamePrefix) - { - return trinoCreateAsSelect(getSession(), tableNamePrefix); - } - - private DataSetup trinoCreateAsSelect(Session session, String tableNamePrefix) - { - return new CreateAsSelectDataSetup(new TrinoSqlExecutor(getQueryRunner(), session), tableNamePrefix); - } - - private DataSetup trinoCreateAndInsert(String tableNamePrefix) - { - return trinoCreateAndInsert(getSession(), tableNamePrefix); - } - - private DataSetup trinoCreateAndInsert(Session session, String tableNamePrefix) - { - return new CreateAndInsertDataSetup(new TrinoSqlExecutor(getQueryRunner(), session), tableNamePrefix); - } - - private DataSetup clickhouseCreateAndInsert(String tableNamePrefix) - { - return new CreateAndInsertDataSetup(new ClickHouseSqlExecutor(clickhouseServer::execute), tableNamePrefix); - } - - private DataSetup clickhouseCreateAndTrinoInsert(String tableNamePrefix) - { - return new CreateAndTrinoInsertDataSetup(new ClickHouseSqlExecutor(clickhouseServer::execute), new TrinoSqlExecutor(getQueryRunner()), tableNamePrefix); - } } diff --git a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestingClickHouseServer.java b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestingClickHouseServer.java index 60295eb2babf..7aac0e1c4531 100644 --- a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestingClickHouseServer.java +++ b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestingClickHouseServer.java @@ -52,7 +52,7 @@ public TestingClickHouseServer() public TestingClickHouseServer(DockerImageName image) { - dockerContainer = (ClickHouseContainer) createContainer(image) + dockerContainer = createContainer(image) .withCopyFileToContainer(forClasspathResource("custom.xml"), "/etc/clickhouse-server/config.d/custom.xml") .withStartupAttempts(10); @@ -93,7 +93,7 @@ public void execute(String sql) public String getJdbcUrl() { - return format("jdbc:clickhouse://%s:%s/", dockerContainer.getContainerIpAddress(), + return format("jdbc:clickhouse://%s:%s/", dockerContainer.getHost(), dockerContainer.getMappedPort(HTTP_PORT)); } diff --git a/plugin/trino-delta-lake/pom.xml b/plugin/trino-delta-lake/pom.xml index e6ae33471091..0cc697731a1e 100644 --- a/plugin/trino-delta-lake/pom.xml +++ b/plugin/trino-delta-lake/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaHiveTypeTranslator.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaHiveTypeTranslator.java index b35f6a986645..0c0fa73d7d81 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaHiveTypeTranslator.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaHiveTypeTranslator.java @@ -18,7 +18,6 @@ import io.trino.spi.TrinoException; import io.trino.spi.type.CharType; import io.trino.spi.type.DecimalType; -import io.trino.spi.type.NamedTypeSignature; import io.trino.spi.type.TimestampType; import io.trino.spi.type.TimestampWithTimeZoneType; import io.trino.spi.type.Type; @@ -149,11 +148,8 @@ public static TypeInfo translate(Type type) if (!parameter.isNamedTypeSignature()) { throw new IllegalArgumentException(format("Expected all parameters to be named type, but got %s", parameter)); } - NamedTypeSignature namedTypeSignature = parameter.getNamedTypeSignature(); - if (namedTypeSignature.getName().isEmpty()) { - throw new TrinoException(NOT_SUPPORTED, format("Anonymous row type is not supported in Hive. Please give each field a name: %s", type)); - } - fieldNames.add(namedTypeSignature.getName().get()); + fieldNames.add(parameter.getNamedTypeSignature().getName() + .orElseThrow(() -> new TrinoException(NOT_SUPPORTED, format("Anonymous row type is not supported in Hive. Please give each field a name: %s", type)))); } return getStructTypeInfo( fieldNames.build(), diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java index 2ae54cf6a571..04792d1dfc5a 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java @@ -56,11 +56,9 @@ import io.trino.plugin.hive.TableAlreadyExistsException; import io.trino.plugin.hive.metastore.Column; import io.trino.plugin.hive.metastore.Database; -import io.trino.plugin.hive.metastore.HivePrincipal; import io.trino.plugin.hive.metastore.PrincipalPrivileges; import io.trino.plugin.hive.metastore.StorageFormat; import io.trino.plugin.hive.metastore.Table; -import io.trino.plugin.hive.security.AccessControlMetadata; import io.trino.spi.NodeManager; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; @@ -98,9 +96,6 @@ import io.trino.spi.expression.Variable; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; -import io.trino.spi.security.GrantInfo; -import io.trino.spi.security.Privilege; -import io.trino.spi.security.RoleGrant; import io.trino.spi.security.TrinoPrincipal; import io.trino.spi.statistics.ColumnStatisticMetadata; import io.trino.spi.statistics.ColumnStatisticType; @@ -273,7 +268,6 @@ public class DeltaLakeMetadata private final TrinoFileSystemFactory fileSystemFactory; private final HdfsEnvironment hdfsEnvironment; private final TypeManager typeManager; - private final AccessControlMetadata accessControlMetadata; private final CheckpointWriterManager checkpointWriterManager; private final long defaultCheckpointInterval; private final boolean ignoreCheckpointWriteFailures; @@ -297,7 +291,6 @@ public DeltaLakeMetadata( TrinoFileSystemFactory fileSystemFactory, HdfsEnvironment hdfsEnvironment, TypeManager typeManager, - AccessControlMetadata accessControlMetadata, int domainCompactionThreshold, boolean unsafeWritesEnabled, JsonCodec dataFileInfoCodec, @@ -318,7 +311,6 @@ public DeltaLakeMetadata( this.fileSystemFactory = requireNonNull(fileSystemFactory, "fileSystemFactory is null"); this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null"); this.typeManager = requireNonNull(typeManager, "typeManager is null"); - this.accessControlMetadata = requireNonNull(accessControlMetadata, "accessControlMetadata is null"); this.domainCompactionThreshold = domainCompactionThreshold; this.unsafeWritesEnabled = unsafeWritesEnabled; this.dataFileInfoCodec = requireNonNull(dataFileInfoCodec, "dataFileInfoCodec is null"); @@ -668,15 +660,13 @@ public void createTable(ConnectorSession session, ConnectorTableMetadata tableMe boolean external = true; String location = getLocation(tableMetadata.getProperties()); if (location == null) { - Optional schemaLocation = getSchemaLocation(schema); - if (schemaLocation.isEmpty()) { - throw new TrinoException(NOT_SUPPORTED, "The 'location' property must be specified either for the table or the schema"); - } + String schemaLocation = getSchemaLocation(schema) + .orElseThrow(() -> new TrinoException(NOT_SUPPORTED, "The 'location' property must be specified either for the table or the schema")); String tableNameForLocation = tableName; if (useUniqueTableLocation) { tableNameForLocation += "-" + randomUUID().toString().replace("-", ""); } - location = new Path(schemaLocation.get(), tableNameForLocation).toString(); + location = new Path(schemaLocation, tableNameForLocation).toString(); checkPathContainsNoFiles(session, new Path(location)); external = false; } @@ -802,15 +792,13 @@ public DeltaLakeOutputTableHandle beginCreateTable(ConnectorSession session, Con boolean external = true; String location = getLocation(tableMetadata.getProperties()); if (location == null) { - Optional schemaLocation = getSchemaLocation(schema); - if (schemaLocation.isEmpty()) { - throw new TrinoException(NOT_SUPPORTED, "The 'location' property must be specified either for the table or the schema"); - } + String schemaLocation = getSchemaLocation(schema) + .orElseThrow(() -> new TrinoException(NOT_SUPPORTED, "The 'location' property must be specified either for the table or the schema")); String tableNameForLocation = tableName; if (useUniqueTableLocation) { tableNameForLocation += "-" + randomUUID().toString().replace("-", ""); } - location = new Path(schemaLocation.get(), tableNameForLocation).toString(); + location = new Path(schemaLocation, tableNameForLocation).toString(); external = false; } Path targetPath = new Path(location); @@ -1984,12 +1972,10 @@ public void dropTable(ConnectorSession session, ConnectorTableHandle tableHandle { DeltaLakeTableHandle handle = (DeltaLakeTableHandle) tableHandle; - Optional table = metastore.getTable(handle.getSchemaName(), handle.getTableName()); - if (table.isEmpty()) { - throw new TableNotFoundException(handle.getSchemaTableName()); - } + Table table = metastore.getTable(handle.getSchemaName(), handle.getTableName()) + .orElseThrow(() -> new TableNotFoundException(handle.getSchemaTableName())); - metastore.dropTable(session, handle.getSchemaName(), handle.getTableName(), table.get().getTableType().equals(EXTERNAL_TABLE.toString())); + metastore.dropTable(session, handle.getSchemaName(), handle.getTableName(), table.getTableType().equals(EXTERNAL_TABLE.toString())); } @Override @@ -2014,83 +2000,6 @@ public Map getSchemaProperties(ConnectorSession session, Catalog return db.map(DeltaLakeSchemaProperties::fromDatabase).orElseThrow(() -> new SchemaNotFoundException(schema)); } - @Override - public void createRole(ConnectorSession session, String role, Optional grantor) - { - accessControlMetadata.createRole(session, role, grantor.map(HivePrincipal::from)); - } - - @Override - public void dropRole(ConnectorSession session, String role) - { - accessControlMetadata.dropRole(session, role); - } - - @Override - public Set listRoles(ConnectorSession session) - { - return accessControlMetadata.listRoles(session); - } - - @Override - public Set listRoleGrants(ConnectorSession session, TrinoPrincipal principal) - { - return ImmutableSet.copyOf(accessControlMetadata.listRoleGrants(session, HivePrincipal.from(principal))); - } - - @Override - public void grantRoles(ConnectorSession session, Set roles, Set grantees, boolean withAdminOption, Optional grantor) - { - accessControlMetadata.grantRoles(session, roles, HivePrincipal.from(grantees), withAdminOption, grantor.map(HivePrincipal::from)); - } - - @Override - public void revokeRoles(ConnectorSession session, Set roles, Set grantees, boolean adminOptionFor, Optional grantor) - { - accessControlMetadata.revokeRoles(session, roles, HivePrincipal.from(grantees), adminOptionFor, grantor.map(HivePrincipal::from)); - } - - @Override - public Set listApplicableRoles(ConnectorSession session, TrinoPrincipal principal) - { - return accessControlMetadata.listApplicableRoles(session, HivePrincipal.from(principal)); - } - - @Override - public Set listEnabledRoles(ConnectorSession session) - { - return accessControlMetadata.listEnabledRoles(session); - } - - @Override - public void grantTablePrivileges(ConnectorSession session, SchemaTableName schemaTableName, Set privileges, TrinoPrincipal grantee, boolean grantOption) - { - accessControlMetadata.grantTablePrivileges(session, schemaTableName, privileges, HivePrincipal.from(grantee), grantOption); - } - - @Override - public void revokeTablePrivileges(ConnectorSession session, SchemaTableName schemaTableName, Set privileges, TrinoPrincipal grantee, boolean grantOption) - { - accessControlMetadata.revokeTablePrivileges(session, schemaTableName, privileges, HivePrincipal.from(grantee), grantOption); - } - - @Override - public List listTablePrivileges(ConnectorSession session, SchemaTablePrefix schemaTablePrefix) - { - return accessControlMetadata.listTablePrivileges(session, listTables(session, schemaTablePrefix)); - } - - private List listTables(ConnectorSession session, SchemaTablePrefix prefix) - { - if (prefix.getTable().isEmpty()) { - return listTables(session, prefix.getSchema()); - } - SchemaTableName tableName = prefix.toSchemaTableName(); - return metastore.getTable(tableName.getSchemaName(), tableName.getTableName()) - .map(table -> ImmutableList.of(tableName)) - .orElse(ImmutableList.of()); - } - private void setRollback(Runnable action) { checkState(rollbackAction.compareAndSet(null, action), "rollback action is already set"); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadataFactory.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadataFactory.java index e8fd6841d953..4c0a999e1965 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadataFactory.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadataFactory.java @@ -42,7 +42,6 @@ public class DeltaLakeMetadataFactory private final HdfsEnvironment hdfsEnvironment; private final TransactionLogAccess transactionLogAccess; private final TypeManager typeManager; - private final DeltaLakeAccessControlMetadataFactory accessControlMetadataFactory; private final JsonCodec dataFileInfoCodec; private final JsonCodec updateResultJsonCodec; private final JsonCodec mergeResultJsonCodec; @@ -68,7 +67,6 @@ public DeltaLakeMetadataFactory( HdfsEnvironment hdfsEnvironment, TransactionLogAccess transactionLogAccess, TypeManager typeManager, - DeltaLakeAccessControlMetadataFactory accessControlMetadataFactory, DeltaLakeConfig deltaLakeConfig, JsonCodec dataFileInfoCodec, JsonCodec updateResultJsonCodec, @@ -85,7 +83,6 @@ public DeltaLakeMetadataFactory( this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null"); this.transactionLogAccess = requireNonNull(transactionLogAccess, "transactionLogAccess is null"); this.typeManager = requireNonNull(typeManager, "typeManager is null"); - this.accessControlMetadataFactory = requireNonNull(accessControlMetadataFactory, "accessControlMetadataFactory is null"); this.dataFileInfoCodec = requireNonNull(dataFileInfoCodec, "dataFileInfoCodec is null"); this.updateResultJsonCodec = requireNonNull(updateResultJsonCodec, "updateResultJsonCodec is null"); this.mergeResultJsonCodec = requireNonNull(mergeResultJsonCodec, "mergeResultJsonCodec is null"); @@ -123,7 +120,6 @@ public DeltaLakeMetadata create(ConnectorIdentity identity) fileSystemFactory, hdfsEnvironment, typeManager, - accessControlMetadataFactory.create(cachingHiveMetastore), domainCompactionThreshold, unsafeWritesEnabled, dataFileInfoCodec, diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeModule.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeModule.java index 67e1e0264720..946bd6707e1f 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeModule.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeModule.java @@ -96,8 +96,7 @@ public void setup(Binder binder) configBinder(binder).bindConfigDefaults(ParquetWriterConfig.class, config -> config.setParquetOptimizedWriterEnabled(true)); install(new ConnectorAccessControlModule()); - newOptionalBinder(binder, DeltaLakeAccessControlMetadataFactory.class) - .setDefault().toInstance(DeltaLakeAccessControlMetadataFactory.SYSTEM); + configBinder(binder).bindConfig(DeltaLakeSecurityConfig.class); Multibinder systemTableProviders = newSetBinder(binder, SystemTableProvider.class); systemTableProviders.addBinding().to(PropertiesSystemTableProvider.class).in(Scopes.SINGLETON); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeSecurityConfig.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeSecurityConfig.java new file mode 100644 index 000000000000..ef967b9c7344 --- /dev/null +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeSecurityConfig.java @@ -0,0 +1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.deltalake; + +import io.airlift.configuration.Config; +import io.airlift.configuration.ConfigDescription; + +import javax.validation.constraints.NotNull; + +import static io.trino.plugin.deltalake.DeltaLakeSecurityConfig.DeltaLakeSecurity.ALLOW_ALL; + +public class DeltaLakeSecurityConfig +{ + public enum DeltaLakeSecurity + { + ALLOW_ALL, + READ_ONLY, + SYSTEM, + FILE, + } + + private DeltaLakeSecurity securitySystem = ALLOW_ALL; + + @NotNull + public DeltaLakeSecurity getSecuritySystem() + { + return securitySystem; + } + + @Config("delta.security") + @ConfigDescription("Authorization checks for Delta Lake connector") + public DeltaLakeSecurityConfig setSecuritySystem(DeltaLakeSecurity securitySystem) + { + this.securitySystem = securitySystem; + return this; + }} diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeSecurityModule.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeSecurityModule.java new file mode 100644 index 000000000000..226643973093 --- /dev/null +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeSecurityModule.java @@ -0,0 +1,50 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.deltalake; + +import com.google.inject.Binder; +import com.google.inject.Module; +import io.airlift.configuration.AbstractConfigurationAwareModule; +import io.trino.plugin.base.security.ConnectorAccessControlModule; +import io.trino.plugin.base.security.FileBasedAccessControlModule; +import io.trino.plugin.base.security.ReadOnlySecurityModule; +import io.trino.plugin.deltalake.DeltaLakeSecurityConfig.DeltaLakeSecurity; +import io.trino.plugin.hive.security.AllowAllSecurityModule; + +import static io.airlift.configuration.ConditionalModule.conditionalModule; +import static io.trino.plugin.deltalake.DeltaLakeSecurityConfig.DeltaLakeSecurity.ALLOW_ALL; +import static io.trino.plugin.deltalake.DeltaLakeSecurityConfig.DeltaLakeSecurity.FILE; +import static io.trino.plugin.deltalake.DeltaLakeSecurityConfig.DeltaLakeSecurity.READ_ONLY; + +public class DeltaLakeSecurityModule + extends AbstractConfigurationAwareModule +{ + @Override + protected void setup(Binder binder) + { + install(new ConnectorAccessControlModule()); + bindSecurityModule(ALLOW_ALL, new AllowAllSecurityModule()); + bindSecurityModule(READ_ONLY, new ReadOnlySecurityModule()); + bindSecurityModule(FILE, new FileBasedAccessControlModule()); + // SYSTEM: do not bind an ConnectorAccessControl so the engine will use system security with system roles + } + + private void bindSecurityModule(DeltaLakeSecurity deltaLakeSecurity, Module module) + { + install(conditionalModule( + DeltaLakeSecurityConfig.class, + security -> deltaLakeSecurity == security.getSecuritySystem(), + module)); + } +} diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/InternalDeltaLakeConnectorFactory.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/InternalDeltaLakeConnectorFactory.java index 593959e629e9..4cb172bf5cbb 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/InternalDeltaLakeConnectorFactory.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/InternalDeltaLakeConnectorFactory.java @@ -91,6 +91,7 @@ public static Connector createConnector( new CatalogNameModule(catalogName), new DeltaLakeMetastoreModule(), new DeltaLakeModule(), + new DeltaLakeSecurityModule(), binder -> { binder.bind(NodeVersion.class).toInstance(new NodeVersion(context.getNodeManager().getCurrentNode().getVersion())); binder.bind(NodeManager.class).toInstance(context.getNodeManager()); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/DeltaLakeParquetStatisticsUtils.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/DeltaLakeParquetStatisticsUtils.java index 448e5650787f..74f0d7526d23 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/DeltaLakeParquetStatisticsUtils.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/DeltaLakeParquetStatisticsUtils.java @@ -14,9 +14,17 @@ package io.trino.plugin.deltalake.transactionlog; import io.airlift.log.Logger; +import io.airlift.slice.Slice; import io.trino.plugin.base.type.DecodedTimestamp; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.RowBlockBuilder; +import io.trino.spi.type.ArrayType; import io.trino.spi.type.DateType; import io.trino.spi.type.DecimalType; +import io.trino.spi.type.Decimals; +import io.trino.spi.type.Int128; +import io.trino.spi.type.MapType; +import io.trino.spi.type.RowType; import io.trino.spi.type.TimestampWithTimeZoneType; import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; @@ -29,26 +37,40 @@ import org.apache.parquet.hadoop.metadata.ColumnChunkMetaData; import org.apache.parquet.schema.LogicalTypeAnnotation; +import javax.annotation.Nullable; + import java.math.BigDecimal; import java.math.BigInteger; import java.time.Instant; import java.time.LocalDate; import java.time.ZonedDateTime; import java.util.Collection; +import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.function.BiFunction; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.airlift.slice.Slices.utf8Slice; import static io.trino.parquet.ParquetTimestampUtils.decodeInt96Timestamp; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.DateTimeEncoding.unpackMillisUtc; import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.RealType.REAL; import static io.trino.spi.type.SmallintType.SMALLINT; +import static io.trino.spi.type.TimestampType.TIMESTAMP_MILLIS; +import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS; +import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_MILLISECOND; import static io.trino.spi.type.TinyintType.TINYINT; +import static io.trino.spi.type.TypeUtils.writeNativeValue; +import static java.lang.Float.floatToRawIntBits; +import static java.lang.Float.intBitsToFloat; +import static java.lang.Math.toIntExact; import static java.nio.charset.StandardCharsets.UTF_8; import static java.time.ZoneOffset.UTC; import static java.time.format.DateTimeFormatter.ISO_INSTANT; @@ -71,6 +93,114 @@ public static boolean hasInvalidStatistics(Collection metad (!metadata.getStatistics().hasNonNullValue() && metadata.getStatistics().getNumNulls() != metadata.getValueCount())); } + @Nullable + public static Object jsonValueToTrinoValue(Type type, @Nullable Object jsonValue) + { + if (jsonValue == null) { + return null; + } + + if (type == SMALLINT || type == TINYINT || type == INTEGER) { + return (long) (int) jsonValue; + } + if (type == BIGINT) { + return (long) (int) jsonValue; + } + if (type == REAL) { + return (long) floatToRawIntBits((float) (double) jsonValue); + } + if (type == DOUBLE) { + return (double) jsonValue; + } + if (type instanceof DecimalType decimalType) { + BigDecimal decimal = new BigDecimal((String) jsonValue); + + if (decimalType.isShort()) { + return Decimals.encodeShortScaledValue(decimal, decimalType.getScale()); + } + return Decimals.encodeScaledValue(decimal, decimalType.getScale()); + } + if (type instanceof VarcharType) { + return utf8Slice((String) jsonValue); + } + if (type == DateType.DATE) { + return LocalDate.parse((String) jsonValue).toEpochDay(); + } + if (type == TIMESTAMP_MILLIS) { + return Instant.parse((String) jsonValue).toEpochMilli() * MICROSECONDS_PER_MILLISECOND; + } + if (type instanceof RowType rowType) { + Map values = (Map) jsonValue; + List fieldTypes = rowType.getTypeParameters(); + BlockBuilder blockBuilder = new RowBlockBuilder(fieldTypes, null, 1); + BlockBuilder singleRowBlockWriter = blockBuilder.beginBlockEntry(); + for (int i = 0; i < values.size(); ++i) { + Type fieldType = fieldTypes.get(i); + String fieldName = rowType.getFields().get(i).getName().orElseThrow(() -> new IllegalArgumentException("Field name must exist")); + Object fieldValue = jsonValueToTrinoValue(fieldType, values.remove(fieldName)); + writeNativeValue(fieldType, singleRowBlockWriter, fieldValue); + } + checkState(values.isEmpty(), "All fields must be converted into Trino value: %s", values); + + blockBuilder.closeEntry(); + return blockBuilder.build(); + } + + throw new UnsupportedOperationException("Unsupported type: " + type); + } + + public static Map toJsonValues(Map columnTypeMapping, Map values) + { + Map jsonValues = new HashMap<>(); + for (Map.Entry value : values.entrySet()) { + Type type = columnTypeMapping.get(value.getKey()); + // TODO: Add support for row type + if (type instanceof ArrayType || type instanceof MapType || type instanceof RowType) { + continue; + } + jsonValues.put(value.getKey(), toJsonValue(columnTypeMapping.get(value.getKey()), value.getValue())); + } + return jsonValues; + } + + @Nullable + private static Object toJsonValue(Type type, @Nullable Object value) + { + if (value == null) { + return null; + } + + if (type == SMALLINT || type == TINYINT || type == INTEGER || type == BIGINT) { + return value; + } + if (type == REAL) { + return intBitsToFloat(toIntExact((long) value)); + } + if (type == DOUBLE) { + return value; + } + if (type instanceof DecimalType) { + DecimalType decimalType = (DecimalType) type; + if (decimalType.isShort()) { + return Decimals.toString((long) value, decimalType.getScale()); + } + return Decimals.toString((Int128) value, decimalType.getScale()); + } + + if (type instanceof VarcharType) { + return ((Slice) value).toStringUtf8(); + } + if (type == DateType.DATE) { + return LocalDate.ofEpochDay((long) value).format(ISO_LOCAL_DATE); + } + if (type == TIMESTAMP_TZ_MILLIS) { + Instant ts = Instant.ofEpochMilli(unpackMillisUtc((long) value)); + return ISO_INSTANT.format(ZonedDateTime.ofInstant(ts, UTC)); + } + + throw new UnsupportedOperationException("Unsupported type: " + type); + } + public static Map jsonEncodeMin(Map>> stats, Map typeForColumn) { return jsonEncode(stats, typeForColumn, DeltaLakeParquetStatisticsUtils::getMin); @@ -110,7 +240,7 @@ private static Optional getMin(Type type, Statistics statistics) Instant ts = Instant.ofEpochMilli(((LongStatistics) statistics).genericGetMin()); return Optional.of(ISO_INSTANT.format(ZonedDateTime.ofInstant(ts, UTC))); } - else if (statistics instanceof BinaryStatistics) { + if (statistics instanceof BinaryStatistics) { DecodedTimestamp decodedTimestamp = decodeInt96Timestamp(((BinaryStatistics) statistics).genericGetMin()); Instant ts = Instant.ofEpochSecond(decodedTimestamp.getEpochSeconds(), decodedTimestamp.getNanosOfSecond()); return Optional.of(ISO_INSTANT.format(ZonedDateTime.ofInstant(ts, UTC).truncatedTo(MILLIS))); @@ -145,11 +275,11 @@ else if (statistics instanceof BinaryStatistics) { min = BigDecimal.valueOf(((IntStatistics) statistics).getMin()).movePointLeft(scale); return Optional.of(min.toPlainString()); } - else if (statistics instanceof LongStatistics) { + if (statistics instanceof LongStatistics) { min = BigDecimal.valueOf(((LongStatistics) statistics).getMin()).movePointLeft(scale); return Optional.of(min.toPlainString()); } - else if (statistics instanceof BinaryStatistics) { + if (statistics instanceof BinaryStatistics) { BigInteger base = new BigInteger(((BinaryStatistics) statistics).genericGetMin().getBytes()); min = new BigDecimal(base, scale); return Optional.of(min.toPlainString()); @@ -187,7 +317,7 @@ private static Optional getMax(Type type, Statistics statistics) Instant ts = Instant.ofEpochMilli(((LongStatistics) statistics).genericGetMax()); return Optional.of(ISO_INSTANT.format(ZonedDateTime.ofInstant(ts, UTC))); } - else if (statistics instanceof BinaryStatistics) { + if (statistics instanceof BinaryStatistics) { DecodedTimestamp decodedTimestamp = decodeInt96Timestamp(((BinaryStatistics) statistics).genericGetMax()); Instant ts = Instant.ofEpochSecond(decodedTimestamp.getEpochSeconds(), decodedTimestamp.getNanosOfSecond()); ZonedDateTime zonedDateTime = ZonedDateTime.ofInstant(ts, UTC); @@ -225,11 +355,11 @@ else if (statistics instanceof BinaryStatistics) { max = BigDecimal.valueOf(((IntStatistics) statistics).getMax()).movePointLeft(scale); return Optional.of(max.toPlainString()); } - else if (statistics instanceof LongStatistics) { + if (statistics instanceof LongStatistics) { max = BigDecimal.valueOf(((LongStatistics) statistics).getMax()).movePointLeft(scale); return Optional.of(max.toPlainString()); } - else if (statistics instanceof BinaryStatistics) { + if (statistics instanceof BinaryStatistics) { BigInteger base = new BigInteger(((BinaryStatistics) statistics).genericGetMax().getBytes()); max = new BigDecimal(base, scale); return Optional.of(max.toPlainString()); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/MetadataEntry.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/MetadataEntry.java index b0fd34c4222b..23c1df2aef68 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/MetadataEntry.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/MetadataEntry.java @@ -32,6 +32,9 @@ public class MetadataEntry { + public static final String DELTA_CHECKPOINT_WRITE_STATS_AS_JSON_PROPERTY = "delta.checkpoint.writeStatsAsJson"; + public static final String DELTA_CHECKPOINT_WRITE_STATS_AS_STRUCT_PROPERTY = "delta.checkpoint.writeStatsAsStruct"; + private static final String DELTA_CHECKPOINT_INTERVAL_PROPERTY = "delta.checkpointInterval"; private final String id; diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/TransactionLogAccess.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/TransactionLogAccess.java index a823d08b817b..5f0454634122 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/TransactionLogAccess.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/TransactionLogAccess.java @@ -209,7 +209,7 @@ public List getActiveFiles(TableSnapshot tableSnapshot, ConnectorS log.warn("Query run with outdated Transaction Log Snapshot, retrieved stale table entries for table: %s and query %s", tableSnapshot.getTable(), session.getQueryId()); return loadActiveFiles(tableSnapshot, session); } - else if (cachedTable.getVersion() < tableSnapshot.getVersion()) { + if (cachedTable.getVersion() < tableSnapshot.getVersion()) { DeltaLakeDataFileCacheEntry updatedCacheEntry; try { List newEntries = getJsonEntries( diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java index 8fe869bd502e..b54c5409f731 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java @@ -201,7 +201,7 @@ private DeltaLakeColumnHandle buildColumnHandle(EntryType entryType, CheckpointS type = schemaManager.getTxnEntryType(); break; case ADD: - type = schemaManager.getAddEntryType(metadataEntry); + type = schemaManager.getAddEntryType(metadataEntry, true, true); break; case REMOVE: type = schemaManager.getRemoveEntryType(); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointSchemaManager.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointSchemaManager.java index 4804cb935eb6..af99d18f3aa6 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointSchemaManager.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointSchemaManager.java @@ -106,7 +106,7 @@ public RowType getMetadataEntryType() return metadataEntryType; } - public RowType getAddEntryType(MetadataEntry metadataEntry) + public RowType getAddEntryType(MetadataEntry metadataEntry, boolean requireWriteStatsAsJson, boolean requireWriteStatsAsStruct) { List allColumns = extractSchema(metadataEntry, typeManager); List minMaxColumns = columnsWithStats(metadataEntry, typeManager); @@ -137,17 +137,21 @@ public RowType getAddEntryType(MetadataEntry metadataEntry) RowType.from(allColumns.stream().map(column -> buildNullCountType(Optional.of(column.getPhysicalName()), column.getPhysicalColumnType())).collect(toImmutableList())))); MapType stringMap = (MapType) typeManager.getType(TypeSignature.mapType(VarcharType.VARCHAR.getTypeSignature(), VarcharType.VARCHAR.getTypeSignature())); - List addFields = ImmutableList.of( - RowType.field("path", VarcharType.createUnboundedVarcharType()), - RowType.field("partitionValues", stringMap), - RowType.field("size", BigintType.BIGINT), - RowType.field("modificationTime", BigintType.BIGINT), - RowType.field("dataChange", BooleanType.BOOLEAN), - RowType.field("stats", VarcharType.createUnboundedVarcharType()), - RowType.field("stats_parsed", RowType.from(statsColumns.build())), - RowType.field("tags", stringMap)); - - return RowType.from(addFields); + ImmutableList.Builder addFields = ImmutableList.builder(); + addFields.add(RowType.field("path", VarcharType.createUnboundedVarcharType())); + addFields.add(RowType.field("partitionValues", stringMap)); + addFields.add(RowType.field("size", BigintType.BIGINT)); + addFields.add(RowType.field("modificationTime", BigintType.BIGINT)); + addFields.add(RowType.field("dataChange", BooleanType.BOOLEAN)); + if (requireWriteStatsAsJson) { + addFields.add(RowType.field("stats", VarcharType.createUnboundedVarcharType())); + } + if (requireWriteStatsAsStruct) { + addFields.add(RowType.field("stats_parsed", RowType.from(statsColumns.build()))); + } + addFields.add(RowType.field("tags", stringMap)); + + return RowType.from(addFields.build()); } private static RowType.Field buildNullCountType(Optional columnName, Type columnType) diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointWriter.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointWriter.java index 122c875c87d0..dbaf9392ba4d 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointWriter.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointWriter.java @@ -13,14 +13,17 @@ */ package io.trino.plugin.deltalake.transactionlog.checkpoint; +import com.fasterxml.jackson.core.JsonProcessingException; import com.google.common.collect.ImmutableList; import io.trino.hdfs.HdfsContext; import io.trino.hdfs.HdfsEnvironment; +import io.trino.plugin.deltalake.DeltaLakeColumnMetadata; import io.trino.plugin.deltalake.transactionlog.AddFileEntry; import io.trino.plugin.deltalake.transactionlog.MetadataEntry; import io.trino.plugin.deltalake.transactionlog.ProtocolEntry; import io.trino.plugin.deltalake.transactionlog.RemoveFileEntry; import io.trino.plugin.deltalake.transactionlog.TransactionEntry; +import io.trino.plugin.deltalake.transactionlog.statistics.DeltaLakeFileStatistics; import io.trino.plugin.deltalake.transactionlog.statistics.DeltaLakeJsonFileStatistics; import io.trino.plugin.deltalake.transactionlog.statistics.DeltaLakeParquetFileStatistics; import io.trino.plugin.hive.RecordFileWriter; @@ -49,9 +52,16 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.airlift.slice.Slices.utf8Slice; import static io.trino.hdfs.ConfigurationUtils.toJobConf; import static io.trino.plugin.deltalake.DeltaLakeSchemaProperties.buildHiveSchema; +import static io.trino.plugin.deltalake.transactionlog.DeltaLakeParquetStatisticsUtils.jsonValueToTrinoValue; +import static io.trino.plugin.deltalake.transactionlog.DeltaLakeParquetStatisticsUtils.toJsonValues; +import static io.trino.plugin.deltalake.transactionlog.DeltaLakeSchemaSupport.extractSchema; +import static io.trino.plugin.deltalake.transactionlog.DeltaLakeSchemaSupport.serializeStatsAsJson; +import static io.trino.plugin.deltalake.transactionlog.MetadataEntry.DELTA_CHECKPOINT_WRITE_STATS_AS_JSON_PROPERTY; +import static io.trino.plugin.deltalake.transactionlog.MetadataEntry.DELTA_CHECKPOINT_WRITE_STATS_AS_STRUCT_PROPERTY; import static io.trino.plugin.hive.HiveCompressionCodec.SNAPPY; import static io.trino.plugin.hive.HiveStorageFormat.PARQUET; import static io.trino.plugin.hive.metastore.StorageFormat.fromHiveStorageFormat; @@ -59,6 +69,7 @@ import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_MILLISECOND; import static io.trino.spi.type.TypeUtils.writeNativeValue; import static java.lang.Math.multiplyExact; +import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toMap; @@ -86,10 +97,15 @@ public CheckpointWriter(TypeManager typeManager, CheckpointSchemaManager checkpo public void write(ConnectorSession session, CheckpointEntries entries, Path targetPath) { + Map configuration = entries.getMetadataEntry().getConfiguration(); + boolean writeStatsAsJson = Boolean.parseBoolean(configuration.getOrDefault(DELTA_CHECKPOINT_WRITE_STATS_AS_JSON_PROPERTY, "true")); + // The default value is false in https://github.com/delta-io/delta/blob/master/PROTOCOL.md#checkpoint-format, but Databricks defaults to true + boolean writeStatsAsStruct = Boolean.parseBoolean(configuration.getOrDefault(DELTA_CHECKPOINT_WRITE_STATS_AS_STRUCT_PROPERTY, "true")); + RowType metadataEntryType = checkpointSchemaManager.getMetadataEntryType(); RowType protocolEntryType = checkpointSchemaManager.getProtocolEntryType(); RowType txnEntryType = checkpointSchemaManager.getTxnEntryType(); - RowType addEntryType = checkpointSchemaManager.getAddEntryType(entries.getMetadataEntry()); + RowType addEntryType = checkpointSchemaManager.getAddEntryType(entries.getMetadataEntry(), writeStatsAsJson, writeStatsAsStruct); RowType removeEntryType = checkpointSchemaManager.getRemoveEntryType(); List columnNames = ImmutableList.of( @@ -130,7 +146,7 @@ public void write(ConnectorSession session, CheckpointEntries entries, Path targ writeTransactionEntry(pageBuilder, txnEntryType, transactionEntry); } for (AddFileEntry addFileEntry : entries.getAddFileEntries()) { - writeAddFileEntry(pageBuilder, addEntryType, addFileEntry); + writeAddFileEntry(pageBuilder, addEntryType, addFileEntry, entries.getMetadataEntry(), writeStatsAsJson, writeStatsAsStruct); } for (RemoveFileEntry removeFileEntry : entries.getRemoveFileEntries()) { writeRemoveFileEntry(pageBuilder, removeEntryType, removeFileEntry); @@ -193,60 +209,107 @@ private void writeTransactionEntry(PageBuilder pageBuilder, RowType entryType, T appendNullOtherBlocks(pageBuilder, TXN_BLOCK_CHANNEL); } - private void writeAddFileEntry(PageBuilder pageBuilder, RowType entryType, AddFileEntry addFileEntry) + private void writeAddFileEntry(PageBuilder pageBuilder, RowType entryType, AddFileEntry addFileEntry, MetadataEntry metadataEntry, boolean writeStatsAsJson, boolean writeStatsAsStruct) { pageBuilder.declarePosition(); BlockBuilder blockBuilder = pageBuilder.getBlockBuilder(ADD_BLOCK_CHANNEL); BlockBuilder entryBlockBuilder = blockBuilder.beginBlockEntry(); - writeString(entryBlockBuilder, entryType, 0, "path", addFileEntry.getPath()); - writeStringMap(entryBlockBuilder, entryType, 1, "partitionValues", addFileEntry.getPartitionValues()); - writeLong(entryBlockBuilder, entryType, 2, "size", addFileEntry.getSize()); - writeLong(entryBlockBuilder, entryType, 3, "modificationTime", addFileEntry.getModificationTime()); - writeBoolean(entryBlockBuilder, entryType, 4, "dataChange", addFileEntry.isDataChange()); - // TODO: determine stats format in checkpoint based on table configuration; (https://github.com/trinodb/trino/issues/12031) - // currently if addFileEntry contains JSON stats we will write JSON - // stats to checkpoint and if addEntryFile contains parsed stats, we - // will write parsed stats to the checkpoint. - writeJsonStats(entryBlockBuilder, entryType, addFileEntry); - writeParsedStats(entryBlockBuilder, entryType, addFileEntry); - writeStringMap(entryBlockBuilder, entryType, 7, "tags", addFileEntry.getTags()); + int fieldId = 0; + writeString(entryBlockBuilder, entryType, fieldId++, "path", addFileEntry.getPath()); + writeStringMap(entryBlockBuilder, entryType, fieldId++, "partitionValues", addFileEntry.getPartitionValues()); + writeLong(entryBlockBuilder, entryType, fieldId++, "size", addFileEntry.getSize()); + writeLong(entryBlockBuilder, entryType, fieldId++, "modificationTime", addFileEntry.getModificationTime()); + writeBoolean(entryBlockBuilder, entryType, fieldId++, "dataChange", addFileEntry.isDataChange()); + if (writeStatsAsJson) { + writeJsonStats(entryBlockBuilder, entryType, addFileEntry, metadataEntry, fieldId++); + } + if (writeStatsAsStruct) { + writeParsedStats(entryBlockBuilder, entryType, addFileEntry, fieldId++); + } + writeStringMap(entryBlockBuilder, entryType, fieldId++, "tags", addFileEntry.getTags()); blockBuilder.closeEntry(); // null for others appendNullOtherBlocks(pageBuilder, ADD_BLOCK_CHANNEL); } - private void writeJsonStats(BlockBuilder entryBlockBuilder, RowType entryType, AddFileEntry addFileEntry) + private void writeJsonStats(BlockBuilder entryBlockBuilder, RowType entryType, AddFileEntry addFileEntry, MetadataEntry metadataEntry, int fieldId) { String statsJson = null; - if (addFileEntry.getStats().isPresent() && addFileEntry.getStats().get() instanceof DeltaLakeJsonFileStatistics) { - statsJson = addFileEntry.getStatsString().orElse(null); + if (addFileEntry.getStats().isPresent()) { + DeltaLakeFileStatistics statistics = addFileEntry.getStats().get(); + if (statistics instanceof DeltaLakeParquetFileStatistics parquetFileStatistics) { + Map columnTypeMapping = getColumnTypeMapping(metadataEntry); + DeltaLakeJsonFileStatistics jsonFileStatistics = new DeltaLakeJsonFileStatistics( + parquetFileStatistics.getNumRecords(), + parquetFileStatistics.getMinValues().map(values -> toJsonValues(columnTypeMapping, values)), + parquetFileStatistics.getMaxValues().map(values -> toJsonValues(columnTypeMapping, values)), + parquetFileStatistics.getNullCount()); + statsJson = getStatsString(jsonFileStatistics).orElse(null); + } + else { + statsJson = addFileEntry.getStatsString().orElse(null); + } + } + writeString(entryBlockBuilder, entryType, fieldId, "stats", statsJson); + } + + private Map getColumnTypeMapping(MetadataEntry deltaMetadata) + { + return extractSchema(deltaMetadata, typeManager).stream() + .collect(toImmutableMap(DeltaLakeColumnMetadata::getName, DeltaLakeColumnMetadata::getType)); + } + + private Optional getStatsString(DeltaLakeJsonFileStatistics parsedStats) + { + try { + return Optional.of(serializeStatsAsJson(parsedStats)); + } + catch (JsonProcessingException e) { + return Optional.empty(); } - writeString(entryBlockBuilder, entryType, 5, "stats", statsJson); } - private void writeParsedStats(BlockBuilder entryBlockBuilder, RowType entryType, AddFileEntry addFileEntry) + private void writeParsedStats(BlockBuilder entryBlockBuilder, RowType entryType, AddFileEntry addFileEntry, int fieldId) { - RowType statsType = getInternalRowType(entryType, 6, "stats_parsed"); - if (addFileEntry.getStats().isEmpty() || !(addFileEntry.getStats().get() instanceof DeltaLakeParquetFileStatistics)) { + RowType statsType = getInternalRowType(entryType, fieldId, "stats_parsed"); + if (addFileEntry.getStats().isEmpty()) { entryBlockBuilder.appendNull(); return; } - DeltaLakeParquetFileStatistics stats = (DeltaLakeParquetFileStatistics) addFileEntry.getStats().get(); + DeltaLakeFileStatistics stats = addFileEntry.getStats().get(); BlockBuilder statsBlockBuilder = entryBlockBuilder.beginBlockEntry(); - writeLong(statsBlockBuilder, statsType, 0, "numRecords", stats.getNumRecords().orElse(null)); - writeMinMaxMapAsFields(statsBlockBuilder, statsType, 1, "minValues", stats.getMinValues()); - writeMinMaxMapAsFields(statsBlockBuilder, statsType, 2, "maxValues", stats.getMaxValues()); - writeObjectMapAsFields(statsBlockBuilder, statsType, 3, "nullCount", stats.getNullCount()); + if (stats instanceof DeltaLakeParquetFileStatistics) { + writeLong(statsBlockBuilder, statsType, 0, "numRecords", stats.getNumRecords().orElse(null)); + writeMinMaxMapAsFields(statsBlockBuilder, statsType, 1, "minValues", stats.getMinValues(), false); + writeMinMaxMapAsFields(statsBlockBuilder, statsType, 2, "maxValues", stats.getMaxValues(), false); + writeNullCountAsFields(statsBlockBuilder, statsType, 3, "nullCount", stats.getNullCount()); + } + else { + int internalFieldId = 0; + writeLong(statsBlockBuilder, statsType, internalFieldId++, "numRecords", stats.getNumRecords().orElse(null)); + if (statsType.getFields().stream().anyMatch(field -> field.getName().orElseThrow().equals("minValues"))) { + writeMinMaxMapAsFields(statsBlockBuilder, statsType, internalFieldId++, "minValues", stats.getMinValues(), true); + } + if (statsType.getFields().stream().anyMatch(field -> field.getName().orElseThrow().equals("maxValues"))) { + writeMinMaxMapAsFields(statsBlockBuilder, statsType, internalFieldId++, "maxValues", stats.getMaxValues(), true); + } + writeNullCountAsFields(statsBlockBuilder, statsType, internalFieldId++, "nullCount", stats.getNullCount()); + } entryBlockBuilder.closeEntry(); } - private void writeMinMaxMapAsFields(BlockBuilder blockBuilder, RowType type, int fieldId, String fieldName, Optional> values) + private void writeMinMaxMapAsFields(BlockBuilder blockBuilder, RowType type, int fieldId, String fieldName, Optional> values, boolean isJson) { RowType.Field valuesField = validateAndGetField(type, fieldId, fieldName); RowType valuesFieldType = (RowType) valuesField.getType(); - writeObjectMapAsFields(blockBuilder, type, fieldId, fieldName, preprocessMinMaxValues(valuesFieldType, values)); + writeObjectMapAsFields(blockBuilder, type, fieldId, fieldName, preprocessMinMaxValues(valuesFieldType, values, isJson)); + } + + private void writeNullCountAsFields(BlockBuilder blockBuilder, RowType type, int fieldId, String fieldName, Optional> values) + { + writeObjectMapAsFields(blockBuilder, type, fieldId, fieldName, preprocessNullCount(values)); } private void writeObjectMapAsFields(BlockBuilder blockBuilder, RowType type, int fieldId, String fieldName, Optional> values) @@ -263,6 +326,11 @@ private void writeObjectMapAsFields(BlockBuilder blockBuilder, RowType type, int Object value = values.get().get(valueField.getName().orElseThrow()); if (valueField.getType() instanceof RowType) { Block rowBlock = (Block) value; + // Statistics were not collected + if (rowBlock == null) { + fieldBlockBuilder.appendNull(); + continue; + } checkState(rowBlock.getPositionCount() == 1, "Invalid RowType statistics for writing Delta Lake checkpoint"); if (rowBlock.isNull(0)) { fieldBlockBuilder.appendNull(); @@ -279,7 +347,7 @@ private void writeObjectMapAsFields(BlockBuilder blockBuilder, RowType type, int blockBuilder.closeEntry(); } - private Optional> preprocessMinMaxValues(RowType valuesType, Optional> valuesOptional) + private Optional> preprocessMinMaxValues(RowType valuesType, Optional> valuesOptional, boolean isJson) { return valuesOptional.map( values -> { @@ -292,8 +360,11 @@ private Optional> preprocessMinMaxValues(RowType valuesType, .collect(toMap( Map.Entry::getKey, entry -> { - Type type = fieldTypes.get(entry.getKey()); + Type type = fieldTypes.get(entry.getKey().toLowerCase(ENGLISH)); Object value = entry.getValue(); + if (isJson) { + return jsonValueToTrinoValue(type, value); + } if (type instanceof TimestampType) { // We need to remap TIMESTAMP WITH TIME ZONE -> TIMESTAMP here because of // inconsistency in what type is used for DL "timestamp" type in data processing and in min/max statistics map. @@ -304,6 +375,22 @@ private Optional> preprocessMinMaxValues(RowType valuesType, }); } + private Optional> preprocessNullCount(Optional> valuesOptional) + { + return valuesOptional.map( + values -> + values.entrySet().stream() + .collect(toMap( + Map.Entry::getKey, + entry -> { + Object value = entry.getValue(); + if (value instanceof Integer) { + return (long) (int) value; + } + return value; + }))); + } + private void writeRemoveFileEntry(PageBuilder pageBuilder, RowType entryType, RemoveFileEntry removeFileEntry) { pageBuilder.declarePosition(); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/TransactionLogTail.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/TransactionLogTail.java index b0771db6c966..11db3114a427 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/TransactionLogTail.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/TransactionLogTail.java @@ -84,9 +84,7 @@ public static TransactionLogTail loadNewTail( if (endVersion.isPresent()) { throw new MissingTransactionLogException(path); } - else { - endOfTail = true; - } + endOfTail = true; } if (endVersion.isPresent() && version == endVersion.get()) { @@ -157,12 +155,10 @@ public static boolean isFileNotFoundException(IOException e) if (e instanceof FileNotFoundException) { return true; } - else if (e.getMessage().contains("The specified key does not exist")) { + if (e.getMessage().contains("The specified key does not exist")) { return true; } - else { - return false; - } + return false; } public List getFileEntries() diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/statistics/DeltaLakeFileStatistics.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/statistics/DeltaLakeFileStatistics.java index 562b209b07ce..9c7a2ef3c540 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/statistics/DeltaLakeFileStatistics.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/statistics/DeltaLakeFileStatistics.java @@ -27,6 +27,12 @@ public interface DeltaLakeFileStatistics { Optional getNumRecords(); + Optional> getMinValues(); + + Optional> getMaxValues(); + + Optional> getNullCount(); + Optional getNullCount(String columnName); Optional getMinColumnValue(DeltaLakeColumnHandle columnHandle); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/statistics/DeltaLakeJsonFileStatistics.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/statistics/DeltaLakeJsonFileStatistics.java index 06b2ad4b629e..d7ee8dba5284 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/statistics/DeltaLakeJsonFileStatistics.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/statistics/DeltaLakeJsonFileStatistics.java @@ -89,18 +89,21 @@ public Optional getNumRecords() } @JsonProperty + @Override public Optional> getMinValues() { return minValues.map(TransactionLogAccess::toOriginalNameKeyedMap); } @JsonProperty + @Override public Optional> getMaxValues() { return maxValues.map(TransactionLogAccess::toOriginalNameKeyedMap); } @JsonProperty + @Override public Optional> getNullCount() { return nullCount.map(TransactionLogAccess::toOriginalNameKeyedMap); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/statistics/DeltaLakeParquetFileStatistics.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/statistics/DeltaLakeParquetFileStatistics.java index 0df33917623b..618c9a6da6a4 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/statistics/DeltaLakeParquetFileStatistics.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/statistics/DeltaLakeParquetFileStatistics.java @@ -61,16 +61,19 @@ public Optional getNumRecords() return numRecords; } + @Override public Optional> getMinValues() { return minValues.map(TransactionLogAccess::toOriginalNameKeyedMap); } + @Override public Optional> getMaxValues() { return maxValues.map(TransactionLogAccess::toOriginalNameKeyedMap); } + @Override public Optional> getNullCount() { return nullCount.map(TransactionLogAccess::toOriginalNameKeyedMap); diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/AbstractTestDeltaLakeCreateTableStatistics.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/AbstractTestDeltaLakeCreateTableStatistics.java index c9ca30616c1b..df09a84faa0f 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/AbstractTestDeltaLakeCreateTableStatistics.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/AbstractTestDeltaLakeCreateTableStatistics.java @@ -57,8 +57,6 @@ import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS; import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; import static io.trino.testing.TestingConnectorSession.SESSION; -import static io.trino.testing.TestingSession.testSessionBuilder; -import static io.trino.testing.assertions.Assert.assertEventually; import static io.trino.testing.sql.TestTable.randomTableSuffix; import static java.lang.Double.NEGATIVE_INFINITY; import static java.lang.Double.POSITIVE_INFINITY; @@ -81,8 +79,9 @@ protected QueryRunner createQueryRunner() this.bucketName = "delta-test-create-table-statistics-" + randomTableSuffix(); HiveMinioDataLake hiveMinioDataLake = closeAfterClass(new HiveMinioDataLake(bucketName)); hiveMinioDataLake.start(); - ImmutableMap.Builder queryRunnerProperties = ImmutableMap.builder(); - queryRunnerProperties.putAll(additionalProperties()); + ImmutableMap.Builder queryRunnerProperties = ImmutableMap.builder() + .put("delta.enable-non-concurrent-writes", "true") + .putAll(additionalProperties()); return DeltaLakeQueryRunner.createS3DeltaLakeQueryRunner( DELTA_CATALOG, SCHEMA, @@ -433,67 +432,29 @@ else if (addFileEntry.getPartitionValues().get(partitionColumn).equals("2")) { } @Test - public void testMultiFileTable() + public void testMultiFileTableWithNaNValue() throws Exception { - DeltaLakeColumnHandle columnHandle = new DeltaLakeColumnHandle("name", createUnboundedVarcharType(), "name", createUnboundedVarcharType(), REGULAR); - Session session = testSessionBuilder() - .setCatalog(DELTA_CATALOG) - .setSystemProperty("scale_writers", "false") - .setSchema(SCHEMA) - .build(); + String columnName = "key"; + DeltaLakeColumnHandle columnHandle = new DeltaLakeColumnHandle(columnName, DoubleType.DOUBLE, columnName, DoubleType.DOUBLE, REGULAR); try (TestTable table = new TestTable( - "test_partitioned_table_", - ImmutableList.of(), + "test_multi_file_table_nan_value_", + ImmutableList.of(columnName), ImmutableList.of(), - "SELECT name FROM tpch.tiny.nation UNION select name from tpch.tiny.customer", - session)) { + "SELECT IF(custkey = 1143, nan(), CAST(custkey AS double)) FROM tpch.tiny.customer")) { + assertUpdate("INSERT INTO %s SELECT CAST(nationkey AS double) FROM tpch.tiny.nation".formatted(table.getName()), 25); List addFileEntries = getAddFileEntries(table.getName()); assertThat(addFileEntries.size()).isGreaterThan(1); List statistics = addFileEntries.stream().map(entry -> entry.getStats().get()).collect(toImmutableList()); - List minValues = statistics.stream().map(stat -> stat.getMinColumnValue(columnHandle).get()).collect(toImmutableList()); - List maxValues = statistics.stream().map(stat -> stat.getMaxColumnValue(columnHandle).get()).collect(toImmutableList()); - - // All values in the table are distinct, so the min and max values should all be different - assertEquals(minValues.size(), minValues.stream().distinct().count()); - assertEquals(maxValues.size(), maxValues.stream().distinct().count()); + assertEquals(statistics.stream().filter(stat -> stat.getMinColumnValue(columnHandle).isEmpty() && stat.getMaxColumnValue(columnHandle).isEmpty()).count(), 1); + assertEquals( + statistics.stream().filter(stat -> stat.getMinColumnValue(columnHandle).isPresent() && stat.getMaxColumnValue(columnHandle).isPresent()).count(), + statistics.size() - 1); } } - @Test - public void testMultiFileTableWithNaNValue() - throws Exception - { - // assertEventually because sometimes write from tpch.tiny.orders creates one file only and the test requires at least two files - assertEventually(() -> { - String columnName = "orderkey"; - DeltaLakeColumnHandle columnHandle = new DeltaLakeColumnHandle(columnName, DoubleType.DOUBLE, columnName, DoubleType.DOUBLE, REGULAR); - Session session = testSessionBuilder() - .setCatalog(DELTA_CATALOG) - .setSchema(SCHEMA) - .setSystemProperty("scale_writers", "false") - .build(); - try (TestTable table = new TestTable( - "test_partitioned_table_", - ImmutableList.of(columnName), - ImmutableList.of(), - "SELECT IF(orderkey = 50597, nan(), CAST(orderkey AS double)) FROM tpch.tiny.orders", - session)) { - List addFileEntries = getAddFileEntries(table.getName()); - assertThat(addFileEntries.size()).isGreaterThan(1); - - List statistics = addFileEntries.stream().map(entry -> entry.getStats().get()).collect(toImmutableList()); - - assertEquals(statistics.stream().filter(stat -> stat.getMinColumnValue(columnHandle).isEmpty() && stat.getMaxColumnValue(columnHandle).isEmpty()).count(), 1); - assertEquals( - statistics.stream().filter(stat -> stat.getMinColumnValue(columnHandle).isPresent() && stat.getMaxColumnValue(columnHandle).isPresent()).count(), - statistics.size() - 1); - } - }); - } - protected class TestTable implements AutoCloseable { diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeConnectorSmokeTest.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeConnectorSmokeTest.java index 5288f7f1c3ea..99d8addc7fdd 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeConnectorSmokeTest.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeConnectorSmokeTest.java @@ -1111,6 +1111,77 @@ public void testCheckpointing() assertUpdate("DROP TABLE " + tableName); } + @Test(dataProvider = "testCheckpointWriteStatsAsStructDataProvider") + public void testCheckpointWriteStatsAsStruct(String type, String inputValue, String nullsFraction, String statsValue) + { + verifySupportsInsert(); + + String tableName = "test_checkpoint_write_stats_as_struct_" + randomTableSuffix(); + + // Set 'checkpoint_interval' as 1 to write 'stats_parsed' field every INSERT + assertUpdate( + format("CREATE TABLE %s (col %s) WITH (location = '%s', checkpoint_interval = 1)", + tableName, + type, + getLocationForTable(bucketName, tableName))); + assertUpdate("INSERT INTO " + tableName + " SELECT " + inputValue, 1); + + assertQuery( + "SHOW STATS FOR " + tableName, + "VALUES " + + "('col', null, null, " + nullsFraction + ", null, " + statsValue + ", " + statsValue + ")," + + "(null, null, null, null, 1.0, null, null)"); + + assertUpdate("DROP TABLE " + tableName); + } + + @DataProvider + public Object[][] testCheckpointWriteStatsAsStructDataProvider() + { + return new Object[][] { + {"boolean", "true", "0.0", "null"}, + {"integer", "1", "0.0", "1"}, + {"tinyint", "2", "0.0", "2"}, + {"smallint", "3", "0.0", "3"}, + {"bigint", "1000", "0.0", "1000"}, + {"real", "0.1", "0.0", "0.1"}, + {"double", "1.0", "0.0", "1.0"}, + {"decimal(3,2)", "3.14", "0.0", "3.14"}, + {"decimal(30,1)", "12345", "0.0", "12345.0"}, + {"varchar", "'test'", "0.0", "null"}, + {"varbinary", "X'65683F'", "0.0", "null"}, + {"date", "date '2021-02-03'", "0.0", "'2021-02-03'"}, + {"timestamp(3) with time zone", "timestamp '2001-08-22 03:04:05.321 -08:00'", "0.0", "'2001-08-22 11:04:05.321 UTC'"}, + {"array(int)", "array[1]", "null", "null"}, + {"map(varchar,int)", "map(array['foo', 'bar'], array[1, 2])", "null", "null"}, + {"row(x bigint)", "cast(row(1) as row(x bigint))", "null", "null"}, + }; + } + + @Test + public void testCheckpointWriteStatsAsStructWithPartiallyUnsupportedColumnStats() + { + verifySupportsInsert(); + + String tableName = "test_checkpoint_write_stats_as_struct_partially_unsupported_" + randomTableSuffix(); + + // Column statistics on boolean column is unsupported + assertUpdate( + format("CREATE TABLE %s (col integer, unsupported boolean) WITH (location = '%s', checkpoint_interval = 1)", + tableName, + getLocationForTable(bucketName, tableName))); + assertUpdate("INSERT INTO " + tableName + " VALUES (1, true)", 1); + + assertQuery( + "SHOW STATS FOR " + tableName, + "VALUES " + + "('col', null, null, 0.0, null, 1, 1)," + + "('unsupported', null, null, 0.0, null, null, null)," + + "(null, null, null, null, 1.0, null, null)"); + + assertUpdate("DROP TABLE " + tableName); + } + @Test public void testDeltaLakeTableLocationChanged() throws Exception diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/DeltaLakeQueryRunner.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/DeltaLakeQueryRunner.java index c0b7f37784e0..3295ff8febd8 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/DeltaLakeQueryRunner.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/DeltaLakeQueryRunner.java @@ -16,14 +16,12 @@ import com.google.common.collect.ImmutableMap; import io.airlift.log.Logger; import io.trino.Session; -import io.trino.metadata.QualifiedObjectName; import io.trino.plugin.hive.containers.HiveHadoop; import io.trino.plugin.hive.containers.HiveMinioDataLake; import io.trino.plugin.tpch.TpchPlugin; import io.trino.testing.DistributedQueryRunner; import io.trino.testing.QueryRunner; import io.trino.tpch.TpchTable; -import org.intellij.lang.annotations.Language; import java.nio.file.Path; import java.util.HashMap; @@ -32,17 +30,14 @@ import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.testing.Closeables.closeAllSuppress; -import static io.airlift.units.Duration.nanosSince; import static io.trino.plugin.deltalake.DeltaLakeConnectorFactory.CONNECTOR_NAME; import static io.trino.plugin.hive.containers.HiveMinioDataLake.MINIO_ACCESS_KEY; import static io.trino.plugin.hive.containers.HiveMinioDataLake.MINIO_SECRET_KEY; import static io.trino.plugin.tpch.TpchMetadata.TINY_SCHEMA_NAME; +import static io.trino.testing.QueryAssertions.copyTpchTables; import static io.trino.testing.TestingSession.testSessionBuilder; import static java.lang.String.format; -import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; -import static java.util.concurrent.TimeUnit.SECONDS; -import static org.assertj.core.api.Assertions.assertThat; import static org.testng.util.Strings.isNullOrEmpty; public final class DeltaLakeQueryRunner @@ -233,42 +228,6 @@ private static String requiredNonEmptySystemProperty(String propertyName) return val; } - private static void copyTpchTables( - QueryRunner queryRunner, - String sourceCatalog, - String sourceSchema, - TableLocationSupplier locationSupplier, - Session session, - Iterable> tables) - { - log.info("Loading data from %s.%s...", sourceCatalog, sourceSchema); - long startTime = System.nanoTime(); - for (TpchTable table : tables) { - copyTable(queryRunner, sourceCatalog, sourceSchema, table.getTableName().toLowerCase(ENGLISH), locationSupplier, session); - } - log.info("Loading from %s.%s complete in %s", sourceCatalog, sourceSchema, nanosSince(startTime).toString(SECONDS)); - } - - private static void copyTable(QueryRunner queryRunner, String sourceCatalog, String sourceSchema, String sourceTable, TableLocationSupplier locationSupplier, Session session) - { - QualifiedObjectName table = new QualifiedObjectName(sourceCatalog, sourceSchema, sourceTable); - copyTable(queryRunner, table, locationSupplier, session); - } - - private static void copyTable(QueryRunner queryRunner, QualifiedObjectName table, TableLocationSupplier locationSupplier, Session session) - { - long start = System.nanoTime(); - log.info("Running import for %s", table.getObjectName()); - String location = locationSupplier.getTableLocation(table.getSchemaName(), table.getObjectName()); - @Language("SQL") String sql = format("CREATE TABLE IF NOT EXISTS %s WITH (location='%s') AS SELECT * FROM %s", table.getObjectName(), location, table); - long rows = (Long) queryRunner.execute(session, sql).getMaterializedRows().get(0).getField(0); - log.info("Imported %s rows for %s in %s", rows, table.getObjectName(), nanosSince(start).convertToMostSuccinctTimeUnit()); - - assertThat(queryRunner.execute(session, "SELECT count(*) FROM " + table).getOnlyValue()) - .as("Table is not loaded properly: %s", table) - .isEqualTo(queryRunner.execute(session, "SELECT count(*) FROM " + table.getObjectName()).getOnlyValue()); - } - private static Session createSession() { return testSessionBuilder() @@ -277,12 +236,6 @@ private static Session createSession() .build(); } - @FunctionalInterface - private interface TableLocationSupplier - { - String getTableLocation(String schemaName, String tableName); - } - public static class DefaultDeltaLakeQueryRunnerMain { public static void main(String[] args) @@ -294,7 +247,7 @@ public static void main(String[] args) ImmutableMap.of("delta.enable-non-concurrent-writes", "true")); Path baseDirectory = queryRunner.getCoordinator().getBaseDataDir().resolve(DELTA_CATALOG); - copyTpchTables(queryRunner, "tpch", TINY_SCHEMA_NAME, (schemaName, tableName) -> format("file://%s/%s/%s", baseDirectory, schemaName, tableName), createSession(), TpchTable.getTables()); + copyTpchTables(queryRunner, "tpch", TINY_SCHEMA_NAME, createSession(), TpchTable.getTables()); log.info("Data directory is: %s", baseDirectory); Thread.sleep(10); @@ -341,8 +294,8 @@ public static void main(String[] args) hiveMinioDataLake.getHiveHadoop(), runner -> {}); - queryRunner.execute("CREATE SCHEMA tpch"); - copyTpchTables(queryRunner, "tpch", TINY_SCHEMA_NAME, (schemaName, tableName) -> format("s3://%s/%s/%s", bucketName, schemaName, tableName), createSession(), TpchTable.getTables()); + queryRunner.execute("CREATE SCHEMA tpch WITH (location='s3://" + bucketName + "/tpch')"); + copyTpchTables(queryRunner, "tpch", TINY_SCHEMA_NAME, createSession(), TpchTable.getTables()); Thread.sleep(10); Logger log = Logger.get(DeltaLakeQueryRunner.class); diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakePlugin.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakePlugin.java index 19c21f511681..2a182ad484a2 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakePlugin.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakePlugin.java @@ -16,11 +16,16 @@ import com.google.common.collect.ImmutableMap; import io.airlift.bootstrap.ApplicationConfigurationException; import io.trino.spi.Plugin; +import io.trino.spi.connector.Connector; import io.trino.spi.connector.ConnectorFactory; import io.trino.testing.TestingConnectorContext; import org.testng.annotations.Test; +import java.io.File; +import java.nio.file.Files; + import static com.google.common.collect.Iterables.getOnlyElement; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.assertj.core.api.Assertions.assertThatThrownBy; public class TestDeltaLakePlugin @@ -136,4 +141,58 @@ public void testNoActiveDataFilesCaching() "delta.metadata.live-files.cache-ttl", "0s"), new TestingConnectorContext()); } + + @Test + public void testReadOnlyAllAccessControl() + { + Plugin plugin = new DeltaLakePlugin(); + ConnectorFactory factory = getOnlyElement(plugin.getConnectorFactories()); + + factory.create( + "test", + ImmutableMap.builder() + .put("hive.metastore.uri", "thrift://foo:1234") + .put("delta.security", "read-only") + .buildOrThrow(), + new TestingConnectorContext()) + .shutdown(); + } + + @Test + public void testSystemAccessControl() + { + Plugin plugin = new DeltaLakePlugin(); + ConnectorFactory factory = getOnlyElement(plugin.getConnectorFactories()); + + Connector connector = factory.create( + "test", + ImmutableMap.builder() + .put("hive.metastore.uri", "thrift://foo:1234") + .put("delta.security", "system") + .buildOrThrow(), + new TestingConnectorContext()); + assertThatThrownBy(connector::getAccessControl).isInstanceOf(UnsupportedOperationException.class); + connector.shutdown(); + } + + @Test + public void testFileBasedAccessControl() + throws Exception + { + Plugin plugin = new DeltaLakePlugin(); + ConnectorFactory factory = getOnlyElement(plugin.getConnectorFactories()); + File tempFile = File.createTempFile("test-delta-lake-plugin-access-control", ".json"); + tempFile.deleteOnExit(); + Files.write(tempFile.toPath(), "{}".getBytes(UTF_8)); + + factory.create( + "test", + ImmutableMap.builder() + .put("hive.metastore.uri", "thrift://foo:1234") + .put("delta.security", "file") + .put("security.config-file", tempFile.getAbsolutePath()) + .buildOrThrow(), + new TestingConnectorContext()) + .shutdown(); + } } diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeSecurityConfig.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeSecurityConfig.java new file mode 100644 index 000000000000..dde1c856ddd4 --- /dev/null +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeSecurityConfig.java @@ -0,0 +1,48 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.deltalake; + +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Map; + +import static io.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static io.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; +import static io.airlift.configuration.testing.ConfigAssertions.recordDefaults; +import static io.trino.plugin.deltalake.DeltaLakeSecurityConfig.DeltaLakeSecurity.ALLOW_ALL; +import static io.trino.plugin.deltalake.DeltaLakeSecurityConfig.DeltaLakeSecurity.READ_ONLY; + +public class TestDeltaLakeSecurityConfig +{ + @Test + public void testDefaults() + { + assertRecordedDefaults(recordDefaults(DeltaLakeSecurityConfig.class) + .setSecuritySystem(ALLOW_ALL)); + } + + @Test + public void testExplicitPropertyMappings() + { + Map properties = ImmutableMap.builder() + .put("delta.security", "read-only") + .buildOrThrow(); + + DeltaLakeSecurityConfig expected = new DeltaLakeSecurityConfig() + .setSecuritySystem(READ_ONLY); + + assertFullMapping(properties, expected); + } +} diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/metastore/glue/TestDeltaLakeRenameToWithGlueMetastore.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/metastore/glue/TestDeltaLakeRenameToWithGlueMetastore.java index 499acd18a4a5..767250efeaca 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/metastore/glue/TestDeltaLakeRenameToWithGlueMetastore.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/metastore/glue/TestDeltaLakeRenameToWithGlueMetastore.java @@ -13,20 +13,9 @@ */ package io.trino.plugin.deltalake.metastore.glue; -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; import io.trino.Session; -import io.trino.hdfs.DynamicHdfsConfiguration; -import io.trino.hdfs.HdfsConfig; -import io.trino.hdfs.HdfsConfigurationInitializer; -import io.trino.hdfs.HdfsEnvironment; -import io.trino.hdfs.authentication.NoHdfsAuthentication; -import io.trino.plugin.deltalake.DeltaLakePlugin; -import io.trino.plugin.hive.metastore.HiveMetastore; -import io.trino.plugin.hive.metastore.glue.DefaultGlueColumnStatisticsProviderFactory; -import io.trino.plugin.hive.metastore.glue.GlueHiveMetastore; -import io.trino.plugin.hive.metastore.glue.GlueHiveMetastoreConfig; +import io.trino.plugin.deltalake.DeltaLakeQueryRunner; import io.trino.testing.AbstractTestQueryFramework; import io.trino.testing.DistributedQueryRunner; import io.trino.testing.QueryRunner; @@ -34,10 +23,7 @@ import org.testng.annotations.Test; import java.io.File; -import java.util.Optional; -import static com.google.common.util.concurrent.MoreExecutors.directExecutor; -import static io.trino.plugin.deltalake.DeltaLakeConnectorFactory.CONNECTOR_NAME; import static io.trino.testing.TestingSession.testSessionBuilder; import static io.trino.testing.sql.TestTable.randomTableSuffix; import static java.lang.String.format; @@ -47,9 +33,8 @@ public class TestDeltaLakeRenameToWithGlueMetastore { protected static final String SCHEMA = "test_delta_lake_rename_to_with_glue_" + randomTableSuffix(); protected static final String CATALOG_NAME = "test_delta_lake_rename_to_with_glue"; - protected File metastoreDir; - protected HiveMetastore metastore; - protected HdfsEnvironment hdfsEnvironment; + + private File schemaLocation; @Override protected QueryRunner createQueryRunner() @@ -60,38 +45,13 @@ protected QueryRunner createQueryRunner() .setSchema(SCHEMA) .build(); - DistributedQueryRunner queryRunner = DistributedQueryRunner.builder(deltaLakeSession).build(); - - this.metastoreDir = new File(queryRunner.getCoordinator().getBaseDataDir().resolve("delta_lake_data").toString()); - this.metastoreDir.deleteOnExit(); - - queryRunner.installPlugin(new DeltaLakePlugin()); - queryRunner.createCatalog( - CATALOG_NAME, - CONNECTOR_NAME, - ImmutableMap.builder() - .put("hive.metastore", "glue") - .put("hive.metastore.glue.region", "us-east-2") - .put("hive.metastore.glue.default-warehouse-dir", metastoreDir.getPath()) - .buildOrThrow()); - - HdfsConfig hdfsConfig = new HdfsConfig(); - hdfsEnvironment = new HdfsEnvironment( - new DynamicHdfsConfiguration(new HdfsConfigurationInitializer(hdfsConfig), ImmutableSet.of()), - hdfsConfig, - new NoHdfsAuthentication()); - GlueHiveMetastoreConfig glueConfig = new GlueHiveMetastoreConfig() - .setGlueRegion("us-east-2"); - metastore = new GlueHiveMetastore( - hdfsEnvironment, - glueConfig, - DefaultAWSCredentialsProviderChain.getInstance(), - directExecutor(), - new DefaultGlueColumnStatisticsProviderFactory(directExecutor(), directExecutor()), - Optional.empty(), - table -> true); - - queryRunner.execute("CREATE SCHEMA " + SCHEMA + " WITH (location = '" + metastoreDir.getPath() + "')"); + DistributedQueryRunner queryRunner = DeltaLakeQueryRunner.builder(deltaLakeSession) + .setCatalogName(CATALOG_NAME) + .setDeltaProperties(ImmutableMap.of("hive.metastore", "glue")) + .build(); + schemaLocation = new File(queryRunner.getCoordinator().getBaseDataDir().resolve("delta_lake_data").toString()); + schemaLocation.deleteOnExit(); + queryRunner.execute("CREATE SCHEMA " + SCHEMA + " WITH (location = '" + schemaLocation.getPath() + "')"); return queryRunner; } @@ -100,7 +60,7 @@ public void testRenameOfExternalTable() { String oldTable = "test_table_external_to_be_renamed_" + randomTableSuffix(); String newTable = "test_table_external_renamed_" + randomTableSuffix(); - String location = metastoreDir.getAbsolutePath() + "/tableLocation/"; + String location = schemaLocation.getPath() + "/tableLocation/"; try { assertUpdate(format("CREATE TABLE %s WITH (location = '%s') AS SELECT 1 AS val ", oldTable, location), 1); String oldLocation = (String) computeScalar("SELECT \"$path\" FROM " + oldTable); diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/transactionlog/checkpoint/TestCheckpointWriter.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/transactionlog/checkpoint/TestCheckpointWriter.java index c4f524bccb5c..b2563c37460e 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/transactionlog/checkpoint/TestCheckpointWriter.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/transactionlog/checkpoint/TestCheckpointWriter.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import io.airlift.slice.Slice; import io.trino.filesystem.TrinoFileSystem; import io.trino.filesystem.TrinoInputFile; import io.trino.filesystem.hdfs.HdfsFileSystemFactory; @@ -82,7 +83,7 @@ public void setUp() } @Test - public void testCheckpointWriteReadRoundtrip() + public void testCheckpointWriteReadJsonRoundtrip() throws IOException { MetadataEntry metadataEntry = new MetadataEntry( @@ -114,6 +115,7 @@ public void testCheckpointWriteReadRoundtrip() "{\"name\":\"s2\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}}]},\"nullable\":true,\"metadata\":{}}]}", ImmutableList.of("part_key"), ImmutableMap.of( + "delta.checkpoint.writeStatsAsStruct", "false", "configOption1", "blah", "configOption2", "plah"), 1000); @@ -176,6 +178,76 @@ public void testCheckpointWriteReadRoundtrip() "someTag", "someValue", "otherTag", "otherValue")); + RemoveFileEntry removeFileEntry = new RemoveFileEntry( + "removeFilePath", + 1000, + true); + + CheckpointEntries entries = new CheckpointEntries( + metadataEntry, + protocolEntry, + ImmutableSet.of(transactionEntry), + ImmutableSet.of(addFileEntryJsonStats), + ImmutableSet.of(removeFileEntry)); + + CheckpointWriter writer = new CheckpointWriter(typeManager, checkpointSchemaManager, HDFS_ENVIRONMENT); + + File targetFile = File.createTempFile("testCheckpointWriteReadRoundtrip-", ".checkpoint.parquet"); + targetFile.deleteOnExit(); + + Path targetPath = new Path("file://" + targetFile.getAbsolutePath()); + targetFile.delete(); // file must not exist when writer is called + writer.write(SESSION, entries, targetPath); + + CheckpointEntries readEntries = readCheckpoint(targetPath, metadataEntry, true); + assertEquals(readEntries.getTransactionEntries(), entries.getTransactionEntries()); + assertEquals(readEntries.getRemoveFileEntries(), entries.getRemoveFileEntries()); + assertEquals(readEntries.getMetadataEntry(), entries.getMetadataEntry()); + assertEquals(readEntries.getProtocolEntry(), entries.getProtocolEntry()); + assertEquals( + readEntries.getAddFileEntries().stream().map(this::makeComparable).collect(toImmutableSet()), + entries.getAddFileEntries().stream().map(this::makeComparable).collect(toImmutableSet())); + } + + @Test + public void testCheckpointWriteReadParquetStatisticsRoundtrip() + throws IOException + { + MetadataEntry metadataEntry = new MetadataEntry( + "metadataId", + "metadataName", + "metadataDescription", + new MetadataEntry.Format( + "metadataFormatProvider", + ImmutableMap.of( + "formatOptionX", "blah", + "fomatOptionY", "plah")), + "{\"type\":\"struct\",\"fields\":" + + "[{\"name\":\"ts\",\"type\":\"timestamp\",\"nullable\":true,\"metadata\":{}}," + + "{\"name\":\"str\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}}," + + "{\"name\":\"dec_short\",\"type\":\"decimal(5,1)\",\"nullable\":true,\"metadata\":{}}," + + "{\"name\":\"dec_long\",\"type\":\"decimal(25,3)\",\"nullable\":true,\"metadata\":{}}," + + "{\"name\":\"l\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}}," + + "{\"name\":\"in\",\"type\":\"integer\",\"nullable\":true,\"metadata\":{}}," + + "{\"name\":\"sh\",\"type\":\"short\",\"nullable\":true,\"metadata\":{}}," + + "{\"name\":\"byt\",\"type\":\"byte\",\"nullable\":true,\"metadata\":{}}," + + "{\"name\":\"fl\",\"type\":\"float\",\"nullable\":true,\"metadata\":{}}," + + "{\"name\":\"dou\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}}," + + "{\"name\":\"bool\",\"type\":\"boolean\",\"nullable\":true,\"metadata\":{}}," + + "{\"name\":\"bin\",\"type\":\"binary\",\"nullable\":true,\"metadata\":{}}," + + "{\"name\":\"dat\",\"type\":\"date\",\"nullable\":true,\"metadata\":{}}," + + "{\"name\":\"arr\",\"type\":{\"type\":\"array\",\"elementType\":\"integer\",\"containsNull\":true},\"nullable\":true,\"metadata\":{}}," + + "{\"name\":\"m\",\"type\":{\"type\":\"map\",\"keyType\":\"integer\",\"valueType\":\"string\",\"valueContainsNull\":true},\"nullable\":true,\"metadata\":{}}," + + "{\"name\":\"row\",\"type\":{\"type\":\"struct\",\"fields\":[{\"name\":\"s1\",\"type\":\"integer\",\"nullable\":true,\"metadata\":{}}," + + "{\"name\":\"s2\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}}]},\"nullable\":true,\"metadata\":{}}]}", + ImmutableList.of("part_key"), + ImmutableMap.of( + "configOption1", "blah", + "configOption2", "plah"), + 1000); + ProtocolEntry protocolEntry = new ProtocolEntry(10, 20); + TransactionEntry transactionEntry = new TransactionEntry("appId", 1, 1001); + Block[] minMaxRowFieldBlocks = new Block[]{ nativeValueToBlock(IntegerType.INTEGER, 1L), nativeValueToBlock(createUnboundedVarcharType(), utf8Slice("a")) @@ -251,7 +323,7 @@ public void testCheckpointWriteReadRoundtrip() metadataEntry, protocolEntry, ImmutableSet.of(transactionEntry), - ImmutableSet.of(addFileEntryJsonStats, addFileEntryParquetStats), + ImmutableSet.of(addFileEntryParquetStats), ImmutableSet.of(removeFileEntry)); CheckpointWriter writer = new CheckpointWriter(typeManager, checkpointSchemaManager, HDFS_ENVIRONMENT); @@ -395,6 +467,9 @@ private Optional> makeComparableStatistics(Optional io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-elasticsearch/pom.xml b/plugin/trino-elasticsearch/pom.xml index 856daa58fb12..28a8e11c4c75 100644 --- a/plugin/trino-elasticsearch/pom.xml +++ b/plugin/trino-elasticsearch/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchMetadata.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchMetadata.java index b5fb548fddef..9db12cc86729 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchMetadata.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchMetadata.java @@ -95,7 +95,7 @@ import static io.trino.plugin.elasticsearch.ElasticsearchTableHandle.Type.SCAN; import static io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS; import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; -import static io.trino.spi.expression.StandardFunctions.LIKE_PATTERN_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.LIKE_FUNCTION_NAME; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.DoubleType.DOUBLE; @@ -599,7 +599,7 @@ public Optional> applyFilter(C protected static boolean isSupportedLikeCall(Call call) { - if (!LIKE_PATTERN_FUNCTION_NAME.equals(call.getFunctionName())) { + if (!LIKE_FUNCTION_NAME.equals(call.getFunctionName())) { return false; } diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchQueryBuilder.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchQueryBuilder.java index 60358786e679..86c24f572c18 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchQueryBuilder.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchQueryBuilder.java @@ -100,13 +100,12 @@ private static void addPredicateToQueryBuilder(BoolQueryBuilder queryBuilder, St queryBuilder.filter(getOnlyElement(shouldClauses)); return; } - else if (shouldClauses.size() > 1) { + if (shouldClauses.size() > 1) { BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); shouldClauses.forEach(boolQueryBuilder::should); queryBuilder.filter(boolQueryBuilder); return; } - return; } private static List getShouldClauses(String columnName, Domain domain, Type type) diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchSplitManager.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchSplitManager.java index 18fce4cba83e..52c42262b08d 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchSplitManager.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchSplitManager.java @@ -57,12 +57,10 @@ public ConnectorSplitSource getSplits( if (tableHandle.getType().equals(QUERY)) { return new FixedSplitSource(ImmutableList.of(new ElasticsearchSplit(tableHandle.getIndex(), 0, Optional.empty()))); } - else { - List splits = client.getSearchShards(tableHandle.getIndex()).stream() - .map(shard -> new ElasticsearchSplit(shard.getIndex(), shard.getId(), shard.getAddress())) - .collect(toImmutableList()); + List splits = client.getSearchShards(tableHandle.getIndex()).stream() + .map(shard -> new ElasticsearchSplit(shard.getIndex(), shard.getId(), shard.getAddress())) + .collect(toImmutableList()); - return new FixedSplitSource(splits); - } + return new FixedSplitSource(splits); } } diff --git a/plugin/trino-example-http/pom.xml b/plugin/trino-example-http/pom.xml index 688a0de4a3a1..bbfa3d500dd5 100644 --- a/plugin/trino-example-http/pom.xml +++ b/plugin/trino-example-http/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-exchange-filesystem/pom.xml b/plugin/trino-exchange-filesystem/pom.xml index cfee88f611f3..0b87f4030451 100644 --- a/plugin/trino-exchange-filesystem/pom.xml +++ b/plugin/trino-exchange-filesystem/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchange.java b/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchange.java index 982fc6357f05..0676b5138b08 100644 --- a/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchange.java +++ b/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchange.java @@ -78,7 +78,7 @@ public class FileSystemExchange private final long exchangeSourceHandleTargetDataSizeInBytes; private final ExecutorService executor; - private final Map randomizedPrefixes = new ConcurrentHashMap<>(); + private final Map outputDirectories = new ConcurrentHashMap<>(); @GuardedBy("this") private final Set allSinks = new HashSet<>(); @@ -270,12 +270,10 @@ private ListenableFuture> getCommittedPartitions(i private URI getTaskOutputDirectory(int taskPartitionId) { - URI baseDirectory = baseDirectories.get(taskPartitionId % baseDirectories.size()); - String randomizedHexPrefix = randomizedPrefixes.computeIfAbsent(taskPartitionId, ignored -> generateRandomizedHexPrefix()); - // Add a randomized prefix to evenly distribute data into different S3 shards // Data output file path format: {randomizedHexPrefix}.{queryId}.{stageId}.{sinkPartitionId}/{attemptId}/{sourcePartitionId}_{splitId}.data - return baseDirectory.resolve(randomizedHexPrefix + "." + exchangeContext.getQueryId() + "." + exchangeContext.getExchangeId() + "." + taskPartitionId + PATH_SEPARATOR); + return outputDirectories.computeIfAbsent(taskPartitionId, ignored -> baseDirectories.get(ThreadLocalRandom.current().nextInt(baseDirectories.size())) + .resolve(generateRandomizedHexPrefix() + "." + exchangeContext.getQueryId() + "." + exchangeContext.getExchangeId() + "." + taskPartitionId + PATH_SEPARATOR)); } @Override diff --git a/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchangeFutures.java b/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchangeFutures.java index d8f1f81f59dc..2fa01a970d60 100644 --- a/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchangeFutures.java +++ b/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchangeFutures.java @@ -33,9 +33,7 @@ public static ListenableFuture translateFailures(ListenableFuture liste if (throwable instanceof Error || throwable instanceof IOException) { return immediateFailedFuture(throwable); } - else { - return immediateFailedFuture(new IOException(throwable)); - } + return immediateFailedFuture(new IOException(throwable)); }, directExecutor())); } } diff --git a/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchangeSinkHandle.java b/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchangeSinkHandle.java index 0a3588ae6362..4308f7f5baca 100644 --- a/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchangeSinkHandle.java +++ b/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchangeSinkHandle.java @@ -64,9 +64,7 @@ public boolean equals(Object o) if (secretKey.isPresent() && that.secretKey.isPresent()) { return partitionId == that.getPartitionId() && Arrays.equals(secretKey.get(), that.secretKey.get()); } - else { - return partitionId == that.getPartitionId() && secretKey.isEmpty() && that.secretKey.isEmpty(); - } + return partitionId == that.getPartitionId() && secretKey.isEmpty() && that.secretKey.isEmpty(); } @Override diff --git a/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchangeSourceHandle.java b/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchangeSourceHandle.java index 7b175d85c0a1..79d4391092b9 100644 --- a/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchangeSourceHandle.java +++ b/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchangeSourceHandle.java @@ -20,9 +20,7 @@ import io.trino.spi.exchange.ExchangeSourceHandle; import org.openjdk.jol.info.ClassLayout; -import java.util.Arrays; import java.util.List; -import java.util.Objects; import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; @@ -85,30 +83,6 @@ public Optional getSecretKey() return secretKey; } - @Override - public boolean equals(Object o) - { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - FileSystemExchangeSourceHandle that = (FileSystemExchangeSourceHandle) o; - if (secretKey.isPresent() && that.secretKey.isPresent()) { - return partitionId == that.getPartitionId() && Arrays.equals(secretKey.get(), that.secretKey.get()); - } - else { - return partitionId == that.getPartitionId() && secretKey.isEmpty() && that.secretKey.isEmpty(); - } - } - - @Override - public int hashCode() - { - return Objects.hash(partitionId, files, secretKey); - } - @Override public String toString() { diff --git a/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/local/LocalFileSystemExchangeStorage.java b/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/local/LocalFileSystemExchangeStorage.java index 631bdfb713c0..3a01955bd2b5 100644 --- a/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/local/LocalFileSystemExchangeStorage.java +++ b/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/local/LocalFileSystemExchangeStorage.java @@ -229,9 +229,7 @@ private InputStreamSliceInput getSliceInput(ExchangeSourceFile sourceFile) throw new TrinoException(GENERIC_INTERNAL_ERROR, "Failed to create CipherInputStream: " + e.getMessage(), e); } } - else { - return new InputStreamSliceInput(new FileInputStream(file), BUFFER_SIZE_IN_BYTES); - } + return new InputStreamSliceInput(new FileInputStream(file), BUFFER_SIZE_IN_BYTES); } } diff --git a/plugin/trino-geospatial/pom.xml b/plugin/trino-geospatial/pom.xml index a5c6f3fd8ed4..910af4ccd009 100644 --- a/plugin/trino-geospatial/pom.xml +++ b/plugin/trino-geospatial/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-google-sheets/pom.xml b/plugin/trino-google-sheets/pom.xml index 513c2309fdc5..e68be224a215 100644 --- a/plugin/trino-google-sheets/pom.xml +++ b/plugin/trino-google-sheets/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-google-sheets/src/main/java/io/trino/plugin/google/sheets/SheetsClient.java b/plugin/trino-google-sheets/src/main/java/io/trino/plugin/google/sheets/SheetsClient.java index ab9014f314fc..d40c05ef6136 100644 --- a/plugin/trino-google-sheets/src/main/java/io/trino/plugin/google/sheets/SheetsClient.java +++ b/plugin/trino-google-sheets/src/main/java/io/trino/plugin/google/sheets/SheetsClient.java @@ -158,11 +158,9 @@ public Set getTableNames() public List> readAllValues(String tableName) { try { - Optional sheetExpression = tableSheetMappingCache.getUnchecked(tableName); - if (sheetExpression.isEmpty()) { - throw new TrinoException(SHEETS_UNKNOWN_TABLE_ERROR, "Sheet expression not found for table " + tableName); - } - return sheetDataCache.getUnchecked(sheetExpression.get()); + String sheetExpression = tableSheetMappingCache.getUnchecked(tableName) + .orElseThrow(() -> new TrinoException(SHEETS_UNKNOWN_TABLE_ERROR, "Sheet expression not found for table " + tableName)); + return sheetDataCache.getUnchecked(sheetExpression); } catch (UncheckedExecutionException e) { throwIfInstanceOf(e.getCause(), TrinoException.class); diff --git a/plugin/trino-google-sheets/src/main/java/io/trino/plugin/google/sheets/SheetsMetadata.java b/plugin/trino-google-sheets/src/main/java/io/trino/plugin/google/sheets/SheetsMetadata.java index e9e38d02cc0a..412e280dc7a8 100644 --- a/plugin/trino-google-sheets/src/main/java/io/trino/plugin/google/sheets/SheetsMetadata.java +++ b/plugin/trino-google-sheets/src/main/java/io/trino/plugin/google/sheets/SheetsMetadata.java @@ -80,25 +80,21 @@ public SheetsTableHandle getTableHandle(ConnectorSession session, SchemaTableNam @Override public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle table) { - Optional connectorTableMetadata = getTableMetadata(((SheetsTableHandle) table).toSchemaTableName()); - if (connectorTableMetadata.isEmpty()) { - throw new TrinoException(SHEETS_UNKNOWN_TABLE_ERROR, "Metadata not found for table " + ((SheetsTableHandle) table).getTableName()); - } - return connectorTableMetadata.get(); + SheetsTableHandle tableHandle = (SheetsTableHandle) table; + return getTableMetadata(tableHandle.toSchemaTableName()) + .orElseThrow(() -> new TrinoException(SHEETS_UNKNOWN_TABLE_ERROR, "Metadata not found for table " + tableHandle.getTableName())); } @Override public Map getColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle) { SheetsTableHandle sheetsTableHandle = (SheetsTableHandle) tableHandle; - Optional table = sheetsClient.getTable(sheetsTableHandle.getTableName()); - if (table.isEmpty()) { - throw new TableNotFoundException(sheetsTableHandle.toSchemaTableName()); - } + SheetsTable table = sheetsClient.getTable(sheetsTableHandle.getTableName()) + .orElseThrow(() -> new TableNotFoundException(sheetsTableHandle.toSchemaTableName())); ImmutableMap.Builder columnHandles = ImmutableMap.builder(); int index = 0; - for (ColumnMetadata column : table.get().getColumnsMetadata()) { + for (ColumnMetadata column : table.getColumnsMetadata()) { columnHandles.put(column.getName(), new SheetsColumnHandle(column.getName(), column.getType(), index)); index++; } diff --git a/plugin/trino-google-sheets/src/main/java/io/trino/plugin/google/sheets/SheetsSplitManager.java b/plugin/trino-google-sheets/src/main/java/io/trino/plugin/google/sheets/SheetsSplitManager.java index 5f763aa16d9d..48d956ca4b32 100644 --- a/plugin/trino-google-sheets/src/main/java/io/trino/plugin/google/sheets/SheetsSplitManager.java +++ b/plugin/trino-google-sheets/src/main/java/io/trino/plugin/google/sheets/SheetsSplitManager.java @@ -29,7 +29,6 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.Optional; import static java.util.Objects.requireNonNull; @@ -53,15 +52,12 @@ public ConnectorSplitSource getSplits( Constraint constraint) { SheetsTableHandle tableHandle = (SheetsTableHandle) connectorTableHandle; - Optional table = sheetsClient.getTable(tableHandle.getTableName()); - - // this can happen if table is removed during a query - if (table.isEmpty()) { - throw new TableNotFoundException(tableHandle.toSchemaTableName()); - } + SheetsTable table = sheetsClient.getTable(tableHandle.getTableName()) + // this can happen if table is removed during a query + .orElseThrow(() -> new TableNotFoundException(tableHandle.toSchemaTableName())); List splits = new ArrayList<>(); - splits.add(new SheetsSplit(tableHandle.getSchemaName(), tableHandle.getTableName(), table.get().getValues())); + splits.add(new SheetsSplit(tableHandle.getSchemaName(), tableHandle.getTableName(), table.getValues())); Collections.shuffle(splits); return new FixedSplitSource(splits); } diff --git a/plugin/trino-hive-hadoop2/pom.xml b/plugin/trino-hive-hadoop2/pom.xml index a46caf1cb549..6bbe3b2f8d70 100644 --- a/plugin/trino-hive-hadoop2/pom.xml +++ b/plugin/trino-hive-hadoop2/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-hive-hadoop2/src/test/java/io/trino/plugin/hive/s3select/S3SelectTestHelper.java b/plugin/trino-hive-hadoop2/src/test/java/io/trino/plugin/hive/s3select/S3SelectTestHelper.java index d0d82d7e0db5..f318474c63c4 100644 --- a/plugin/trino-hive-hadoop2/src/test/java/io/trino/plugin/hive/s3select/S3SelectTestHelper.java +++ b/plugin/trino-hive-hadoop2/src/test/java/io/trino/plugin/hive/s3select/S3SelectTestHelper.java @@ -48,7 +48,6 @@ import io.trino.plugin.hive.fs.FileSystemDirectoryLister; import io.trino.plugin.hive.metastore.HiveMetastoreConfig; import io.trino.plugin.hive.metastore.HiveMetastoreFactory; -import io.trino.plugin.hive.metastore.MetastoreTypeConfig; import io.trino.plugin.hive.metastore.thrift.BridgingHiveMetastore; import io.trino.plugin.hive.s3.HiveS3Config; import io.trino.plugin.hive.s3.TrinoS3ConfigurationInitializer; @@ -152,7 +151,7 @@ public S3SelectTestHelper(String host, SqlStandardAccessControlMetadata::new, new FileSystemDirectoryLister(), new PartitionProjectionService(this.hiveConfig, ImmutableMap.of(), new TestingTypeManager()), - new MetastoreTypeConfig()); + true); transactionManager = new HiveTransactionManager(metadataFactory); splitManager = new HiveSplitManager( diff --git a/plugin/trino-hive/pom.xml b/plugin/trino-hive/pom.xml index 288589f383d0..9bbaa3854877 100644 --- a/plugin/trino-hive/pom.xml +++ b/plugin/trino-hive/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeAccessControlMetadataFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/AllowHiveTableRename.java similarity index 54% rename from plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeAccessControlMetadataFactory.java rename to plugin/trino-hive/src/main/java/io/trino/plugin/hive/AllowHiveTableRename.java index 84fd80e519ca..12d6b7c44795 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeAccessControlMetadataFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/AllowHiveTableRename.java @@ -11,20 +11,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.plugin.deltalake; +package io.trino.plugin.hive; -import io.trino.plugin.hive.metastore.HiveMetastore; -import io.trino.plugin.hive.security.AccessControlMetadata; +import javax.inject.Qualifier; -public interface DeltaLakeAccessControlMetadataFactory -{ - DeltaLakeAccessControlMetadataFactory SYSTEM = metastore -> new AccessControlMetadata() { - @Override - public boolean isUsingSystemSecurity() - { - return true; - } - }; +import java.lang.annotation.Retention; +import java.lang.annotation.Target; - AccessControlMetadata create(HiveMetastore metastore); -} +import static java.lang.annotation.ElementType.FIELD; +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.ElementType.PARAMETER; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +@Retention(RUNTIME) +@Target({FIELD, PARAMETER, METHOD}) +@Qualifier +public @interface AllowHiveTableRename {} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveApplyProjectionUtil.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveApplyProjectionUtil.java index 487bfb56b93f..d031c3b1cd91 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveApplyProjectionUtil.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveApplyProjectionUtil.java @@ -72,7 +72,7 @@ public static ProjectedColumnRepresentation createProjectedColumnRepresentation( target = (Variable) expression; break; } - else if (expression instanceof FieldDereference) { + if (expression instanceof FieldDereference) { FieldDereference dereference = (FieldDereference) expression; ordinals.add(dereference.getField()); expression = dereference.getTarget(); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveLocationService.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveLocationService.java index 54878a5ee300..18e7d2e468d9 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveLocationService.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveLocationService.java @@ -69,9 +69,7 @@ public LocationHandle forNewTable(SemiTransactionalHiveMetastore metastore, Conn Path writePath = createTemporaryPath(session, context, hdfsEnvironment, targetPath); return new LocationHandle(targetPath, writePath, false, STAGE_AND_MOVE_TO_TARGET_DIRECTORY); } - else { - return new LocationHandle(targetPath, targetPath, false, DIRECT_TO_TARGET_NEW_DIRECTORY); - } + return new LocationHandle(targetPath, targetPath, false, DIRECT_TO_TARGET_NEW_DIRECTORY); } @Override @@ -84,9 +82,7 @@ public LocationHandle forExistingTable(SemiTransactionalHiveMetastore metastore, Path writePath = createTemporaryPath(session, context, hdfsEnvironment, targetPath); return new LocationHandle(targetPath, writePath, true, STAGE_AND_MOVE_TO_TARGET_DIRECTORY); } - else { - return new LocationHandle(targetPath, targetPath, true, DIRECT_TO_TARGET_EXISTING_DIRECTORY); - } + return new LocationHandle(targetPath, targetPath, true, DIRECT_TO_TARGET_EXISTING_DIRECTORY); } @Override @@ -133,13 +129,11 @@ public WriteInfo getPartitionWriteInfo(LocationHandle locationHandle, Optional

table = metastore.getTable(handle.getSchemaName(), handle.getTableName()); - if (table.isEmpty()) { - throw new TableNotFoundException(handle.getSchemaTableName()); - } + Table table = metastore.getTable(handle.getSchemaName(), handle.getTableName()) + .orElseThrow(() -> new TableNotFoundException(handle.getSchemaTableName())); - if (table.get().getPartitionColumns().isEmpty()) { + if (table.getPartitionColumns().isEmpty()) { metastore.truncateUnpartitionedTable(session, handle.getSchemaName(), handle.getTableName()); } else { diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMetadataFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMetadataFactory.java index b53d2c9b4daa..2c3dc07ce43a 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMetadataFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMetadataFactory.java @@ -23,7 +23,6 @@ import io.trino.plugin.hive.fs.TransactionScopeCachingDirectoryLister; import io.trino.plugin.hive.metastore.HiveMetastoreConfig; import io.trino.plugin.hive.metastore.HiveMetastoreFactory; -import io.trino.plugin.hive.metastore.MetastoreTypeConfig; import io.trino.plugin.hive.metastore.SemiTransactionalHiveMetastore; import io.trino.plugin.hive.security.AccessControlMetadataFactory; import io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider; @@ -76,7 +75,7 @@ public class HiveMetadataFactory private final DirectoryLister directoryLister; private final long perTransactionFileStatusCacheMaximumSize; private final PartitionProjectionService partitionProjectionService; - private final String metastoreType; + private final boolean allowTableRename; @Inject public HiveMetadataFactory( @@ -99,7 +98,7 @@ public HiveMetadataFactory( AccessControlMetadataFactory accessControlMetadataFactory, DirectoryLister directoryLister, PartitionProjectionService partitionProjectionService, - MetastoreTypeConfig metastoreTypeConfig) + @AllowHiveTableRename boolean allowTableRename) { this( catalogName, @@ -133,7 +132,7 @@ public HiveMetadataFactory( directoryLister, hiveConfig.getPerTransactionFileStatusCacheMaximumSize(), partitionProjectionService, - metastoreTypeConfig); + allowTableRename); } public HiveMetadataFactory( @@ -168,7 +167,7 @@ public HiveMetadataFactory( DirectoryLister directoryLister, long perTransactionFileStatusCacheMaximumSize, PartitionProjectionService partitionProjectionService, - MetastoreTypeConfig metastoreTypeConfig) + boolean allowTableRename) { this.catalogName = requireNonNull(catalogName, "catalogName is null"); this.skipDeletionForAlter = skipDeletionForAlter; @@ -208,8 +207,7 @@ public HiveMetadataFactory( this.directoryLister = requireNonNull(directoryLister, "directoryLister is null"); this.perTransactionFileStatusCacheMaximumSize = perTransactionFileStatusCacheMaximumSize; this.partitionProjectionService = requireNonNull(partitionProjectionService, "partitionProjectionService is null"); - requireNonNull(metastoreTypeConfig, "metastoreTypeConfig is null"); - this.metastoreType = requireNonNull(metastoreTypeConfig.getMetastoreType(), "metastoreType is null"); + this.allowTableRename = allowTableRename; } @Override @@ -256,6 +254,6 @@ public TransactionalMetadata create(ConnectorIdentity identity, boolean autoComm accessControlMetadataFactory.create(metastore), directoryLister, partitionProjectionService, - metastoreType); + allowTableRename); } } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSplitManager.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSplitManager.java index 42845cc78db1..1783838fcdbb 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSplitManager.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSplitManager.java @@ -344,17 +344,15 @@ private Iterable getPartitionMetadata(ConnectorSession se TableToPartitionMapping tableToPartitionMapping = getTableToPartitionMapping(session, storageFormat, tableName, partName, tableColumns, partitionColumns); if (bucketProperty.isPresent()) { - Optional partitionBucketProperty = partition.getStorage().getBucketProperty(); - if (partitionBucketProperty.isEmpty()) { - throw new TrinoException(HIVE_PARTITION_SCHEMA_MISMATCH, format( - "Hive table (%s) is bucketed but partition (%s) is not bucketed", - hivePartition.getTableName(), - hivePartition.getPartitionId())); - } + HiveBucketProperty partitionBucketProperty = partition.getStorage().getBucketProperty() + .orElseThrow(() -> new TrinoException(HIVE_PARTITION_SCHEMA_MISMATCH, format( + "Hive table (%s) is bucketed but partition (%s) is not bucketed", + hivePartition.getTableName(), + hivePartition.getPartitionId()))); int tableBucketCount = bucketProperty.get().getBucketCount(); - int partitionBucketCount = partitionBucketProperty.get().getBucketCount(); + int partitionBucketCount = partitionBucketProperty.getBucketCount(); List tableBucketColumns = bucketProperty.get().getBucketedBy(); - List partitionBucketColumns = partitionBucketProperty.get().getBucketedBy(); + List partitionBucketColumns = partitionBucketProperty.getBucketedBy(); if (!tableBucketColumns.equals(partitionBucketColumns) || !isBucketCountCompatible(tableBucketCount, partitionBucketCount)) { throw new TrinoException(HIVE_PARTITION_SCHEMA_MISMATCH, format( "Hive table (%s) bucketing (columns=%s, buckets=%s) is not compatible with partition (%s) bucketing (columns=%s, buckets=%s)", @@ -367,7 +365,7 @@ private Iterable getPartitionMetadata(ConnectorSession se } if (isPropagateTableScanSortingProperties(session)) { List tableSortedColumns = bucketProperty.get().getSortedBy(); - List partitionSortedColumns = partitionBucketProperty.get().getSortedBy(); + List partitionSortedColumns = partitionBucketProperty.getSortedBy(); if (!isSortingCompatible(tableSortedColumns, partitionSortedColumns)) { throw new TrinoException(HIVE_PARTITION_SCHEMA_MISMATCH, format( "Hive table (%s) sorting by %s is not compatible with partition (%s) sorting by %s. This restriction can be avoided by disabling propagate_table_scan_sorting_properties.", diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSplitSource.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSplitSource.java index bdd7c970cf89..f7e696c923cb 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSplitSource.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSplitSource.java @@ -362,9 +362,7 @@ else if (maxSplitBytes * 2 >= remainingBlockBytes) { // But an extra invocation likely doesn't matter. return new ConnectorSplitBatch(splits, splits.isEmpty() && queues.isFinished()); } - else { - return new ConnectorSplitBatch(splits, false); - } + return new ConnectorSplitBatch(splits, false); }); } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveUpdatablePageSource.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveUpdatablePageSource.java index 3ddba881672a..955f2e3c2cb3 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveUpdatablePageSource.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveUpdatablePageSource.java @@ -257,9 +257,7 @@ public Page getNextPage() List channels = dependencyChannels.orElseThrow(() -> new IllegalArgumentException("dependencyChannels not present")); return updateProcessor.removeNonDependencyColumns(page, channels); } - else { - return page; - } + return page; } @Override diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveWriterFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveWriterFactory.java index e1db0a850f95..e41ca8276a23 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveWriterFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveWriterFactory.java @@ -262,11 +262,8 @@ public HiveWriterFactory( writePath = writeInfo.getWritePath(); } else { - Optional

table = pageSinkMetadataProvider.getTable(); - if (table.isEmpty()) { - throw new TrinoException(HIVE_INVALID_METADATA, format("Table '%s.%s' was dropped during insert", schemaName, tableName)); - } - this.table = table.get(); + this.table = pageSinkMetadataProvider.getTable() + .orElseThrow(() -> new TrinoException(HIVE_INVALID_METADATA, format("Table '%s.%s' was dropped during insert", schemaName, tableName))); writePath = locationService.getQueryWriteInfo(locationHandle).getWritePath(); } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/aws/athena/projection/InjectedProjection.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/aws/athena/projection/InjectedProjection.java index 1f8791a7f7c2..3c0d3c9ae423 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/aws/athena/projection/InjectedProjection.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/aws/athena/projection/InjectedProjection.java @@ -34,10 +34,8 @@ public InjectedProjection(String columnName) @Override public List getProjectedValues(Optional partitionValueFilter) { - if (partitionValueFilter.isEmpty()) { - throw invalidProjectionException(getColumnName(), "Injected projection requires single predicate for it's column in where clause"); - } - Domain domain = partitionValueFilter.get(); + Domain domain = partitionValueFilter + .orElseThrow(() -> invalidProjectionException(getColumnName(), "Injected projection requires single predicate for it's column in where clause")); Type type = domain.getType(); if (!domain.isNullableSingleValue() || !canConvertSqlTypeToStringForParts(type, true)) { throw invalidProjectionException(getColumnName(), "Injected projection requires single predicate for it's column in where clause. Currently provided can't be converted to single partition."); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/DecimalCoercers.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/DecimalCoercers.java index 27fd67f16b5d..bb93e4851c8b 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/DecimalCoercers.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/DecimalCoercers.java @@ -49,18 +49,12 @@ public static Function createDecimalToDecimalCoercer(DecimalType f if (toType.isShort()) { return new ShortDecimalToShortDecimalCoercer(fromType, toType); } - else { - return new ShortDecimalToLongDecimalCoercer(fromType, toType); - } + return new ShortDecimalToLongDecimalCoercer(fromType, toType); } - else { - if (toType.isShort()) { - return new LongDecimalToShortDecimalCoercer(fromType, toType); - } - else { - return new LongDecimalToLongDecimalCoercer(fromType, toType); - } + if (toType.isShort()) { + return new LongDecimalToShortDecimalCoercer(fromType, toType); } + return new LongDecimalToLongDecimalCoercer(fromType, toType); } private static class ShortDecimalToShortDecimalCoercer @@ -153,9 +147,7 @@ public static Function createDecimalToDoubleCoercer(DecimalType fr if (fromType.isShort()) { return new ShortDecimalToDoubleCoercer(fromType); } - else { - return new LongDecimalToDoubleCoercer(fromType); - } + return new LongDecimalToDoubleCoercer(fromType); } private static class ShortDecimalToDoubleCoercer @@ -198,9 +190,7 @@ public static Function createDecimalToRealCoercer(DecimalType from if (fromType.isShort()) { return new ShortDecimalToRealCoercer(fromType); } - else { - return new LongDecimalToRealCoercer(fromType); - } + return new LongDecimalToRealCoercer(fromType); } private static class ShortDecimalToRealCoercer @@ -243,9 +233,7 @@ public static Function createDoubleToDecimalCoercer(DecimalType to if (toType.isShort()) { return new DoubleToShortDecimalCoercer(toType); } - else { - return new DoubleToLongDecimalCoercer(toType); - } + return new DoubleToLongDecimalCoercer(toType); } private static class DoubleToShortDecimalCoercer @@ -285,9 +273,7 @@ public static Function createRealToDecimalCoercer(DecimalType toTy if (toType.isShort()) { return new RealToShortDecimalCoercer(toType); } - else { - return new RealToLongDecimalCoercer(toType); - } + return new RealToLongDecimalCoercer(toType); } private static class RealToShortDecimalCoercer diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/fs/CachingDirectoryLister.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/fs/CachingDirectoryLister.java index 3c1bacfaef32..71123cbbd40b 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/fs/CachingDirectoryLister.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/fs/CachingDirectoryLister.java @@ -127,9 +127,7 @@ private static RemoteIterator createListingRemoteIterator(Fil if (cacheKey.isRecursiveFilesOnly()) { return fs.listFiles(cacheKey.getPath(), true); } - else { - return fs.listLocatedStatus(cacheKey.getPath()); - } + return fs.listLocatedStatus(cacheKey.getPath()); } @Override diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/fs/TransactionScopeCachingDirectoryLister.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/fs/TransactionScopeCachingDirectoryLister.java index fe462b408dd4..37aa0da0c6e2 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/fs/TransactionScopeCachingDirectoryLister.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/fs/TransactionScopeCachingDirectoryLister.java @@ -107,9 +107,7 @@ private RemoteIterator createListingRemoteIterator(FileSystem if (cacheKey.isRecursiveFilesOnly()) { return delegate.listFilesRecursively(fs, table, cacheKey.getPath()); } - else { - return delegate.list(fs, table, cacheKey.getPath()); - } + return delegate.list(fs, table, cacheKey.getPath()); } @Override diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/HiveMetastoreModule.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/HiveMetastoreModule.java index 0c2411703ddc..943122ace1bc 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/HiveMetastoreModule.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/HiveMetastoreModule.java @@ -14,10 +14,12 @@ package io.trino.plugin.hive.metastore; import com.google.inject.Binder; +import com.google.inject.Key; import com.google.inject.Module; import com.google.inject.Provides; import com.google.inject.Singleton; import io.airlift.configuration.AbstractConfigurationAwareModule; +import io.trino.plugin.hive.AllowHiveTableRename; import io.trino.plugin.hive.HideDeltaLakeTables; import io.trino.plugin.hive.metastore.file.FileMetastoreModule; import io.trino.plugin.hive.metastore.glue.GlueMetastoreModule; @@ -42,9 +44,7 @@ protected void setup(Binder binder) { if (metastore.isPresent()) { binder.bind(HiveMetastoreFactory.class).annotatedWith(RawHiveMetastoreFactory.class).toInstance(HiveMetastoreFactory.ofInstance(metastore.get())); - MetastoreTypeConfig metastoreTypeConfig = new MetastoreTypeConfig(); - metastoreTypeConfig.setMetastoreType("provided"); - binder.bind(MetastoreTypeConfig.class).toInstance(metastoreTypeConfig); + binder.bind(Key.get(boolean.class, AllowHiveTableRename.class)).toInstance(true); } else { bindMetastoreModule("thrift", new ThriftMetastoreModule()); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/HivePageSinkMetadataProvider.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/HivePageSinkMetadataProvider.java index 750f774509f1..dc1f14d0a9b7 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/HivePageSinkMetadataProvider.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/HivePageSinkMetadataProvider.java @@ -54,8 +54,6 @@ public Optional getPartition(List partitionValues) if (modifiedPartition == null) { return delegate.getPartition(schemaTableName.getSchemaName(), schemaTableName.getTableName(), partitionValues); } - else { - return modifiedPartition; - } + return modifiedPartition; } } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/MetastoreUtil.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/MetastoreUtil.java index 9e5c16be3808..219044f2d716 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/MetastoreUtil.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/MetastoreUtil.java @@ -276,9 +276,7 @@ private static ProtectMode getProtectMode(Map parameters) if (!parameters.containsKey(ProtectMode.PARAMETER_NAME)) { return new ProtectMode(); } - else { - return getProtectModeFromString(parameters.get(ProtectMode.PARAMETER_NAME)); - } + return getProtectModeFromString(parameters.get(ProtectMode.PARAMETER_NAME)); } public static void verifyOnline(SchemaTableName tableName, Optional partitionName, ProtectMode protectMode, Map parameters) @@ -384,29 +382,29 @@ public static String sqlScalarToString(Type type, Object value, String nullStrin if (value == null) { return nullString; } - else if (type instanceof CharType) { + if (type instanceof CharType) { Slice slice = (Slice) value; return padSpaces(slice, (CharType) type).toStringUtf8(); } - else if (type instanceof VarcharType) { + if (type instanceof VarcharType) { Slice slice = (Slice) value; return slice.toStringUtf8(); } - else if (type instanceof DecimalType && !((DecimalType) type).isShort()) { + if (type instanceof DecimalType && !((DecimalType) type).isShort()) { return Decimals.toString((Int128) value, ((DecimalType) type).getScale()); } - else if (type instanceof DecimalType && ((DecimalType) type).isShort()) { + if (type instanceof DecimalType && ((DecimalType) type).isShort()) { return Decimals.toString((long) value, ((DecimalType) type).getScale()); } - else if (type instanceof DateType) { + if (type instanceof DateType) { DateTimeFormatter dateTimeFormatter = ISODateTimeFormat.date().withZoneUTC(); return dateTimeFormatter.print(TimeUnit.DAYS.toMillis((long) value)); } - else if (type instanceof TimestampType) { + if (type instanceof TimestampType) { // we throw on this type as we don't have timezone. Callers should not ask for this conversion type, but document for possible future work (?) throw new TrinoException(NOT_SUPPORTED, "TimestampType conversion to scalar expressions is not supported"); } - else if (type instanceof TinyintType + if (type instanceof TinyintType || type instanceof SmallintType || type instanceof IntegerType || type instanceof BigintType @@ -415,9 +413,7 @@ else if (type instanceof TinyintType || type instanceof BooleanType) { return value.toString(); } - else { - throw new TrinoException(NOT_SUPPORTED, format("Unsupported partition key type: %s", type.getDisplayName())); - } + throw new TrinoException(NOT_SUPPORTED, format("Unsupported partition key type: %s", type.getDisplayName())); } /** diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/SemiTransactionalHiveMetastore.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/SemiTransactionalHiveMetastore.java index 63c7bf450d0f..a27fce3af945 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/SemiTransactionalHiveMetastore.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/SemiTransactionalHiveMetastore.java @@ -702,19 +702,17 @@ private boolean isAcidTransactionRunning() public synchronized void truncateUnpartitionedTable(ConnectorSession session, String databaseName, String tableName) { checkReadable(); - Optional
table = getTable(databaseName, tableName); SchemaTableName schemaTableName = new SchemaTableName(databaseName, tableName); - if (table.isEmpty()) { - throw new TableNotFoundException(schemaTableName); - } - if (!table.get().getTableType().equals(MANAGED_TABLE.toString())) { + Table table = getTable(databaseName, tableName) + .orElseThrow(() -> new TableNotFoundException(schemaTableName)); + if (!table.getTableType().equals(MANAGED_TABLE.toString())) { throw new TrinoException(NOT_SUPPORTED, "Cannot delete from non-managed Hive table"); } - if (!table.get().getPartitionColumns().isEmpty()) { + if (!table.getPartitionColumns().isEmpty()) { throw new IllegalArgumentException("Table is partitioned"); } - Path path = new Path(table.get().getStorage().getLocation()); + Path path = new Path(table.getStorage().getLocation()); HdfsContext context = new HdfsContext(session); setExclusive((delegate, hdfsEnvironment) -> { RecursiveDeleteResult recursiveDeleteResult = recursiveDeleteFiles(hdfsEnvironment, context, path, ImmutableSet.of(""), false); @@ -926,11 +924,8 @@ private Optional> doGetPartitionNames( partitionNames = ImmutableList.of(); break; case PRE_EXISTING_TABLE: - Optional> partitionNameResult = delegate.getPartitionNamesByFilter(databaseName, tableName, columnNames, partitionKeysFilter); - if (partitionNameResult.isEmpty()) { - throw new TrinoException(TRANSACTION_CONFLICT, format("Table '%s.%s' was dropped by another transaction", databaseName, tableName)); - } - partitionNames = partitionNameResult.get(); + partitionNames = delegate.getPartitionNamesByFilter(databaseName, tableName, columnNames, partitionKeysFilter) + .orElseThrow(() -> new TrinoException(TRANSACTION_CONFLICT, format("Table '%s.%s' was dropped by another transaction", databaseName, tableName))); break; default: throw new UnsupportedOperationException("Unknown table source"); @@ -2025,15 +2020,13 @@ private void prepareAlterPartition(HdfsContext hdfsContext, String queryId, Part Partition partition = partitionAndMore.getPartition(); partitionsToInvalidate.add(partition); String targetLocation = partition.getStorage().getLocation(); - Optional oldPartition = delegate.getPartition(partition.getDatabaseName(), partition.getTableName(), partition.getValues()); - if (oldPartition.isEmpty()) { - throw new TrinoException( - TRANSACTION_CONFLICT, - format("The partition that this transaction modified was deleted in another transaction. %s %s", partition.getTableName(), partition.getValues())); - } + Partition oldPartition = delegate.getPartition(partition.getDatabaseName(), partition.getTableName(), partition.getValues()) + .orElseThrow(() -> new TrinoException( + TRANSACTION_CONFLICT, + format("The partition that this transaction modified was deleted in another transaction. %s %s", partition.getTableName(), partition.getValues()))); String partitionName = getPartitionName(partition.getDatabaseName(), partition.getTableName(), partition.getValues()); PartitionStatistics oldPartitionStatistics = getExistingPartitionStatistics(partition, partitionName); - String oldPartitionLocation = oldPartition.get().getStorage().getLocation(); + String oldPartitionLocation = oldPartition.getStorage().getLocation(); Path oldPartitionPath = new Path(oldPartitionLocation); cleanExtraOutputFiles(hdfsContext, queryId, partitionAndMore); @@ -2077,7 +2070,7 @@ private void prepareAlterPartition(HdfsContext hdfsContext, String queryId, Part // because metadata might change: e.g. storage format, column types, etc alterPartitionOperations.add(new AlterPartitionOperation( new PartitionWithStatistics(partition, partitionName, partitionAndMore.getStatisticsUpdate()), - new PartitionWithStatistics(oldPartition.get(), partitionName, oldPartitionStatistics))); + new PartitionWithStatistics(oldPartition, partitionName, oldPartitionStatistics))); } private void cleanExtraOutputFiles(HdfsContext hdfsContext, String queryId, PartitionAndMore partitionAndMore) diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/alluxio/AlluxioMetastoreModule.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/alluxio/AlluxioMetastoreModule.java index 1c4e660d696c..67d737227f23 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/alluxio/AlluxioMetastoreModule.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/alluxio/AlluxioMetastoreModule.java @@ -21,9 +21,11 @@ import alluxio.conf.PropertyKey; import alluxio.master.MasterClientContext; import com.google.inject.Binder; +import com.google.inject.Key; import com.google.inject.Provides; import com.google.inject.Scopes; import io.airlift.configuration.AbstractConfigurationAwareModule; +import io.trino.plugin.hive.AllowHiveTableRename; import io.trino.plugin.hive.metastore.HiveMetastore; import io.trino.plugin.hive.metastore.HiveMetastoreFactory; import io.trino.plugin.hive.metastore.RawHiveMetastoreFactory; @@ -42,6 +44,7 @@ protected void setup(Binder binder) configBinder(binder).bindConfig(AlluxioHiveMetastoreConfig.class); binder.bind(HiveMetastoreFactory.class).annotatedWith(RawHiveMetastoreFactory.class).to(AlluxioHiveMetastoreFactory.class).in(Scopes.SINGLETON); + binder.bind(Key.get(boolean.class, AllowHiveTableRename.class)).toInstance(false); } @Provides diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/alluxio/ProtoUtils.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/alluxio/ProtoUtils.java index 14b0ca3bae75..de3397f019b0 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/alluxio/ProtoUtils.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/alluxio/ProtoUtils.java @@ -250,9 +250,7 @@ public static HiveColumnStatistics fromProto(ColumnStatisticsData columnStatisti getTotalSizeInBytes(averageColumnLength, rowCount, nullsCount), nullsCount); } - else { - throw new TrinoException(HIVE_INVALID_METADATA, "Invalid column statistics data: " + columnStatistics); - } + throw new TrinoException(HIVE_INVALID_METADATA, "Invalid column statistics data: " + columnStatistics); } static Column fromProto(alluxio.grpc.table.FieldSchema column) diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/file/FileMetastoreModule.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/file/FileMetastoreModule.java index a54e296a03a9..71b0a7450679 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/file/FileMetastoreModule.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/file/FileMetastoreModule.java @@ -14,8 +14,10 @@ package io.trino.plugin.hive.metastore.file; import com.google.inject.Binder; +import com.google.inject.Key; import com.google.inject.Module; import com.google.inject.Scopes; +import io.trino.plugin.hive.AllowHiveTableRename; import io.trino.plugin.hive.metastore.HiveMetastoreFactory; import io.trino.plugin.hive.metastore.RawHiveMetastoreFactory; @@ -29,5 +31,6 @@ public void configure(Binder binder) { configBinder(binder).bindConfig(FileHiveMetastoreConfig.class); binder.bind(HiveMetastoreFactory.class).annotatedWith(RawHiveMetastoreFactory.class).to(FileHiveMetastoreFactory.class).in(Scopes.SINGLETON); + binder.bind(Key.get(boolean.class, AllowHiveTableRename.class)).toInstance(true); } } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/glue/DefaultGlueColumnStatisticsProvider.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/glue/DefaultGlueColumnStatisticsProvider.java index 67f9695e9dbe..0ca852f4a08c 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/glue/DefaultGlueColumnStatisticsProvider.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/glue/DefaultGlueColumnStatisticsProvider.java @@ -200,15 +200,15 @@ private boolean isGlueWritable(ColumnStatistics stats) DateColumnStatisticsData data = statisticsData.getDateColumnStatisticsData(); return data.getMaximumValue() != null && data.getMinimumValue() != null; } - else if (columnType.equals(ColumnStatisticsType.DECIMAL.toString())) { + if (columnType.equals(ColumnStatisticsType.DECIMAL.toString())) { DecimalColumnStatisticsData data = statisticsData.getDecimalColumnStatisticsData(); return data.getMaximumValue() != null && data.getMinimumValue() != null; } - else if (columnType.equals(ColumnStatisticsType.DOUBLE.toString())) { + if (columnType.equals(ColumnStatisticsType.DOUBLE.toString())) { DoubleColumnStatisticsData data = statisticsData.getDoubleColumnStatisticsData(); return data.getMaximumValue() != null && data.getMinimumValue() != null; } - else if (columnType.equals(ColumnStatisticsType.LONG.toString())) { + if (columnType.equals(ColumnStatisticsType.LONG.toString())) { LongColumnStatisticsData data = statisticsData.getLongColumnStatisticsData(); return data.getMaximumValue() != null && data.getMinimumValue() != null; } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/glue/GlueMetastoreModule.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/glue/GlueMetastoreModule.java index e7c580fab7fe..d441e6e97497 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/glue/GlueMetastoreModule.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/glue/GlueMetastoreModule.java @@ -26,6 +26,7 @@ import io.airlift.concurrent.BoundedExecutor; import io.airlift.configuration.AbstractConfigurationAwareModule; import io.trino.plugin.base.CatalogName; +import io.trino.plugin.hive.AllowHiveTableRename; import io.trino.plugin.hive.HiveConfig; import io.trino.plugin.hive.metastore.HiveMetastoreFactory; import io.trino.plugin.hive.metastore.RawHiveMetastoreFactory; @@ -65,6 +66,8 @@ protected void setup(Binder binder) binder.bind(GlueHiveMetastoreFactory.class).in(Scopes.SINGLETON); newExporter(binder).export(GlueHiveMetastoreFactory.class).as(generator -> generator.generatedNameOf(GlueHiveMetastore.class)); + binder.bind(Key.get(boolean.class, AllowHiveTableRename.class)).toInstance(false); + install(conditionalModule( HiveConfig.class, HiveConfig::isTableStatisticsEnabled, diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/glue/converter/GlueToTrinoConverter.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/glue/converter/GlueToTrinoConverter.java index 329b27822881..353e38ceee02 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/glue/converter/GlueToTrinoConverter.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/glue/converter/GlueToTrinoConverter.java @@ -117,9 +117,7 @@ private static Column convertColumn(com.amazonaws.services.glue.model.Column glu //TODO(https://github.com/trinodb/trino/issues/7240) Add tests return new Column(glueColumn.getName(), HiveType.HIVE_STRING, Optional.ofNullable(glueColumn.getComment())); } - else { - return new Column(glueColumn.getName(), HiveType.valueOf(glueColumn.getType().toLowerCase(Locale.ENGLISH)), Optional.ofNullable(glueColumn.getComment())); - } + return new Column(glueColumn.getName(), HiveType.valueOf(glueColumn.getType().toLowerCase(Locale.ENGLISH)), Optional.ofNullable(glueColumn.getComment())); } private static List convertColumns(List glueColumns, String serde) diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/thrift/BridgingHiveMetastore.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/thrift/BridgingHiveMetastore.java index 59167216f737..b3f1705319b8 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/thrift/BridgingHiveMetastore.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/thrift/BridgingHiveMetastore.java @@ -216,11 +216,8 @@ public void replaceTable(String databaseName, String tableName, Table newTable, @Override public void renameTable(String databaseName, String tableName, String newDatabaseName, String newTableName) { - Optional source = delegate.getTable(databaseName, tableName); - if (source.isEmpty()) { - throw new TableNotFoundException(new SchemaTableName(databaseName, tableName)); - } - org.apache.hadoop.hive.metastore.api.Table table = source.get(); + org.apache.hadoop.hive.metastore.api.Table table = delegate.getTable(databaseName, tableName) + .orElseThrow(() -> new TableNotFoundException(new SchemaTableName(databaseName, tableName))); table.setDbName(newDatabaseName); table.setTableName(newTableName); alterTable(databaseName, tableName, table); @@ -229,11 +226,8 @@ public void renameTable(String databaseName, String tableName, String newDatabas @Override public void commentTable(String databaseName, String tableName, Optional comment) { - Optional source = delegate.getTable(databaseName, tableName); - if (source.isEmpty()) { - throw new TableNotFoundException(new SchemaTableName(databaseName, tableName)); - } - org.apache.hadoop.hive.metastore.api.Table table = source.get(); + org.apache.hadoop.hive.metastore.api.Table table = delegate.getTable(databaseName, tableName) + .orElseThrow(() -> new TableNotFoundException(new SchemaTableName(databaseName, tableName))); Map parameters = table.getParameters().entrySet().stream() .filter(entry -> !entry.getKey().equals(TABLE_COMMENT)) @@ -265,11 +259,8 @@ public void setTableOwner(String databaseName, String tableName, HivePrincipal p @Override public void commentColumn(String databaseName, String tableName, String columnName, Optional comment) { - Optional source = delegate.getTable(databaseName, tableName); - if (source.isEmpty()) { - throw new TableNotFoundException(new SchemaTableName(databaseName, tableName)); - } - org.apache.hadoop.hive.metastore.api.Table table = source.get(); + org.apache.hadoop.hive.metastore.api.Table table = delegate.getTable(databaseName, tableName) + .orElseThrow(() -> new TableNotFoundException(new SchemaTableName(databaseName, tableName))); for (FieldSchema fieldSchema : table.getSd().getCols()) { if (fieldSchema.getName().equals(columnName)) { @@ -288,11 +279,8 @@ public void commentColumn(String databaseName, String tableName, String columnNa @Override public void addColumn(String databaseName, String tableName, String columnName, HiveType columnType, String columnComment) { - Optional source = delegate.getTable(databaseName, tableName); - if (source.isEmpty()) { - throw new TableNotFoundException(new SchemaTableName(databaseName, tableName)); - } - org.apache.hadoop.hive.metastore.api.Table table = source.get(); + org.apache.hadoop.hive.metastore.api.Table table = delegate.getTable(databaseName, tableName) + .orElseThrow(() -> new TableNotFoundException(new SchemaTableName(databaseName, tableName))); table.getSd().getCols().add( new FieldSchema(columnName, columnType.getHiveTypeName().toString(), columnComment)); alterTable(databaseName, tableName, table); @@ -301,11 +289,8 @@ public void addColumn(String databaseName, String tableName, String columnName, @Override public void renameColumn(String databaseName, String tableName, String oldColumnName, String newColumnName) { - Optional source = delegate.getTable(databaseName, tableName); - if (source.isEmpty()) { - throw new TableNotFoundException(new SchemaTableName(databaseName, tableName)); - } - org.apache.hadoop.hive.metastore.api.Table table = source.get(); + org.apache.hadoop.hive.metastore.api.Table table = delegate.getTable(databaseName, tableName) + .orElseThrow(() -> new TableNotFoundException(new SchemaTableName(databaseName, tableName))); for (FieldSchema fieldSchema : table.getPartitionKeys()) { if (fieldSchema.getName().equals(oldColumnName)) { throw new TrinoException(NOT_SUPPORTED, "Renaming partition columns is not supported"); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/thrift/ThriftHiveMetastore.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/thrift/ThriftHiveMetastore.java index 9160e8b5e9f8..3a53a05663aa 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/thrift/ThriftHiveMetastore.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/thrift/ThriftHiveMetastore.java @@ -139,8 +139,6 @@ public class ThriftHiveMetastore private static final String DEFAULT_METASTORE_USER = "presto"; - private final HdfsContext hdfsContext = new HdfsContext(ConnectorIdentity.ofUser(DEFAULT_METASTORE_USER)); - private final Optional identity; private final HdfsEnvironment hdfsEnvironment; private final IdentityAwareMetastoreClientFactory metastoreClientFactory; @@ -955,7 +953,7 @@ public void dropTable(String databaseName, String tableName, boolean deleteData) client.dropTable(databaseName, tableName, deleteData); String tableLocation = table.getSd().getLocation(); if (deleteFilesOnDrop && deleteData && isManagedTable(table) && !isNullOrEmpty(tableLocation)) { - deleteDirRecursive(hdfsContext, hdfsEnvironment, new Path(tableLocation)); + deleteDirRecursive(new Path(tableLocation)); } } return null; @@ -972,9 +970,11 @@ public void dropTable(String databaseName, String tableName, boolean deleteData) } } - private static void deleteDirRecursive(HdfsContext context, HdfsEnvironment hdfsEnvironment, Path path) + private void deleteDirRecursive(Path path) { try { + HdfsContext context = new HdfsContext(identity.orElseGet(() -> + ConnectorIdentity.ofUser(DEFAULT_METASTORE_USER))); hdfsEnvironment.getFileSystem(context, path).delete(path, true); } catch (IOException | RuntimeException e) { diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/thrift/ThriftMetastore.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/thrift/ThriftMetastore.java index 0698fb3fe78f..414d1a68ffa7 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/thrift/ThriftMetastore.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/thrift/ThriftMetastore.java @@ -120,16 +120,14 @@ public interface ThriftMetastore default Optional> getFields(String databaseName, String tableName) { - Optional
table = getTable(databaseName, tableName); - if (table.isEmpty()) { - throw new TableNotFoundException(new SchemaTableName(databaseName, tableName)); - } + Table table = getTable(databaseName, tableName) + .orElseThrow(() -> new TableNotFoundException(new SchemaTableName(databaseName, tableName))); - if (table.get().getSd() == null) { + if (table.getSd() == null) { throw new TrinoException(HIVE_INVALID_METADATA, "Table is missing storage descriptor"); } - return Optional.of(table.get().getSd().getCols()); + return Optional.of(table.getSd().getCols()); } default long openTransaction(AcidTransactionOwner transactionOwner) diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/thrift/ThriftMetastoreModule.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/thrift/ThriftMetastoreModule.java index 7c2e3d4a23e6..caa324135c0e 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/thrift/ThriftMetastoreModule.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/thrift/ThriftMetastoreModule.java @@ -14,9 +14,11 @@ package io.trino.plugin.hive.metastore.thrift; import com.google.inject.Binder; +import com.google.inject.Key; import com.google.inject.Scopes; import com.google.inject.multibindings.OptionalBinder; import io.airlift.configuration.AbstractConfigurationAwareModule; +import io.trino.plugin.hive.AllowHiveTableRename; import io.trino.plugin.hive.metastore.HiveMetastoreFactory; import io.trino.plugin.hive.metastore.RawHiveMetastoreFactory; @@ -44,6 +46,8 @@ protected void setup(Binder binder) .to(BridgingHiveMetastoreFactory.class) .in(Scopes.SINGLETON); + binder.bind(Key.get(boolean.class, AllowHiveTableRename.class)).toInstance(true); + install(new ThriftMetastoreAuthenticationModule()); } } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/thrift/ThriftMetastoreUtil.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/thrift/ThriftMetastoreUtil.java index ac3ef86446ba..7a0afb5aed76 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/thrift/ThriftMetastoreUtil.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/thrift/ThriftMetastoreUtil.java @@ -552,9 +552,7 @@ public static HiveColumnStatistics fromMetastoreApiColumnStatistics(ColumnStatis getTotalSizeInBytes(averageColumnLength, rowCount, nullsCount), nullsCount); } - else { - throw new TrinoException(HIVE_INVALID_METADATA, "Invalid column statistics data: " + columnStatistics); - } + throw new TrinoException(HIVE_INVALID_METADATA, "Invalid column statistics data: " + columnStatistics); } private static Optional fromMetastoreDate(Date date) diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java index 67489e06e1f4..1a55b7bbbf77 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java @@ -285,7 +285,7 @@ && predicateMatches(parquetPredicate, block, dataSource, descriptorsByPath, parq timeZone, newSimpleAggregatedMemoryContext(), options, - parquetPredicate, + Optional.of(parquetPredicate), columnIndexes.build()); ConnectorPageSource parquetPageSource = new ParquetPageSource( diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/s3select/S3SelectRecordCursorProvider.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/s3select/S3SelectRecordCursorProvider.java index de8c1cfa4082..d06d1b057923 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/s3select/S3SelectRecordCursorProvider.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/s3select/S3SelectRecordCursorProvider.java @@ -117,10 +117,8 @@ public Optional createRecordCursor( RecordCursor cursor = new S3SelectRecordCursor<>(configuration, path, recordReader.get(), length, schema, readerColumns); return Optional.of(new ReaderRecordCursorWithProjections(cursor, projectedReaderColumns)); } - else { - // unsupported serdes - return Optional.empty(); - } + // unsupported serdes + return Optional.empty(); } private static boolean hasFilters( diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/LegacyAccessControl.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/LegacyAccessControl.java index c0beb4e9150f..e3fbeac07126 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/LegacyAccessControl.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/LegacyAccessControl.java @@ -32,7 +32,6 @@ import java.util.Optional; import java.util.Set; -import static io.trino.spi.function.FunctionKind.TABLE; import static io.trino.spi.security.AccessDeniedException.denyAddColumn; import static io.trino.spi.security.AccessDeniedException.denyCommentColumn; import static io.trino.spi.security.AccessDeniedException.denyCommentTable; @@ -399,9 +398,13 @@ public void checkCanExecuteTableProcedure(ConnectorSecurityContext context, Sche @Override public void checkCanExecuteFunction(ConnectorSecurityContext context, FunctionKind functionKind, SchemaRoutineName function) { - if (functionKind == TABLE) { - denyExecuteFunction(function.toString()); + switch (functionKind) { + case SCALAR, AGGREGATE, WINDOW: + return; + case TABLE: + denyExecuteFunction(function.toString()); } + throw new UnsupportedOperationException("Unsupported function kind: " + functionKind); } @Override diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/SqlStandardAccessControl.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/SqlStandardAccessControl.java index ddfa2fe406af..609b89fa4d60 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/SqlStandardAccessControl.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/SqlStandardAccessControl.java @@ -55,7 +55,6 @@ import static io.trino.plugin.hive.metastore.thrift.ThriftMetastoreUtil.listApplicableRoles; import static io.trino.plugin.hive.metastore.thrift.ThriftMetastoreUtil.listEnabledPrincipals; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; -import static io.trino.spi.function.FunctionKind.TABLE; import static io.trino.spi.security.AccessDeniedException.denyAddColumn; import static io.trino.spi.security.AccessDeniedException.denyCommentColumn; import static io.trino.spi.security.AccessDeniedException.denyCommentTable; @@ -583,9 +582,16 @@ public void checkCanExecuteTableProcedure(ConnectorSecurityContext context, Sche @Override public void checkCanExecuteFunction(ConnectorSecurityContext context, FunctionKind functionKind, SchemaRoutineName function) { - if (functionKind == TABLE && !isAdmin(context)) { - denyExecuteFunction(function.toString()); + switch (functionKind) { + case SCALAR, AGGREGATE, WINDOW: + return; + case TABLE: + if (isAdmin(context)) { + return; + } + denyExecuteFunction(function.toString()); } + throw new UnsupportedOperationException("Unsupported function kind: " + functionKind); } @Override diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveTypeTranslator.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveTypeTranslator.java index 39168b5a2675..22d1e5f4ce2d 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveTypeTranslator.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveTypeTranslator.java @@ -20,7 +20,6 @@ import io.trino.spi.TrinoException; import io.trino.spi.type.CharType; import io.trino.spi.type.DecimalType; -import io.trino.spi.type.NamedTypeSignature; import io.trino.spi.type.TimestampType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -163,11 +162,8 @@ public static TypeInfo toTypeInfo(Type type) if (!parameter.isNamedTypeSignature()) { throw new IllegalArgumentException(format("Expected all parameters to be named type, but got %s", parameter)); } - NamedTypeSignature namedTypeSignature = parameter.getNamedTypeSignature(); - if (namedTypeSignature.getName().isEmpty()) { - throw new TrinoException(NOT_SUPPORTED, format("Anonymous row type is not supported in Hive. Please give each field a name: %s", type)); - } - fieldNames.add(namedTypeSignature.getName().get()); + fieldNames.add(parameter.getNamedTypeSignature().getName() + .orElseThrow(() -> new TrinoException(NOT_SUPPORTED, format("Anonymous row type is not supported in Hive. Please give each field a name: %s", type)))); } return getStructTypeInfo( fieldNames.build(), diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveUtil.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveUtil.java index 02945e194231..38f11ad8f515 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveUtil.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveUtil.java @@ -556,12 +556,10 @@ public static NullableValue parsePartitionValue(String partitionName, String val } return NullableValue.of(decimalType, shortDecimalPartitionKey(value, decimalType, partitionName)); } - else { - if (value.isEmpty()) { - return NullableValue.of(decimalType, Int128.ZERO); - } - return NullableValue.of(decimalType, longDecimalPartitionKey(value, decimalType, partitionName)); + if (value.isEmpty()) { + return NullableValue.of(decimalType, Int128.ZERO); } + return NullableValue.of(decimalType, longDecimalPartitionKey(value, decimalType, partitionName)); } if (BOOLEAN.equals(type)) { @@ -685,9 +683,7 @@ public static Optional getDecimalType(String hiveTypeName) int scale = parseInt(matcher.group(DECIMAL_SCALE_GROUP)); return Optional.of(createDecimalType(precision, scale)); } - else { - return Optional.empty(); - } + return Optional.empty(); } public static boolean isArrayType(Type type) @@ -988,50 +984,50 @@ else if (isPartitionColumnHandle(columnHandle)) { if (isHiveNull(bytes)) { return NullableValue.asNull(type); } - else if (type.equals(BOOLEAN)) { + if (type.equals(BOOLEAN)) { return NullableValue.of(type, booleanPartitionKey(columnValue, name)); } - else if (type.equals(BIGINT)) { + if (type.equals(BIGINT)) { return NullableValue.of(type, bigintPartitionKey(columnValue, name)); } - else if (type.equals(INTEGER)) { + if (type.equals(INTEGER)) { return NullableValue.of(type, integerPartitionKey(columnValue, name)); } - else if (type.equals(SMALLINT)) { + if (type.equals(SMALLINT)) { return NullableValue.of(type, smallintPartitionKey(columnValue, name)); } - else if (type.equals(TINYINT)) { + if (type.equals(TINYINT)) { return NullableValue.of(type, tinyintPartitionKey(columnValue, name)); } - else if (type.equals(REAL)) { + if (type.equals(REAL)) { return NullableValue.of(type, floatPartitionKey(columnValue, name)); } - else if (type.equals(DOUBLE)) { + if (type.equals(DOUBLE)) { return NullableValue.of(type, doublePartitionKey(columnValue, name)); } - else if (type instanceof VarcharType) { + if (type instanceof VarcharType) { return NullableValue.of(type, varcharPartitionKey(columnValue, name, type)); } - else if (type instanceof CharType) { + if (type instanceof CharType) { return NullableValue.of(type, charPartitionKey(columnValue, name, type)); } - else if (type.equals(DATE)) { + if (type.equals(DATE)) { return NullableValue.of(type, datePartitionKey(columnValue, name)); } - else if (type.equals(TIMESTAMP_MILLIS)) { + if (type.equals(TIMESTAMP_MILLIS)) { return NullableValue.of(type, timestampPartitionKey(columnValue, name)); } - else if (type.equals(TIMESTAMP_TZ_MILLIS)) { + if (type.equals(TIMESTAMP_TZ_MILLIS)) { // used for $file_modified_time return NullableValue.of(type, packDateTimeWithZone(floorDiv(timestampPartitionKey(columnValue, name), MICROSECONDS_PER_MILLISECOND), DateTimeZone.getDefault().getID())); } - else if (isShortDecimal(type)) { + if (isShortDecimal(type)) { return NullableValue.of(type, shortDecimalPartitionKey(columnValue, (DecimalType) type, name)); } - else if (isLongDecimal(type)) { + if (isLongDecimal(type)) { return NullableValue.of(type, longDecimalPartitionKey(columnValue, (DecimalType) type, name)); } - else if (type.equals(VarbinaryType.VARBINARY)) { + if (type.equals(VarbinaryType.VARBINARY)) { return NullableValue.of(type, utf8Slice(columnValue)); } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveWriteUtils.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveWriteUtils.java index cd58aee6e5fa..c8e4019669db 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveWriteUtils.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveWriteUtils.java @@ -448,12 +448,10 @@ public static Path getTableDefaultLocation(HdfsContext context, SemiTransactiona public static Path getTableDefaultLocation(Database database, HdfsContext context, HdfsEnvironment hdfsEnvironment, String schemaName, String tableName) { - Optional location = database.getLocation(); - if (location.isEmpty()) { - throw new TrinoException(HIVE_DATABASE_LOCATION_ERROR, format("Database '%s' location is not set", schemaName)); - } + String location = database.getLocation() + .orElseThrow(() -> new TrinoException(HIVE_DATABASE_LOCATION_ERROR, format("Database '%s' location is not set", schemaName))); - Path databasePath = new Path(location.get()); + Path databasePath = new Path(location); if (!isS3FileSystem(context, hdfsEnvironment, databasePath)) { if (!pathExists(context, hdfsEnvironment, databasePath)) { throw new TrinoException(HIVE_DATABASE_LOCATION_ERROR, format("Database '%s' location does not exist: %s", schemaName, databasePath)); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/SerDeUtils.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/SerDeUtils.java index 19324bb7d904..1633f04846b6 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/SerDeUtils.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/SerDeUtils.java @@ -199,10 +199,8 @@ private static Block serializeList(Type type, BlockBuilder builder, Object objec builder.closeEntry(); return null; } - else { - Block resultBlock = currentBuilder.build(); - return resultBlock; - } + Block resultBlock = currentBuilder.build(); + return resultBlock; } private static Block serializeMap(Type type, BlockBuilder builder, Object object, MapObjectInspector inspector, boolean filterNullMapKeys) @@ -240,9 +238,7 @@ private static Block serializeMap(Type type, BlockBuilder builder, Object object if (builderSynthesized) { return (Block) type.getObject(builder, 0); } - else { - return null; - } + return null; } private static Block serializeStruct(Type type, BlockBuilder builder, Object object, StructObjectInspector inspector) @@ -273,9 +269,7 @@ private static Block serializeStruct(Type type, BlockBuilder builder, Object obj if (builderSynthesized) { return (Block) type.getObject(builder, 0); } - else { - return null; - } + return null; } // Use row blocks to represent union objects when reading diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHive.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHive.java index d8dee4c71b19..d451861e8090 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHive.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHive.java @@ -41,7 +41,6 @@ import io.trino.plugin.hive.metastore.HivePrincipal; import io.trino.plugin.hive.metastore.HivePrivilegeInfo; import io.trino.plugin.hive.metastore.HivePrivilegeInfo.HivePrivilege; -import io.trino.plugin.hive.metastore.MetastoreTypeConfig; import io.trino.plugin.hive.metastore.Partition; import io.trino.plugin.hive.metastore.PartitionWithStatistics; import io.trino.plugin.hive.metastore.PrincipalPrivileges; @@ -874,7 +873,7 @@ public Optional getMaterializedView(Connect countingDirectoryLister, 1000, new PartitionProjectionService(hiveConfig, ImmutableMap.of(), new TestingTypeManager()), - new MetastoreTypeConfig()); + true); transactionManager = new HiveTransactionManager(metadataFactory); splitManager = new HiveSplitManager( transactionManager, diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHiveFileFormats.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHiveFileFormats.java index 5ba17c0cccca..64d3e564c076 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHiveFileFormats.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHiveFileFormats.java @@ -727,9 +727,7 @@ public static Object getFieldFromCursor(RecordCursor cursor, Type type, int fiel if (decimalType.isShort()) { return BigInteger.valueOf(cursor.getLong(field)); } - else { - return ((Int128) cursor.getObject(field)).toBigInteger(); - } + return ((Int128) cursor.getObject(field)).toBigInteger(); } throw new RuntimeException("unknown type"); } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHiveFileSystem.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHiveFileSystem.java index 4102f112ffe2..d46ab6db2f33 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHiveFileSystem.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHiveFileSystem.java @@ -40,7 +40,6 @@ import io.trino.plugin.hive.metastore.HiveMetastore; import io.trino.plugin.hive.metastore.HiveMetastoreConfig; import io.trino.plugin.hive.metastore.HiveMetastoreFactory; -import io.trino.plugin.hive.metastore.MetastoreTypeConfig; import io.trino.plugin.hive.metastore.PrincipalPrivileges; import io.trino.plugin.hive.metastore.StorageFormat; import io.trino.plugin.hive.metastore.Table; @@ -216,7 +215,7 @@ protected void setup(String host, int port, String databaseName, boolean s3Selec SqlStandardAccessControlMetadata::new, new FileSystemDirectoryLister(), new PartitionProjectionService(config, ImmutableMap.of(), new TestingTypeManager()), - new MetastoreTypeConfig()); + true); transactionManager = new HiveTransactionManager(metadataFactory); splitManager = new HiveSplitManager( transactionManager, @@ -611,15 +610,13 @@ public void createTable(Table table, PrincipalPrivileges privileges) public void dropTable(String databaseName, String tableName, boolean deleteData) { try { - Optional
table = getTable(databaseName, tableName); - if (table.isEmpty()) { - throw new TableNotFoundException(new SchemaTableName(databaseName, tableName)); - } + Table table = getTable(databaseName, tableName) + .orElseThrow(() -> new TableNotFoundException(new SchemaTableName(databaseName, tableName))); // hack to work around the metastore not being configured for S3 or other FS List locations = listAllDataPaths(databaseName, tableName); - Table.Builder tableBuilder = Table.builder(table.get()); + Table.Builder tableBuilder = Table.builder(table); tableBuilder.getStorageBuilder().setLocation("/"); // drop table @@ -641,12 +638,9 @@ public void dropTable(String databaseName, String tableName, boolean deleteData) public void updateTableLocation(String databaseName, String tableName, String location) { - Optional
table = getTable(databaseName, tableName); - if (table.isEmpty()) { - throw new TableNotFoundException(new SchemaTableName(databaseName, tableName)); - } - - Table.Builder tableBuilder = Table.builder(table.get()); + Table table = getTable(databaseName, tableName) + .orElseThrow(() -> new TableNotFoundException(new SchemaTableName(databaseName, tableName))); + Table.Builder tableBuilder = Table.builder(table); tableBuilder.getStorageBuilder().setLocation(location); // NOTE: this clears the permissions diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestReaderProjectionsAdapter.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestReaderProjectionsAdapter.java index eac2452834e4..b39e347310c1 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestReaderProjectionsAdapter.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestReaderProjectionsAdapter.java @@ -163,12 +163,10 @@ private static Block createInputBlock(List data, Type type) if (type instanceof RowType) { return new LazyBlock(data.size(), () -> createRowBlockWithLazyNestedBlocks(data, (RowType) type)); } - else if (BIGINT.equals(type)) { + if (BIGINT.equals(type)) { return new LazyBlock(positionCount, () -> createLongArrayBlock(data)); } - else { - throw new UnsupportedOperationException(); - } + throw new UnsupportedOperationException(); } private static Block createRowBlockWithLazyNestedBlocks(List data, RowType rowType) diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/benchmark/BenchmarkProjectionPushdownHive.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/benchmark/BenchmarkProjectionPushdownHive.java index f1d044cef9d6..1a114ced0e1b 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/benchmark/BenchmarkProjectionPushdownHive.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/benchmark/BenchmarkProjectionPushdownHive.java @@ -274,14 +274,14 @@ private Block createBlock(Type type, int rowCount) return RowBlock.fromFieldBlocks(rowCount, Optional.empty(), fieldBlocks); } - else if (type instanceof VarcharType) { + if (type instanceof VarcharType) { BlockBuilder builder = VARCHAR.createBlockBuilder(null, rowCount); for (int i = 0; i < rowCount; i++) { VARCHAR.writeString(builder, generateRandomString(random, 500)); } return builder.build(); } - else if (type instanceof ArrayType) { + if (type instanceof ArrayType) { ArrayType arrayType = (ArrayType) type; Type elementType = arrayType.getElementType(); diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/metastore/recording/TestRecordingHiveMetastore.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/metastore/recording/TestRecordingHiveMetastore.java index 586d41f36698..97c44abd4e7f 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/metastore/recording/TestRecordingHiveMetastore.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/metastore/recording/TestRecordingHiveMetastore.java @@ -311,7 +311,7 @@ public Optional getPartition(Table table, List partitionValue if (partitionValues.equals(ImmutableList.of("value"))) { return Optional.of(PARTITION); } - else if (partitionValues.equals(ImmutableList.of("other_value"))) { + if (partitionValues.equals(ImmutableList.of("other_value"))) { return Optional.of(OTHER_PARTITION); } } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/write/MapKeyValuesSchemaConverter.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/write/MapKeyValuesSchemaConverter.java index c0c9347c50b9..2ee37c0ed82c 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/write/MapKeyValuesSchemaConverter.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/write/MapKeyValuesSchemaConverter.java @@ -71,72 +71,68 @@ private static Type convertType(String name, TypeInfo typeInfo, Repetition repet return Types.primitive(PrimitiveTypeName.BINARY, repetition).as(LogicalTypeAnnotation.stringType()) .named(name); } - else if (typeInfo.equals(TypeInfoFactory.intTypeInfo) || + if (typeInfo.equals(TypeInfoFactory.intTypeInfo) || typeInfo.equals(TypeInfoFactory.shortTypeInfo) || typeInfo.equals(TypeInfoFactory.byteTypeInfo)) { return Types.primitive(PrimitiveTypeName.INT32, repetition).named(name); } - else if (typeInfo.equals(TypeInfoFactory.longTypeInfo)) { + if (typeInfo.equals(TypeInfoFactory.longTypeInfo)) { return Types.primitive(PrimitiveTypeName.INT64, repetition).named(name); } - else if (typeInfo.equals(TypeInfoFactory.doubleTypeInfo)) { + if (typeInfo.equals(TypeInfoFactory.doubleTypeInfo)) { return Types.primitive(PrimitiveTypeName.DOUBLE, repetition).named(name); } - else if (typeInfo.equals(TypeInfoFactory.floatTypeInfo)) { + if (typeInfo.equals(TypeInfoFactory.floatTypeInfo)) { return Types.primitive(PrimitiveTypeName.FLOAT, repetition).named(name); } - else if (typeInfo.equals(TypeInfoFactory.booleanTypeInfo)) { + if (typeInfo.equals(TypeInfoFactory.booleanTypeInfo)) { return Types.primitive(PrimitiveTypeName.BOOLEAN, repetition).named(name); } - else if (typeInfo.equals(TypeInfoFactory.binaryTypeInfo)) { + if (typeInfo.equals(TypeInfoFactory.binaryTypeInfo)) { return Types.primitive(PrimitiveTypeName.BINARY, repetition).named(name); } - else if (typeInfo.equals(TypeInfoFactory.timestampTypeInfo)) { + if (typeInfo.equals(TypeInfoFactory.timestampTypeInfo)) { return Types.primitive(PrimitiveTypeName.INT96, repetition).named(name); } - else if (typeInfo.equals(TypeInfoFactory.voidTypeInfo)) { + if (typeInfo.equals(TypeInfoFactory.voidTypeInfo)) { throw new UnsupportedOperationException("Void type not implemented"); } - else if (typeInfo.getTypeName().toLowerCase(Locale.ENGLISH).startsWith( + if (typeInfo.getTypeName().toLowerCase(Locale.ENGLISH).startsWith( serdeConstants.CHAR_TYPE_NAME)) { return Types.optional(PrimitiveTypeName.BINARY).as(LogicalTypeAnnotation.stringType()).named(name); } - else if (typeInfo.getTypeName().toLowerCase(Locale.ENGLISH).startsWith( + if (typeInfo.getTypeName().toLowerCase(Locale.ENGLISH).startsWith( serdeConstants.VARCHAR_TYPE_NAME)) { return Types.optional(PrimitiveTypeName.BINARY).as(LogicalTypeAnnotation.stringType()).named(name); } - else if (typeInfo instanceof DecimalTypeInfo) { + if (typeInfo instanceof DecimalTypeInfo) { DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) typeInfo; int prec = decimalTypeInfo.precision(); int scale = decimalTypeInfo.scale(); int bytes = ParquetHiveSerDe.PRECISION_TO_BYTE_COUNT[prec - 1]; return Types.optional(PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY).length(bytes).as(LogicalTypeAnnotation.decimalType(scale, prec)).named(name); } - else if (typeInfo.equals(TypeInfoFactory.dateTypeInfo)) { + if (typeInfo.equals(TypeInfoFactory.dateTypeInfo)) { return Types.primitive(PrimitiveTypeName.INT32, repetition).as(LogicalTypeAnnotation.dateType()).named(name); } - else if (typeInfo.equals(TypeInfoFactory.unknownTypeInfo)) { + if (typeInfo.equals(TypeInfoFactory.unknownTypeInfo)) { throw new UnsupportedOperationException("Unknown type not implemented"); } - else { - throw new IllegalArgumentException("Unknown type: " + typeInfo); - } + throw new IllegalArgumentException("Unknown type: " + typeInfo); } - else if (typeInfo.getCategory() == Category.LIST) { + if (typeInfo.getCategory() == Category.LIST) { return convertArrayType(name, (ListTypeInfo) typeInfo); } - else if (typeInfo.getCategory() == Category.STRUCT) { + if (typeInfo.getCategory() == Category.STRUCT) { return convertStructType(name, (StructTypeInfo) typeInfo); } - else if (typeInfo.getCategory() == Category.MAP) { + if (typeInfo.getCategory() == Category.MAP) { return convertMapType(name, (MapTypeInfo) typeInfo); } - else if (typeInfo.getCategory() == Category.UNION) { + if (typeInfo.getCategory() == Category.UNION) { throw new UnsupportedOperationException("Union type not implemented"); } - else { - throw new IllegalArgumentException("Unknown type: " + typeInfo); - } + throw new IllegalArgumentException("Unknown type: " + typeInfo); } // An optional group containing a repeated anonymous group "bag", containing @@ -183,19 +179,17 @@ public static GroupType mapType(Repetition repetition, String alias, String mapA mapAlias, keyType)); } - else { - if (!valueType.getName().equals("value")) { - throw new RuntimeException(valueType.getName() + " should be value"); - } - return mapKvWrapper( - repetition, - alias, - new GroupType( - Repetition.REPEATED, - mapAlias, - keyType, - valueType)); + if (!valueType.getName().equals("value")) { + throw new RuntimeException(valueType.getName() + " should be value"); } + return mapKvWrapper( + repetition, + alias, + new GroupType( + Repetition.REPEATED, + mapAlias, + keyType, + valueType)); } private static GroupType mapKvWrapper(Repetition repetition, String alias, Type nested) diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/write/SingleLevelArraySchemaConverter.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/write/SingleLevelArraySchemaConverter.java index f508b4dfb290..1b6f44d901b1 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/write/SingleLevelArraySchemaConverter.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/write/SingleLevelArraySchemaConverter.java @@ -74,51 +74,47 @@ private static Type convertType(String name, TypeInfo typeInfo, return Types.primitive(PrimitiveTypeName.BINARY, repetition).as(LogicalTypeAnnotation.stringType()) .named(name); } - else if (typeInfo.equals(TypeInfoFactory.intTypeInfo) || + if (typeInfo.equals(TypeInfoFactory.intTypeInfo) || typeInfo.equals(TypeInfoFactory.shortTypeInfo) || typeInfo.equals(TypeInfoFactory.byteTypeInfo)) { return Types.primitive(PrimitiveTypeName.INT32, repetition).named(name); } - else if (typeInfo.equals(TypeInfoFactory.longTypeInfo)) { + if (typeInfo.equals(TypeInfoFactory.longTypeInfo)) { return Types.primitive(PrimitiveTypeName.INT64, repetition).named(name); } - else if (typeInfo.equals(TypeInfoFactory.doubleTypeInfo)) { + if (typeInfo.equals(TypeInfoFactory.doubleTypeInfo)) { return Types.primitive(PrimitiveTypeName.DOUBLE, repetition).named(name); } - else if (typeInfo.equals(TypeInfoFactory.floatTypeInfo)) { + if (typeInfo.equals(TypeInfoFactory.floatTypeInfo)) { return Types.primitive(PrimitiveTypeName.FLOAT, repetition).named(name); } - else if (typeInfo.equals(TypeInfoFactory.booleanTypeInfo)) { + if (typeInfo.equals(TypeInfoFactory.booleanTypeInfo)) { return Types.primitive(PrimitiveTypeName.BOOLEAN, repetition).named(name); } - else if (typeInfo.equals(TypeInfoFactory.binaryTypeInfo)) { + if (typeInfo.equals(TypeInfoFactory.binaryTypeInfo)) { return Types.primitive(PrimitiveTypeName.BINARY, repetition).named(name); } - else if (typeInfo.equals(TypeInfoFactory.timestampTypeInfo)) { + if (typeInfo.equals(TypeInfoFactory.timestampTypeInfo)) { return Types.primitive(PrimitiveTypeName.INT96, repetition).named(name); } - else if (typeInfo.equals(TypeInfoFactory.voidTypeInfo)) { + if (typeInfo.equals(TypeInfoFactory.voidTypeInfo)) { throw new UnsupportedOperationException("Void type not implemented"); } - else if (typeInfo.getTypeName().toLowerCase(Locale.ENGLISH).startsWith( + if (typeInfo.getTypeName().toLowerCase(Locale.ENGLISH).startsWith( serdeConstants.CHAR_TYPE_NAME)) { if (repetition == Repetition.OPTIONAL) { return Types.optional(PrimitiveTypeName.BINARY).as(LogicalTypeAnnotation.stringType()).named(name); } - else { - return Types.repeated(PrimitiveTypeName.BINARY).as(LogicalTypeAnnotation.stringType()).named(name); - } + return Types.repeated(PrimitiveTypeName.BINARY).as(LogicalTypeAnnotation.stringType()).named(name); } - else if (typeInfo.getTypeName().toLowerCase(Locale.ENGLISH).startsWith( + if (typeInfo.getTypeName().toLowerCase(Locale.ENGLISH).startsWith( serdeConstants.VARCHAR_TYPE_NAME)) { if (repetition == Repetition.OPTIONAL) { return Types.optional(PrimitiveTypeName.BINARY).as(LogicalTypeAnnotation.stringType()).named(name); } - else { - return Types.repeated(PrimitiveTypeName.BINARY).as(LogicalTypeAnnotation.stringType()).named(name); - } + return Types.repeated(PrimitiveTypeName.BINARY).as(LogicalTypeAnnotation.stringType()).named(name); } - else if (typeInfo instanceof DecimalTypeInfo) { + if (typeInfo instanceof DecimalTypeInfo) { DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) typeInfo; int prec = decimalTypeInfo.precision(); int scale = decimalTypeInfo.scale(); @@ -126,35 +122,29 @@ else if (typeInfo instanceof DecimalTypeInfo) { if (repetition == Repetition.OPTIONAL) { return Types.optional(PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY).length(bytes).as(LogicalTypeAnnotation.decimalType(scale, prec)).named(name); } - else { - return Types.repeated(PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY).length(bytes).as(LogicalTypeAnnotation.decimalType(scale, prec)).named(name); - } + return Types.repeated(PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY).length(bytes).as(LogicalTypeAnnotation.decimalType(scale, prec)).named(name); } - else if (typeInfo.equals(TypeInfoFactory.dateTypeInfo)) { + if (typeInfo.equals(TypeInfoFactory.dateTypeInfo)) { return Types.primitive(PrimitiveTypeName.INT32, repetition).as(LogicalTypeAnnotation.dateType()).named(name); } - else if (typeInfo.equals(TypeInfoFactory.unknownTypeInfo)) { + if (typeInfo.equals(TypeInfoFactory.unknownTypeInfo)) { throw new UnsupportedOperationException("Unknown type not implemented"); } - else { - throw new IllegalArgumentException("Unknown type: " + typeInfo); - } + throw new IllegalArgumentException("Unknown type: " + typeInfo); } - else if (typeInfo.getCategory() == Category.LIST) { + if (typeInfo.getCategory() == Category.LIST) { return convertArrayType(name, (ListTypeInfo) typeInfo, repetition); } - else if (typeInfo.getCategory() == Category.STRUCT) { + if (typeInfo.getCategory() == Category.STRUCT) { return convertStructType(name, (StructTypeInfo) typeInfo, repetition); } - else if (typeInfo.getCategory() == Category.MAP) { + if (typeInfo.getCategory() == Category.MAP) { return convertMapType(name, (MapTypeInfo) typeInfo, repetition); } - else if (typeInfo.getCategory() == Category.UNION) { + if (typeInfo.getCategory() == Category.UNION) { throw new UnsupportedOperationException("Union type not implemented"); } - else { - throw new IllegalArgumentException("Unknown type: " + typeInfo); - } + throw new IllegalArgumentException("Unknown type: " + typeInfo); } // 1 anonymous element "array_element" diff --git a/plugin/trino-http-event-listener/pom.xml b/plugin/trino-http-event-listener/pom.xml index acc55b234087..f53cda967d1d 100644 --- a/plugin/trino-http-event-listener/pom.xml +++ b/plugin/trino-http-event-listener/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-iceberg/pom.xml b/plugin/trino-iceberg/pom.xml index b0c9d6745c85..43661692af98 100644 --- a/plugin/trino-iceberg/pom.xml +++ b/plugin/trino-iceberg/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAnalyzeProperties.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAnalyzeProperties.java new file mode 100644 index 000000000000..a7a231c19a4f --- /dev/null +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAnalyzeProperties.java @@ -0,0 +1,87 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.iceberg; + +import com.google.common.collect.ImmutableList; +import io.trino.spi.TrinoException; +import io.trino.spi.session.PropertyMetadata; +import io.trino.spi.type.ArrayType; + +import javax.inject.Inject; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.trino.spi.StandardErrorCode.INVALID_ANALYZE_PROPERTY; +import static io.trino.spi.type.VarcharType.VARCHAR; +import static java.lang.String.format; + +public class IcebergAnalyzeProperties +{ + public static final String COLUMNS_PROPERTY = "columns"; + + private final List> analyzeProperties; + + @Inject + public IcebergAnalyzeProperties() + { + analyzeProperties = ImmutableList.>builder() + .add(new PropertyMetadata<>( + COLUMNS_PROPERTY, + "Columns to be analyzed", + new ArrayType(VARCHAR), + Set.class, + null, + false, + IcebergAnalyzeProperties::decodeColumnNames, + value -> value)) + .build(); + } + + public List> getAnalyzeProperties() + { + return analyzeProperties; + } + + public static Optional> getColumnNames(Map properties) + { + @SuppressWarnings("unchecked") + Set columns = (Set) properties.get(COLUMNS_PROPERTY); + return Optional.ofNullable(columns); + } + + private static Set decodeColumnNames(Object object) + { + if (object == null) { + return null; + } + + Collection columns = ((Collection) object); + return columns.stream() + .peek(property -> throwIfNull(property, "columns")) + .map(String.class::cast) + .collect(toImmutableSet()); + } + + private static void throwIfNull(Object object, String propertyName) + { + if (object == null) { + throw new TrinoException(INVALID_ANALYZE_PROPERTY, format("Invalid null value in analyze %s property", propertyName)); + } + } +} diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergConfig.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergConfig.java index afc36cb11c2c..42d385699d10 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergConfig.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergConfig.java @@ -39,6 +39,8 @@ public class IcebergConfig { public static final int FORMAT_VERSION_SUPPORT_MIN = 1; public static final int FORMAT_VERSION_SUPPORT_MAX = 2; + public static final String EXTENDED_STATISTICS_CONFIG = "iceberg.experimental.extended-statistics.enabled"; + public static final String EXTENDED_STATISTICS_DESCRIPTION = "Allow ANALYZE and use of extended statistics collected by it. Currently, the statistics are collected in Trino-specific format"; public static final String EXPIRE_SNAPSHOTS_MIN_RETENTION = "iceberg.expire_snapshots.min-retention"; public static final String REMOVE_ORPHAN_FILES_MIN_RETENTION = "iceberg.remove_orphan_files.min-retention"; @@ -50,6 +52,7 @@ public class IcebergConfig private CatalogType catalogType = HIVE_METASTORE; private Duration dynamicFilteringWaitTimeout = new Duration(0, SECONDS); private boolean tableStatisticsEnabled = true; + private boolean extendedStatisticsEnabled; private boolean projectionPushdownEnabled = true; private Optional hiveCatalogName = Optional.empty(); private int formatVersion = FORMAT_VERSION_SUPPORT_MAX; @@ -164,6 +167,11 @@ public IcebergConfig setDynamicFilteringWaitTimeout(Duration dynamicFilteringWai return this; } + public boolean isTableStatisticsEnabled() + { + return tableStatisticsEnabled; + } + // In case of some queries / tables, retrieving table statistics from Iceberg // can take 20+ seconds. This config allows the user / operator the option // to opt out of retrieving table statistics in those cases to speed up query planning. @@ -175,9 +183,17 @@ public IcebergConfig setTableStatisticsEnabled(boolean tableStatisticsEnabled) return this; } - public boolean isTableStatisticsEnabled() + public boolean isExtendedStatisticsEnabled() { - return tableStatisticsEnabled; + return extendedStatisticsEnabled; + } + + @Config(EXTENDED_STATISTICS_CONFIG) + @ConfigDescription(EXTENDED_STATISTICS_DESCRIPTION) + public IcebergConfig setExtendedStatisticsEnabled(boolean extendedStatisticsEnabled) + { + this.extendedStatisticsEnabled = extendedStatisticsEnabled; + return this; } public boolean isProjectionPushdownEnabled() diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergConnector.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergConnector.java index 81846e020557..76d2b86bf6d4 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergConnector.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergConnector.java @@ -58,6 +58,7 @@ public class IcebergConnector private final List> schemaProperties; private final List> tableProperties; private final List> materializedViewProperties; + private final List> analyzeProperties; private final Optional accessControl; private final Set procedures; private final Set tableProcedures; @@ -73,6 +74,7 @@ public IcebergConnector( List> schemaProperties, List> tableProperties, List> materializedViewProperties, + List> analyzeProperties, Optional accessControl, Set procedures, Set tableProcedures) @@ -89,6 +91,7 @@ public IcebergConnector( this.schemaProperties = ImmutableList.copyOf(requireNonNull(schemaProperties, "schemaProperties is null")); this.tableProperties = ImmutableList.copyOf(requireNonNull(tableProperties, "tableProperties is null")); this.materializedViewProperties = ImmutableList.copyOf(requireNonNull(materializedViewProperties, "materializedViewProperties is null")); + this.analyzeProperties = ImmutableList.copyOf(requireNonNull(analyzeProperties, "analyzeProperties is null")); this.accessControl = requireNonNull(accessControl, "accessControl is null"); this.procedures = ImmutableSet.copyOf(requireNonNull(procedures, "procedures is null")); this.tableProcedures = ImmutableSet.copyOf(requireNonNull(tableProcedures, "tableProcedures is null")); @@ -167,6 +170,12 @@ public List> getMaterializedViewProperties() return materializedViewProperties; } + @Override + public List> getAnalyzeProperties() + { + return analyzeProperties; + } + @Override public ConnectorAccessControl getAccessControl() { diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java index 872f65cfd00b..1ae1a18d5339 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java @@ -20,6 +20,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; +import com.google.common.collect.Sets; import io.airlift.json.JsonCodec; import io.airlift.log.Logger; import io.airlift.slice.Slice; @@ -34,6 +35,7 @@ import io.trino.plugin.hive.HiveApplyProjectionUtil.ProjectedColumnRepresentation; import io.trino.plugin.hive.HiveWrittenPartitions; import io.trino.plugin.iceberg.catalog.TrinoCatalog; +import io.trino.plugin.iceberg.procedure.IcebergDropExtendedStatsHandle; import io.trino.plugin.iceberg.procedure.IcebergExpireSnapshotsHandle; import io.trino.plugin.iceberg.procedure.IcebergOptimizeHandle; import io.trino.plugin.iceberg.procedure.IcebergRemoveOrphanFilesHandle; @@ -41,12 +43,14 @@ import io.trino.plugin.iceberg.procedure.IcebergTableProcedureId; import io.trino.plugin.iceberg.util.DataFileWithDeleteFiles; import io.trino.spi.TrinoException; +import io.trino.spi.block.Block; import io.trino.spi.connector.Assignment; import io.trino.spi.connector.BeginTableExecuteResult; import io.trino.spi.connector.CatalogSchemaName; import io.trino.spi.connector.CatalogSchemaTableName; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ColumnMetadata; +import io.trino.spi.connector.ConnectorAnalyzeMetadata; import io.trino.spi.connector.ConnectorInsertTableHandle; import io.trino.spi.connector.ConnectorMaterializedViewDefinition; import io.trino.spi.connector.ConnectorMergeTableHandle; @@ -81,8 +85,10 @@ import io.trino.spi.predicate.NullableValue; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.security.TrinoPrincipal; +import io.trino.spi.statistics.ColumnStatisticMetadata; import io.trino.spi.statistics.ComputedStatistics; import io.trino.spi.statistics.TableStatistics; +import io.trino.spi.statistics.TableStatisticsMetadata; import io.trino.spi.type.LongTimestampWithTimeZone; import io.trino.spi.type.TimestampWithTimeZoneType; import io.trino.spi.type.TypeManager; @@ -117,6 +123,7 @@ import org.apache.iceberg.expressions.Term; import org.apache.iceberg.io.CloseableIterable; import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; import org.apache.iceberg.types.Types.IntegerType; import org.apache.iceberg.types.Types.NestedField; import org.apache.iceberg.types.Types.StringType; @@ -147,6 +154,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Verify.verify; +import static com.google.common.base.Verify.verifyNotNull; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; @@ -162,6 +170,7 @@ import static io.trino.plugin.hive.util.HiveUtil.isStructuralType; import static io.trino.plugin.iceberg.ConstraintExtractor.extractTupleDomain; import static io.trino.plugin.iceberg.ExpressionConverter.toIcebergExpression; +import static io.trino.plugin.iceberg.IcebergAnalyzeProperties.getColumnNames; import static io.trino.plugin.iceberg.IcebergColumnHandle.TRINO_MERGE_PARTITION_DATA; import static io.trino.plugin.iceberg.IcebergColumnHandle.TRINO_MERGE_PARTITION_SPEC_ID; import static io.trino.plugin.iceberg.IcebergColumnHandle.TRINO_MERGE_ROW_ID; @@ -180,6 +189,7 @@ import static io.trino.plugin.iceberg.IcebergSessionProperties.getExpireSnapshotMinRetention; import static io.trino.plugin.iceberg.IcebergSessionProperties.getRemoveOrphanFilesMinRetention; import static io.trino.plugin.iceberg.IcebergSessionProperties.isAllowLegacySnapshotSyntax; +import static io.trino.plugin.iceberg.IcebergSessionProperties.isExtendedStatisticsEnabled; import static io.trino.plugin.iceberg.IcebergSessionProperties.isProjectionPushdownEnabled; import static io.trino.plugin.iceberg.IcebergSessionProperties.isStatisticsEnabled; import static io.trino.plugin.iceberg.IcebergTableProperties.FILE_FORMAT_PROPERTY; @@ -199,18 +209,25 @@ import static io.trino.plugin.iceberg.IcebergUtil.schemaFromMetadata; import static io.trino.plugin.iceberg.PartitionFields.parsePartitionFields; import static io.trino.plugin.iceberg.PartitionFields.toPartitionFields; +import static io.trino.plugin.iceberg.TableStatisticsMaker.TRINO_STATS_COLUMN_ID_PATTERN; +import static io.trino.plugin.iceberg.TableStatisticsMaker.TRINO_STATS_NDV_FORMAT; +import static io.trino.plugin.iceberg.TableStatisticsMaker.TRINO_STATS_PREFIX; import static io.trino.plugin.iceberg.TableType.DATA; import static io.trino.plugin.iceberg.TypeConverter.toIcebergType; import static io.trino.plugin.iceberg.TypeConverter.toTrinoType; import static io.trino.plugin.iceberg.catalog.hms.TrinoHiveCatalog.DEPENDS_ON_TABLES; +import static io.trino.plugin.iceberg.procedure.IcebergTableProcedureId.DROP_EXTENDED_STATS; import static io.trino.plugin.iceberg.procedure.IcebergTableProcedureId.EXPIRE_SNAPSHOTS; import static io.trino.plugin.iceberg.procedure.IcebergTableProcedureId.OPTIMIZE; import static io.trino.plugin.iceberg.procedure.IcebergTableProcedureId.REMOVE_ORPHAN_FILES; import static io.trino.spi.StandardErrorCode.GENERIC_USER_ERROR; +import static io.trino.spi.StandardErrorCode.INVALID_ANALYZE_PROPERTY; import static io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.connector.RetryMode.NO_RETRIES; import static io.trino.spi.connector.RowChangeParadigm.DELETE_ROW_AND_INSERT_ROW; +import static io.trino.spi.predicate.Utils.blockToNativeValue; +import static io.trino.spi.statistics.ColumnStatisticType.NUMBER_OF_DISTINCT_VALUES; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.DateTimeEncoding.unpackMillisUtc; import static io.trino.spi.type.UuidType.UUID; @@ -888,6 +905,8 @@ public Optional getTableHandleForExecute( switch (procedureId) { case OPTIMIZE: return getTableHandleForOptimize(session, tableHandle, executeProperties, retryMode); + case DROP_EXTENDED_STATS: + return getTableHandleForDropExtendedStats(session, tableHandle); case EXPIRE_SNAPSHOTS: return getTableHandleForExpireSnapshots(session, tableHandle, executeProperties); case REMOVE_ORPHAN_FILES: @@ -917,6 +936,17 @@ private Optional getTableHandleForOptimize(Connecto icebergTable.location())); } + private Optional getTableHandleForDropExtendedStats(ConnectorSession session, IcebergTableHandle tableHandle) + { + Table icebergTable = catalog.loadTable(session, tableHandle.getSchemaTableName()); + + return Optional.of(new IcebergTableExecuteHandle( + tableHandle.getSchemaTableName(), + DROP_EXTENDED_STATS, + new IcebergDropExtendedStatsHandle(), + icebergTable.location())); + } + private Optional getTableHandleForExpireSnapshots(ConnectorSession session, IcebergTableHandle tableHandle, Map executeProperties) { Duration retentionThreshold = (Duration) executeProperties.get(RETENTION_THRESHOLD); @@ -948,6 +978,7 @@ public Optional getLayoutForTableExecute(ConnectorSession switch (executeHandle.getProcedureId()) { case OPTIMIZE: return getLayoutForOptimize(session, executeHandle); + case DROP_EXTENDED_STATS: case EXPIRE_SNAPSHOTS: case REMOVE_ORPHAN_FILES: // handled via executeTableExecute @@ -974,6 +1005,7 @@ public BeginTableExecuteResult getColumnMetadatas(Schema schema) return columns.build(); } + @Override + public ConnectorAnalyzeMetadata getStatisticsCollectionMetadata(ConnectorSession session, ConnectorTableHandle tableHandle, Map analyzeProperties) + { + if (!isExtendedStatisticsEnabled(session)) { + throw new TrinoException(NOT_SUPPORTED, "Analyze is not enabled. You can enable analyze using %s config or %s catalog session property".formatted( + IcebergConfig.EXTENDED_STATISTICS_CONFIG, + IcebergSessionProperties.EXTENDED_STATISTICS_ENABLED)); + } + + IcebergTableHandle handle = (IcebergTableHandle) tableHandle; + checkArgument(handle.getTableType() == DATA, "Cannot analyze non-DATA table: %s", handle.getTableType()); + Table icebergTable = catalog.loadTable(session, handle.getSchemaTableName()); + if (handle.getSnapshotId().isPresent() && (handle.getSnapshotId().get() != icebergTable.currentSnapshot().snapshotId())) { + throw new TrinoException(NOT_SUPPORTED, "Cannot analyze old snapshot %s".formatted(handle.getSnapshotId().get())); + } + + ConnectorTableMetadata tableMetadata = getTableMetadata(session, handle); + Set allDataColumnNames = tableMetadata.getColumns().stream() + .filter(column -> !column.isHidden()) + .map(ColumnMetadata::getName) + .collect(toImmutableSet()); + + Set analyzeColumnNames = getColumnNames(analyzeProperties) + .map(columnNames -> { + // validate that proper column names are passed via `columns` analyze property + if (columnNames.isEmpty()) { + throw new TrinoException(INVALID_ANALYZE_PROPERTY, "Cannot specify empty list of columns for analysis"); + } + if (!allDataColumnNames.containsAll(columnNames)) { + throw new TrinoException( + INVALID_ANALYZE_PROPERTY, + format("Invalid columns specified for analysis: %s", Sets.difference(columnNames, allDataColumnNames))); + } + return columnNames; + }) + .orElse(allDataColumnNames); + + Set columnStatistics = tableMetadata.getColumns().stream() + .filter(column -> analyzeColumnNames.contains(column.getName())) + // TODO: add support for NDV summary/sketch, but using Theta sketch, not HLL; see https://github.com/apache/iceberg-docs/pull/69 + .map(column -> new ColumnStatisticMetadata(column.getName(), NUMBER_OF_DISTINCT_VALUES)) + .collect(toImmutableSet()); + + return new ConnectorAnalyzeMetadata( + tableHandle, + new TableStatisticsMetadata(columnStatistics, ImmutableSet.of(), ImmutableList.of())); + } + + @Override + public ConnectorTableHandle beginStatisticsCollection(ConnectorSession session, ConnectorTableHandle tableHandle) + { + IcebergTableHandle handle = (IcebergTableHandle) tableHandle; + Table icebergTable = catalog.loadTable(session, handle.getSchemaTableName()); + beginTransaction(icebergTable); + return handle; + } + + @Override + public void finishStatisticsCollection(ConnectorSession session, ConnectorTableHandle tableHandle, Collection computedStatistics) + { + UpdateProperties updateProperties = transaction.updateProperties(); + Map columnNameToId = transaction.table().schema().columns().stream() + .collect(toImmutableMap(Types.NestedField::name, Types.NestedField::fieldId)); + Set columnIds = columnNameToId.values().stream() + .collect(toImmutableSet()); + + // Drop stats for obsolete columns + transaction.table().properties().keySet().stream() + .filter(key -> { + if (!key.startsWith(TRINO_STATS_PREFIX)) { + return false; + } + Matcher matcher = TRINO_STATS_COLUMN_ID_PATTERN.matcher(key); + if (!matcher.matches()) { + return false; + } + return !columnIds.contains(Integer.parseInt(matcher.group("columnId"))); + }) + .forEach(updateProperties::remove); + + for (ComputedStatistics computedStatistic : computedStatistics) { + verify(computedStatistic.getGroupingColumns().isEmpty() && computedStatistic.getGroupingValues().isEmpty(), "Unexpected grouping"); + verify(computedStatistic.getTableStatistics().isEmpty(), "Unexpected table statistics"); + for (Map.Entry entry : computedStatistic.getColumnStatistics().entrySet()) { + ColumnStatisticMetadata statisticMetadata = entry.getKey(); + if (statisticMetadata.getStatisticType() == NUMBER_OF_DISTINCT_VALUES) { + long ndv = (long) blockToNativeValue(BIGINT, entry.getValue()); + Integer columnId = verifyNotNull( + columnNameToId.get(statisticMetadata.getColumnName()), + "Column not found in table: [%s]", + statisticMetadata.getColumnName()); + updateProperties.set(TRINO_STATS_NDV_FORMAT.formatted(columnId), Long.toString(ndv)); + } + else { + throw new UnsupportedOperationException("Unsupported statistic type: " + statisticMetadata.getStatisticType()); + } + } + } + + updateProperties.commit(); + transaction.commitTransaction(); + transaction = null; + } + @Override public Optional applyDelete(ConnectorSession session, ConnectorTableHandle handle) { @@ -1924,7 +2081,7 @@ public TableStatistics getTableStatistics(ConnectorSession session, ConnectorTab IcebergTableHandle handle = (IcebergTableHandle) tableHandle; Table icebergTable = catalog.loadTable(session, handle.getSchemaTableName()); - return TableStatisticsMaker.getTableStatistics(typeManager, handle, icebergTable); + return TableStatisticsMaker.getTableStatistics(typeManager, session, handle, icebergTable); } @Override diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergModule.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergModule.java index 0269f657a26b..def2527db491 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergModule.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergModule.java @@ -25,6 +25,7 @@ import io.trino.plugin.hive.orc.OrcWriterConfig; import io.trino.plugin.hive.parquet.ParquetReaderConfig; import io.trino.plugin.hive.parquet.ParquetWriterConfig; +import io.trino.plugin.iceberg.procedure.DropExtendedStatsTableProcedure; import io.trino.plugin.iceberg.procedure.ExpireSnapshotsTableProcedure; import io.trino.plugin.iceberg.procedure.OptimizeTableProcedure; import io.trino.plugin.iceberg.procedure.RemoveOrphanFilesTableProcedure; @@ -54,6 +55,7 @@ public void configure(Binder binder) newSetBinder(binder, SessionPropertiesProvider.class).addBinding().to(IcebergSessionProperties.class).in(Scopes.SINGLETON); binder.bind(IcebergTableProperties.class).in(Scopes.SINGLETON); binder.bind(IcebergMaterializedViewAdditionalProperties.class).in(Scopes.SINGLETON); + binder.bind(IcebergAnalyzeProperties.class).in(Scopes.SINGLETON); binder.bind(ConnectorSplitManager.class).to(IcebergSplitManager.class).in(Scopes.SINGLETON); newOptionalBinder(binder, ConnectorPageSourceProvider.class).setDefault().to(IcebergPageSourceProvider.class).in(Scopes.SINGLETON); @@ -81,6 +83,7 @@ public void configure(Binder binder) Multibinder tableProcedures = newSetBinder(binder, TableProcedureMetadata.class); tableProcedures.addBinding().toProvider(OptimizeTableProcedure.class).in(Scopes.SINGLETON); + tableProcedures.addBinding().toProvider(DropExtendedStatsTableProcedure.class).in(Scopes.SINGLETON); tableProcedures.addBinding().toProvider(ExpireSnapshotsTableProcedure.class).in(Scopes.SINGLETON); tableProcedures.addBinding().toProvider(RemoveOrphanFilesTableProcedure.class).in(Scopes.SINGLETON); } diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSinkProvider.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSinkProvider.java index fc68120d5560..212aee409bab 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSinkProvider.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSinkProvider.java @@ -123,6 +123,7 @@ public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHa optimizeHandle.getFileFormat(), optimizeHandle.getTableStorageProperties(), maxOpenPartitions); + case DROP_EXTENDED_STATS: case EXPIRE_SNAPSHOTS: case REMOVE_ORPHAN_FILES: // handled via ConnectorMetadata.executeTableExecute diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java index baa49f2c0e14..5dcf5d6cbfa6 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java @@ -175,7 +175,6 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.TimeZoneKey.UTC_KEY; import static io.trino.spi.type.UuidType.UUID; -import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.spi.type.VarcharType.VARCHAR; import static java.lang.String.format; import static java.util.Locale.ENGLISH; @@ -820,13 +819,6 @@ private static Integer getIcebergFieldId(OrcColumn column) private static Type getOrcReadType(Type columnType, TypeManager typeManager) { - if (columnType == UUID) { - // ORC spec doesn't have UUID - // TODO read into Int128ArrayBlock for better performance when operating on read values - // TODO: Validate that the OrcColumn attribute ICEBERG_BINARY_TYPE is equal to "UUID" - return VARBINARY; - } - if (columnType instanceof ArrayType) { return new ArrayType(getOrcReadType(((ArrayType) columnType).getElementType(), typeManager)); } @@ -1058,7 +1050,8 @@ else if (column.getId() == TRINO_MERGE_PARTITION_DATA) { dataSource, UTC, memoryContext, - options); + options, + Optional.empty()); return new ReaderPageSourceWithRowPositions( new ReaderPageSource( diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergSessionProperties.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergSessionProperties.java index 40b9b1027f59..2d4fb4f93563 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergSessionProperties.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergSessionProperties.java @@ -36,6 +36,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static io.trino.plugin.base.session.PropertyMetadataUtil.dataSizeProperty; import static io.trino.plugin.base.session.PropertyMetadataUtil.durationProperty; +import static io.trino.plugin.iceberg.IcebergConfig.EXTENDED_STATISTICS_DESCRIPTION; import static io.trino.spi.StandardErrorCode.INVALID_SESSION_PROPERTY; import static io.trino.spi.session.PropertyMetadata.booleanProperty; import static io.trino.spi.session.PropertyMetadata.doubleProperty; @@ -70,6 +71,7 @@ public final class IcebergSessionProperties private static final String PARQUET_WRITER_BATCH_SIZE = "parquet_writer_batch_size"; private static final String DYNAMIC_FILTERING_WAIT_TIMEOUT = "dynamic_filtering_wait_timeout"; private static final String STATISTICS_ENABLED = "statistics_enabled"; + public static final String EXTENDED_STATISTICS_ENABLED = "experimental_extended_statistics_enabled"; private static final String PROJECTION_PUSHDOWN_ENABLED = "projection_pushdown_enabled"; private static final String TARGET_MAX_FILE_SIZE = "target_max_file_size"; private static final String HIVE_CATALOG_NAME = "hive_catalog_name"; @@ -214,6 +216,11 @@ public IcebergSessionProperties( "Expose table statistics", icebergConfig.isTableStatisticsEnabled(), false)) + .add(booleanProperty( + EXTENDED_STATISTICS_ENABLED, + EXTENDED_STATISTICS_DESCRIPTION, + icebergConfig.isExtendedStatisticsEnabled(), + false)) .add(booleanProperty( PROJECTION_PUSHDOWN_ENABLED, "Read only required fields from a struct", @@ -382,6 +389,11 @@ public static boolean isStatisticsEnabled(ConnectorSession session) return session.getProperty(STATISTICS_ENABLED, Boolean.class); } + public static boolean isExtendedStatisticsEnabled(ConnectorSession session) + { + return session.getProperty(EXTENDED_STATISTICS_ENABLED, Boolean.class); + } + public static boolean isProjectionPushdownEnabled(ConnectorSession session) { return session.getProperty(PROJECTION_PUSHDOWN_ENABLED, Boolean.class); diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergTableHandle.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergTableHandle.java index 9d3b38c2315b..8cd1d3f58299 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergTableHandle.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergTableHandle.java @@ -33,6 +33,7 @@ import java.util.Set; import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.joining; public class IcebergTableHandle implements ConnectorTableHandle @@ -381,6 +382,17 @@ public int hashCode() @Override public String toString() { - return getSchemaTableNameWithType() + snapshotId.map(v -> "@" + v).orElse(""); + StringBuilder builder = new StringBuilder(getSchemaTableNameWithType().toString()); + snapshotId.ifPresent(snapshotId -> builder.append("@").append(snapshotId)); + if (enforcedPredicate.isNone()) { + builder.append(" constraint=FALSE"); + } + else if (!enforcedPredicate.isAll()) { + builder.append(" constraint on "); + builder.append(enforcedPredicate.getDomains().orElseThrow().keySet().stream() + .map(IcebergColumnHandle::getQualifiedName) + .collect(joining(", ", "[", "]"))); + } + return builder.toString(); } } diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/InternalIcebergConnectorFactory.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/InternalIcebergConnectorFactory.java index 6bd4d38c8bd3..52f980192183 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/InternalIcebergConnectorFactory.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/InternalIcebergConnectorFactory.java @@ -118,6 +118,7 @@ public static Connector createConnector( Set sessionPropertiesProviders = injector.getInstance(Key.get(new TypeLiteral>() {})); IcebergTableProperties icebergTableProperties = injector.getInstance(IcebergTableProperties.class); IcebergMaterializedViewAdditionalProperties materializedViewAdditionalProperties = injector.getInstance(IcebergMaterializedViewAdditionalProperties.class); + IcebergAnalyzeProperties icebergAnalyzeProperties = injector.getInstance(IcebergAnalyzeProperties.class); Set procedures = injector.getInstance(Key.get(new TypeLiteral>() {})); Set tableProcedures = injector.getInstance(Key.get(new TypeLiteral>() {})); Optional accessControl = injector.getInstance(Key.get(new TypeLiteral>() {})); @@ -137,6 +138,7 @@ public static Connector createConnector( IcebergSchemaProperties.SCHEMA_PROPERTIES, icebergTableProperties.getTableProperties(), materializedViewProperties, + icebergAnalyzeProperties.getAnalyzeProperties(), accessControl, procedures, tableProcedures); diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/TableStatisticsMaker.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/TableStatisticsMaker.java index be68d0b4116c..bd296d144063 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/TableStatisticsMaker.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/TableStatisticsMaker.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableMap; import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ConnectorSession; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.statistics.ColumnStatistics; import io.trino.spi.statistics.DoubleRange; @@ -32,26 +33,37 @@ import java.io.UncheckedIOException; import java.util.List; import java.util.Map; +import java.util.Optional; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import static io.trino.plugin.iceberg.ExpressionConverter.toIcebergExpression; +import static io.trino.plugin.iceberg.IcebergSessionProperties.isExtendedStatisticsEnabled; import static io.trino.plugin.iceberg.IcebergUtil.getColumns; import static java.util.function.Function.identity; import static java.util.stream.Collectors.toUnmodifiableMap; public class TableStatisticsMaker { + public static final String TRINO_STATS_PREFIX = "trino.stats.ndv."; + public static final String TRINO_STATS_NDV_FORMAT = TRINO_STATS_PREFIX + "%d.ndv"; + public static final Pattern TRINO_STATS_COLUMN_ID_PATTERN = Pattern.compile(Pattern.quote(TRINO_STATS_PREFIX) + "(?\\d+)\\..*"); + public static final Pattern TRINO_STATS_NDV_PATTERN = Pattern.compile(Pattern.quote(TRINO_STATS_PREFIX) + "(?\\d+)\\.ndv"); + private final TypeManager typeManager; + private final ConnectorSession session; private final Table icebergTable; - private TableStatisticsMaker(TypeManager typeManager, Table icebergTable) + private TableStatisticsMaker(TypeManager typeManager, ConnectorSession session, Table icebergTable) { this.typeManager = typeManager; + this.session = session; this.icebergTable = icebergTable; } - public static TableStatistics getTableStatistics(TypeManager typeManager, IcebergTableHandle tableHandle, Table icebergTable) + public static TableStatistics getTableStatistics(TypeManager typeManager, ConnectorSession session, IcebergTableHandle tableHandle, Table icebergTable) { - return new TableStatisticsMaker(typeManager, icebergTable).makeTableStatistics(tableHandle); + return new TableStatisticsMaker(typeManager, session, icebergTable).makeTableStatistics(tableHandle); } private TableStatistics makeTableStatistics(IcebergTableHandle tableHandle) @@ -98,6 +110,8 @@ private TableStatistics makeTableStatistics(IcebergTableHandle tableHandle) .build(); } + Map ndvs = readNdvs(icebergTable); + ImmutableMap.Builder columnHandleBuilder = ImmutableMap.builder(); double recordCount = summary.getRecordCount(); for (IcebergColumnHandle columnHandle : idToColumnHandle.values()) { @@ -118,8 +132,32 @@ private TableStatistics makeTableStatistics(IcebergTableHandle tableHandle) if (min != null && max != null) { columnBuilder.setRange(DoubleRange.from(columnHandle.getType(), min, max)); } + columnBuilder.setDistinctValuesCount( + Optional.ofNullable(ndvs.get(fieldId)) + .map(Estimate::of) + .orElseGet(Estimate::unknown)); columnHandleBuilder.put(columnHandle, columnBuilder.build()); } return new TableStatistics(Estimate.of(recordCount), columnHandleBuilder.buildOrThrow()); } + + private Map readNdvs(Table icebergTable) + { + if (!isExtendedStatisticsEnabled(session)) { + return ImmutableMap.of(); + } + + ImmutableMap.Builder ndvByColumnId = ImmutableMap.builder(); + icebergTable.properties().forEach((key, value) -> { + if (key.startsWith(TRINO_STATS_PREFIX)) { + Matcher matcher = TRINO_STATS_NDV_PATTERN.matcher(key); + if (matcher.matches()) { + int columnId = Integer.parseInt(matcher.group("columnId")); + long ndv = Long.parseLong(value); + ndvByColumnId.put(columnId, ndv); + } + } + }); + return ndvByColumnId.buildOrThrow(); + } } diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/catalog/AbstractIcebergTableOperations.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/catalog/AbstractIcebergTableOperations.java index a52facba4472..64e601843ec9 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/catalog/AbstractIcebergTableOperations.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/catalog/AbstractIcebergTableOperations.java @@ -51,6 +51,7 @@ import static org.apache.iceberg.TableProperties.METADATA_COMPRESSION; import static org.apache.iceberg.TableProperties.METADATA_COMPRESSION_DEFAULT; import static org.apache.iceberg.TableProperties.WRITE_METADATA_LOCATION; +import static org.apache.iceberg.util.LocationUtil.stripTrailingSlash; @NotThreadSafe public abstract class AbstractIcebergTableOperations @@ -173,14 +174,14 @@ public String metadataFileLocation(String filename) if (metadata != null) { String writeLocation = metadata.properties().get(WRITE_METADATA_LOCATION); if (writeLocation != null) { - return format("%s/%s", writeLocation, filename); + return format("%s/%s", stripTrailingSlash(writeLocation), filename); } location = metadata.location(); } else { location = this.location.orElseThrow(() -> new IllegalStateException("Location not set")); } - return format("%s/%s/%s", location, METADATA_FOLDER_NAME, filename); + return format("%s/%s/%s", stripTrailingSlash(location), METADATA_FOLDER_NAME, filename); } @Override @@ -244,9 +245,9 @@ protected static String metadataFileLocation(TableMetadata metadata, String file { String location = metadata.properties().get(WRITE_METADATA_LOCATION); if (location != null) { - return format("%s/%s", location, filename); + return format("%s/%s", stripTrailingSlash(location), filename); } - return format("%s/%s/%s", metadata.location(), METADATA_FOLDER_NAME, filename); + return format("%s/%s/%s", stripTrailingSlash(metadata.location()), METADATA_FOLDER_NAME, filename); } protected static int parseVersion(String metadataLocation) diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/catalog/glue/TrinoGlueCatalog.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/catalog/glue/TrinoGlueCatalog.java index 986613042d6c..e83d011b6127 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/catalog/glue/TrinoGlueCatalog.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/catalog/glue/TrinoGlueCatalog.java @@ -338,11 +338,9 @@ public void renameTable(ConnectorSession session, SchemaTableName from, SchemaTa { boolean newTableCreated = false; try { - Optional table = getTable(from); - if (table.isEmpty()) { - throw new TableNotFoundException(from); - } - TableInput tableInput = getTableInput(to.getTableName(), Optional.ofNullable(table.get().getOwner()), table.get().getParameters()); + com.amazonaws.services.glue.model.Table table = getTable(from) + .orElseThrow(() -> new TableNotFoundException(from)); + TableInput tableInput = getTableInput(to.getTableName(), Optional.ofNullable(table.getOwner()), table.getParameters()); CreateTableRequest createTableRequest = new CreateTableRequest() .withDatabaseName(to.getSchemaName()) .withTableInput(tableInput); @@ -478,15 +476,13 @@ public void renameView(ConnectorSession session, SchemaTableName source, SchemaT { boolean newTableCreated = false; try { - Optional existingView = getTable(source); - if (existingView.isEmpty()) { - throw new TableNotFoundException(source); - } + com.amazonaws.services.glue.model.Table existingView = getTable(source) + .orElseThrow(() -> new TableNotFoundException(source)); TableInput viewTableInput = getViewTableInput( target.getTableName(), - existingView.get().getViewOriginalText(), - existingView.get().getOwner(), + existingView.getViewOriginalText(), + existingView.getOwner(), createViewProperties(session)); CreateTableRequest createTableRequest = new CreateTableRequest() .withDatabaseName(target.getSchemaName()) @@ -770,11 +766,8 @@ public void renameMaterializedView(ConnectorSession session, SchemaTableName sou { boolean newTableCreated = false; try { - Optional table = getTable(source); - if (table.isEmpty()) { - throw new TableNotFoundException(source); - } - com.amazonaws.services.glue.model.Table glueTable = table.get(); + com.amazonaws.services.glue.model.Table glueTable = getTable(source) + .orElseThrow(() -> new TableNotFoundException(source)); if (!isTrinoMaterializedView(glueTable.getTableType(), glueTable.getParameters())) { throw new TrinoException(UNSUPPORTED_TABLE_TYPE, "Not a Materialized View: " + source); } diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/DropExtendedStatsTableProcedure.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/DropExtendedStatsTableProcedure.java new file mode 100644 index 000000000000..aa892a7c42c0 --- /dev/null +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/DropExtendedStatsTableProcedure.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.iceberg.procedure; + +import com.google.common.collect.ImmutableList; +import io.trino.spi.connector.TableProcedureMetadata; + +import javax.inject.Provider; + +import static io.trino.plugin.iceberg.procedure.IcebergTableProcedureId.DROP_EXTENDED_STATS; +import static io.trino.spi.connector.TableProcedureExecutionMode.coordinatorOnly; + +public class DropExtendedStatsTableProcedure + implements Provider +{ + @Override + public TableProcedureMetadata get() + { + return new TableProcedureMetadata( + DROP_EXTENDED_STATS.name(), + coordinatorOnly(), + ImmutableList.of()); + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/repdef/DefLevelIterable.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/IcebergDropExtendedStatsHandle.java similarity index 65% rename from lib/trino-parquet/src/main/java/io/trino/parquet/writer/repdef/DefLevelIterable.java rename to plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/IcebergDropExtendedStatsHandle.java index 7dcc7cbca5bd..55bf7e092a7d 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/repdef/DefLevelIterable.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/IcebergDropExtendedStatsHandle.java @@ -11,19 +11,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.parquet.writer.repdef; +package io.trino.plugin.iceberg.procedure; -import com.google.common.collect.AbstractIterator; +import static com.google.common.base.MoreObjects.toStringHelper; -import java.util.OptionalInt; - -public interface DefLevelIterable +public class IcebergDropExtendedStatsHandle + extends IcebergProcedureHandle { - DefLevelIterator getIterator(); - - abstract class DefLevelIterator - extends AbstractIterator + @Override + public String toString() { - abstract boolean end(); + return toStringHelper(this) + .toString(); } } diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/IcebergProcedureHandle.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/IcebergProcedureHandle.java index 96e247926338..e9ce4199eeac 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/IcebergProcedureHandle.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/IcebergProcedureHandle.java @@ -21,6 +21,7 @@ property = "@type") @JsonSubTypes({ @JsonSubTypes.Type(value = IcebergOptimizeHandle.class, name = "optimize"), + @JsonSubTypes.Type(value = IcebergDropExtendedStatsHandle.class, name = "drop_extended_stats"), @JsonSubTypes.Type(value = IcebergExpireSnapshotsHandle.class, name = "expire_snapshots"), @JsonSubTypes.Type(value = IcebergRemoveOrphanFilesHandle.class, name = "remove_orphan_files"), }) diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/IcebergTableProcedureId.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/IcebergTableProcedureId.java index e81c8336fb04..8b1c68fb23ed 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/IcebergTableProcedureId.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/IcebergTableProcedureId.java @@ -16,6 +16,7 @@ public enum IcebergTableProcedureId { OPTIMIZE, + DROP_EXTENDED_STATS, EXPIRE_SNAPSHOTS, REMOVE_ORPHAN_FILES, } diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java index 3d4e6b19479a..80f37b650181 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java @@ -90,11 +90,13 @@ import static com.google.common.util.concurrent.Uninterruptibles.sleepUninterruptibly; import static io.trino.SystemSessionProperties.PREFERRED_WRITE_PARTITIONING_MIN_NUMBER_OF_PARTITIONS; import static io.trino.SystemSessionProperties.SCALE_WRITERS; +import static io.trino.SystemSessionProperties.TASK_WRITER_COUNT; import static io.trino.plugin.hive.HiveTestUtils.HDFS_ENVIRONMENT; import static io.trino.plugin.iceberg.IcebergFileFormat.AVRO; import static io.trino.plugin.iceberg.IcebergFileFormat.ORC; import static io.trino.plugin.iceberg.IcebergFileFormat.PARQUET; import static io.trino.plugin.iceberg.IcebergQueryRunner.ICEBERG_CATALOG; +import static io.trino.plugin.iceberg.IcebergSessionProperties.EXTENDED_STATISTICS_ENABLED; import static io.trino.plugin.iceberg.IcebergSplitManager.ICEBERG_DOMAIN_COMPACTION_THRESHOLD; import static io.trino.spi.predicate.Domain.multipleValues; import static io.trino.spi.predicate.Domain.singleValue; @@ -2858,6 +2860,69 @@ public void testBasicTableStatistics() dropTable(tableName); } + /** + * @see TestIcebergAnalyze + */ + @Test + public void testBasicAnalyze() + { + Session defaultSession = getSession(); + String catalog = defaultSession.getCatalog().orElseThrow(); + Session extendedStatisticsEnabled = Session.builder(defaultSession) + .setCatalogSessionProperty(catalog, EXTENDED_STATISTICS_ENABLED, "true") + .build(); + String tableName = "test_basic_analyze"; + + assertUpdate(defaultSession, "CREATE TABLE " + tableName + " AS SELECT * FROM tpch.tiny.region", 5); + + String statsWithoutNdv = format == AVRO + ? ("VALUES " + + " ('regionkey', NULL, NULL, NULL, NULL, NULL, NULL), " + + " ('name', NULL, NULL, NULL, NULL, NULL, NULL), " + + " ('comment', NULL, NULL, NULL, NULL, NULL, NULL), " + + " (NULL, NULL, NULL, NULL, 5e0, NULL, NULL)") + : ("VALUES " + + " ('regionkey', NULL, NULL, 0e0, NULL, '0', '4'), " + + " ('name', " + (format == PARQUET ? "87e0" : "NULL") + ", NULL, 0e0, NULL, NULL, NULL), " + + " ('comment', " + (format == PARQUET ? "237e0" : "NULL") + ", NULL, 0e0, NULL, NULL, NULL), " + + " (NULL, NULL, NULL, NULL, 5e0, NULL, NULL)"); + + String statsWithNdv = format == AVRO + ? ("VALUES " + + " ('regionkey', NULL, 5e0, NULL, NULL, NULL, NULL), " + + " ('name', NULL, 5e0, NULL, NULL, NULL, NULL), " + + " ('comment', NULL, 5e0, NULL, NULL, NULL, NULL), " + + " (NULL, NULL, NULL, NULL, 5e0, NULL, NULL)") + : ("VALUES " + + " ('regionkey', NULL, 5e0, 0e0, NULL, '0', '4'), " + + " ('name', " + (format == PARQUET ? "87e0" : "NULL") + ", 5e0, 0e0, NULL, NULL, NULL), " + + " ('comment', " + (format == PARQUET ? "237e0" : "NULL") + ", 5e0, 0e0, NULL, NULL, NULL), " + + " (NULL, NULL, NULL, NULL, 5e0, NULL, NULL)"); + + // initially, no NDV information + assertThat(query(defaultSession, "SHOW STATS FOR " + tableName)).skippingTypesCheck().matches(statsWithoutNdv); + assertThat(query(extendedStatisticsEnabled, "SHOW STATS FOR " + tableName)).skippingTypesCheck().matches(statsWithoutNdv); + + // ANALYZE needs to be enabled. This is because it currently stores additional statistics in a Trino-specific format and the format will change. + assertQueryFails( + defaultSession, + "ANALYZE " + tableName, + "\\QAnalyze is not enabled. You can enable analyze using iceberg.experimental.extended-statistics.enabled config or experimental_extended_statistics_enabled catalog session property"); + + // ANALYZE the table + assertUpdate(extendedStatisticsEnabled, "ANALYZE " + tableName); + // After ANALYZE, NDV information present + assertThat(query(extendedStatisticsEnabled, "SHOW STATS FOR " + tableName)) + .skippingTypesCheck() + .matches(statsWithNdv); + // NDV information is not present in a session with extended statistics not enabled + assertThat(query(defaultSession, "SHOW STATS FOR " + tableName)) + .skippingTypesCheck() + .matches(statsWithoutNdv); + + assertUpdate("DROP TABLE " + tableName); + } + @Test public void testMultipleColumnTableStatistics() { @@ -5314,6 +5379,24 @@ public void testInsertingIntoTablesWithColumnsWithQuotesInName() assertUpdate("DROP TABLE " + tableName); } + @Test + public void testInsertIntoBucketedColumnTaskWriterCount() + { + int taskWriterCount = 4; + assertThat(taskWriterCount).isGreaterThan(getQueryRunner().getNodeCount()); + Session session = Session.builder(getSession()) + .setSystemProperty(TASK_WRITER_COUNT, String.valueOf(taskWriterCount)) + .build(); + + String tableName = "test_inserting_into_bucketed_column_task_writer_count_" + randomTableSuffix(); + assertUpdate("CREATE TABLE " + tableName + " (x INT) WITH (partitioning = ARRAY['bucket(x, 7)'])"); + + assertUpdate(session, "INSERT INTO " + tableName + " SELECT nationkey FROM nation", 25); + assertQuery("SELECT * FROM " + tableName, "SELECT nationkey FROM nation"); + + assertUpdate("DROP TABLE " + tableName); + } + @Test public void testReadFromVersionedTableWithSchemaEvolution() { diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergMinioConnectorSmokeTest.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergMinioConnectorSmokeTest.java index 25c036be7413..adb65faa5095 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergMinioConnectorSmokeTest.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergMinioConnectorSmokeTest.java @@ -19,6 +19,7 @@ import org.apache.iceberg.FileFormat; import org.testng.annotations.Test; +import java.util.List; import java.util.Locale; import java.util.Map; @@ -26,6 +27,7 @@ import static io.trino.plugin.hive.containers.HiveMinioDataLake.MINIO_SECRET_KEY; import static io.trino.testing.sql.TestTable.randomTableSuffix; import static java.lang.String.format; +import static org.assertj.core.api.Assertions.assertThat; public abstract class BaseIcebergMinioConnectorSmokeTest extends BaseIcebergConnectorSmokeTest @@ -85,4 +87,24 @@ public void testRenameSchema() format("ALTER SCHEMA %s RENAME TO %s", schemaName, schemaName + randomTableSuffix()), "Hive metastore does not support renaming schemas"); } + + @Test + public void testS3LocationWithTrailingSlash() + { + // Verify data and metadata files' uri don't contain fragments + String schemaName = getSession().getSchema().orElseThrow(); + String tableName = "test_s3_location_with_trailing_slash_" + randomTableSuffix(); + String location = "s3://%s/%s/%s/".formatted(bucketName, schemaName, tableName); + assertThat(location).doesNotContain("#"); + + assertUpdate("CREATE TABLE " + tableName + " WITH (location='" + location + "') AS SELECT 1 col", 1); + + List dataFiles = hiveMinioDataLake.getMinioClient().listObjects(bucketName, "/%s/%s/data".formatted(schemaName, tableName)); + assertThat(dataFiles).isNotEmpty().filteredOn(filePath -> filePath.contains("#")).isEmpty(); + + List metadataFiles = hiveMinioDataLake.getMinioClient().listObjects(bucketName, "/%s/%s/metadata".formatted(schemaName, tableName)); + assertThat(metadataFiles).isNotEmpty().filteredOn(filePath -> filePath.contains("#")).isEmpty(); + + assertUpdate("DROP TABLE " + tableName); + } } diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergAnalyze.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergAnalyze.java new file mode 100644 index 000000000000..aee239d6cb33 --- /dev/null +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergAnalyze.java @@ -0,0 +1,589 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.iceberg; + +import io.trino.Session; +import io.trino.testing.AbstractTestQueryFramework; +import io.trino.testing.QueryRunner; +import org.testng.annotations.Test; + +import java.util.Map; + +import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.EXECUTE_TABLE_PROCEDURE; +import static io.trino.testing.TestingAccessControlManager.privilege; +import static io.trino.testing.sql.TestTable.randomTableSuffix; +import static io.trino.tpch.TpchTable.NATION; +import static java.lang.String.format; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; + +public class TestIcebergAnalyze + extends AbstractTestQueryFramework +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return IcebergQueryRunner.builder() + .setIcebergProperties(Map.of("iceberg.experimental.extended-statistics.enabled", "true")) + .setInitialTables(NATION) + .build(); + } + + @Test + public void testAnalyze() + { + String tableName = "test_analyze"; + assertUpdate("CREATE TABLE " + tableName + " AS SELECT * FROM tpch.sf1.nation", 25); + + assertQuery( + "SHOW STATS FOR " + tableName, + """ + VALUES + ('nationkey', null, null, 0, null, '0', '24'), + ('regionkey', null, null, 0, null, '0', '4'), + ('comment', null, null, 0, null, null, null), + ('name', null, null, 0, null, null, null), + (null, null, null, null, 25, null, null)"""); + + assertUpdate("ANALYZE " + tableName); + assertQuery( + "SHOW STATS FOR " + tableName, + """ + VALUES + ('nationkey', null, 25, 0, null, '0', '24'), + ('regionkey', null, 5, 0, null, '0', '4'), + ('comment', null, 25, 0, null, null, null), + ('name', null, 25, 0, null, null, null), + (null, null, null, null, 25, null, null)"""); + + // reanalyze data + assertUpdate("ANALYZE " + tableName); + assertQuery( + "SHOW STATS FOR " + tableName, + """ + VALUES + ('nationkey', null, 25, 0, null, '0', '24'), + ('regionkey', null, 5, 0, null, '0', '4'), + ('comment', null, 25, 0, null, null, null), + ('name', null, 25, 0, null, null, null), + (null, null, null, null, 25, null, null)"""); + + // insert one more copy; should not influence stats other than rowcount + assertUpdate("INSERT INTO " + tableName + " SELECT * FROM tpch.sf1.nation", 25); + + assertUpdate("ANALYZE " + tableName); + assertQuery( + "SHOW STATS FOR " + tableName, + """ + VALUES + ('nationkey', null, 25, 0, null, '0', '24'), + ('regionkey', null, 5, 0, null, '0', '4'), + ('comment', null, 25, 0, null, null, null), + ('name', null, 25, 0, null, null, null), + (null, null, null, null, 50, null, null)"""); + + // insert modified rows + assertUpdate("INSERT INTO " + tableName + " SELECT nationkey + 25, reverse(name), regionkey + 5, reverse(comment) FROM tpch.sf1.nation", 25); + + // without ANALYZE all stats but NDV should be updated + assertQuery( + "SHOW STATS FOR " + tableName, + """ + VALUES + ('nationkey', null, 25, 0, null, '0', '49'), + ('regionkey', null, 5, 0, null, '0', '9'), + ('comment', null, 25, 0, null, null, null), + ('name', null, 25, 0, null, null, null), + (null, null, null, null, 75, null, null)"""); + + // with analyze we should get new NDV + assertUpdate("ANALYZE " + tableName); + assertQuery( + "SHOW STATS FOR " + tableName, + """ + VALUES + ('nationkey', null, 50, 0, null, '0', '49'), + ('regionkey', null, 10, 0, null, '0', '9'), + ('comment', null, 50, 0, null, null, null), + ('name', null, 50, 0, null, null, null), + (null, null, null, null, 75, null, null)"""); + + assertUpdate("DROP TABLE " + tableName); + } + + @Test + public void testAnalyzeWithSchemaEvolution() + { + String tableName = "test_analyze_with_schema_evolution"; + assertUpdate("CREATE TABLE " + tableName + " AS SELECT * FROM tpch.sf1.nation", 25); + + assertUpdate("ANALYZE " + tableName); + + assertUpdate("ALTER TABLE " + tableName + " ADD COLUMN info varchar"); + assertUpdate("UPDATE " + tableName + " SET info = format('%s %s', name, comment)", 25); + assertUpdate("ALTER TABLE " + tableName + " DROP COLUMN comment"); + + // schema changed, ANALYZE hasn't been re-run yet + assertQuery( + "SHOW STATS FOR " + tableName, + """ + VALUES + ('nationkey', null, 25, 0, null, '0', '24'), + ('regionkey', null, 5, 0, null, '0', '4'), + ('name', null, 25, 0, null, null, null), + ('info', null, null, null, null, null, null), + (null, null, null, null, 50, null, null)"""); + + assertUpdate("ANALYZE " + tableName); + assertQuery( + "SHOW STATS FOR " + tableName, + """ + VALUES + ('nationkey', null, 25, 0, null, '0', '24'), + ('regionkey', null, 5, 0, null, '0', '4'), + ('name', null, 25, 0, null, null, null), + ('info', null, 25, null, null, null, null), + (null, null, null, null, 50, null, null)"""); + + assertUpdate("DROP TABLE " + tableName); + } + + @Test + public void testAnalyzePartitioned() + { + String tableName = "test_analyze_partitioned"; + assertUpdate("CREATE TABLE " + tableName + " WITH (partitioning = ARRAY['regionkey']) AS SELECT * FROM tpch.sf1.nation", 25); + + assertQuery( + "SHOW STATS FOR " + tableName, + """ + VALUES + ('nationkey', null, null, 0, null, '0', '24'), + ('regionkey', null, null, 0, null, '0', '4'), + ('comment', null, null, 0, null, null, null), + ('name', null, null, 0, null, null, null), + (null, null, null, null, 25, null, null)"""); + + assertUpdate("ANALYZE " + tableName); + assertQuery( + "SHOW STATS FOR " + tableName, + """ + VALUES + ('nationkey', null, 25, 0, null, '0', '24'), + ('regionkey', null, 5, 0, null, '0', '4'), + ('comment', null, 25, 0, null, null, null), + ('name', null, 25, 0, null, null, null), + (null, null, null, null, 25, null, null)"""); + + // insert one more copy; should not influence stats other than rowcount + assertUpdate("INSERT INTO " + tableName + " SELECT * FROM tpch.sf1.nation", 25); + + assertUpdate("ANALYZE " + tableName); + assertQuery( + "SHOW STATS FOR " + tableName, + """ + VALUES + ('nationkey', null, 25, 0, null, '0', '24'), + ('regionkey', null, 5, 0, null, '0', '4'), + ('comment', null, 25, 0, null, null, null), + ('name', null, 25, 0, null, null, null), + (null, null, null, null, 50, null, null)"""); + + // insert modified rows + assertUpdate("INSERT INTO " + tableName + " SELECT nationkey + 25, reverse(name), regionkey + 5, reverse(comment) FROM tpch.sf1.nation", 25); + + // without ANALYZE all stats but NDV should be updated + assertQuery( + "SHOW STATS FOR " + tableName, + """ + VALUES + ('nationkey', null, 25, 0, null, '0', '49'), + ('regionkey', null, 5, 0, null, '0', '9'), + ('comment', null, 25, 0, null, null, null), + ('name', null, 25, 0, null, null, null), + (null, null, null, null, 75, null, null)"""); + + // with analyze we should get new NDV + assertUpdate("ANALYZE " + tableName); + assertQuery( + "SHOW STATS FOR " + tableName, + """ + VALUES + ('nationkey', null, 50, 0, null, '0', '49'), + ('regionkey', null, 10, 0, null, '0', '9'), + ('comment', null, 50, 0, null, null, null), + ('name', null, 50, 0, null, null, null), + (null, null, null, null, 75, null, null)"""); + + assertUpdate("DROP TABLE " + tableName); + } + + @Test + public void testAnalyzeEmpty() + { + String tableName = "test_analyze_empty"; + assertUpdate("CREATE TABLE " + tableName + " AS SELECT * FROM tpch.sf1.nation WITH NO DATA", 0); + + assertQuery( + "SHOW STATS FOR " + tableName, + """ + VALUES + ('nationkey', 0, 0, 1, null, null, null), + ('regionkey', 0, 0, 1, null, null, null), + ('comment', 0, 0, 1, null, null, null), + ('name', 0, 0, 1, null, null, null), + (null, null, null, null, 0, null, null)"""); + + assertUpdate("ANALYZE " + tableName); + assertQuery( + "SHOW STATS FOR " + tableName, + """ + VALUES + ('nationkey', 0, 0, 1, null, null, null), + ('regionkey', 0, 0, 1, null, null, null), + ('comment', 0, 0, 1, null, null, null), + ('name', 0, 0, 1, null, null, null), + (null, null, null, null, 0, null, null)"""); + + // add some data and reanalyze + assertUpdate("INSERT INTO " + tableName + " SELECT * FROM tpch.sf1.nation", 25); + + assertUpdate("ANALYZE " + tableName); + assertQuery( + "SHOW STATS FOR " + tableName, + """ + VALUES + ('nationkey', null, 25, 0, null, '0', '24'), + ('regionkey', null, 5, 0, null, '0', '4'), + ('comment', null, 25, 0, null, null, null), + ('name', null, 25, 0, null, null, null), + (null, null, null, null, 25, null, null)"""); + + assertUpdate("DROP TABLE " + tableName); + } + + @Test + public void testAnalyzeSomeColumns() + { + String tableName = "test_analyze_some_columns"; + assertUpdate("CREATE TABLE " + tableName + " AS SELECT * FROM tpch.sf1.nation", 25); + + // analyze NULL list of columns + assertQueryFails("ANALYZE " + tableName + " WITH (columns = NULL)", "\\QInvalid null value for catalog 'iceberg' analyze property 'columns' from [null]"); + + // analyze empty list of columns + assertQueryFails("ANALYZE " + tableName + " WITH (columns = ARRAY[])", "\\QCannot specify empty list of columns for analysis"); + + // specify non-existent column + assertQueryFails("ANALYZE " + tableName + " WITH (columns = ARRAY['nationkey', 'blah'])", "\\QInvalid columns specified for analysis: [blah]"); + + // specify column with wrong case + assertQueryFails("ANALYZE " + tableName + " WITH (columns = ARRAY['NationKey'])", "\\QInvalid columns specified for analysis: [NationKey]"); + + // specify NULL column + assertQueryFails( + "ANALYZE " + tableName + " WITH (columns = ARRAY['nationkey', NULL])", + "\\QUnable to set catalog 'iceberg' analyze property 'columns' to [ARRAY['nationkey',null]]: Invalid null value in analyze columns property"); + + // analyze nationkey and regionkey + assertUpdate("ANALYZE " + tableName + " WITH (columns = ARRAY['nationkey', 'regionkey'])"); + assertQuery( + "SHOW STATS FOR " + tableName, + """ + VALUES + ('nationkey', null, 25, 0, null, '0', '24'), + ('regionkey', null, 5, 0, null, '0', '4'), + ('comment', null, null, 0, null, null, null), + ('name', null, null, 0, null, null, null), + (null, null, null, null, 25, null, null)"""); + + // insert modified rows + assertUpdate("INSERT INTO " + tableName + " SELECT nationkey + 25, concat(name, '1'), regionkey + 5, concat(comment, '21') FROM tpch.sf1.nation", 25); + + // perform one more analyze for nationkey and regionkey + assertUpdate("ANALYZE " + tableName + " WITH (columns = ARRAY['nationkey', 'regionkey'])"); + assertQuery( + "SHOW STATS FOR " + tableName, + """ + VALUES + ('nationkey', null, 50, 0, null, '0', '49'), + ('regionkey', null, 10, 0, null, '0', '9'), + ('comment', null, null, 0, null, null, null), + ('name', null, null, 0, null, null, null), + (null, null, null, null, 50, null, null)"""); + + // drop stats + assertUpdate("ALTER TABLE " + tableName + " EXECUTE DROP_EXTENDED_STATS"); + + // analyze all columns + assertUpdate("ANALYZE " + tableName); + assertQuery( + "SHOW STATS FOR " + tableName, + """ + VALUES + ('nationkey', null, 50, 0, null, '0', '49'), + ('regionkey', null, 10, 0, null, '0', '9'), + ('comment', null, 50, 0, null, null, null), + ('name', null, 50, 0, null, null, null), + (null, null, null, null, 50, null, null)"""); + + // insert modified rows + assertUpdate("INSERT INTO " + tableName + " SELECT nationkey + 50, concat(name, '2'), regionkey + 10, concat(comment, '22') FROM tpch.sf1.nation", 25); + + // without ANALYZE all stats but NDV should be updated + assertQuery( + "SHOW STATS FOR " + tableName, + """ + VALUES + ('nationkey', null, 50, 0, null, '0', '74'), + ('regionkey', null, 10, 0, null, '0', '14'), + ('comment', null, 50, 0, null, null, null), + ('name', null, 50, 0, null, null, null), + (null, null, null, null, 75, null, null)"""); + + // reanalyze with a subset of columns + assertUpdate("ANALYZE " + tableName + " WITH (columns = ARRAY['nationkey', 'regionkey'])"); + assertQuery( + "SHOW STATS FOR " + tableName, + """ + VALUES + ('nationkey', null, 75, 0, null, '0', '74'), + ('regionkey', null, 15, 0, null, '0', '14'), + ('comment', null, 50, 0, null, null, null), -- result of previous analyze + ('name', null, 50, 0, null, null, null), -- result of previous analyze + (null, null, null, null, 75, null, null)"""); + + // analyze all columns + assertUpdate("ANALYZE " + tableName); + assertQuery( + "SHOW STATS FOR " + tableName, + """ + VALUES + ('nationkey', null, 75, 0, null, '0', '74'), + ('regionkey', null, 15, 0, null, '0', '14'), + ('comment', null, 75, 0, null, null, null), + ('name', null, 75, 0, null, null, null), + (null, null, null, null, 75, null, null)"""); + + assertUpdate("DROP TABLE " + tableName); + } + + @Test + public void testAnalyzeSnapshot() + { + Session session = Session.builder(getSession()) + .setCatalogSessionProperty(getSession().getCatalog().orElseThrow(), "allow_legacy_snapshot_syntax", "true") + .build(); + String tableName = "test_analyze_snapshot_" + randomTableSuffix(); + + assertUpdate("CREATE TABLE " + tableName + " (a) AS VALUES 11", 1); + long snapshotId = getCurrentSnapshotId(tableName); + assertUpdate("INSERT INTO " + tableName + " VALUES 22", 1); + assertThatThrownBy(() -> query(session, "ANALYZE \"%s@%d\"".formatted(tableName, snapshotId))) + .hasMessage("Cannot analyze old snapshot " + snapshotId); + assertThat(query("SELECT * FROM " + tableName)) + .matches("VALUES 11, 22"); + + assertUpdate("DROP TABLE " + tableName); + } + + @Test + public void testAnalyzeSystemTable() + { + assertThatThrownBy(() -> query("ANALYZE \"nation$files\"")) + // The error message isn't clear to the user, but it doesn't matter + .hasMessage("Cannot record write for catalog not part of transaction"); + assertThatThrownBy(() -> query("ANALYZE \"nation$snapshots\"")) + // The error message isn't clear to the user, but it doesn't matter + .hasMessage("Cannot record write for catalog not part of transaction"); + } + + @Test + public void testDropExtendedStats() + { + String tableName = "test_drop_extended_stats"; + assertUpdate("CREATE TABLE " + tableName + " AS SELECT * FROM tpch.sf1.nation", 25); + + String baseStats = """ + VALUES + ('nationkey', null, null, 0, null, '0', '24'), + ('regionkey', null, null, 0, null, '0', '4'), + ('comment', null, null, 0, null, null, null), + ('name', null, null, 0, null, null, null), + (null, null, null, null, 25, null, null)"""; + String extendedStats = """ + VALUES + ('nationkey', null, 25, 0, null, '0', '24'), + ('regionkey', null, 5, 0, null, '0', '4'), + ('comment', null, 25, 0, null, null, null), + ('name', null, 25, 0, null, null, null), + (null, null, null, null, 25, null, null)"""; + + assertQuery("SHOW STATS FOR " + tableName, baseStats); + + // Update stats to include distinct count + assertUpdate("ANALYZE " + tableName); + assertQuery("SHOW STATS FOR " + tableName, extendedStats); + + // Dropping extended stats clears distinct count and leaves other stats alone + assertUpdate("ALTER TABLE " + tableName + " EXECUTE DROP_EXTENDED_STATS"); + assertQuery("SHOW STATS FOR " + tableName, baseStats); + + // Re-analyzing should work + assertUpdate("ANALYZE " + tableName); + assertQuery("SHOW STATS FOR " + tableName, extendedStats); + + assertUpdate("DROP TABLE " + tableName); + } + + @Test + public void testDropMissingStats() + { + String tableName = "test_drop_missing_stats"; + assertUpdate("CREATE TABLE " + tableName + " AS SELECT * FROM tpch.sf1.nation", 25); + + // When there are no extended stats, the procedure should have no effect + assertUpdate("ALTER TABLE " + tableName + " EXECUTE DROP_EXTENDED_STATS"); + assertQuery( + "SHOW STATS FOR " + tableName, + """ + VALUES + ('nationkey', null, null, 0, null, '0', '24'), + ('regionkey', null, null, 0, null, '0', '4'), + ('comment', null, null, 0, null, null, null), + ('name', null, null, 0, null, null, null), + (null, null, null, null, 25, null, null)"""); + + assertUpdate("DROP TABLE " + tableName); + } + + @Test + public void testDropStatsAccessControl() + { + String catalog = getSession().getCatalog().orElseThrow(); + String schema = getSession().getSchema().orElseThrow(); + String tableName = "test_deny_drop_stats"; + assertUpdate("CREATE TABLE " + tableName + " AS SELECT * FROM tpch.sf1.nation", 25); + + assertAccessDenied( + "ALTER TABLE " + tableName + " EXECUTE DROP_EXTENDED_STATS", + "Cannot execute table procedure DROP_EXTENDED_STATS on iceberg.tpch.test_deny_drop_stats", + privilege(format("%s.%s.%s.DROP_EXTENDED_STATS", catalog, schema, tableName), EXECUTE_TABLE_PROCEDURE)); + + assertUpdate("DROP TABLE " + tableName); + } + + @Test + public void testDropStatsSnapshot() + { + Session session = Session.builder(getSession()) + .setCatalogSessionProperty(getSession().getCatalog().orElseThrow(), "allow_legacy_snapshot_syntax", "true") + .build(); + String tableName = "test_drop_stats_snapshot_" + randomTableSuffix(); + + assertUpdate("CREATE TABLE " + tableName + " (a) AS VALUES 11", 1); + long snapshotId = getCurrentSnapshotId(tableName); + assertUpdate("INSERT INTO " + tableName + " VALUES 22", 1); + assertThatThrownBy(() -> query(session, "ALTER TABLE \"%s@%d\" EXECUTE DROP_EXTENDED_STATS".formatted(tableName, snapshotId))) + .hasMessage("Cannot execute table procedure DROP_EXTENDED_STATS on old snapshot " + snapshotId); + assertThat(query("SELECT * FROM " + tableName)) + .matches("VALUES 11, 22"); + + assertUpdate("DROP TABLE " + tableName); + } + + @Test + public void testDropStatsSystemTable() + { + assertThatThrownBy(() -> query("ALTER TABLE \"nation$files\" EXECUTE DROP_EXTENDED_STATS")) + .hasMessage("This connector does not support table procedures"); + assertThatThrownBy(() -> query("ALTER TABLE \"nation$snapshots\" EXECUTE DROP_EXTENDED_STATS")) + .hasMessage("This connector does not support table procedures"); + } + + @Test + public void testAnalyzeAndRollbackToSnapshot() + { + String schema = getSession().getSchema().orElseThrow(); + String tableName = "test_analyze_and_rollback"; + assertUpdate("CREATE TABLE " + tableName + " AS SELECT * FROM tpch.sf1.nation", 25); + long createSnapshot = getCurrentSnapshotId(tableName); + assertUpdate("ANALYZE " + tableName); + long analyzeSnapshot = getCurrentSnapshotId(tableName); + // ANALYZE currently does not create a new snapshot + assertEquals(analyzeSnapshot, createSnapshot); + + assertUpdate("INSERT INTO " + tableName + " SELECT * FROM tpch.sf1.nation WHERE nationkey = 1", 1); + assertNotEquals(getCurrentSnapshotId(tableName), createSnapshot); + // NDV information present after INSERT + assertQuery( + "SHOW STATS FOR " + tableName, + """ + VALUES + ('nationkey', null, 25, 0, null, '0', '24'), + ('regionkey', null, 5, 0, null, '0', '4'), + ('comment', null, 25, 0, null, null, null), + ('name', null, 25, 0, null, null, null), + (null, null, null, null, 26, null, null)"""); + + assertUpdate(format("CALL system.rollback_to_snapshot('%s', '%s', %s)", schema, tableName, createSnapshot)); + // NDV information still present after rollback_to_snapshot + assertQuery( + "SHOW STATS FOR " + tableName, + """ + VALUES + ('nationkey', null, 25, 0, null, '0', '24'), + ('regionkey', null, 5, 0, null, '0', '4'), + ('comment', null, 25, 0, null, null, null), + ('name', null, 25, 0, null, null, null), + (null, null, null, null, 25, null, null)"""); + + assertUpdate("DROP TABLE " + tableName); + } + + @Test + public void testAnalyzeAndDeleteOrphanFiles() + { + String tableName = "test_analyze_and_delete_orphan_files"; + assertUpdate("CREATE TABLE " + tableName + " AS SELECT * FROM tpch.sf1.nation", 25); + assertUpdate("ANALYZE " + tableName); + + assertQuerySucceeds( + Session.builder(getSession()) + .setCatalogSessionProperty("iceberg", "remove_orphan_files_min_retention", "0s") + .build(), + "ALTER TABLE " + tableName + " EXECUTE REMOVE_ORPHAN_FILES (retention_threshold => '0s')"); + // NDV information still present + assertQuery( + "SHOW STATS FOR " + tableName, + """ + VALUES + ('nationkey', null, 25, 0, null, '0', '24'), + ('regionkey', null, 5, 0, null, '0', '4'), + ('comment', null, 25, 0, null, null, null), + ('name', null, 25, 0, null, null, null), + (null, null, null, null, 25, null, null)"""); + + assertUpdate("DROP TABLE " + tableName); + } + + private long getCurrentSnapshotId(String tableName) + { + return (long) computeActual(format("SELECT snapshot_id FROM \"%s$snapshots\" ORDER BY committed_at DESC FETCH FIRST 1 ROW WITH TIES", tableName)) + .getOnlyValue(); + } +} diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergConfig.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergConfig.java index ea00782fbcf2..d83a158e99bb 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergConfig.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergConfig.java @@ -49,6 +49,7 @@ public void testDefaults() .setCatalogType(HIVE_METASTORE) .setDynamicFilteringWaitTimeout(new Duration(0, MINUTES)) .setTableStatisticsEnabled(true) + .setExtendedStatisticsEnabled(false) .setProjectionPushdownEnabled(true) .setHiveCatalogName(null) .setFormatVersion(2) @@ -73,6 +74,7 @@ public void testExplicitPropertyMappings() .put("iceberg.catalog.type", "GLUE") .put("iceberg.dynamic-filtering.wait-timeout", "1h") .put("iceberg.table-statistics-enabled", "false") + .put("iceberg.experimental.extended-statistics.enabled", "true") .put("iceberg.projection-pushdown-enabled", "false") .put("iceberg.hive-catalog-name", "hive") .put("iceberg.format-version", "1") @@ -94,6 +96,7 @@ public void testExplicitPropertyMappings() .setCatalogType(GLUE) .setDynamicFilteringWaitTimeout(Duration.valueOf("1h")) .setTableStatisticsEnabled(false) + .setExtendedStatisticsEnabled(true) .setProjectionPushdownEnabled(false) .setHiveCatalogName("hive") .setFormatVersion(1) diff --git a/plugin/trino-jmx/pom.xml b/plugin/trino-jmx/pom.xml index 4c2b96668ea5..5f0c01fcdd85 100644 --- a/plugin/trino-jmx/pom.xml +++ b/plugin/trino-jmx/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-jmx/src/main/java/io/trino/plugin/jmx/JmxMetadata.java b/plugin/trino-jmx/src/main/java/io/trino/plugin/jmx/JmxMetadata.java index a4fe0b656409..9029216786c1 100644 --- a/plugin/trino-jmx/src/main/java/io/trino/plugin/jmx/JmxMetadata.java +++ b/plugin/trino-jmx/src/main/java/io/trino/plugin/jmx/JmxMetadata.java @@ -204,7 +204,7 @@ public List listTables(ConnectorSession session, Optional new SchemaTableName(JmxMetadata.HISTORY_SCHEMA_NAME, tableName)) .collect(toList()); diff --git a/plugin/trino-kafka/pom.xml b/plugin/trino-kafka/pom.xml index bb827ff05794..27eddeb883f0 100644 --- a/plugin/trino-kafka/pom.xml +++ b/plugin/trino-kafka/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/KafkaFilterManager.java b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/KafkaFilterManager.java index 098b1fecd314..aed71a9f4338 100644 --- a/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/KafkaFilterManager.java +++ b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/KafkaFilterManager.java @@ -248,27 +248,23 @@ public static Set filterValuesByDomain(Domain domain, Set sourceValu long singleValue = (long) domain.getSingleValue(); return sourceValues.stream().filter(sourceValue -> sourceValue == singleValue).collect(toImmutableSet()); } - else { - ValueSet valueSet = domain.getValues(); - if (valueSet instanceof SortedRangeSet) { - Ranges ranges = ((SortedRangeSet) valueSet).getRanges(); - List rangeList = ranges.getOrderedRanges(); - if (rangeList.stream().allMatch(io.trino.spi.predicate.Range::isSingleValue)) { - return rangeList.stream() - .map(range -> (Long) range.getSingleValue()) - .filter(sourceValues::contains) - .collect(toImmutableSet()); - } - else { - // still return values for range case like (_partition_id > 1) - io.trino.spi.predicate.Range span = ranges.getSpan(); - long low = getLowIncludedValue(span).orElse(0L); - long high = getHighIncludedValue(span).orElse(Long.MAX_VALUE); - return sourceValues.stream() - .filter(item -> item >= low && item <= high) - .collect(toImmutableSet()); - } + ValueSet valueSet = domain.getValues(); + if (valueSet instanceof SortedRangeSet) { + Ranges ranges = ((SortedRangeSet) valueSet).getRanges(); + List rangeList = ranges.getOrderedRanges(); + if (rangeList.stream().allMatch(io.trino.spi.predicate.Range::isSingleValue)) { + return rangeList.stream() + .map(range -> (Long) range.getSingleValue()) + .filter(sourceValues::contains) + .collect(toImmutableSet()); } + // still return values for range case like (_partition_id > 1) + io.trino.spi.predicate.Range span = ranges.getSpan(); + long low = getLowIncludedValue(span).orElse(0L); + long high = getHighIncludedValue(span).orElse(Long.MAX_VALUE); + return sourceValues.stream() + .filter(item -> item >= low && item <= high) + .collect(toImmutableSet()); } return sourceValues; } diff --git a/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/schema/confluent/AvroSchemaConverter.java b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/schema/confluent/AvroSchemaConverter.java index 7729187a1c09..9752939fea91 100644 --- a/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/schema/confluent/AvroSchemaConverter.java +++ b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/schema/confluent/AvroSchemaConverter.java @@ -159,19 +159,19 @@ private Optional convertUnion(Schema schema) .filter(type -> type.getType() != NULL) .collect(toImmutableList()))); } - else if (schema.getTypes().size() == 1) { + if (schema.getTypes().size() == 1) { return convert(getOnlyElement(schema.getTypes())); } - else if (INTEGRAL_TYPES.containsAll(types)) { + if (INTEGRAL_TYPES.containsAll(types)) { return Optional.of(BigintType.BIGINT); } - else if (DECIMAL_TYPES.containsAll(types)) { + if (DECIMAL_TYPES.containsAll(types)) { return Optional.of(DoubleType.DOUBLE); } - else if (STRING_TYPES.containsAll(types)) { + if (STRING_TYPES.containsAll(types)) { return Optional.of(VarcharType.VARCHAR); } - else if (BINARY_TYPES.containsAll(types)) { + if (BINARY_TYPES.containsAll(types)) { return Optional.of(VarbinaryType.VARBINARY); } throw new UnsupportedOperationException(format("Incompatible UNION type: '%s'", schema.toString(true))); diff --git a/plugin/trino-kinesis/pom.xml b/plugin/trino-kinesis/pom.xml index ccc4d4314b47..ae550eda88c7 100644 --- a/plugin/trino-kinesis/pom.xml +++ b/plugin/trino-kinesis/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-kinesis/src/main/java/io/trino/plugin/kinesis/KinesisRecordSet.java b/plugin/trino-kinesis/src/main/java/io/trino/plugin/kinesis/KinesisRecordSet.java index 9aa5060ae913..b4500b6fae73 100644 --- a/plugin/trino-kinesis/src/main/java/io/trino/plugin/kinesis/KinesisRecordSet.java +++ b/plugin/trino-kinesis/src/main/java/io/trino/plugin/kinesis/KinesisRecordSet.java @@ -241,15 +241,13 @@ public boolean advanceNextPosition() if (listIterator.hasNext()) { return nextRow(); } - else { - log.debug("(%s:%s) Read all of the records from the shard: %d batches and %d messages and %d total bytes.", - split.getStreamName(), - split.getShardId(), - batchesRead, - totalMessages, - totalBytes); - return false; - } + log.debug("(%s:%s) Read all of the records from the shard: %d batches and %d messages and %d total bytes.", + split.getStreamName(), + split.getShardId(), + batchesRead, + totalMessages, + totalBytes); + return false; } private boolean shouldGetMoreRecords() diff --git a/plugin/trino-kinesis/src/test/java/io/trino/plugin/kinesis/util/MockKinesisClient.java b/plugin/trino-kinesis/src/test/java/io/trino/plugin/kinesis/util/MockKinesisClient.java index e53ef407d41c..8d762746ed96 100644 --- a/plugin/trino-kinesis/src/test/java/io/trino/plugin/kinesis/util/MockKinesisClient.java +++ b/plugin/trino-kinesis/src/test/java/io/trino/plugin/kinesis/util/MockKinesisClient.java @@ -169,9 +169,7 @@ public List getShardsFrom(String afterShardId) return returnArray; } - else { - return new ArrayList<>(); - } + return new ArrayList<>(); } public PutRecordResult putRecord(ByteBuffer data, String partitionKey) @@ -289,9 +287,7 @@ public PutRecordResult putRecord(PutRecordRequest putRecordRequest) if (theStream != null) { return theStream.putRecord(putRecordRequest.getData(), putRecordRequest.getPartitionKey()); } - else { - throw new AmazonClientException("This stream does not exist!"); - } + throw new AmazonClientException("This stream does not exist!"); } @Override @@ -328,9 +324,7 @@ public PutRecordsResult putRecords(PutRecordsRequest putRecordsRequest) result.setRecords(resultList); return result; } - else { - throw new AmazonClientException("This stream does not exist!"); - } + throw new AmazonClientException("This stream does not exist!"); } @Override diff --git a/plugin/trino-kudu/pom.xml b/plugin/trino-kudu/pom.xml index da3b3825b96c..4c44c1bb47f3 100644 --- a/plugin/trino-kudu/pom.xml +++ b/plugin/trino-kudu/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/properties/KuduTableProperties.java b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/properties/KuduTableProperties.java index 1603a233bc3d..21a81ad4ec1e 100644 --- a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/properties/KuduTableProperties.java +++ b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/properties/KuduTableProperties.java @@ -342,23 +342,21 @@ private static RangeBoundValue buildRangePartitionBound(KuduTable table, byte[] if (rangeKey.length == 0) { return null; } - else { - Schema schema = table.getSchema(); - PartitionSchema partitionSchema = table.getPartitionSchema(); - PartitionSchema.RangeSchema rangeSchema = partitionSchema.getRangeSchema(); - List rangeColumns = rangeSchema.getColumnIds(); + Schema schema = table.getSchema(); + PartitionSchema partitionSchema = table.getPartitionSchema(); + PartitionSchema.RangeSchema rangeSchema = partitionSchema.getRangeSchema(); + List rangeColumns = rangeSchema.getColumnIds(); - int numColumns = rangeColumns.size(); + int numColumns = rangeColumns.size(); - PartialRow bound = KeyEncoderAccessor.decodeRangePartitionKey(schema, partitionSchema, rangeKey); + PartialRow bound = KeyEncoderAccessor.decodeRangePartitionKey(schema, partitionSchema, rangeKey); - ArrayList list = new ArrayList<>(); - for (int i = 0; i < numColumns; i++) { - Object obj = toValue(schema, bound, rangeColumns.get(i)); - list.add(obj); - } - return new RangeBoundValue(list); + ArrayList list = new ArrayList<>(); + for (int i = 0; i < numColumns; i++) { + Object obj = toValue(schema, bound, rangeColumns.get(i)); + list.add(obj); } + return new RangeBoundValue(list); } private static Object toValue(Schema schema, PartialRow bound, Integer idx) @@ -506,13 +504,11 @@ private static byte[] toByteArray(Object obj, Type type, String name) if (obj instanceof byte[]) { return (byte[]) obj; } - else if (obj instanceof String) { + if (obj instanceof String) { return Base64.getDecoder().decode((String) obj); } - else { - handleInvalidValue(name, type, obj); - return null; - } + handleInvalidValue(name, type, obj); + return null; } private static boolean toBoolean(Object obj, Type type, String name) @@ -520,13 +516,11 @@ private static boolean toBoolean(Object obj, Type type, String name) if (obj instanceof Boolean) { return (Boolean) obj; } - else if (obj instanceof String) { + if (obj instanceof String) { return Boolean.valueOf((String) obj); } - else { - handleInvalidValue(name, type, obj); - return false; - } + handleInvalidValue(name, type, obj); + return false; } private static long toUnixTimeMicros(Object obj, Type type, String name) @@ -534,16 +528,14 @@ private static long toUnixTimeMicros(Object obj, Type type, String name) if (Number.class.isAssignableFrom(obj.getClass())) { return ((Number) obj).longValue(); } - else if (obj instanceof String) { + if (obj instanceof String) { String s = (String) obj; s = s.trim().replace(' ', 'T'); long millis = ISODateTimeFormat.dateOptionalTimeParser().withZone(DateTimeZone.UTC).parseMillis(s); return millis * 1000; } - else { - handleInvalidValue(name, type, obj); - return 0; - } + handleInvalidValue(name, type, obj); + return 0; } private static Number toNumber(Object obj, Type type, String name) @@ -551,13 +543,11 @@ private static Number toNumber(Object obj, Type type, String name) if (Number.class.isAssignableFrom(obj.getClass())) { return (Number) obj; } - else if (obj instanceof String) { + if (obj instanceof String) { return new BigDecimal((String) obj); } - else { - handleInvalidValue(name, type, obj); - return 0; - } + handleInvalidValue(name, type, obj); + return 0; } private static void handleInvalidValue(String name, Type type, Object obj) diff --git a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/properties/RangeBoundValueDeserializer.java b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/properties/RangeBoundValueDeserializer.java index df7b9ad527c4..554f61a1f7e8 100644 --- a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/properties/RangeBoundValueDeserializer.java +++ b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/properties/RangeBoundValueDeserializer.java @@ -36,22 +36,20 @@ public RangeBoundValue deserialize(JsonParser jp, DeserializationContext ctxt) if (node.isNull()) { return null; } - else { - List list; - if (node.isArray()) { - list = new ArrayList<>(); - Iterator iter = node.elements(); - while (iter.hasNext()) { - Object v = toValue(iter.next()); - list.add(v); - } - } - else { - Object v = toValue(node); - list = ImmutableList.of(v); + List list; + if (node.isArray()) { + list = new ArrayList<>(); + Iterator iter = node.elements(); + while (iter.hasNext()) { + Object v = toValue(iter.next()); + list.add(v); } - return new RangeBoundValue(list); } + else { + Object v = toValue(node); + list = ImmutableList.of(v); + } + return new RangeBoundValue(list); } private Object toValue(JsonNode node) @@ -60,17 +58,15 @@ private Object toValue(JsonNode node) if (node.isTextual()) { return node.asText(); } - else if (node.isNumber()) { + if (node.isNumber()) { return node.numberValue(); } - else if (node.isBoolean()) { + if (node.isBoolean()) { return node.asBoolean(); } - else if (node.isBinary()) { + if (node.isBinary()) { return node.binaryValue(); } - else { - throw new IllegalStateException("Unexpected range bound value: " + node); - } + throw new IllegalStateException("Unexpected range bound value: " + node); } } diff --git a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/properties/RangeBoundValueSerializer.java b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/properties/RangeBoundValueSerializer.java index 73a9dd902879..b1c484b30120 100644 --- a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/properties/RangeBoundValueSerializer.java +++ b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/properties/RangeBoundValueSerializer.java @@ -51,7 +51,7 @@ private void writeValue(Object obj, JsonGenerator gen) if (obj == null) { throw new IllegalStateException("Unexpected null value"); } - else if (obj instanceof String) { + if (obj instanceof String) { gen.writeString((String) obj); } else if (Number.class.isAssignableFrom(obj.getClass())) { diff --git a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/schema/NoSchemaEmulation.java b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/schema/NoSchemaEmulation.java index 692db5327d3f..b3e4b44f8b0d 100644 --- a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/schema/NoSchemaEmulation.java +++ b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/schema/NoSchemaEmulation.java @@ -33,9 +33,7 @@ public void createSchema(KuduClientWrapper client, String schemaName) if (DEFAULT_SCHEMA.equals(schemaName)) { throw new SchemaAlreadyExistsException(schemaName); } - else { - throw new TrinoException(GENERIC_USER_ERROR, "Creating schema in Kudu connector not allowed if schema emulation is disabled."); - } + throw new TrinoException(GENERIC_USER_ERROR, "Creating schema in Kudu connector not allowed if schema emulation is disabled."); } @Override @@ -44,9 +42,7 @@ public void dropSchema(KuduClientWrapper client, String schemaName) if (DEFAULT_SCHEMA.equals(schemaName)) { throw new TrinoException(GENERIC_USER_ERROR, "Deleting default schema not allowed."); } - else { - throw new SchemaNotFoundException(schemaName); - } + throw new SchemaNotFoundException(schemaName); } @Override @@ -67,9 +63,7 @@ public String toRawName(SchemaTableName schemaTableName) if (DEFAULT_SCHEMA.equals(schemaTableName.getSchemaName())) { return schemaTableName.getTableName(); } - else { - throw new SchemaNotFoundException(schemaTableName.getSchemaName()); - } + throw new SchemaNotFoundException(schemaTableName.getSchemaName()); } @Override diff --git a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/schema/SchemaEmulationByTableNameConvention.java b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/schema/SchemaEmulationByTableNameConvention.java index 72d558adab2d..496e2a86205d 100644 --- a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/schema/SchemaEmulationByTableNameConvention.java +++ b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/schema/SchemaEmulationByTableNameConvention.java @@ -59,16 +59,14 @@ public void createSchema(KuduClientWrapper client, String schemaName) if (DEFAULT_SCHEMA.equals(schemaName)) { throw new SchemaAlreadyExistsException(schemaName); } - else { - try (KuduOperationApplier operationApplier = KuduOperationApplier.fromKuduClientWrapper(client)) { - KuduTable schemasTable = getSchemasTable(client); - Upsert upsert = schemasTable.newUpsert(); - upsert.getRow().addString(0, schemaName); - operationApplier.applyOperationAsync(upsert); - } - catch (KuduException e) { - throw new TrinoException(GENERIC_INTERNAL_ERROR, e); - } + try (KuduOperationApplier operationApplier = KuduOperationApplier.fromKuduClientWrapper(client)) { + KuduTable schemasTable = getSchemasTable(client); + Upsert upsert = schemasTable.newUpsert(); + upsert.getRow().addString(0, schemaName); + operationApplier.applyOperationAsync(upsert); + } + catch (KuduException e) { + throw new TrinoException(GENERIC_INTERNAL_ERROR, e); } } @@ -78,10 +76,8 @@ public boolean existsSchema(KuduClientWrapper client, String schemaName) if (DEFAULT_SCHEMA.equals(schemaName)) { return true; } - else { - List schemas = listSchemaNames(client); - return schemas.contains(schemaName); - } + List schemas = listSchemaNames(client); + return schemas.contains(schemaName); } @Override @@ -90,21 +86,19 @@ public void dropSchema(KuduClientWrapper client, String schemaName) if (DEFAULT_SCHEMA.equals(schemaName)) { throw new TrinoException(GENERIC_USER_ERROR, "Deleting default schema not allowed."); } - else { - try (KuduOperationApplier operationApplier = KuduOperationApplier.fromKuduClientWrapper(client)) { - String prefix = getPrefixForTablesOfSchema(schemaName); - for (String name : client.getTablesList(prefix).getTablesList()) { - client.deleteTable(name); - } - - KuduTable schemasTable = getSchemasTable(client); - Delete delete = schemasTable.newDelete(); - delete.getRow().addString(0, schemaName); - operationApplier.applyOperationAsync(delete); - } - catch (KuduException e) { - throw new TrinoException(GENERIC_INTERNAL_ERROR, e); + try (KuduOperationApplier operationApplier = KuduOperationApplier.fromKuduClientWrapper(client)) { + String prefix = getPrefixForTablesOfSchema(schemaName); + for (String name : client.getTablesList(prefix).getTablesList()) { + client.deleteTable(name); } + + KuduTable schemasTable = getSchemasTable(client); + Delete delete = schemasTable.newDelete(); + delete.getRow().addString(0, schemaName); + operationApplier.applyOperationAsync(delete); + } + catch (KuduException e) { + throw new TrinoException(GENERIC_INTERNAL_ERROR, e); } } @@ -217,9 +211,7 @@ else if (schemaTableName.getSchemaName().indexOf('.') != -1) { if (DEFAULT_SCHEMA.equals(schemaTableName.getSchemaName())) { return schemaTableName.getTableName(); } - else { - return commonPrefix + schemaTableName.getSchemaName() + "." + schemaTableName.getTableName(); - } + return commonPrefix + schemaTableName.getSchemaName() + "." + schemaTableName.getTableName(); } @Override @@ -256,9 +248,7 @@ public String getPrefixForTablesOfSchema(String schemaName) if (DEFAULT_SCHEMA.equals(schemaName)) { return ""; } - else { - return commonPrefix + schemaName + "."; - } + return commonPrefix + schemaName + "."; } @Override diff --git a/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/TestingKuduServer.java b/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/TestingKuduServer.java index a3166aee0922..a7128afc5b03 100644 --- a/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/TestingKuduServer.java +++ b/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/TestingKuduServer.java @@ -42,7 +42,7 @@ public class TestingKuduServer private static final Integer KUDU_TSERVER_PORT = 7050; private static final Integer NUMBER_OF_REPLICA = 3; - private static final String TOXIPROXY_IMAGE = "shopify/toxiproxy:2.1.4"; + private static final String TOXIPROXY_IMAGE = "ghcr.io/shopify/toxiproxy:2.4.0"; private static final String TOXIPROXY_NETWORK_ALIAS = "toxiproxy"; private final Network network; @@ -102,7 +102,7 @@ public TestingKuduServer(String kuduVersion) public HostAndPort getMasterAddress() { - // Do not use master.getContainerIpAddress(), it returns "localhost" which the kudu client resolves to: + // Do not use master.getHost(), it returns "localhost" which the kudu client resolves to: // localhost/127.0.0.1, localhost/0:0:0:0:0:0:0:1 // Instead explicitly list only the ipv4 loopback address 127.0.0.1 return HostAndPort.fromParts("127.0.0.1", master.getMappedPort(KUDU_MASTER_PORT)); diff --git a/plugin/trino-local-file/pom.xml b/plugin/trino-local-file/pom.xml index f917d09d5797..b6509d6ab568 100644 --- a/plugin/trino-local-file/pom.xml +++ b/plugin/trino-local-file/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-local-file/src/main/java/io/trino/plugin/localfile/LocalFileRecordCursor.java b/plugin/trino-local-file/src/main/java/io/trino/plugin/localfile/LocalFileRecordCursor.java index 49195ce9b73b..6c22c1597ba6 100644 --- a/plugin/trino-local-file/src/main/java/io/trino/plugin/localfile/LocalFileRecordCursor.java +++ b/plugin/trino-local-file/src/main/java/io/trino/plugin/localfile/LocalFileRecordCursor.java @@ -198,10 +198,8 @@ public long getLong(int field) if (getType(field).equals(createTimestampWithTimeZoneType(3))) { return parseTimestamp(getFieldValue(field)); } - else { - checkFieldType(field, BIGINT, INTEGER); - return Long.parseLong(getFieldValue(field)); - } + checkFieldType(field, BIGINT, INTEGER); + return Long.parseLong(getFieldValue(field)); } @Override diff --git a/plugin/trino-mariadb/pom.xml b/plugin/trino-mariadb/pom.xml index 3431196b8626..4ccfe5506f55 100644 --- a/plugin/trino-mariadb/pom.xml +++ b/plugin/trino-mariadb/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestingMariaDbServer.java b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestingMariaDbServer.java index ef9dbcbb4ccf..a88e94a002e5 100644 --- a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestingMariaDbServer.java +++ b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestingMariaDbServer.java @@ -76,7 +76,7 @@ public String getPassword() public String getJdbcUrl() { - return format("jdbc:mariadb://%s:%s", container.getContainerIpAddress(), container.getMappedPort(MARIADB_PORT)); + return format("jdbc:mariadb://%s:%s", container.getHost(), container.getMappedPort(MARIADB_PORT)); } @Override diff --git a/plugin/trino-memory/pom.xml b/plugin/trino-memory/pom.xml index 284766c7fe9c..e5940639e984 100644 --- a/plugin/trino-memory/pom.xml +++ b/plugin/trino-memory/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-ml/pom.xml b/plugin/trino-ml/pom.xml index c1c915f73bda..265daae945a2 100644 --- a/plugin/trino-ml/pom.xml +++ b/plugin/trino-ml/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-mongodb/pom.xml b/plugin/trino-mongodb/pom.xml index 6eb6eb832a18..3efdfe141272 100644 --- a/plugin/trino-mongodb/pom.xml +++ b/plugin/trino-mongodb/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSource.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSource.java index 30d0d16e5788..19040251cda4 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSource.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSource.java @@ -313,7 +313,7 @@ else if (isMapType(type)) { output.closeEntry(); return; } - else if (value instanceof Map) { + if (value instanceof Map) { BlockBuilder builder = output.beginBlockEntry(); Map document = (Map) value; for (Map.Entry entry : document.entrySet()) { @@ -341,7 +341,7 @@ else if (isRowType(type)) { output.closeEntry(); return; } - else if (value instanceof DBRef) { + if (value instanceof DBRef) { DBRef dbRefValue = (DBRef) value; BlockBuilder builder = output.beginBlockEntry(); @@ -353,7 +353,7 @@ else if (value instanceof DBRef) { output.closeEntry(); return; } - else if (value instanceof List) { + if (value instanceof List) { List listValue = (List) value; BlockBuilder builder = output.beginBlockEntry(); for (int index = 0; index < type.getTypeParameters().size(); index++) { diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java index 122460f54298..f43530a1a14c 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java @@ -557,17 +557,15 @@ private Document getTableMetadata(String schemaName, String tableName) if (!collectionExists(db, tableName)) { throw new TableNotFoundException(new SchemaTableName(schemaName, tableName), format("Table '%s.%s' not found", schemaName, tableName), null); } - else { - Document metadata = new Document(TABLE_NAME_KEY, tableName); - metadata.append(FIELDS_KEY, guessTableFields(schemaName, tableName)); - if (!indexExists(schema)) { - schema.createIndex(new Document(TABLE_NAME_KEY, 1), new IndexOptions().unique(true)); - } + Document metadata = new Document(TABLE_NAME_KEY, tableName); + metadata.append(FIELDS_KEY, guessTableFields(schemaName, tableName)); + if (!indexExists(schema)) { + schema.createIndex(new Document(TABLE_NAME_KEY, 1), new IndexOptions().unique(true)); + } - schema.insertOne(metadata); + schema.insertOne(metadata); - return metadata; - } + return metadata; } return doc; diff --git a/plugin/trino-mysql/pom.xml b/plugin/trino-mysql/pom.xml index 1a557eecedcf..2b4d9efce3c0 100644 --- a/plugin/trino-mysql/pom.xml +++ b/plugin/trino-mysql/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestingMySqlServer.java b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestingMySqlServer.java index 5d464ebc3704..973f86c9b8db 100644 --- a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestingMySqlServer.java +++ b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestingMySqlServer.java @@ -104,7 +104,7 @@ public String getDatabaseName() public String getJdbcUrl() { - return format("jdbc:mysql://%s:%s?useSSL=false&allowPublicKeyRetrieval=true", container.getContainerIpAddress(), container.getMappedPort(MYSQL_PORT)); + return format("jdbc:mysql://%s:%s?useSSL=false&allowPublicKeyRetrieval=true", container.getHost(), container.getMappedPort(MYSQL_PORT)); } @Override diff --git a/plugin/trino-oracle/pom.xml b/plugin/trino-oracle/pom.xml index 0f72378d196d..70c781f8de72 100644 --- a/plugin/trino-oracle/pom.xml +++ b/plugin/trino-oracle/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-password-authenticators/pom.xml b/plugin/trino-password-authenticators/pom.xml index 97abb542a645..c6a08755147d 100644 --- a/plugin/trino-password-authenticators/pom.xml +++ b/plugin/trino-password-authenticators/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-password-authenticators/src/test/java/io/trino/plugin/password/ldap/TestLdapAuthenticatorWithTimeouts.java b/plugin/trino-password-authenticators/src/test/java/io/trino/plugin/password/ldap/TestLdapAuthenticatorWithTimeouts.java index 56be4cf7c2e3..4c9df0e33b3c 100644 --- a/plugin/trino-password-authenticators/src/test/java/io/trino/plugin/password/ldap/TestLdapAuthenticatorWithTimeouts.java +++ b/plugin/trino-password-authenticators/src/test/java/io/trino/plugin/password/ldap/TestLdapAuthenticatorWithTimeouts.java @@ -47,7 +47,7 @@ public void setup() Network network = Network.newNetwork(); closer.register(network::close); - ToxiproxyContainer proxyServer = new ToxiproxyContainer("shopify/toxiproxy:2.1.0") + ToxiproxyContainer proxyServer = new ToxiproxyContainer("ghcr.io/shopify/toxiproxy:2.4.0") .withNetwork(network); closer.register(proxyServer::close); proxyServer.start(); diff --git a/plugin/trino-password-authenticators/src/test/java/io/trino/plugin/password/ldap/TestingOpenLdapServer.java b/plugin/trino-password-authenticators/src/test/java/io/trino/plugin/password/ldap/TestingOpenLdapServer.java index 6a0955adc7ae..90104c9601d6 100644 --- a/plugin/trino-password-authenticators/src/test/java/io/trino/plugin/password/ldap/TestingOpenLdapServer.java +++ b/plugin/trino-password-authenticators/src/test/java/io/trino/plugin/password/ldap/TestingOpenLdapServer.java @@ -76,7 +76,7 @@ public String getNetworkAlias() public String getLdapUrl() { - return format("ldap://%s:%s", openLdapServer.getContainerIpAddress(), openLdapServer.getMappedPort(LDAP_PORT)); + return format("ldap://%s:%s", openLdapServer.getHost(), openLdapServer.getMappedPort(LDAP_PORT)); } public DisposableSubContext createOrganization() diff --git a/plugin/trino-phoenix5/pom.xml b/plugin/trino-phoenix5/pom.xml index 8f8096b1d6b4..2f7eb968648f 100644 --- a/plugin/trino-phoenix5/pom.xml +++ b/plugin/trino-phoenix5/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClient.java b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClient.java index 2ea59b77ee9b..895502f609a4 100644 --- a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClient.java +++ b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClient.java @@ -205,6 +205,9 @@ public class PhoenixClient private static final String DATE_FORMAT = "y-MM-dd G"; private static final DateTimeFormatter LOCAL_DATE_FORMATTER = DateTimeFormatter.ofPattern(DATE_FORMAT); + // Phoenix threshold for simplifying big IN predicates is 50k https://issues.apache.org/jira/browse/PHOENIX-6751 + public static final int PHOENIX_MAX_LIST_EXPRESSIONS = 5_000; + private final Configuration configuration; @Inject diff --git a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClientModule.java b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClientModule.java index d9119b2e52e2..5f6a32a161d4 100644 --- a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClientModule.java +++ b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClientModule.java @@ -77,6 +77,7 @@ import static io.trino.plugin.jdbc.JdbcModule.bindSessionPropertiesProvider; import static io.trino.plugin.jdbc.JdbcModule.bindTablePropertiesProvider; import static io.trino.plugin.phoenix5.ConfigurationInstantiator.newEmptyConfiguration; +import static io.trino.plugin.phoenix5.PhoenixClient.PHOENIX_MAX_LIST_EXPRESSIONS; import static io.trino.plugin.phoenix5.PhoenixErrorCode.PHOENIX_CONFIG_ERROR; import static java.util.Objects.requireNonNull; import static org.weakref.jmx.guice.ExportBinder.newExporter; @@ -102,7 +103,7 @@ protected void setup(Binder binder) binder.bind(ConnectorPageSinkProvider.class).annotatedWith(ForClassLoaderSafe.class).to(JdbcPageSinkProvider.class).in(Scopes.SINGLETON); binder.bind(ConnectorPageSinkProvider.class).to(ClassLoaderSafeConnectorPageSinkProvider.class).in(Scopes.SINGLETON); binder.bind(QueryBuilder.class).to(DefaultQueryBuilder.class).in(Scopes.SINGLETON); - newOptionalBinder(binder, Key.get(int.class, MaxDomainCompactionThreshold.class)); + newOptionalBinder(binder, Key.get(int.class, MaxDomainCompactionThreshold.class)).setBinding().toInstance(PHOENIX_MAX_LIST_EXPRESSIONS); configBinder(binder).bindConfig(TypeHandlingJdbcConfig.class); bindSessionPropertiesProvider(binder, TypeHandlingJdbcSessionProperties.class); diff --git a/plugin/trino-pinot/pom.xml b/plugin/trino-pinot/pom.xml index acaf85d7c34d..c07200b7d192 100755 --- a/plugin/trino-pinot/pom.xml +++ b/plugin/trino-pinot/pom.xml @@ -4,7 +4,7 @@ trino-root io.trino - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml @@ -546,6 +546,18 @@ + + io.trino + trino-tpch + test + + + + io.trino.tpch + tpch + test + + io.airlift testing diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotSegmentPageSource.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotSegmentPageSource.java index 7e35f909bc9e..18e7e472f4d9 100755 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotSegmentPageSource.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotSegmentPageSource.java @@ -266,9 +266,7 @@ private double getDouble(int rowIndex, int columnIndex) if (dataType.equals(ColumnDataType.FLOAT)) { return currentDataTable.getDataTable().getFloat(rowIndex, columnIndex); } - else { - return currentDataTable.getDataTable().getDouble(rowIndex, columnIndex); - } + return currentDataTable.getDataTable().getDouble(rowIndex, columnIndex); } private Block getArrayBlock(int rowIndex, int columnIndex) @@ -327,10 +325,10 @@ private Slice getSlice(int rowIndex, int columnIndex) String field = currentDataTable.getDataTable().getString(rowIndex, columnIndex); return getUtf8Slice(field); } - else if (trinoType instanceof VarbinaryType) { + if (trinoType instanceof VarbinaryType) { return Slices.wrappedBuffer(toBytes(currentDataTable.getDataTable().getString(rowIndex, columnIndex))); } - else if (trinoType.getTypeSignature().getBase() == StandardTypes.JSON) { + if (trinoType.getTypeSignature().getBase() == StandardTypes.JSON) { String field = currentDataTable.getDataTable().getString(rowIndex, columnIndex); return jsonParse(getUtf8Slice(field)); } diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotSplitManager.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotSplitManager.java index be2fbef644bb..3e5aeb2cb2dd 100755 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotSplitManager.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotSplitManager.java @@ -174,9 +174,7 @@ public ConnectorSplitSource getSplits( } return generateSplitsForSegmentBasedScan(pinotTableHandle, session); } - else { - return generateSplitForBrokerBasedScan(pinotTableHandle); - } + return generateSplitForBrokerBasedScan(pinotTableHandle); } private static boolean isBrokerQuery(ConnectorSession session, PinotTableHandle tableHandle) diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/client/PinotClient.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/client/PinotClient.java index ae8a1001ac0d..9f6465dd2254 100755 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/client/PinotClient.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/client/PinotClient.java @@ -364,12 +364,10 @@ public List getAllBrokersForTable(String table) if (matcher.matches() && matcher.groupCount() == 2) { return pinotHostMapper.getBrokerHost(matcher.group(1), matcher.group(2)); } - else { - throw new PinotException( - PINOT_UNABLE_TO_FIND_BROKER, - Optional.empty(), - format("Cannot parse %s in the broker instance", brokerToParse)); - } + throw new PinotException( + PINOT_UNABLE_TO_FIND_BROKER, + Optional.empty(), + format("Cannot parse %s in the broker instance", brokerToParse)); }) .collect(Collectors.toCollection(ArrayList::new)); Collections.shuffle(brokers); @@ -390,9 +388,7 @@ public String getBrokerHost(String table) if (throwable instanceof PinotException) { throw (PinotException) throwable; } - else { - throw new PinotException(PINOT_UNABLE_TO_FIND_BROKER, Optional.empty(), "Error when getting brokers for table " + table, throwable); - } + throw new PinotException(PINOT_UNABLE_TO_FIND_BROKER, Optional.empty(), "Error when getting brokers for table " + table, throwable); } } diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/decoders/DecoderFactory.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/decoders/DecoderFactory.java index ffc72dd8d668..da0cbf0bc347 100644 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/decoders/DecoderFactory.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/decoders/DecoderFactory.java @@ -43,33 +43,29 @@ public static Decoder createDecoder(Type type) if (type instanceof DoubleType) { return new DoubleDecoder(); } - else if (type instanceof RealType) { + if (type instanceof RealType) { return new RealDecoder(); } - else if (type instanceof BigintType) { + if (type instanceof BigintType) { return new BigintDecoder(); } - else if (type instanceof IntegerType) { + if (type instanceof IntegerType) { return new IntegerDecoder(); } - else if (type instanceof BooleanType) { + if (type instanceof BooleanType) { return new BooleanDecoder(); } - else { - throw new PinotException(PINOT_UNSUPPORTED_COLUMN_TYPE, Optional.empty(), "type '" + type + "' not supported"); - } + throw new PinotException(PINOT_UNSUPPORTED_COLUMN_TYPE, Optional.empty(), "type '" + type + "' not supported"); } - else if (type instanceof ArrayType) { + if (type instanceof ArrayType) { return new ArrayDecoder(type); } - else if (type instanceof VarbinaryType) { + if (type instanceof VarbinaryType) { return new VarbinaryDecoder(); } - else if (type.getTypeSignature().getBase().equals(JSON)) { + if (type.getTypeSignature().getBase().equals(JSON)) { return new JsonDecoder(); } - else { - return new VarcharDecoder(); - } + return new VarcharDecoder(); } } diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/DynamicTableBuilder.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/DynamicTableBuilder.java index 2ea712674e64..d79a778b9b1a 100755 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/DynamicTableBuilder.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/DynamicTableBuilder.java @@ -268,9 +268,7 @@ private static OptionalLong getOffset(QueryContext queryContext) if (queryContext.getOffset() > 0) { return OptionalLong.of(queryContext.getOffset()); } - else { - return OptionalLong.empty(); - } + return OptionalLong.empty(); } private static String stripSuffix(String tableName) @@ -279,12 +277,10 @@ private static String stripSuffix(String tableName) if (tableName.toUpperCase(ENGLISH).endsWith(OFFLINE_SUFFIX)) { return tableName.substring(0, tableName.length() - OFFLINE_SUFFIX.length()); } - else if (tableName.toUpperCase(ENGLISH).endsWith(REALTIME_SUFFIX)) { + if (tableName.toUpperCase(ENGLISH).endsWith(REALTIME_SUFFIX)) { return tableName.substring(0, tableName.length() - REALTIME_SUFFIX.length()); } - else { - return tableName; - } + return tableName; } private static Optional getSuffix(String tableName) @@ -293,12 +289,10 @@ private static Optional getSuffix(String tableName) if (tableName.toUpperCase(ENGLISH).endsWith(OFFLINE_SUFFIX)) { return Optional.of(OFFLINE_SUFFIX); } - else if (tableName.toUpperCase(ENGLISH).endsWith(REALTIME_SUFFIX)) { + if (tableName.toUpperCase(ENGLISH).endsWith(REALTIME_SUFFIX)) { return Optional.of(REALTIME_SUFFIX); } - else { - return Optional.empty(); - } + return Optional.empty(); } private static class PinotColumnNameAndTrinoType diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/DynamicTablePqlExtractor.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/DynamicTablePqlExtractor.java index 229443ad1f04..40ff0c5dccaf 100755 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/DynamicTablePqlExtractor.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/DynamicTablePqlExtractor.java @@ -93,15 +93,13 @@ private static Optional getFilter(Optional filter, TupleDomain binaryFunction() if (predicate.getType() == Predicate.Type.IN) { return Optional.of(((InPredicate) predicate).getValues()); } - else if (predicate.getType() == Predicate.Type.NOT_IN) { + if (predicate.getType() == Predicate.Type.NOT_IN) { return Optional.of(((NotInPredicate) predicate).getValues()); } return Optional.empty(); diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/PinotQueryBuilder.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/PinotQueryBuilder.java index 012feaf2f219..a4415db0d10a 100755 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/PinotQueryBuilder.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/query/PinotQueryBuilder.java @@ -120,9 +120,7 @@ public static Optional getFilterClause(TupleDomain tupleDo if (!conjuncts.isEmpty()) { return Optional.of(Joiner.on(" AND ").join(conjuncts)); } - else { - return Optional.empty(); - } + return Optional.empty(); } private static String toPredicate(PinotColumnHandle pinotColumnHandle, Domain domain) @@ -163,10 +161,10 @@ private static Object convertValue(Type type, Object value) if (type instanceof RealType) { return intBitsToFloat(toIntExact((Long) value)); } - else if (type instanceof VarcharType) { + if (type instanceof VarcharType) { return ((Slice) value).toStringUtf8(); } - else if (type instanceof VarbinaryType) { + if (type instanceof VarbinaryType) { return Hex.encodeHexString(((Slice) value).getBytes()); } return value; diff --git a/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/AbstractPinotIntegrationSmokeTest.java b/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/BasePinotIntegrationConnectorSmokeTest.java similarity index 93% rename from plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/AbstractPinotIntegrationSmokeTest.java rename to plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/BasePinotIntegrationConnectorSmokeTest.java index cffea1e42a6f..ed109cc6ace2 100644 --- a/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/AbstractPinotIntegrationSmokeTest.java +++ b/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/BasePinotIntegrationConnectorSmokeTest.java @@ -21,15 +21,19 @@ import io.confluent.kafka.serializers.KafkaAvroSerializer; import io.trino.Session; import io.trino.plugin.pinot.client.PinotHostMapper; +import io.trino.plugin.tpch.TpchPlugin; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.ExchangeNode; import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.LimitNode; import io.trino.sql.planner.plan.MarkDistinctNode; import io.trino.sql.planner.plan.ProjectNode; -import io.trino.testing.AbstractTestQueryFramework; +import io.trino.testing.BaseConnectorSmokeTest; +import io.trino.testing.DistributedQueryRunner; import io.trino.testing.MaterializedResult; +import io.trino.testing.MaterializedRow; import io.trino.testing.QueryRunner; +import io.trino.testing.TestingConnectorBehavior; import io.trino.testing.kafka.TestingKafka; import org.apache.avro.Schema; import org.apache.avro.SchemaBuilder; @@ -76,6 +80,7 @@ import static com.google.common.io.RecursiveDeleteOption.ALLOW_INSECURE; import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder; import static io.confluent.kafka.serializers.AbstractKafkaSchemaSerDeConfig.SCHEMA_REGISTRY_URL_CONFIG; +import static io.trino.plugin.pinot.PinotQueryRunner.createPinotQueryRunner; import static io.trino.plugin.pinot.TestingPinotCluster.PINOT_PREVIOUS_IMAGE_NAME; import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.RealType.REAL; @@ -93,9 +98,8 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.testng.Assert.assertEquals; -public abstract class AbstractPinotIntegrationSmokeTest - // TODO extend BaseConnectorTest - extends AbstractTestQueryFramework +public abstract class BasePinotIntegrationConnectorSmokeTest + extends BaseConnectorSmokeTest { private static final int MAX_ROWS_PER_SPLIT_FOR_SEGMENT_QUERIES = 11; private static final int MAX_ROWS_PER_SPLIT_FOR_BROKER_QUERIES = 12; @@ -142,6 +146,38 @@ protected QueryRunner createQueryRunner() TestingPinotCluster pinot = closeAfterClass(new TestingPinotCluster(kafka.getNetwork(), isSecured(), getPinotImageName())); pinot.start(); + createAndPopulateAllTypesTopic(kafka, pinot); + createAndPopulateMixedCaseTableAndTopic(kafka, pinot); + createAndPopulateMixedCaseDistinctTableAndTopic(kafka, pinot); + createAndPopulateTooManyRowsTable(kafka, pinot); + createAndPopulateTooManyBrokerRowsTableAndTopic(kafka, pinot); + createTheDuplicateTablesAndTopics(kafka, pinot); + createAndPopulateDateTimeFieldsTableAndTopic(kafka, pinot); + createAndPopulateJsonTypeTable(kafka, pinot); + createAndPopulateJsonTable(kafka, pinot); + createAndPopulateMixedCaseHybridTablesAndTopic(kafka, pinot); + createAndPopulateTableHavingReservedKeywordColumnNames(kafka, pinot); + createAndPopulateHavingQuotesInColumnNames(kafka, pinot); + createAndPopulateHavingMultipleColumnsWithDuplicateValues(kafka, pinot); + + DistributedQueryRunner queryRunner = createPinotQueryRunner( + ImmutableMap.of(), + pinotProperties(pinot), + Optional.of(binder -> newOptionalBinder(binder, PinotHostMapper.class).setBinding() + .toInstance(new TestingPinotHostMapper(pinot.getBrokerHostAndPort(), pinot.getServerHostAndPort(), pinot.getServerGrpcHostAndPort())))); + + queryRunner.installPlugin(new TpchPlugin()); + queryRunner.createCatalog("tpch", "tpch"); + + // We need the query runner to populate nation and region data from tpch schema + createAndPopulateNationAndRegionData(kafka, pinot, queryRunner); + + return queryRunner; + } + + private void createAndPopulateAllTypesTopic(TestingKafka kafka, TestingPinotCluster pinot) + throws Exception + { // Create and populate the all_types topic and table kafka.createTopic(ALL_TYPES_TABLE); @@ -165,7 +201,11 @@ protected QueryRunner createQueryRunner() pinot.createSchema(getClass().getClassLoader().getResourceAsStream("alltypes_schema.json"), ALL_TYPES_TABLE); pinot.addRealTimeTable(getClass().getClassLoader().getResourceAsStream("alltypes_realtimeSpec.json"), ALL_TYPES_TABLE); + } + private void createAndPopulateMixedCaseTableAndTopic(TestingKafka kafka, TestingPinotCluster pinot) + throws Exception + { // Create and populate mixed case table and topic kafka.createTopic(MIXED_CASE_COLUMN_NAMES_TABLE); Schema mixedCaseAvroSchema = SchemaBuilder.record(MIXED_CASE_COLUMN_NAMES_TABLE).fields() @@ -200,7 +240,11 @@ protected QueryRunner createQueryRunner() kafka.sendMessages(mixedCaseProducerRecords.stream(), schemaRegistryAwareProducer(kafka)); pinot.createSchema(getClass().getClassLoader().getResourceAsStream("mixed_case_schema.json"), MIXED_CASE_COLUMN_NAMES_TABLE); pinot.addRealTimeTable(getClass().getClassLoader().getResourceAsStream("mixed_case_realtimeSpec.json"), MIXED_CASE_COLUMN_NAMES_TABLE); + } + private void createAndPopulateMixedCaseDistinctTableAndTopic(TestingKafka kafka, TestingPinotCluster pinot) + throws Exception + { // Create and populate mixed case distinct table and topic kafka.createTopic(MIXED_CASE_DISTINCT_TABLE); Schema mixedCaseDistinctAvroSchema = SchemaBuilder.record(MIXED_CASE_DISTINCT_TABLE).fields() @@ -234,7 +278,11 @@ protected QueryRunner createQueryRunner() // Create mixed case table name, populated from the mixed case topic pinot.createSchema(getClass().getClassLoader().getResourceAsStream("mixed_case_table_name_schema.json"), MIXED_CASE_TABLE_NAME); pinot.addRealTimeTable(getClass().getClassLoader().getResourceAsStream("mixed_case_table_name_realtimeSpec.json"), MIXED_CASE_TABLE_NAME); + } + private void createAndPopulateTooManyRowsTable(TestingKafka kafka, TestingPinotCluster pinot) + throws Exception + { // Create and populate too many rows table and topic kafka.createTopic(TOO_MANY_ROWS_TABLE); Schema tooManyRowsAvroSchema = SchemaBuilder.record(TOO_MANY_ROWS_TABLE).fields() @@ -255,7 +303,11 @@ protected QueryRunner createQueryRunner() kafka.sendMessages(tooManyRowsRecordsBuilder.build().stream(), schemaRegistryAwareProducer(kafka)); pinot.createSchema(getClass().getClassLoader().getResourceAsStream("too_many_rows_schema.json"), TOO_MANY_ROWS_TABLE); pinot.addRealTimeTable(getClass().getClassLoader().getResourceAsStream("too_many_rows_realtimeSpec.json"), TOO_MANY_ROWS_TABLE); + } + private void createAndPopulateTooManyBrokerRowsTableAndTopic(TestingKafka kafka, TestingPinotCluster pinot) + throws Exception + { // Create and populate too many broker rows table and topic kafka.createTopic(TOO_MANY_BROKER_ROWS_TABLE); Schema tooManyBrokerRowsAvroSchema = SchemaBuilder.record(TOO_MANY_BROKER_ROWS_TABLE).fields() @@ -273,7 +325,11 @@ protected QueryRunner createQueryRunner() kafka.sendMessages(tooManyBrokerRowsRecordsBuilder.build().stream(), schemaRegistryAwareProducer(kafka)); pinot.createSchema(getClass().getClassLoader().getResourceAsStream("too_many_broker_rows_schema.json"), TOO_MANY_BROKER_ROWS_TABLE); pinot.addRealTimeTable(getClass().getClassLoader().getResourceAsStream("too_many_broker_rows_realtimeSpec.json"), TOO_MANY_BROKER_ROWS_TABLE); + } + private void createTheDuplicateTablesAndTopics(TestingKafka kafka, TestingPinotCluster pinot) + throws Exception + { // Create the duplicate tables and topics kafka.createTopic(DUPLICATE_TABLE_LOWERCASE); pinot.createSchema(getClass().getClassLoader().getResourceAsStream("dup_table_lower_case_schema.json"), DUPLICATE_TABLE_LOWERCASE); @@ -282,7 +338,11 @@ protected QueryRunner createQueryRunner() kafka.createTopic(DUPLICATE_TABLE_MIXED_CASE); pinot.createSchema(getClass().getClassLoader().getResourceAsStream("dup_table_mixed_case_schema.json"), DUPLICATE_TABLE_MIXED_CASE); pinot.addRealTimeTable(getClass().getClassLoader().getResourceAsStream("dup_table_mixed_case_realtimeSpec.json"), DUPLICATE_TABLE_MIXED_CASE); + } + private void createAndPopulateDateTimeFieldsTableAndTopic(TestingKafka kafka, TestingPinotCluster pinot) + throws Exception + { // Create and populate date time fields table and topic kafka.createTopic(DATE_TIME_FIELDS_TABLE); Schema dateTimeFieldsAvroSchema = SchemaBuilder.record(DATE_TIME_FIELDS_TABLE).fields() @@ -310,7 +370,11 @@ protected QueryRunner createQueryRunner() kafka.sendMessages(dateTimeFieldsProducerRecords.stream(), schemaRegistryAwareProducer(kafka)); pinot.createSchema(getClass().getClassLoader().getResourceAsStream("date_time_fields_schema.json"), DATE_TIME_FIELDS_TABLE); pinot.addRealTimeTable(getClass().getClassLoader().getResourceAsStream("date_time_fields_realtimeSpec.json"), DATE_TIME_FIELDS_TABLE); + } + private void createAndPopulateJsonTypeTable(TestingKafka kafka, TestingPinotCluster pinot) + throws Exception + { // Create json type table kafka.createTopic(JSON_TYPE_TABLE); @@ -332,7 +396,11 @@ protected QueryRunner createQueryRunner() pinot.createSchema(getClass().getClassLoader().getResourceAsStream("json_schema.json"), JSON_TYPE_TABLE); pinot.addRealTimeTable(getClass().getClassLoader().getResourceAsStream("json_realtimeSpec.json"), JSON_TYPE_TABLE); pinot.addOfflineTable(getClass().getClassLoader().getResourceAsStream("json_offlineSpec.json"), JSON_TYPE_TABLE); + } + private void createAndPopulateJsonTable(TestingKafka kafka, TestingPinotCluster pinot) + throws Exception + { // Create json table kafka.createTopic(JSON_TABLE); long key = 0L; @@ -347,7 +415,11 @@ protected QueryRunner createQueryRunner() pinot.createSchema(getClass().getClassLoader().getResourceAsStream("schema.json"), JSON_TABLE); pinot.addRealTimeTable(getClass().getClassLoader().getResourceAsStream("realtimeSpec.json"), JSON_TABLE); + } + private void createAndPopulateMixedCaseHybridTablesAndTopic(TestingKafka kafka, TestingPinotCluster pinot) + throws Exception + { // Create and populate mixed case table and topic kafka.createTopic(HYBRID_TABLE_NAME); Schema hybridAvroSchema = SchemaBuilder.record(HYBRID_TABLE_NAME).fields() @@ -433,7 +505,11 @@ protected QueryRunner createQueryRunner() } kafka.sendMessages(hybridProducerRecords.stream(), schemaRegistryAwareProducer(kafka)); + } + private void createAndPopulateTableHavingReservedKeywordColumnNames(TestingKafka kafka, TestingPinotCluster pinot) + throws Exception + { // Create a table having reserved keyword column names kafka.createTopic(RESERVED_KEYWORD_TABLE); Schema reservedKeywordAvroSchema = SchemaBuilder.record(RESERVED_KEYWORD_TABLE).fields() @@ -447,7 +523,11 @@ protected QueryRunner createQueryRunner() kafka.sendMessages(reservedKeywordRecordsBuilder.build().stream(), schemaRegistryAwareProducer(kafka)); pinot.createSchema(getClass().getClassLoader().getResourceAsStream("reserved_keyword_schema.json"), RESERVED_KEYWORD_TABLE); pinot.addRealTimeTable(getClass().getClassLoader().getResourceAsStream("reserved_keyword_realtimeSpec.json"), RESERVED_KEYWORD_TABLE); + } + private void createAndPopulateHavingQuotesInColumnNames(TestingKafka kafka, TestingPinotCluster pinot) + throws Exception + { // Create a table having quotes in column names kafka.createTopic(QUOTES_IN_COLUMN_NAME_TABLE); Schema quotesInColumnNameAvroSchema = SchemaBuilder.record(QUOTES_IN_COLUMN_NAME_TABLE).fields() @@ -460,7 +540,11 @@ protected QueryRunner createQueryRunner() kafka.sendMessages(quotesInColumnNameRecordsBuilder.build().stream(), schemaRegistryAwareProducer(kafka)); pinot.createSchema(getClass().getClassLoader().getResourceAsStream("quotes_in_column_name_schema.json"), QUOTES_IN_COLUMN_NAME_TABLE); pinot.addRealTimeTable(getClass().getClassLoader().getResourceAsStream("quotes_in_column_name_realtimeSpec.json"), QUOTES_IN_COLUMN_NAME_TABLE); + } + private void createAndPopulateHavingMultipleColumnsWithDuplicateValues(TestingKafka kafka, TestingPinotCluster pinot) + throws Exception + { // Create a table having multiple columns with duplicate values kafka.createTopic(DUPLICATE_VALUES_IN_COLUMNS_TABLE); Schema duplicateValuesInColumnsAvroSchema = SchemaBuilder.record(DUPLICATE_VALUES_IN_COLUMNS_TABLE).fields() @@ -523,12 +607,71 @@ protected QueryRunner createQueryRunner() kafka.sendMessages(duplicateValuesInColumnsRecordsBuilder.build().stream(), schemaRegistryAwareProducer(kafka)); pinot.createSchema(getClass().getClassLoader().getResourceAsStream("duplicate_values_in_columns_schema.json"), DUPLICATE_VALUES_IN_COLUMNS_TABLE); pinot.addRealTimeTable(getClass().getClassLoader().getResourceAsStream("duplicate_values_in_columns_realtimeSpec.json"), DUPLICATE_VALUES_IN_COLUMNS_TABLE); + } - return PinotQueryRunner.createPinotQueryRunner( - ImmutableMap.of(), - pinotProperties(pinot), - Optional.of(binder -> newOptionalBinder(binder, PinotHostMapper.class).setBinding() - .toInstance(new TestingPinotHostMapper(pinot.getBrokerHostAndPort(), pinot.getServerHostAndPort(), pinot.getServerGrpcHostAndPort())))); + private void createAndPopulateNationAndRegionData(TestingKafka kafka, TestingPinotCluster pinot, DistributedQueryRunner queryRunner) + throws Exception + { + // Create and populate table and topic data + String regionTableName = "region"; + kafka.createTopicWithConfig(2, 1, regionTableName, false); + Schema regionSchema = SchemaBuilder.record(regionTableName).fields() + // regionkey bigint, name varchar, comment varchar + .name("regionkey").type().longType().noDefault() + .name("name").type().stringType().noDefault() + .name("comment").type().stringType().noDefault() + .name("updated_at_seconds").type().longType().noDefault() + .endRecord(); + ImmutableList.Builder> regionRowsBuilder = ImmutableList.builder(); + MaterializedResult regionRows = queryRunner.execute("SELECT * FROM tpch.tiny.region"); + for (MaterializedRow row : regionRows.getMaterializedRows()) { + regionRowsBuilder.add(new ProducerRecord<>(regionTableName, "key" + row.getField(0), new GenericRecordBuilder(regionSchema) + .set("regionkey", row.getField(0)) + .set("name", row.getField(1)) + .set("comment", row.getField(2)) + .set("updated_at_seconds", initialUpdatedAt.plusMillis(1000).toEpochMilli()) + .build())); + } + kafka.sendMessages(regionRowsBuilder.build().stream(), schemaRegistryAwareProducer(kafka)); + pinot.createSchema(getClass().getClassLoader().getResourceAsStream("region_schema.json"), regionTableName); + pinot.addRealTimeTable(getClass().getClassLoader().getResourceAsStream("region_realtimeSpec.json"), regionTableName); + + String nationTableName = "nation"; + kafka.createTopicWithConfig(2, 1, nationTableName, false); + Schema nationSchema = SchemaBuilder.record(nationTableName).fields() + // nationkey BIGINT, name VARCHAR, VARCHAR, regionkey BIGINT + .name("nationkey").type().longType().noDefault() + .name("name").type().stringType().noDefault() + .name("comment").type().stringType().noDefault() + .name("regionkey").type().longType().noDefault() + .name("updated_at_seconds").type().longType().noDefault() + .endRecord(); + ImmutableList.Builder> nationRowsBuilder = ImmutableList.builder(); + MaterializedResult nationRows = queryRunner.execute("SELECT * FROM tpch.tiny.nation"); + for (MaterializedRow row : nationRows.getMaterializedRows()) { + nationRowsBuilder.add(new ProducerRecord<>(nationTableName, "key" + row.getField(0), new GenericRecordBuilder(nationSchema) + .set("nationkey", row.getField(0)) + .set("name", row.getField(1)) + .set("comment", row.getField(3)) + .set("regionkey", row.getField(2)) + .set("updated_at_seconds", initialUpdatedAt.plusMillis(1000).toEpochMilli()) + .build())); + } + kafka.sendMessages(nationRowsBuilder.build().stream(), schemaRegistryAwareProducer(kafka)); + pinot.createSchema(getClass().getClassLoader().getResourceAsStream("nation_schema.json"), nationTableName); + pinot.addRealTimeTable(getClass().getClassLoader().getResourceAsStream("nation_realtimeSpec.json"), nationTableName); + } + + @Override + protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) + { + return switch (connectorBehavior) { + case SUPPORTS_CREATE_SCHEMA, + SUPPORTS_CREATE_TABLE, + SUPPORTS_INSERT, + SUPPORTS_RENAME_TABLE -> false; + default -> super.hasBehavior(connectorBehavior); + }; } private Map pinotProperties(TestingPinotCluster pinot) @@ -833,6 +976,37 @@ public static Object of( } } + @Override + public void testShowCreateTable() + { + assertQueryFails("SHOW CREATE TABLE region", "No PropertyMetadata for property: pinotColumnName"); + } + + @Override + public void testSelectInformationSchemaColumns() + { + // Override because there's updated_at_seconds column + assertThat(query("SELECT column_name FROM information_schema.columns WHERE table_schema = 'default' AND table_name = 'region'")) + .skippingTypesCheck() + .matches("VALUES 'regionkey', 'name', 'comment', 'updated_at_seconds'"); + } + + @Override + public void testTopN() + { + // TODO https://github.com/trinodb/trino/issues/14045 Fix ORDER BY ... LIMIT query + assertQueryFails("SELECT regionkey FROM nation ORDER BY name LIMIT 3", + format("Segment query returned '%2$s' rows per split, maximum allowed is '%1$s' rows. with query \"SELECT \"regionkey\", \"name\" FROM nation_REALTIME LIMIT 12\"", MAX_ROWS_PER_SPLIT_FOR_SEGMENT_QUERIES, MAX_ROWS_PER_SPLIT_FOR_SEGMENT_QUERIES + 1)); + } + + @Override + public void testJoin() + { + // TODO https://github.com/trinodb/trino/issues/14046 Fix JOIN query + assertQueryFails("SELECT n.name, r.name FROM nation n JOIN region r on n.regionkey = r.regionkey", + format("Segment query returned '%2$s' rows per split, maximum allowed is '%1$s' rows. with query \"SELECT \"regionkey\", \"name\" FROM nation_REALTIME LIMIT 12\"", MAX_ROWS_PER_SPLIT_FOR_SEGMENT_QUERIES, MAX_ROWS_PER_SPLIT_FOR_SEGMENT_QUERIES + 1)); + } + @Test public void testRealType() { @@ -1361,12 +1535,6 @@ public void testLimitPushdown() .isNotFullyPushedDown(LimitNode.class); } - @Test - public void testCreateTable() - { - assertQueryFails("CREATE TABLE test_create_table (col INT)", "This connector does not support creating tables"); - } - /** * https://github.com/trinodb/trino/issues/8307 */ diff --git a/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/PinotQueryRunner.java b/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/PinotQueryRunner.java index 235e55bd9bce..565561700b60 100755 --- a/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/PinotQueryRunner.java +++ b/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/PinotQueryRunner.java @@ -43,7 +43,6 @@ public static DistributedQueryRunner createPinotQueryRunner(Map throws Exception { DistributedQueryRunner queryRunner = DistributedQueryRunner.builder(createSession("default")) - .setNodeCount(2) .setExtraProperties(extraProperties) .build(); queryRunner.installPlugin(new PinotPlugin(extension)); diff --git a/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotWithoutAuthenticationIntegrationSmokeTest.java b/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotWithoutAuthenticationIntegrationConnectorConnectorSmokeTest.java similarity index 83% rename from plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotWithoutAuthenticationIntegrationSmokeTest.java rename to plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotWithoutAuthenticationIntegrationConnectorConnectorSmokeTest.java index 54d1008c3475..424aa53ec07d 100644 --- a/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotWithoutAuthenticationIntegrationSmokeTest.java +++ b/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotWithoutAuthenticationIntegrationConnectorConnectorSmokeTest.java @@ -13,8 +13,8 @@ */ package io.trino.plugin.pinot; -public class TestPinotWithoutAuthenticationIntegrationSmokeTest - extends AbstractPinotIntegrationSmokeTest +public class TestPinotWithoutAuthenticationIntegrationConnectorConnectorSmokeTest + extends BasePinotIntegrationConnectorSmokeTest { @Override protected boolean isSecured() diff --git a/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotWithoutAuthenticationIntegrationSmokeTestLatestVersion.java b/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotWithoutAuthenticationIntegrationLatestVersionConnectorConnectorSmokeTest.java similarity index 85% rename from plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotWithoutAuthenticationIntegrationSmokeTestLatestVersion.java rename to plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotWithoutAuthenticationIntegrationLatestVersionConnectorConnectorSmokeTest.java index 05b7c5d89e38..35795c92c566 100644 --- a/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotWithoutAuthenticationIntegrationSmokeTestLatestVersion.java +++ b/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotWithoutAuthenticationIntegrationLatestVersionConnectorConnectorSmokeTest.java @@ -15,8 +15,8 @@ import static io.trino.plugin.pinot.TestingPinotCluster.PINOT_LATEST_IMAGE_NAME; -public class TestPinotWithoutAuthenticationIntegrationSmokeTestLatestVersion - extends AbstractPinotIntegrationSmokeTest +public class TestPinotWithoutAuthenticationIntegrationLatestVersionConnectorConnectorSmokeTest + extends BasePinotIntegrationConnectorSmokeTest { @Override protected boolean isSecured() diff --git a/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotWithoutAuthenticationIntegrationSmokeTestLatestVersionNoGrpc.java b/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotWithoutAuthenticationIntegrationLatestVersionNoGrpcConnectorSmokeTest.java similarity index 86% rename from plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotWithoutAuthenticationIntegrationSmokeTestLatestVersionNoGrpc.java rename to plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotWithoutAuthenticationIntegrationLatestVersionNoGrpcConnectorSmokeTest.java index 979ed75d5c8b..274e6f2da1a0 100644 --- a/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotWithoutAuthenticationIntegrationSmokeTestLatestVersionNoGrpc.java +++ b/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestPinotWithoutAuthenticationIntegrationLatestVersionNoGrpcConnectorSmokeTest.java @@ -15,8 +15,8 @@ import static io.trino.plugin.pinot.TestingPinotCluster.PINOT_LATEST_IMAGE_NAME; -public class TestPinotWithoutAuthenticationIntegrationSmokeTestLatestVersionNoGrpc - extends AbstractPinotIntegrationSmokeTest +public class TestPinotWithoutAuthenticationIntegrationLatestVersionNoGrpcConnectorSmokeTest + extends BasePinotIntegrationConnectorSmokeTest { @Override protected boolean isSecured() diff --git a/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestSecuredPinotIntegrationSmokeTest.java b/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestSecuredPinotIntegrationConnectorSmokeTest.java similarity index 92% rename from plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestSecuredPinotIntegrationSmokeTest.java rename to plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestSecuredPinotIntegrationConnectorSmokeTest.java index 189b6236a327..4cd780e1b046 100644 --- a/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestSecuredPinotIntegrationSmokeTest.java +++ b/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestSecuredPinotIntegrationConnectorSmokeTest.java @@ -19,8 +19,8 @@ import static io.trino.plugin.pinot.auth.PinotAuthenticationType.PASSWORD; -public class TestSecuredPinotIntegrationSmokeTest - extends AbstractPinotIntegrationSmokeTest +public class TestSecuredPinotIntegrationConnectorSmokeTest + extends BasePinotIntegrationConnectorSmokeTest { @Override protected boolean isSecured() diff --git a/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestingPinotCluster.java b/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestingPinotCluster.java index a98b87b3cd72..4732955ee2aa 100644 --- a/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestingPinotCluster.java +++ b/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestingPinotCluster.java @@ -152,22 +152,22 @@ private static String getZookeeperInternalHostPort() public String getControllerConnectString() { - return controller.getContainerIpAddress() + ":" + controller.getMappedPort(CONTROLLER_PORT); + return controller.getHost() + ":" + controller.getMappedPort(CONTROLLER_PORT); } public HostAndPort getBrokerHostAndPort() { - return HostAndPort.fromParts(broker.getContainerIpAddress(), broker.getMappedPort(BROKER_PORT)); + return HostAndPort.fromParts(broker.getHost(), broker.getMappedPort(BROKER_PORT)); } public HostAndPort getServerHostAndPort() { - return HostAndPort.fromParts(server.getContainerIpAddress(), server.getMappedPort(SERVER_PORT)); + return HostAndPort.fromParts(server.getHost(), server.getMappedPort(SERVER_PORT)); } public HostAndPort getServerGrpcHostAndPort() { - return HostAndPort.fromParts(server.getContainerIpAddress(), server.getMappedPort(GRPC_PORT)); + return HostAndPort.fromParts(server.getHost(), server.getMappedPort(GRPC_PORT)); } public void createSchema(InputStream tableSchemaSpec, String tableName) @@ -271,9 +271,7 @@ public void publishOfflineSegment(String tableName, Path segmentPath) if (statusCode >= 500) { return false; } - else { - throw e; - } + throw e; } }); } diff --git a/plugin/trino-pinot/src/test/resources/nation_realtimeSpec.json b/plugin/trino-pinot/src/test/resources/nation_realtimeSpec.json new file mode 100644 index 000000000000..9cce19a01f8c --- /dev/null +++ b/plugin/trino-pinot/src/test/resources/nation_realtimeSpec.json @@ -0,0 +1,46 @@ +{ + "tableName": "nation", + "tableType": "REALTIME", + "segmentsConfig": { + "timeColumnName": "updated_at_seconds", + "retentionTimeUnit": "DAYS", + "retentionTimeValue": "365", + "segmentPushType": "APPEND", + "segmentPushFrequency": "daily", + "segmentAssignmentStrategy": "BalanceNumSegmentAssignmentStrategy", + "schemaName": "nation", + "replicasPerPartition": "1" + }, + "tenants": { + "broker": "DefaultTenant", + "server": "DefaultTenant" + }, + "tableIndexConfig": { + "loadMode": "MMAP", + "noDictionaryColumns": [], + "sortedColumn": [ + "updated_at_seconds" + ], + "aggregateMetrics": "false", + "nullHandlingEnabled": "true", + "streamConfigs": { + "streamType": "kafka", + "stream.kafka.consumer.type": "lowLevel", + "stream.kafka.topic.name": "nation", + "stream.kafka.decoder.class.name": "org.apache.pinot.plugin.inputformat.avro.confluent.KafkaConfluentSchemaRegistryAvroMessageDecoder", + "stream.kafka.consumer.factory.class.name": "org.apache.pinot.plugin.stream.kafka20.KafkaConsumerFactory", + "stream.kafka.decoder.prop.schema.registry.rest.url": "http://schema-registry:8081", + "stream.kafka.zk.broker.url": "zookeeper:2181/", + "stream.kafka.broker.list": "kafka:9092", + "realtime.segment.flush.threshold.time": "1m", + "realtime.segment.flush.threshold.size": "0", + "realtime.segment.flush.desired.size": "1M", + "isolation.level": "read_committed", + "stream.kafka.consumer.prop.auto.offset.reset": "smallest", + "stream.kafka.consumer.prop.group.id": "pinot_nation" + } + }, + "metadata": { + "customConfigs": {} + } +} diff --git a/plugin/trino-pinot/src/test/resources/nation_schema.json b/plugin/trino-pinot/src/test/resources/nation_schema.json new file mode 100644 index 000000000000..b6c2922d8f07 --- /dev/null +++ b/plugin/trino-pinot/src/test/resources/nation_schema.json @@ -0,0 +1,31 @@ +{ + "schemaName": "nation", + "dimensionFieldSpecs": [ + { + "name": "nationkey", + "dataType": "LONG" + }, + { + "name": "name", + "dataType": "STRING" + }, + { + "name": "comment", + "dataType": "STRING" + }, + { + "name": "regionkey", + "dataType": "LONG" + } + ], + "dateTimeFieldSpecs": [ + { + "name": "updated_at_seconds", + "dataType": "LONG", + "defaultNullValue": 0, + "format": "1:SECONDS:EPOCH", + "transformFunction": "toEpochSeconds(updated_at)", + "granularity": "1:SECONDS" + } + ] +} diff --git a/plugin/trino-pinot/src/test/resources/region_realtimeSpec.json b/plugin/trino-pinot/src/test/resources/region_realtimeSpec.json new file mode 100644 index 000000000000..2fa3cdea0320 --- /dev/null +++ b/plugin/trino-pinot/src/test/resources/region_realtimeSpec.json @@ -0,0 +1,46 @@ +{ + "tableName": "region", + "tableType": "REALTIME", + "segmentsConfig": { + "timeColumnName": "updated_at_seconds", + "retentionTimeUnit": "DAYS", + "retentionTimeValue": "365", + "segmentPushType": "APPEND", + "segmentPushFrequency": "daily", + "segmentAssignmentStrategy": "BalanceNumSegmentAssignmentStrategy", + "schemaName": "region", + "replicasPerPartition": "1" + }, + "tenants": { + "broker": "DefaultTenant", + "server": "DefaultTenant" + }, + "tableIndexConfig": { + "loadMode": "MMAP", + "noDictionaryColumns": [], + "sortedColumn": [ + "updated_at_seconds" + ], + "aggregateMetrics": "false", + "nullHandlingEnabled": "true", + "streamConfigs": { + "streamType": "kafka", + "stream.kafka.consumer.type": "lowLevel", + "stream.kafka.topic.name": "region", + "stream.kafka.decoder.class.name": "org.apache.pinot.plugin.inputformat.avro.confluent.KafkaConfluentSchemaRegistryAvroMessageDecoder", + "stream.kafka.consumer.factory.class.name": "org.apache.pinot.plugin.stream.kafka20.KafkaConsumerFactory", + "stream.kafka.decoder.prop.schema.registry.rest.url": "http://schema-registry:8081", + "stream.kafka.zk.broker.url": "zookeeper:2181/", + "stream.kafka.broker.list": "kafka:9092", + "realtime.segment.flush.threshold.time": "1m", + "realtime.segment.flush.threshold.size": "0", + "realtime.segment.flush.desired.size": "1M", + "isolation.level": "read_committed", + "stream.kafka.consumer.prop.auto.offset.reset": "smallest", + "stream.kafka.consumer.prop.group.id": "pinot_region" + } + }, + "metadata": { + "customConfigs": {} + } +} diff --git a/plugin/trino-pinot/src/test/resources/region_schema.json b/plugin/trino-pinot/src/test/resources/region_schema.json new file mode 100644 index 000000000000..65f3dee8c8b9 --- /dev/null +++ b/plugin/trino-pinot/src/test/resources/region_schema.json @@ -0,0 +1,27 @@ +{ + "schemaName": "region", + "dimensionFieldSpecs": [ + { + "name": "regionkey", + "dataType": "LONG" + }, + { + "name": "name", + "dataType": "STRING" + }, + { + "name": "comment", + "dataType": "STRING" + } + ], + "dateTimeFieldSpecs": [ + { + "name": "updated_at_seconds", + "dataType": "LONG", + "defaultNullValue": 0, + "format": "1:SECONDS:EPOCH", + "transformFunction": "toEpochSeconds(updated_at)", + "granularity": "1:SECONDS" + } + ] +} diff --git a/plugin/trino-postgresql/pom.xml b/plugin/trino-postgresql/pom.xml index 4cf1839eefab..f67c531ce733 100644 --- a/plugin/trino-postgresql/pom.xml +++ b/plugin/trino-postgresql/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index ebf2ed68e8cb..00dc909757d1 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -309,8 +309,8 @@ public PostgreSqlClient( .map("$divide(left: integer_type, right: integer_type)").to("left / right") .map("$modulus(left: integer_type, right: integer_type)").to("left % right") .map("$negate(value: integer_type)").to("-value") - .map("$like_pattern(value: varchar, pattern: varchar): boolean").to("value LIKE pattern") - .map("$like_pattern(value: varchar, pattern: varchar, escape: varchar(1)): boolean").to("value LIKE pattern ESCAPE escape") + .map("$like(value: varchar, pattern: varchar): boolean").to("value LIKE pattern") + .map("$like(value: varchar, pattern: varchar, escape: varchar(1)): boolean").to("value LIKE pattern ESCAPE escape") .map("$not($is_null(value))").to("value IS NOT NULL") .map("$not(value: boolean)").to("NOT value") .map("$is_null(value)").to("value IS NULL") diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/TypeUtils.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/TypeUtils.java index 10e6059f9cb5..6020a4e38732 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/TypeUtils.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/TypeUtils.java @@ -208,13 +208,11 @@ private static Object trinoNativeToJdbcObject(ConnectorSession session, Type tri long millisUtc = unpackMillisUtc((long) trinoNative); return new Timestamp(millisUtc); } - else { - LongTimestampWithTimeZone value = (LongTimestampWithTimeZone) trinoNative; - long epochSeconds = floorDiv(value.getEpochMillis(), MILLISECONDS_PER_SECOND); - long nanosOfSecond = floorMod(value.getEpochMillis(), MILLISECONDS_PER_SECOND) * NANOSECONDS_PER_MILLISECOND - + value.getPicosOfMilli() / PICOSECONDS_PER_NANOSECOND; - return OffsetDateTime.ofInstant(Instant.ofEpochSecond(epochSeconds, nanosOfSecond), UTC_KEY.getZoneId()); - } + LongTimestampWithTimeZone value = (LongTimestampWithTimeZone) trinoNative; + long epochSeconds = floorDiv(value.getEpochMillis(), MILLISECONDS_PER_SECOND); + long nanosOfSecond = floorMod(value.getEpochMillis(), MILLISECONDS_PER_SECOND) * NANOSECONDS_PER_MILLISECOND + + value.getPicosOfMilli() / PICOSECONDS_PER_NANOSECOND; + return OffsetDateTime.ofInstant(Instant.ofEpochSecond(epochSeconds, nanosOfSecond), UTC_KEY.getZoneId()); } if (trinoType instanceof VarcharType || trinoType instanceof CharType) { diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlTypeMapping.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlTypeMapping.java index d3bdc6be9bc0..90c19b65b2c9 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlTypeMapping.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlTypeMapping.java @@ -1813,9 +1813,7 @@ public static DataType timestampWithTimeZoneDataType(int precisio if (insertWithTrino) { return trinoTimestampWithTimeZoneDataType(precision); } - else { - return postgreSqlTimestampWithTimeZoneDataType(precision); - } + return postgreSqlTimestampWithTimeZoneDataType(precision); } public static DataType trinoTimestampWithTimeZoneDataType(int precision) @@ -1848,9 +1846,7 @@ public static DataType> arrayOfTimestampWithTimeZoneDataType if (insertWithTrino) { return arrayDataType(trinoTimestampWithTimeZoneDataType(precision)); } - else { - return arrayDataType(postgreSqlTimestampWithTimeZoneDataType(precision), format("timestamptz(%d)[]", precision)); - } + return arrayDataType(postgreSqlTimestampWithTimeZoneDataType(precision), format("timestamptz(%d)[]", precision)); } private Session sessionWithArrayAsArray() diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestingPostgreSqlServer.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestingPostgreSqlServer.java index f84cb8e46dfc..afc5f4c08f67 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestingPostgreSqlServer.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestingPostgreSqlServer.java @@ -129,7 +129,7 @@ public Properties getProperties() public String getJdbcUrl() { - return format("jdbc:postgresql://%s:%s/%s", dockerContainer.getContainerIpAddress(), dockerContainer.getMappedPort(POSTGRESQL_PORT), DATABASE); + return format("jdbc:postgresql://%s:%s/%s", dockerContainer.getHost(), dockerContainer.getMappedPort(POSTGRESQL_PORT), DATABASE); } @Override diff --git a/plugin/trino-prometheus/pom.xml b/plugin/trino-prometheus/pom.xml index 0a338f6df6a1..9cc9e3ac5f6b 100644 --- a/plugin/trino-prometheus/pom.xml +++ b/plugin/trino-prometheus/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusRecordCursor.java b/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusRecordCursor.java index 874142bd92b4..f596f991502a 100644 --- a/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusRecordCursor.java +++ b/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusRecordCursor.java @@ -143,9 +143,7 @@ public long getLong(int field) int offsetMinutes = dateTime.atZone(ZoneId.systemDefault()).getOffset().getTotalSeconds() / 60; return packDateTimeWithZone(dateTime.toEpochMilli(), offsetMinutes); } - else { - throw new TrinoException(NOT_SUPPORTED, "Unsupported type " + getType(field)); - } + throw new TrinoException(NOT_SUPPORTED, "Unsupported type " + getType(field)); } @Override @@ -264,17 +262,15 @@ private static Object readObject(Type type, Block block, int position) Type elementType = ((ArrayType) type).getElementType(); return getArrayFromBlock(elementType, block.getObject(position, Block.class)); } - else if (type instanceof MapType) { + if (type instanceof MapType) { return getMapFromBlock(type, block.getObject(position, Block.class)); } - else { - if (type.getJavaType() == Slice.class) { - Slice slice = (Slice) requireNonNull(TypeUtils.readNativeValue(type, block, position)); - return (type instanceof VarcharType) ? slice.toStringUtf8() : slice.getBytes(); - } - - return TypeUtils.readNativeValue(type, block, position); + if (type.getJavaType() == Slice.class) { + Slice slice = (Slice) requireNonNull(TypeUtils.readNativeValue(type, block, position)); + return (type instanceof VarcharType) ? slice.toStringUtf8() : slice.getBytes(); } + + return TypeUtils.readNativeValue(type, block, position); } private static List getArrayFromBlock(Type elementType, Block block) diff --git a/plugin/trino-prometheus/src/test/java/io/trino/plugin/prometheus/PrometheusServer.java b/plugin/trino-prometheus/src/test/java/io/trino/plugin/prometheus/PrometheusServer.java index bf22307340f6..df938c126f6c 100644 --- a/plugin/trino-prometheus/src/test/java/io/trino/plugin/prometheus/PrometheusServer.java +++ b/plugin/trino-prometheus/src/test/java/io/trino/plugin/prometheus/PrometheusServer.java @@ -59,7 +59,7 @@ public PrometheusServer(String version, boolean enableBasicAuth) public URI getUri() { - return URI.create("http://" + dockerContainer.getContainerIpAddress() + ":" + dockerContainer.getMappedPort(PROMETHEUS_PORT) + "/"); + return URI.create("http://" + dockerContainer.getHost() + ":" + dockerContainer.getMappedPort(PROMETHEUS_PORT) + "/"); } @Override diff --git a/plugin/trino-raptor-legacy/pom.xml b/plugin/trino-raptor-legacy/pom.xml index 6d73f106bbe9..10d0d254e025 100644 --- a/plugin/trino-raptor-legacy/pom.xml +++ b/plugin/trino-raptor-legacy/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/storage/Row.java b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/storage/Row.java index a264fed2dfc1..5077b5d4bc1e 100644 --- a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/storage/Row.java +++ b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/storage/Row.java @@ -118,27 +118,25 @@ private static Object getNativeContainerValue(Type type, Block block, int positi if (block.isNull(position)) { return null; } - else if (type.getJavaType() == boolean.class) { + if (type.getJavaType() == boolean.class) { return type.getBoolean(block, position); } - else if (type.getJavaType() == long.class) { + if (type.getJavaType() == long.class) { return type.getLong(block, position); } - else if (type.getJavaType() == double.class) { + if (type.getJavaType() == double.class) { return type.getDouble(block, position); } - else if (type.getJavaType() == Slice.class) { + if (type.getJavaType() == Slice.class) { return type.getSlice(block, position); } - else if (type.getJavaType() == Block.class) { + if (type.getJavaType() == Block.class) { return type.getObject(block, position); } - else if (type.getJavaType() == Int128.class) { + if (type.getJavaType() == Int128.class) { return type.getObject(block, position); } - else { - throw new AssertionError("Unimplemented type: " + type); - } + throw new AssertionError("Unimplemented type: " + type); } private static Object nativeContainerToOrcValue(Type type, Object nativeValue) diff --git a/plugin/trino-redis/pom.xml b/plugin/trino-redis/pom.xml index 1438ee2cd94b..482928544471 100644 --- a/plugin/trino-redis/pom.xml +++ b/plugin/trino-redis/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-redis/src/main/java/io/trino/plugin/redis/RedisRecordCursor.java b/plugin/trino-redis/src/main/java/io/trino/plugin/redis/RedisRecordCursor.java index c32b03fcb4e3..b9cc3a1d6517 100644 --- a/plugin/trino-redis/src/main/java/io/trino/plugin/redis/RedisRecordCursor.java +++ b/plugin/trino-redis/src/main/java/io/trino/plugin/redis/RedisRecordCursor.java @@ -379,19 +379,17 @@ private void setPushdownKeys() log.debug("Set pushdown keys %s with single value", keys.toString()); return; } - else { - ValueSet valueSet = domain.getValues(); - if (valueSet instanceof SortedRangeSet) { - Ranges ranges = ((SortedRangeSet) valueSet).getRanges(); - List rangeList = ranges.getOrderedRanges(); - if (rangeList.stream().allMatch(Range::isSingleValue)) { - keys = rangeList.stream() - .map(range -> ((Slice) range.getSingleValue()).toStringUtf8()) - .filter(str -> keyStringPrefix.isEmpty() || str.contains(keyStringPrefix)) - .collect(toList()); - log.debug("Set pushdown keys %s with sorted range values", keys.toString()); - return; - } + ValueSet valueSet = domain.getValues(); + if (valueSet instanceof SortedRangeSet) { + Ranges ranges = ((SortedRangeSet) valueSet).getRanges(); + List rangeList = ranges.getOrderedRanges(); + if (rangeList.stream().allMatch(Range::isSingleValue)) { + keys = rangeList.stream() + .map(range -> ((Slice) range.getSingleValue()).toStringUtf8()) + .filter(str -> keyStringPrefix.isEmpty() || str.contains(keyStringPrefix)) + .collect(toList()); + log.debug("Set pushdown keys %s with sorted range values", keys.toString()); + return; } } } diff --git a/plugin/trino-redis/src/test/java/io/trino/plugin/redis/util/RedisServer.java b/plugin/trino-redis/src/test/java/io/trino/plugin/redis/util/RedisServer.java index 870b5ac7c434..d429cd39f4b4 100644 --- a/plugin/trino-redis/src/test/java/io/trino/plugin/redis/util/RedisServer.java +++ b/plugin/trino-redis/src/test/java/io/trino/plugin/redis/util/RedisServer.java @@ -44,12 +44,12 @@ public RedisServer(String version, boolean setAccessControl) if (setAccessControl) { container.withCommand("redis-server", "--requirepass", PASSWORD); container.start(); - jedisPool = new JedisPool(container.getContainerIpAddress(), container.getMappedPort(PORT), null, PASSWORD); + jedisPool = new JedisPool(container.getHost(), container.getMappedPort(PORT), null, PASSWORD); jedisPool.getResource().aclSetUser(USER, "on", ">" + PASSWORD, "~*:*", "+@all"); } else { container.start(); - jedisPool = new JedisPool(container.getContainerIpAddress(), container.getMappedPort(PORT)); + jedisPool = new JedisPool(container.getHost(), container.getMappedPort(PORT)); } } @@ -65,7 +65,7 @@ public void destroyJedisPool() public HostAndPort getHostAndPort() { - return HostAndPort.fromParts(container.getContainerIpAddress(), container.getMappedPort(PORT)); + return HostAndPort.fromParts(container.getHost(), container.getMappedPort(PORT)); } @Override diff --git a/plugin/trino-redshift/pom.xml b/plugin/trino-redshift/pom.xml index af07aa05c6dc..3878e457a57e 100644 --- a/plugin/trino-redshift/pom.xml +++ b/plugin/trino-redshift/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-resource-group-managers/pom.xml b/plugin/trino-resource-group-managers/pom.xml index d5a2b87f74ef..323cb0572946 100644 --- a/plugin/trino-resource-group-managers/pom.xml +++ b/plugin/trino-resource-group-managers/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-resource-group-managers/src/main/java/io/trino/plugin/resourcegroups/AbstractResourceConfigurationManager.java b/plugin/trino-resource-group-managers/src/main/java/io/trino/plugin/resourcegroups/AbstractResourceConfigurationManager.java index 867c3f3727bf..2f363a3d0641 100644 --- a/plugin/trino-resource-group-managers/src/main/java/io/trino/plugin/resourcegroups/AbstractResourceConfigurationManager.java +++ b/plugin/trino-resource-group-managers/src/main/java/io/trino/plugin/resourcegroups/AbstractResourceConfigurationManager.java @@ -112,15 +112,13 @@ private void validateSelectors(List groups, SelectorSpec spec StringBuilder fullyQualifiedGroupName = new StringBuilder(); for (ResourceGroupNameTemplate groupName : spec.getGroup().getSegments()) { fullyQualifiedGroupName.append(groupName); - Optional match = groups + ResourceGroupSpec match = groups .stream() .filter(groupSpec -> groupSpec.getName().equals(groupName)) - .findFirst(); - if (match.isEmpty()) { - throw new IllegalArgumentException(format("Selector refers to nonexistent group: %s", fullyQualifiedGroupName)); - } + .findFirst() + .orElseThrow(() -> new IllegalArgumentException(format("Selector refers to nonexistent group: %s", fullyQualifiedGroupName))); fullyQualifiedGroupName.append("."); - groups = match.get().getSubGroups(); + groups = match.getSubGroups(); } } diff --git a/plugin/trino-resource-group-managers/src/main/java/io/trino/plugin/resourcegroups/db/FlywayMigration.java b/plugin/trino-resource-group-managers/src/main/java/io/trino/plugin/resourcegroups/db/FlywayMigration.java index 7e1da2367835..b96cd938de5d 100644 --- a/plugin/trino-resource-group-managers/src/main/java/io/trino/plugin/resourcegroups/db/FlywayMigration.java +++ b/plugin/trino-resource-group-managers/src/main/java/io/trino/plugin/resourcegroups/db/FlywayMigration.java @@ -32,10 +32,10 @@ private static String getLocation(String configDbUrl) if (configDbUrl.startsWith("jdbc:postgresql")) { return "/db/migration/postgresql"; } - else if (configDbUrl.startsWith("jdbc:oracle")) { + if (configDbUrl.startsWith("jdbc:oracle")) { return "/db/migration/oracle"; } - else if (configDbUrl.startsWith("jdbc:mysql")) { + if (configDbUrl.startsWith("jdbc:mysql")) { return "/db/migration/mysql"; } // validation is not performed in DbResourceGroupConfig because DB backed diff --git a/plugin/trino-session-property-managers/pom.xml b/plugin/trino-session-property-managers/pom.xml index 0a265f6ce9a9..114b30b352ee 100644 --- a/plugin/trino-session-property-managers/pom.xml +++ b/plugin/trino-session-property-managers/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-singlestore/pom.xml b/plugin/trino-singlestore/pom.xml index 98389a060bc7..e8fa6310abb8 100644 --- a/plugin/trino-singlestore/pom.xml +++ b/plugin/trino-singlestore/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-singlestore/src/test/java/io/trino/plugin/singlestore/TestingSingleStoreServer.java b/plugin/trino-singlestore/src/test/java/io/trino/plugin/singlestore/TestingSingleStoreServer.java index ec9e8620629b..8e0ee1acb38e 100644 --- a/plugin/trino-singlestore/src/test/java/io/trino/plugin/singlestore/TestingSingleStoreServer.java +++ b/plugin/trino-singlestore/src/test/java/io/trino/plugin/singlestore/TestingSingleStoreServer.java @@ -89,7 +89,7 @@ public String getPassword() @Override public String getJdbcUrl() { - return "jdbc:singlestore://" + getContainerIpAddress() + ":" + getMappedPort(SINGLESTORE_PORT); + return "jdbc:singlestore://" + getHost() + ":" + getMappedPort(SINGLESTORE_PORT); } @Override diff --git a/plugin/trino-sqlserver/pom.xml b/plugin/trino-sqlserver/pom.xml index 3122ee35013e..d1d5317b1561 100644 --- a/plugin/trino-sqlserver/pom.xml +++ b/plugin/trino-sqlserver/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-teradata-functions/pom.xml b/plugin/trino-teradata-functions/pom.xml index 2f82f4980f49..00b56c01c125 100644 --- a/plugin/trino-teradata-functions/pom.xml +++ b/plugin/trino-teradata-functions/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-thrift-api/pom.xml b/plugin/trino-thrift-api/pom.xml index 0b30ec9801c2..9ff321e0aa1d 100644 --- a/plugin/trino-thrift-api/pom.xml +++ b/plugin/trino-thrift-api/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-thrift-api/src/main/java/io/trino/plugin/thrift/api/TrinoThriftBlock.java b/plugin/trino-thrift-api/src/main/java/io/trino/plugin/thrift/api/TrinoThriftBlock.java index d1e794479d9a..7279ca56741c 100644 --- a/plugin/trino-thrift-api/src/main/java/io/trino/plugin/thrift/api/TrinoThriftBlock.java +++ b/plugin/trino-thrift-api/src/main/java/io/trino/plugin/thrift/api/TrinoThriftBlock.java @@ -300,9 +300,7 @@ public static TrinoThriftBlock fromBlock(Block block, Type type) if (BigintType.BIGINT.equals(elementType)) { return TrinoThriftBigintArray.fromBlock(block); } - else { - throw new IllegalArgumentException("Unsupported array block type: " + type); - } + throw new IllegalArgumentException("Unsupported array block type: " + type); } if (type.getBaseName().equals(JSON)) { return TrinoThriftJson.fromBlock(block, type); diff --git a/plugin/trino-thrift-api/src/main/java/io/trino/plugin/thrift/api/datatypes/TrinoThriftBigintArray.java b/plugin/trino-thrift-api/src/main/java/io/trino/plugin/thrift/api/datatypes/TrinoThriftBigintArray.java index 525b9c2f7173..56874d0fed87 100644 --- a/plugin/trino-thrift-api/src/main/java/io/trino/plugin/thrift/api/datatypes/TrinoThriftBigintArray.java +++ b/plugin/trino-thrift-api/src/main/java/io/trino/plugin/thrift/api/datatypes/TrinoThriftBigintArray.java @@ -21,6 +21,7 @@ import io.trino.spi.block.ArrayBlock; import io.trino.spi.block.Block; import io.trino.spi.block.LongArrayBlock; +import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.type.Type; import javax.annotation.Nullable; @@ -144,12 +145,18 @@ public String toString() public static TrinoThriftBlock fromBlock(Block block) { - checkArgument(block instanceof AbstractArrayBlock, "block is not of an array type"); - AbstractArrayBlock arrayBlock = (AbstractArrayBlock) block; - int positions = arrayBlock.getPositionCount(); + int positions = block.getPositionCount(); if (positions == 0) { return bigintArrayData(new TrinoThriftBigintArray(null, null, null)); } + if (block instanceof RunLengthEncodedBlock && block.isNull(0)) { + boolean[] nulls = new boolean[positions]; + Arrays.fill(nulls, true); + return bigintArrayData(new TrinoThriftBigintArray(nulls, null, null)); + } + checkArgument(block instanceof AbstractArrayBlock, "block is not of an array type"); + AbstractArrayBlock arrayBlock = (AbstractArrayBlock) block; + boolean[] nulls = null; int[] sizes = null; for (int position = 0; position < positions; position++) { diff --git a/plugin/trino-thrift-api/src/main/java/io/trino/plugin/thrift/api/valuesets/TrinoThriftValueSet.java b/plugin/trino-thrift-api/src/main/java/io/trino/plugin/thrift/api/valuesets/TrinoThriftValueSet.java index 9ad736cfd565..8741510796b8 100644 --- a/plugin/trino-thrift-api/src/main/java/io/trino/plugin/thrift/api/valuesets/TrinoThriftValueSet.java +++ b/plugin/trino-thrift-api/src/main/java/io/trino/plugin/thrift/api/valuesets/TrinoThriftValueSet.java @@ -109,21 +109,19 @@ public static TrinoThriftValueSet fromValueSet(ValueSet valueSet) null, null); } - else if (valueSet.getClass() == EquatableValueSet.class) { + if (valueSet.getClass() == EquatableValueSet.class) { return new TrinoThriftValueSet( null, fromEquatableValueSet((EquatableValueSet) valueSet), null); } - else if (valueSet.getClass() == SortedRangeSet.class) { + if (valueSet.getClass() == SortedRangeSet.class) { return new TrinoThriftValueSet( null, null, fromSortedRangeSet((SortedRangeSet) valueSet)); } - else { - throw new IllegalArgumentException("Unknown implementation of a value set: " + valueSet.getClass()); - } + throw new IllegalArgumentException("Unknown implementation of a value set: " + valueSet.getClass()); } private static boolean isExactlyOneNonNull(Object a, Object b, Object c) diff --git a/plugin/trino-thrift-testing-server/pom.xml b/plugin/trino-thrift-testing-server/pom.xml index 54b8f2e1557f..a3f877a4edd3 100644 --- a/plugin/trino-thrift-testing-server/pom.xml +++ b/plugin/trino-thrift-testing-server/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-thrift-testing-server/src/main/java/io/trino/plugin/thrift/server/ThriftTpchService.java b/plugin/trino-thrift-testing-server/src/main/java/io/trino/plugin/thrift/server/ThriftTpchService.java index 108434d00f30..b73b6ba0b1fa 100644 --- a/plugin/trino-thrift-testing-server/src/main/java/io/trino/plugin/thrift/server/ThriftTpchService.java +++ b/plugin/trino-thrift-testing-server/src/main/java/io/trino/plugin/thrift/server/ThriftTpchService.java @@ -310,12 +310,10 @@ private static List getSchemaNames(String schemaNameOrNull) if (schemaNameOrNull == null) { return SCHEMAS; } - else if (SCHEMAS.contains(schemaNameOrNull)) { + if (SCHEMAS.contains(schemaNameOrNull)) { return ImmutableList.of(schemaNameOrNull); } - else { - return ImmutableList.of(); - } + return ImmutableList.of(); } private static String getTypeString(TpchColumn column) diff --git a/plugin/trino-thrift/pom.xml b/plugin/trino-thrift/pom.xml index e3ea3f382072..679f3197ae2f 100644 --- a/plugin/trino-thrift/pom.xml +++ b/plugin/trino-thrift/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-thrift/src/main/java/io/trino/plugin/thrift/ThriftIndexPageSource.java b/plugin/trino-thrift/src/main/java/io/trino/plugin/thrift/ThriftIndexPageSource.java index 2978bb66a7c6..36ead2d164ee 100644 --- a/plugin/trino-thrift/src/main/java/io/trino/plugin/thrift/ThriftIndexPageSource.java +++ b/plugin/trino-thrift/src/main/java/io/trino/plugin/thrift/ThriftIndexPageSource.java @@ -262,13 +262,11 @@ private boolean loadAllSplits() statusFuture = toCompletableFuture(nonCancellationPropagating(splitFuture)); return false; } - else { - // no more splits - splitFuture = null; - statusFuture = null; - haveSplits = true; - return true; - } + // no more splits + splitFuture = null; + statusFuture = null; + haveSplits = true; + return true; } private void updateSignalAndStatusFutures() diff --git a/plugin/trino-thrift/src/main/java/io/trino/plugin/thrift/ThriftMetadata.java b/plugin/trino-thrift/src/main/java/io/trino/plugin/thrift/ThriftMetadata.java index 78b3bd35346a..56f96a5e9cd5 100644 --- a/plugin/trino-thrift/src/main/java/io/trino/plugin/thrift/ThriftMetadata.java +++ b/plugin/trino-thrift/src/main/java/io/trino/plugin/thrift/ThriftMetadata.java @@ -161,9 +161,7 @@ public Optional resolveIndex(ConnectorSession session, C if (tableMetadata.containsIndexableColumns(indexableColumns)) { return Optional.of(new ConnectorResolvedIndex(new ThriftIndexHandle(tableMetadata.getSchemaTableName(), tupleDomain, session), tupleDomain)); } - else { - return Optional.empty(); - } + return Optional.empty(); } @Override @@ -219,13 +217,8 @@ public Optional> applyProjecti private ThriftTableMetadata getRequiredTableMetadata(SchemaTableName schemaTableName) { - Optional table = tableCache.getUnchecked(schemaTableName); - if (table.isEmpty()) { - throw new TableNotFoundException(schemaTableName); - } - else { - return table.get(); - } + return tableCache.getUnchecked(schemaTableName) + .orElseThrow(() -> new TableNotFoundException(schemaTableName)); } // this method makes actual thrift request and should be called only by cache load method diff --git a/plugin/trino-tpcds/pom.xml b/plugin/trino-tpcds/pom.xml index ae1ab55d3682..6173149a29b4 100644 --- a/plugin/trino-tpcds/pom.xml +++ b/plugin/trino-tpcds/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-tpcds/src/test/java/io/trino/plugin/tpcds/EstimateAssertion.java b/plugin/trino-tpcds/src/test/java/io/trino/plugin/tpcds/EstimateAssertion.java index 1672b72da23d..d907b4a9a035 100644 --- a/plugin/trino-tpcds/src/test/java/io/trino/plugin/tpcds/EstimateAssertion.java +++ b/plugin/trino-tpcds/src/test/java/io/trino/plugin/tpcds/EstimateAssertion.java @@ -72,8 +72,6 @@ private double toDouble(Object object) if (object instanceof Number) { return ((Number) object).doubleValue(); } - else { - throw new UnsupportedOperationException(format("Can't compare with tolerance objects of class %s. Use assertEquals.", object.getClass())); - } + throw new UnsupportedOperationException(format("Can't compare with tolerance objects of class %s. Use assertEquals.", object.getClass())); } } diff --git a/plugin/trino-tpch/pom.xml b/plugin/trino-tpch/pom.xml index 5ff8decd0d2d..8790817e1d01 100644 --- a/plugin/trino-tpch/pom.xml +++ b/plugin/trino-tpch/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchMetadata.java b/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchMetadata.java index 3b3555a012b5..822eea5c2cbb 100644 --- a/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchMetadata.java +++ b/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchMetadata.java @@ -288,21 +288,19 @@ private Map, List> getColumnValuesRestrictions(TpchTable> columns = ImmutableSet.copyOf(tpchTable.getColumns()); return asMap(columns, key -> emptyList()); } - else { - Map domains = constraintSummary.getDomains().orElseThrow(); - Optional orderStatusDomain = Optional.ofNullable(domains.get(toColumnHandle(OrderColumn.ORDER_STATUS))); - Optional, List>> allowedColumnValues = orderStatusDomain.map(domain -> { - List allowedValues = ORDER_STATUS_VALUES.stream() - .filter(domain::includesNullableValue) - .collect(toList()); - return avoidTrivialOrderStatusRestriction(allowedValues); - }); - return allowedColumnValues.orElse(emptyMap()); - } + Map domains = constraintSummary.getDomains().orElseThrow(); + Optional orderStatusDomain = Optional.ofNullable(domains.get(toColumnHandle(OrderColumn.ORDER_STATUS))); + Optional, List>> allowedColumnValues = orderStatusDomain.map(domain -> { + List allowedValues = ORDER_STATUS_VALUES.stream() + .filter(domain::includesNullableValue) + .collect(toList()); + return avoidTrivialOrderStatusRestriction(allowedValues); + }); + return allowedColumnValues.orElse(emptyMap()); } private static Map, List> avoidTrivialOrderStatusRestriction(List allowedValues) @@ -310,9 +308,7 @@ private static Map, List> avoidTrivialOrderStatusRestricti if (allowedValues.containsAll(ORDER_STATUS_VALUES)) { return emptyMap(); } - else { - return ImmutableMap.of(OrderColumn.ORDER_STATUS, allowedValues); - } + return ImmutableMap.of(OrderColumn.ORDER_STATUS, allowedValues); } private TableStatistics toTableStatistics(TableStatisticsData tableStatisticsData, TpchTableHandle tpchTableHandle, Map columnHandles) diff --git a/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchRecordSet.java b/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchRecordSet.java index bb431c75588f..87143ee7b802 100644 --- a/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchRecordSet.java +++ b/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchRecordSet.java @@ -253,15 +253,13 @@ private Object getTrinoObject(TpchColumn column, Type type) if (type.getJavaType() == long.class) { return getLong(column); } - else if (type.getJavaType() == double.class) { + if (type.getJavaType() == double.class) { return getDouble(column); } - else if (type.getJavaType() == Slice.class) { + if (type.getJavaType() == Slice.class) { return getSlice(column); } - else { - throw new TrinoException(NOT_SUPPORTED, format("Unsupported column type %s", type.getDisplayName())); - } + throw new TrinoException(NOT_SUPPORTED, format("Unsupported column type %s", type.getDisplayName())); } private TpchColumn getTpchColumn(int field) diff --git a/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/statistics/StatisticsEstimator.java b/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/statistics/StatisticsEstimator.java index 922f24cf715f..53c9e519f042 100644 --- a/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/statistics/StatisticsEstimator.java +++ b/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/statistics/StatisticsEstimator.java @@ -43,25 +43,23 @@ public Optional estimateStats(TpchTable tpchTable, Map partitionColumn = getOnlyElement(columnValuesRestrictions.keySet()); - List partitionValues = columnValuesRestrictions.get(partitionColumn); - TableStatisticsData result = zeroStatistics(tpchTable); - for (Object partitionValue : partitionValues) { - Slice value = checkType(partitionValue, Slice.class, "Only string (Slice) partition values supported for now"); - Optional tableStatisticsData = tableStatisticsDataRepository - .load(schemaName, tpchTable, Optional.of(partitionColumn), Optional.of(value.toStringUtf8())); - if (tableStatisticsData.isEmpty()) { - return Optional.empty(); - } - result = addPartitionStats(result, tableStatisticsData.get(), partitionColumn); + checkArgument(columnValuesRestrictions.size() <= 1, "Can only estimate stats when at most one column has value restrictions"); + TpchColumn partitionColumn = getOnlyElement(columnValuesRestrictions.keySet()); + List partitionValues = columnValuesRestrictions.get(partitionColumn); + TableStatisticsData result = zeroStatistics(tpchTable); + for (Object partitionValue : partitionValues) { + Slice value = checkType(partitionValue, Slice.class, "Only string (Slice) partition values supported for now"); + Optional tableStatisticsData = tableStatisticsDataRepository + .load(schemaName, tpchTable, Optional.of(partitionColumn), Optional.of(value.toStringUtf8())); + if (tableStatisticsData.isEmpty()) { + return Optional.empty(); } - return Optional.of(result); + result = addPartitionStats(result, tableStatisticsData.get(), partitionColumn); } + return Optional.of(result); } private TableStatisticsData addPartitionStats(TableStatisticsData left, TableStatisticsData right, TpchColumn partitionColumn) diff --git a/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/util/Optionals.java b/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/util/Optionals.java index 5efae1f73a43..f9b0a3a166b5 100644 --- a/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/util/Optionals.java +++ b/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/util/Optionals.java @@ -32,11 +32,9 @@ public static Optional combine(Optional left, Optional right, Binar if (left.isPresent() && right.isPresent()) { return Optional.of(combiner.apply(left.get(), right.get())); } - else if (left.isPresent()) { + if (left.isPresent()) { return left; } - else { - return right; - } + return right; } } diff --git a/plugin/trino-tpch/src/test/java/io/trino/plugin/tpch/EstimateAssertion.java b/plugin/trino-tpch/src/test/java/io/trino/plugin/tpch/EstimateAssertion.java index ae350f0ae644..c532ecadfdde 100644 --- a/plugin/trino-tpch/src/test/java/io/trino/plugin/tpch/EstimateAssertion.java +++ b/plugin/trino-tpch/src/test/java/io/trino/plugin/tpch/EstimateAssertion.java @@ -78,8 +78,6 @@ private double toDouble(Object object) if (object instanceof Number) { return ((Number) object).doubleValue(); } - else { - throw new UnsupportedOperationException(format("Can't compare with tolerance objects of class %s. Use assertEquals.", object.getClass())); - } + throw new UnsupportedOperationException(format("Can't compare with tolerance objects of class %s. Use assertEquals.", object.getClass())); } } diff --git a/pom.xml b/pom.xml index 88f5a3997a77..bbfbffd67006 100644 --- a/pom.xml +++ b/pom.xml @@ -11,7 +11,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT trino-root Trino @@ -49,7 +49,7 @@ 1.7.4 2.7.7-1 - 4.9.3 + 4.11.1 218 1.7.36 ${dep.airlift.version} @@ -70,6 +70,7 @@ 4.14.0 7.1.4 0.14.0 + 4.7.2 65 @@ -1166,7 +1167,7 @@ com.linkedin.calcite calcite-core - 1.21.0.151 + 1.21.0.152 shaded @@ -1690,7 +1691,7 @@ org.jdbi jdbi3-bom pom - 3.23.0 + 3.32.0 import diff --git a/service/trino-proxy/pom.xml b/service/trino-proxy/pom.xml index 617a55eb149e..785035a1509c 100644 --- a/service/trino-proxy/pom.xml +++ b/service/trino-proxy/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/service/trino-verifier/pom.xml b/service/trino-verifier/pom.xml index 7b6e8044d0e4..58da02804803 100644 --- a/service/trino-verifier/pom.xml +++ b/service/trino-verifier/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/service/trino-verifier/src/main/java/io/trino/verifier/QueryRewriter.java b/service/trino-verifier/src/main/java/io/trino/verifier/QueryRewriter.java index 83394c637ae4..55320c845dcb 100644 --- a/service/trino-verifier/src/main/java/io/trino/verifier/QueryRewriter.java +++ b/service/trino-verifier/src/main/java/io/trino/verifier/QueryRewriter.java @@ -110,7 +110,7 @@ public Query shadowQuery(Query query) if (statement instanceof CreateTableAsSelect) { return rewriteCreateTableAsSelect(connection, query, (CreateTableAsSelect) statement); } - else if (statement instanceof Insert) { + if (statement instanceof Insert) { return rewriteInsertQuery(connection, query, (Insert) statement); } } diff --git a/service/trino-verifier/src/main/java/io/trino/verifier/Validator.java b/service/trino-verifier/src/main/java/io/trino/verifier/Validator.java index 0f9160ad0ed8..9fc68896cc1e 100644 --- a/service/trino-verifier/src/main/java/io/trino/verifier/Validator.java +++ b/service/trino-verifier/src/main/java/io/trino/verifier/Validator.java @@ -321,7 +321,7 @@ private static QueryResult setup(Query query, List preQueryResults, if (queryResult.getState() == State.TIMEOUT) { return queryResult; } - else if (queryResult.getState() != State.SUCCESS) { + if (queryResult.getState() != State.SUCCESS) { return new QueryResult(State.FAILED_TO_SETUP, queryResult.getException(), queryResult.getWallTime(), queryResult.getCpuTime(), queryResult.getQueryId(), ImmutableList.of(), ImmutableList.of()); } } @@ -812,10 +812,8 @@ private static boolean isClose(double a, double b, double epsilon) if (a == 0 || b == 0 || diff < Float.MIN_NORMAL) { return diff < (epsilon * Float.MIN_NORMAL); } - else { - // use relative error - return diff / Math.min((absA + absB), Float.MAX_VALUE) < epsilon; - } + // use relative error + return diff / Math.min((absA + absB), Float.MAX_VALUE) < epsilon; } @VisibleForTesting @@ -861,9 +859,7 @@ public String toString() if (changed == Changed.ADDED) { return "+ " + row; } - else { - return "- " + row; - } + return "- " + row; } @Override diff --git a/service/trino-verifier/src/main/java/io/trino/verifier/Verifier.java b/service/trino-verifier/src/main/java/io/trino/verifier/Verifier.java index 00acc695d19e..8aac11541d3f 100644 --- a/service/trino-verifier/src/main/java/io/trino/verifier/Verifier.java +++ b/service/trino-verifier/src/main/java/io/trino/verifier/Verifier.java @@ -220,9 +220,7 @@ private boolean isCheckCorrectness(QueryPair query) // If so disable correctness checking return false; } - else { - return config.isCheckCorrectnessEnabled(); - } + return config.isCheckCorrectnessEnabled(); } private VerifierQueryEvent buildEvent(Validator validator) diff --git a/testing/trino-benchmark/pom.xml b/testing/trino-benchmark/pom.xml index 7473aee7f05f..6a3f69524878 100644 --- a/testing/trino-benchmark/pom.xml +++ b/testing/trino-benchmark/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/testing/trino-benchmark/src/main/java/io/trino/benchmark/HashBuildAndJoinBenchmark.java b/testing/trino-benchmark/src/main/java/io/trino/benchmark/HashBuildAndJoinBenchmark.java index ab6ab4b21f0f..7653b8444dc0 100644 --- a/testing/trino-benchmark/src/main/java/io/trino/benchmark/HashBuildAndJoinBenchmark.java +++ b/testing/trino-benchmark/src/main/java/io/trino/benchmark/HashBuildAndJoinBenchmark.java @@ -43,6 +43,7 @@ import static io.trino.benchmark.BenchmarkQueryRunner.createLocalQueryRunner; import static io.trino.benchmark.BenchmarkQueryRunner.createLocalQueryRunnerHashEnabled; import static io.trino.operator.HashArraySizeSupplier.incrementalLoadFactorHashArraySizeSupplier; +import static io.trino.operator.OperatorFactories.JoinOperatorType.innerJoin; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spiller.PartitioningSpillerFactory.unsupportedPartitioningSpillerFactory; import static io.trino.testing.TestingSession.testSessionBuilder; @@ -140,14 +141,12 @@ protected List createDrivers(TaskContext taskContext) hashChannel = OptionalInt.of(sourceTypes.size() - 1); } - OperatorFactory joinOperator = operatorFactories.innerJoin( + OperatorFactory joinOperator = operatorFactories.spillingJoin( + innerJoin(false, false), 2, new PlanNodeId("test"), lookupSourceFactoryManager, false, - false, - false, - true, sourceTypes, Ints.asList(0), hashChannel, diff --git a/testing/trino-benchmark/src/main/java/io/trino/benchmark/HashBuildBenchmark.java b/testing/trino-benchmark/src/main/java/io/trino/benchmark/HashBuildBenchmark.java index 22c20c8b9e16..9235295fe0aa 100644 --- a/testing/trino-benchmark/src/main/java/io/trino/benchmark/HashBuildBenchmark.java +++ b/testing/trino-benchmark/src/main/java/io/trino/benchmark/HashBuildBenchmark.java @@ -41,6 +41,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.benchmark.BenchmarkQueryRunner.createLocalQueryRunner; import static io.trino.operator.HashArraySizeSupplier.incrementalLoadFactorHashArraySizeSupplier; +import static io.trino.operator.OperatorFactories.JoinOperatorType.innerJoin; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spiller.PartitioningSpillerFactory.unsupportedPartitioningSpillerFactory; import static java.util.Objects.requireNonNull; @@ -99,14 +100,12 @@ protected List createDrivers(TaskContext taskContext) // empty join so build finishes ImmutableList.Builder joinDriversBuilder = ImmutableList.builder(); joinDriversBuilder.add(new ValuesOperatorFactory(0, new PlanNodeId("values"), ImmutableList.of())); - OperatorFactory joinOperator = operatorFactories.innerJoin( + OperatorFactory joinOperator = operatorFactories.spillingJoin( + innerJoin(false, false), 2, new PlanNodeId("test"), lookupSourceFactoryManager, false, - false, - false, - true, ImmutableList.of(BIGINT), Ints.asList(0), OptionalInt.empty(), diff --git a/testing/trino-benchmark/src/main/java/io/trino/benchmark/HashJoinBenchmark.java b/testing/trino-benchmark/src/main/java/io/trino/benchmark/HashJoinBenchmark.java index 2290bd67a268..c43edde2055c 100644 --- a/testing/trino-benchmark/src/main/java/io/trino/benchmark/HashJoinBenchmark.java +++ b/testing/trino-benchmark/src/main/java/io/trino/benchmark/HashJoinBenchmark.java @@ -45,6 +45,7 @@ import static io.trino.benchmark.BenchmarkQueryRunner.createLocalQueryRunner; import static io.trino.execution.executor.PrioritizedSplitRunner.SPLIT_RUN_QUANTA; import static io.trino.operator.HashArraySizeSupplier.incrementalLoadFactorHashArraySizeSupplier; +import static io.trino.operator.OperatorFactories.JoinOperatorType.innerJoin; import static io.trino.spiller.PartitioningSpillerFactory.unsupportedPartitioningSpillerFactory; import static java.util.Objects.requireNonNull; @@ -109,14 +110,12 @@ protected List createDrivers(TaskContext taskContext) List lineItemTypes = getColumnTypes("lineitem", "orderkey", "quantity"); OperatorFactory lineItemTableScan = createTableScanOperator(0, new PlanNodeId("test"), "lineitem", "orderkey", "quantity"); - OperatorFactory joinOperator = operatorFactories.innerJoin( + OperatorFactory joinOperator = operatorFactories.spillingJoin( + innerJoin(false, false), 1, new PlanNodeId("test"), lookupSourceFactoryManager, false, - false, - false, - true, lineItemTypes, Ints.asList(0), OptionalInt.empty(), diff --git a/testing/trino-benchto-benchmarks/pom.xml b/testing/trino-benchto-benchmarks/pom.xml index 6efdcbe5b1ca..4ebae96dcdfc 100644 --- a/testing/trino-benchto-benchmarks/pom.xml +++ b/testing/trino-benchto-benchmarks/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/testing/trino-faulttolerant-tests/pom.xml b/testing/trino-faulttolerant-tests/pom.xml index 6e22fc400b91..996138c69c25 100644 --- a/testing/trino-faulttolerant-tests/pom.xml +++ b/testing/trino-faulttolerant-tests/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/iceberg/BaseIcebergFailureRecoveryTest.java b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/iceberg/BaseIcebergFailureRecoveryTest.java index d85708803c8d..861d5411d0f2 100644 --- a/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/iceberg/BaseIcebergFailureRecoveryTest.java +++ b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/iceberg/BaseIcebergFailureRecoveryTest.java @@ -44,13 +44,6 @@ protected boolean areWriteRetriesSupported() return true; } - @Override - public void testAnalyzeStatistics() - { - assertThatThrownBy(super::testAnalyzeStatistics) - .hasMessageContaining("This connector does not support analyze"); - } - @Override protected void createPartitionedLineitemTable(String tableName, List columns, String partitionColumn) { diff --git a/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/iceberg/TestIcebergQueryFailureRecoveryTest.java b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/iceberg/TestIcebergQueryFailureRecoveryTest.java index c79b60398192..9a88ca4ce4df 100644 --- a/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/iceberg/TestIcebergQueryFailureRecoveryTest.java +++ b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/iceberg/TestIcebergQueryFailureRecoveryTest.java @@ -51,6 +51,7 @@ protected QueryRunner createQueryRunner( .setInitialTables(requiredTpchTables) .setCoordinatorProperties(coordinatorProperties) .setExtraProperties(configProperties) + .setIcebergProperties(Map.of("iceberg.experimental.extended-statistics.enabled", "true")) .setAdditionalSetup(runner -> { runner.installPlugin(new FileSystemExchangePlugin()); runner.loadExchangeManager("filesystem", getExchangeManagerProperties(minioStorage)); diff --git a/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/iceberg/TestIcebergTaskFailureRecoveryTest.java b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/iceberg/TestIcebergTaskFailureRecoveryTest.java index 371d53833e6d..cd2a09ee3745 100644 --- a/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/iceberg/TestIcebergTaskFailureRecoveryTest.java +++ b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/iceberg/TestIcebergTaskFailureRecoveryTest.java @@ -51,6 +51,7 @@ protected QueryRunner createQueryRunner( .setInitialTables(requiredTpchTables) .setCoordinatorProperties(coordinatorProperties) .setExtraProperties(configProperties) + .setIcebergProperties(Map.of("iceberg.experimental.extended-statistics.enabled", "true")) .setAdditionalSetup(runner -> { runner.installPlugin(new FileSystemExchangePlugin()); runner.loadExchangeManager("filesystem", getExchangeManagerProperties(minioStorage)); diff --git a/testing/trino-plugin-reader/pom.xml b/testing/trino-plugin-reader/pom.xml index a1da29184f12..31865ed4ad27 100644 --- a/testing/trino-plugin-reader/pom.xml +++ b/testing/trino-plugin-reader/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/testing/trino-product-tests-launcher/pom.xml b/testing/trino-product-tests-launcher/pom.xml index 0829912c833f..b013c7cbbecd 100644 --- a/testing/trino-product-tests-launcher/pom.xml +++ b/testing/trino-product-tests-launcher/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/testing/trino-product-tests-launcher/src/main/java/io/trino/tests/product/launcher/cli/OptionsPrinter.java b/testing/trino-product-tests-launcher/src/main/java/io/trino/tests/product/launcher/cli/OptionsPrinter.java index 0d361d373cc4..cad061766b2b 100644 --- a/testing/trino-product-tests-launcher/src/main/java/io/trino/tests/product/launcher/cli/OptionsPrinter.java +++ b/testing/trino-product-tests-launcher/src/main/java/io/trino/tests/product/launcher/cli/OptionsPrinter.java @@ -84,7 +84,7 @@ private static String formatOption(Object value, Option annotation) if ((boolean) value) { return annotation.names()[0].replaceFirst("--no-", "--"); } - else if (annotation.negatable()) { + if (annotation.negatable()) { return annotation.names()[0]; } @@ -103,9 +103,7 @@ else if (annotation.negatable()) { if (((Optional) value).isPresent()) { return formatOption(((Optional) value).get(), annotation); } - else { - return null; - } + return null; } if (value instanceof Map) { diff --git a/testing/trino-product-tests/pom.xml b/testing/trino-product-tests/pom.xml index aaefc17edfee..a954e91bc307 100644 --- a/testing/trino-product-tests/pom.xml +++ b/testing/trino-product-tests/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/ImmutableLdapObjectDefinitions.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/ImmutableLdapObjectDefinitions.java index b218ae52d943..a44e4083ab82 100644 --- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/ImmutableLdapObjectDefinitions.java +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/ImmutableLdapObjectDefinitions.java @@ -78,10 +78,8 @@ public static LdapObjectDefinition buildLdapGroupObject(String groupName, String if (childGroupNames.isPresent()) { return buildLdapGroupObject(groupName, AMERICA_DISTINGUISHED_NAME, userName, ASIA_DISTINGUISHED_NAME, childGroupNames, Optional.of(AMERICA_DISTINGUISHED_NAME)); } - else { - return buildLdapGroupObject(groupName, AMERICA_DISTINGUISHED_NAME, userName, ASIA_DISTINGUISHED_NAME, - Optional.empty(), Optional.empty()); - } + return buildLdapGroupObject(groupName, AMERICA_DISTINGUISHED_NAME, userName, ASIA_DISTINGUISHED_NAME, + Optional.empty(), Optional.empty()); } public static LdapObjectDefinition buildLdapGroupObject(String groupName, String groupOrganizationName, @@ -97,15 +95,13 @@ public static LdapObjectDefinition buildLdapGroupObject(String groupName, String .setObjectClasses(Arrays.asList("groupOfNames")) .build(); } - else { - return LdapObjectDefinition.builder(groupName) - .setDistinguishedName(format("cn=%s,%s", groupName, groupOrganizationName)) - .setAttributes(ImmutableMap.of( - "cn", groupName, - "member", format("uid=%s,%s", userName, userOrganizationName))) - .setObjectClasses(Arrays.asList("groupOfNames")) - .build(); - } + return LdapObjectDefinition.builder(groupName) + .setDistinguishedName(format("cn=%s,%s", groupName, groupOrganizationName)) + .setAttributes(ImmutableMap.of( + "cn", groupName, + "member", format("uid=%s,%s", userName, userOrganizationName))) + .setObjectClasses(Arrays.asList("groupOfNames")) + .build(); } public static LdapObjectDefinition buildLdapUserObject(String userName, Optional> groupNames, String password) @@ -114,10 +110,8 @@ public static LdapObjectDefinition buildLdapUserObject(String userName, Optional return buildLdapUserObject(userName, ASIA_DISTINGUISHED_NAME, groupNames, Optional.of(AMERICA_DISTINGUISHED_NAME), password); } - else { - return buildLdapUserObject(userName, ASIA_DISTINGUISHED_NAME, - Optional.empty(), Optional.empty(), password); - } + return buildLdapUserObject(userName, ASIA_DISTINGUISHED_NAME, + Optional.empty(), Optional.empty(), password); } public static LdapObjectDefinition buildLdapUserObject(String userName, String userOrganizationName, @@ -134,16 +128,14 @@ public static LdapObjectDefinition buildLdapUserObject(String userName, String u .setModificationAttributes(getAttributes(groupNames.get(), groupOrganizationName.get(), MEMBER_OF)) .build(); } - else { - return LdapObjectDefinition.builder(userName) - .setDistinguishedName(format("uid=%s,%s", userName, userOrganizationName)) - .setAttributes(ImmutableMap.of( - "cn", userName, - "sn", userName, - "userPassword", password)) - .setObjectClasses(Arrays.asList("person", "inetOrgPerson")) - .build(); - } + return LdapObjectDefinition.builder(userName) + .setDistinguishedName(format("uid=%s,%s", userName, userOrganizationName)) + .setAttributes(ImmutableMap.of( + "cn", userName, + "sn", userName, + "userPassword", password)) + .setObjectClasses(Arrays.asList("person", "inetOrgPerson")) + .build(); } private static ImmutableMap> getAttributes(List groupNames, String groupOrganizationName, String relation) diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/deltalake/TestDeltaLakeDatabricksCheckpointsCompatibility.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/deltalake/TestDeltaLakeDatabricksCheckpointsCompatibility.java index 1c8715432e1e..f1e08a08f218 100644 --- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/deltalake/TestDeltaLakeDatabricksCheckpointsCompatibility.java +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/deltalake/TestDeltaLakeDatabricksCheckpointsCompatibility.java @@ -21,6 +21,8 @@ import com.google.inject.name.Named; import io.trino.tempto.BeforeTestWithContext; import io.trino.tempto.assertions.QueryAssert; +import org.testng.SkipException; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.util.List; @@ -31,16 +33,19 @@ import static io.trino.tempto.assertions.QueryAssert.Row.row; import static io.trino.tempto.assertions.QueryAssert.assertThat; import static io.trino.tests.product.TestGroups.DELTA_LAKE_DATABRICKS; +import static io.trino.tests.product.TestGroups.DELTA_LAKE_EXCLUDE_73; import static io.trino.tests.product.TestGroups.PROFILE_SPECIFIC_TESTS; import static io.trino.tests.product.deltalake.TransactionLogAssertions.assertLastEntryIsCheckpointed; import static io.trino.tests.product.deltalake.TransactionLogAssertions.assertTransactionLogVersion; import static io.trino.tests.product.deltalake.util.DeltaLakeTestUtils.DATABRICKS_104_RUNTIME_VERSION; +import static io.trino.tests.product.deltalake.util.DeltaLakeTestUtils.DATABRICKS_91_RUNTIME_VERSION; import static io.trino.tests.product.deltalake.util.DeltaLakeTestUtils.getDatabricksRuntimeVersion; import static io.trino.tests.product.hive.util.TemporaryHiveTable.randomTableSuffix; import static io.trino.tests.product.utils.QueryExecutors.onDelta; import static io.trino.tests.product.utils.QueryExecutors.onTrino; import static java.lang.String.format; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; public class TestDeltaLakeDatabricksCheckpointsCompatibility extends BaseTestDeltaLakeS3Storage @@ -378,6 +383,233 @@ private void testCheckpointNullStatisticsForRowType(Consumer sqlExecutor } } + @Test(groups = {DELTA_LAKE_DATABRICKS, DELTA_LAKE_EXCLUDE_73, PROFILE_SPECIFIC_TESTS}) + public void testTrinoWriteStatsAsJsonDisabled() + { + String tableName = "test_dl_checkpoints_write_stats_as_json_disabled_trino_" + randomTableSuffix(); + testWriteStatsAsJsonDisabled(sql -> onTrino().executeQuery(sql), tableName, "delta.default." + tableName); + } + + @Test(groups = {DELTA_LAKE_DATABRICKS, DELTA_LAKE_EXCLUDE_73, PROFILE_SPECIFIC_TESTS}) + public void testDatabricksWriteStatsAsJsonDisabled() + { + String tableName = "test_dl_checkpoints_write_stats_as_json_disabled_databricks_" + randomTableSuffix(); + testWriteStatsAsJsonDisabled(sql -> onDelta().executeQuery(sql), tableName, "default." + tableName); + } + + private void testWriteStatsAsJsonDisabled(Consumer sqlExecutor, String tableName, String qualifiedTableName) + { + onDelta().executeQuery(format( + "CREATE TABLE default.%s" + + "(a_number INT, a_string STRING) " + + "USING DELTA " + + "PARTITIONED BY (a_number) " + + "LOCATION 's3://%s/databricks-compatibility-test-%1$s' " + + "TBLPROPERTIES (" + + " delta.checkpointInterval = 5, " + + " delta.checkpoint.writeStatsAsJson = false)", + tableName, bucketName)); + + try { + sqlExecutor.accept("INSERT INTO " + qualifiedTableName + " VALUES (1,'ala')"); + + assertThat(onTrino().executeQuery("SHOW STATS FOR delta.default." + tableName)) + .containsOnly(ImmutableList.of( + row("a_number", null, 1.0, 0.0, null, null, null), + row("a_string", null, null, 0.0, null, null, null), + row(null, null, null, null, 1.0, null, null))); + } + finally { + onDelta().executeQuery("DROP TABLE default." + tableName); + } + } + + @Test(groups = {DELTA_LAKE_DATABRICKS, DELTA_LAKE_EXCLUDE_73, PROFILE_SPECIFIC_TESTS}) + public void testTrinoWriteStatsAsStructDisabled() + { + String tableName = "test_dl_checkpoints_write_stats_as_struct_disabled_trino_" + randomTableSuffix(); + testWriteStatsAsStructDisabled(sql -> onTrino().executeQuery(sql), tableName, "delta.default." + tableName); + } + + @Test(groups = {DELTA_LAKE_DATABRICKS, DELTA_LAKE_EXCLUDE_73, PROFILE_SPECIFIC_TESTS}) + public void testDatabricksWriteStatsAsStructDisabled() + { + String tableName = "test_dl_checkpoints_write_stats_as_struct_disabled_databricks_" + randomTableSuffix(); + testWriteStatsAsStructDisabled(sql -> onDelta().executeQuery(sql), tableName, "default." + tableName); + } + + private void testWriteStatsAsStructDisabled(Consumer sqlExecutor, String tableName, String qualifiedTableName) + { + onDelta().executeQuery(format( + "CREATE TABLE default.%s" + + "(a_number INT, a_string STRING) " + + "USING DELTA " + + "PARTITIONED BY (a_number) " + + "LOCATION 's3://%s/databricks-compatibility-test-%1$s' " + + "TBLPROPERTIES (" + + " delta.checkpointInterval = 1, " + + " delta.checkpoint.writeStatsAsJson = false, " + // Disable json stats to avoid merging statistics with 'stats' field + " delta.checkpoint.writeStatsAsStruct = false)", + tableName, bucketName)); + + try { + sqlExecutor.accept("INSERT INTO " + qualifiedTableName + " VALUES (1,'ala')"); + + assertThat(onTrino().executeQuery("SHOW STATS FOR delta.default." + tableName)) + .containsOnly(ImmutableList.of( + row("a_number", null, null, null, null, null, null), + row("a_string", null, null, null, null, null, null), + row(null, null, null, null, null, null, null))); + } + finally { + onDelta().executeQuery("DROP TABLE default." + tableName); + } + } + + @Test(groups = {DELTA_LAKE_DATABRICKS, DELTA_LAKE_EXCLUDE_73, PROFILE_SPECIFIC_TESTS}, dataProvider = "testTrinoCheckpointWriteStatsAsJson") + public void testTrinoWriteStatsAsJsonEnabled(String type, String inputValue, Double nullsFraction, Object statsValue) + { + String tableName = "test_dl_checkpoints_write_stats_as_json_enabled_trino_" + randomTableSuffix(); + testWriteStatsAsJsonEnabled(sql -> onTrino().executeQuery(sql), tableName, "delta.default." + tableName, type, inputValue, nullsFraction, statsValue); + } + + @Test(groups = {DELTA_LAKE_DATABRICKS, DELTA_LAKE_EXCLUDE_73, PROFILE_SPECIFIC_TESTS}, dataProvider = "testDeltaCheckpointWriteStatsAsJson") + public void testDatabricksWriteStatsAsJsonEnabled(String type, String inputValue, Double nullsFraction, Object statsValue) + { + String tableName = "test_dl_checkpoints_write_stats_as_json_enabled_databricks_" + randomTableSuffix(); + testWriteStatsAsJsonEnabled(sql -> onDelta().executeQuery(sql), tableName, "default." + tableName, type, inputValue, nullsFraction, statsValue); + } + + private void testWriteStatsAsJsonEnabled(Consumer sqlExecutor, String tableName, String qualifiedTableName, String type, String inputValue, Double nullsFraction, Object statsValue) + { + String createTableSql = format( + "CREATE TABLE default.%s" + + "(col %s) " + + "USING DELTA " + + "LOCATION 's3://%s/databricks-compatibility-test-%1$s' " + + "TBLPROPERTIES (" + + " delta.checkpointInterval = 2, " + + " delta.checkpoint.writeStatsAsJson = false, " + + " delta.checkpoint.writeStatsAsStruct = true)", + tableName, type, bucketName); + + if (getDatabricksRuntimeVersion().equals(DATABRICKS_91_RUNTIME_VERSION) && type.equals("struct")) { + assertThatThrownBy(() -> onDelta().executeQuery(createTableSql)).hasStackTraceContaining("ParseException"); + throw new SkipException("New runtime version covers the type"); + } + + onDelta().executeQuery(createTableSql); + + try { + sqlExecutor.accept("INSERT INTO " + qualifiedTableName + " SELECT " + inputValue); + sqlExecutor.accept("INSERT INTO " + qualifiedTableName + " SELECT " + inputValue); + + // SET TBLPROPERTIES increments checkpoint + onDelta().executeQuery("" + + "ALTER TABLE default." + tableName + " SET TBLPROPERTIES (" + + "'delta.checkpoint.writeStatsAsJson' = true, " + + "'delta.checkpoint.writeStatsAsStruct' = false)"); + + sqlExecutor.accept("INSERT INTO " + qualifiedTableName + " SELECT " + inputValue); + + assertThat(onTrino().executeQuery("SHOW STATS FOR delta.default." + tableName)) + .containsOnly(ImmutableList.of( + row("col", null, null, nullsFraction, null, statsValue, statsValue), + row(null, null, null, null, 3.0, null, null))); + } + finally { + onDelta().executeQuery("DROP TABLE default." + tableName); + } + } + + @DataProvider + public Object[][] testTrinoCheckpointWriteStatsAsJson() + { + return new Object[][] { + {"boolean", "true", 0.0, null}, + {"integer", "1", 0.0, "1"}, + {"tinyint", "2", 0.0, "2"}, + {"smallint", "3", 0.0, "3"}, + {"bigint", "1000", 0.0, "1000"}, + {"real", "0.1", 0.0, "0.1"}, + {"double", "1.0", 0.0, "1.0"}, + {"decimal(3,2)", "3.14", 0.0, "3.14"}, + {"decimal(30,1)", "12345", 0.0, "12345.0"}, + {"string", "'test'", 0.0, null}, + {"binary", "X'65683F'", 0.0, null}, + {"date", "date '2021-02-03'", 0.0, "2021-02-03"}, + {"timestamp", "timestamp '2001-08-22 11:04:05.321 UTC'", 0.0, "2001-08-22 11:04:05.321 UTC"}, + {"array", "array[1]", null, null}, + {"map", "map(array['key1', 'key2'], array[1, 2])", null, null}, + {"struct", "cast(row(1) as row(x bigint))", null, null}, + }; + } + + @DataProvider + public Object[][] testDeltaCheckpointWriteStatsAsJson() + { + return new Object[][] { + {"boolean", "true", 0.0, null}, + {"integer", "1", 0.0, "1"}, + {"tinyint", "2", 0.0, "2"}, + {"smallint", "3", 0.0, "3"}, + {"bigint", "1000", 0.0, "1000"}, + {"real", "0.1", 0.0, "0.1"}, + {"double", "1.0", 0.0, "1.0"}, + {"decimal(3,2)", "3.14", 0.0, "3.14"}, + {"decimal(30,1)", "12345", 0.0, "12345.0"}, + {"string", "'test'", 0.0, null}, + {"binary", "X'65683F'", 0.0, null}, + {"date", "date '2021-02-03'", 0.0, "2021-02-03"}, + {"timestamp", "timestamp '2001-08-22 11:04:05.321 UTC'", 0.0, "2001-08-22 11:04:05.321 UTC"}, + {"array", "array(1)", 0.0, null}, + {"map", "map('key1', 1, 'key2', 2)", 0.0, null}, + {"struct", "named_struct('x', 1)", null, null}, + }; + } + + @Test(groups = {DELTA_LAKE_DATABRICKS, DELTA_LAKE_EXCLUDE_73, PROFILE_SPECIFIC_TESTS}) + public void testTrinoWriteStatsAsStructEnabled() + { + String tableName = "test_dl_checkpoints_write_stats_as_struct_enabled_trino_" + randomTableSuffix(); + testWriteStatsAsStructEnabled(sql -> onTrino().executeQuery(sql), tableName, "delta.default." + tableName); + } + + @Test(groups = {DELTA_LAKE_DATABRICKS, DELTA_LAKE_EXCLUDE_73, PROFILE_SPECIFIC_TESTS}) + public void testDatabricksWriteStatsAsStructEnabled() + { + String tableName = "test_dl_checkpoints_write_stats_as_struct_enabled_databricks_" + randomTableSuffix(); + testWriteStatsAsStructEnabled(sql -> onDelta().executeQuery(sql), tableName, "default." + tableName); + } + + private void testWriteStatsAsStructEnabled(Consumer sqlExecutor, String tableName, String qualifiedTableName) + { + onDelta().executeQuery(format( + "CREATE TABLE default.%s" + + "(a_number INT, a_string STRING) " + + "USING DELTA " + + "PARTITIONED BY (a_number) " + + "LOCATION 's3://%s/databricks-compatibility-test-%1$s' " + + "TBLPROPERTIES (" + + " delta.checkpointInterval = 1, " + + " delta.checkpoint.writeStatsAsJson = false, " + + " delta.checkpoint.writeStatsAsStruct = true)", + tableName, bucketName)); + + try { + sqlExecutor.accept("INSERT INTO " + qualifiedTableName + " VALUES (1,'ala')"); + + assertThat(onTrino().executeQuery("SHOW STATS FOR delta.default." + tableName)) + .containsOnly(ImmutableList.of( + row("a_number", null, 1.0, 0.0, null, null, null), + row("a_string", null, null, 0.0, null, null, null), + row(null, null, null, null, 1.0, null, null))); + } + finally { + onDelta().executeQuery("DROP TABLE default." + tableName); + } + } + private void fillWithInserts(String tableName, String values, int toCreate) { for (int i = 0; i < toCreate; i++) { diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/deltalake/util/DeltaLakeTestUtils.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/deltalake/util/DeltaLakeTestUtils.java index 81dc76e693d0..1b36ed37eda9 100644 --- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/deltalake/util/DeltaLakeTestUtils.java +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/deltalake/util/DeltaLakeTestUtils.java @@ -23,6 +23,7 @@ public final class DeltaLakeTestUtils { public static final String DATABRICKS_104_RUNTIME_VERSION = "10.4"; + public static final String DATABRICKS_91_RUNTIME_VERSION = "9.1"; private DeltaLakeTestUtils() {} diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestAllDatatypesFromHiveConnector.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestAllDatatypesFromHiveConnector.java index 90c73e52d1d2..2bf4374d4abc 100644 --- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestAllDatatypesFromHiveConnector.java +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestAllDatatypesFromHiveConnector.java @@ -410,9 +410,7 @@ private static TableInstance mutableTableInstanceOf(TableDefinition tableDefi if (tableDefinition.getDatabase().isPresent()) { return mutableTableInstanceOf(tableDefinition, tableDefinition.getDatabase().get()); } - else { - return mutableTableInstanceOf(tableHandleInSchema(tableDefinition)); - } + return mutableTableInstanceOf(tableHandleInSchema(tableDefinition)); } private static TableInstance mutableTableInstanceOf(TableDefinition tableDefinition, String database) diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestHiveCoercion.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestHiveCoercion.java index 7f8a67bb1cd4..265b5d862b75 100644 --- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestHiveCoercion.java +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestHiveCoercion.java @@ -719,9 +719,7 @@ private static TableInstance mutableTableInstanceOf(TableDefinition tableDefi if (tableDefinition.getDatabase().isPresent()) { return mutableTableInstanceOf(tableDefinition, tableDefinition.getDatabase().get()); } - else { - return mutableTableInstanceOf(tableHandleInSchema(tableDefinition)); - } + return mutableTableInstanceOf(tableHandleInSchema(tableDefinition)); } private static TableInstance mutableTableInstanceOf(TableDefinition tableDefinition, String database) diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java index a4e40d8e109e..e2dfe1f440cb 100644 --- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java @@ -62,6 +62,7 @@ import static io.trino.tests.product.iceberg.TestIcebergSparkCompatibility.CreateMode.CREATE_TABLE_AS_SELECT; import static io.trino.tests.product.iceberg.TestIcebergSparkCompatibility.CreateMode.CREATE_TABLE_WITH_NO_DATA_AND_INSERT; import static io.trino.tests.product.iceberg.TestIcebergSparkCompatibility.StorageFormat.AVRO; +import static io.trino.tests.product.iceberg.util.IcebergTestUtils.getTableLocation; import static io.trino.tests.product.utils.QueryExecutors.onSpark; import static io.trino.tests.product.utils.QueryExecutors.onTrino; import static java.lang.String.format; @@ -2106,6 +2107,40 @@ public void testHandlingPartitionSchemaEvolutionInPartitionMetadata() ImmutableMap.of("old_partition_key", "3", "new_partition_key", "null", "value_day", "null", "value_month", "null"))); } + @Test(groups = {ICEBERG, PROFILE_SPECIFIC_TESTS}) + public void testMetadataCompressionCodecGzip() + { + // Verify that Trino can read and write to a table created by Spark + String baseTableName = "test_metadata_compression_codec_gzip" + randomTableSuffix(); + String trinoTableName = trinoTableName(baseTableName); + String sparkTableName = sparkTableName(baseTableName); + + onSpark().executeQuery("CREATE TABLE " + sparkTableName + "(col int) USING iceberg TBLPROPERTIES ('write.metadata.compression-codec'='gzip')"); + onSpark().executeQuery("INSERT INTO " + sparkTableName + " VALUES (1)"); + onTrino().executeQuery("INSERT INTO " + trinoTableName + " VALUES (2)"); + + assertThat(onTrino().executeQuery("SELECT * FROM " + trinoTableName)).containsOnly(row(1), row(2)); + + // Verify that all metadata file is compressed as Gzip + String tableLocation = getTableLocation(trinoTableName); + List metadataFiles = hdfsClient.listDirectory(tableLocation + "/metadata").stream() + .filter(file -> file.endsWith("metadata.json")) + .collect(toImmutableList()); + Assertions.assertThat(metadataFiles) + .isNotEmpty() + .filteredOn(file -> file.endsWith("gz.metadata.json")) + .isEqualTo(metadataFiles); + + // Change 'write.metadata.compression-codec' to none and insert and select the table in Trino + onSpark().executeQuery("ALTER TABLE " + sparkTableName + " SET TBLPROPERTIES ('write.metadata.compression-codec'='none')"); + assertThat(onTrino().executeQuery("SELECT * FROM " + trinoTableName)).containsOnly(row(1), row(2)); + + onTrino().executeQuery("INSERT INTO " + trinoTableName + " VALUES (3)"); + assertThat(onTrino().executeQuery("SELECT * FROM " + trinoTableName)).containsOnly(row(1), row(2), row(3)); + + onSpark().executeQuery("DROP TABLE " + sparkTableName); + } + private void validatePartitioning(String baseTableName, String sparkTableName, List> expectedValues) { List trinoResult = expectedValues.stream().map(m -> @@ -2129,6 +2164,25 @@ private void validatePartitioning(String baseTableName, String sparkTableName, L Assertions.assertThat(partitions).containsAll(sparkResult); } + @Test(groups = {ICEBERG, PROFILE_SPECIFIC_TESTS}) + public void testTrinoAnalyze() + { + String baseTableName = "test_trino_analyze_" + randomTableSuffix(); + String trinoTableName = trinoTableName(baseTableName); + String sparkTableName = sparkTableName(baseTableName); + onTrino().executeQuery("DROP TABLE IF EXISTS " + trinoTableName); + onTrino().executeQuery("CREATE TABLE " + trinoTableName + " AS SELECT regionkey, name FROM tpch.tiny.region"); + onTrino().executeQuery("SET SESSION " + TRINO_CATALOG + ".experimental_extended_statistics_enabled = true"); + onTrino().executeQuery("ANALYZE " + trinoTableName); + + // We're not verifying results of ANALYZE (covered by non-product tests), but we're verifying table is readable. + List expected = List.of(row(0, "AFRICA"), row(1, "AMERICA"), row(2, "ASIA"), row(3, "EUROPE"), row(4, "MIDDLE EAST")); + assertThat(onTrino().executeQuery("SELECT * FROM " + trinoTableName)).containsOnly(expected); + assertThat(onSpark().executeQuery("SELECT * FROM " + sparkTableName)).containsOnly(expected); + + onTrino().executeQuery("DROP TABLE " + trinoTableName); + } + private int calculateMetadataFilesForPartitionedTable(String tableName) { String dataFilePath = onTrino().executeQuery(format("SELECT file_path FROM iceberg.default.\"%s$files\" limit 1", tableName)).row(0).get(0).toString(); diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkDropTableCompatibility.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkDropTableCompatibility.java index b7fd73a2f053..a11c5599ef49 100644 --- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkDropTableCompatibility.java +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkDropTableCompatibility.java @@ -26,16 +26,14 @@ import java.net.URI; import java.net.URISyntaxException; import java.util.List; -import java.util.regex.Matcher; -import java.util.regex.Pattern; import java.util.stream.Collectors; -import static com.google.common.base.Verify.verify; import static io.trino.tests.product.TestGroups.ICEBERG; import static io.trino.tests.product.TestGroups.PROFILE_SPECIFIC_TESTS; import static io.trino.tests.product.hive.Engine.SPARK; import static io.trino.tests.product.hive.Engine.TRINO; import static io.trino.tests.product.hive.util.TemporaryHiveTable.randomTableSuffix; +import static io.trino.tests.product.iceberg.util.IcebergTestUtils.getTableLocation; import static io.trino.tests.product.utils.QueryExecutors.onSpark; import static io.trino.tests.product.utils.QueryExecutors.onTrino; import static java.lang.String.format; @@ -86,18 +84,6 @@ public void testCleanupOnDropTable(Engine tableCreatorEngine, Engine tableDroppe dataFilePaths.forEach(dataFilePath -> assertFileExistence(dataFilePath, expectExists, format("The data file %s removed after dropping the table", dataFilePath))); } - private String getTableLocation(String tableName) - { - Pattern locationPattern = Pattern.compile(".*location = 'hdfs://hadoop-master:9000(.*?)'.*", Pattern.DOTALL); - Matcher m = locationPattern.matcher((String) onTrino().executeQuery("SHOW CREATE TABLE " + tableName).row(0).get(0)); - if (m.find()) { - String location = m.group(1); - verify(!m.find(), "Unexpected second match"); - return location; - } - throw new IllegalStateException("Location not found in SHOW CREATE TABLE result"); - } - private void assertFileExistence(String path, boolean exists, String description) { Assertions.assertThat(hdfsClient.exist(path)).as(description).isEqualTo(exists); diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/util/IcebergTestUtils.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/util/IcebergTestUtils.java new file mode 100644 index 000000000000..f6716a5aa4b1 --- /dev/null +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/util/IcebergTestUtils.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.tests.product.iceberg.util; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static com.google.common.base.Verify.verify; +import static io.trino.tests.product.utils.QueryExecutors.onTrino; + +public final class IcebergTestUtils +{ + private IcebergTestUtils() {} + + public static String getTableLocation(String tableName) + { + Pattern locationPattern = Pattern.compile(".*location = 'hdfs://hadoop-master:9000(.*?)'.*", Pattern.DOTALL); + Matcher m = locationPattern.matcher((String) onTrino().executeQuery("SHOW CREATE TABLE " + tableName).row(0).get(0)); + if (m.find()) { + String location = m.group(1); + verify(!m.find(), "Unexpected second match"); + return location; + } + throw new IllegalStateException("Location not found in SHOW CREATE TABLE result"); + } +} diff --git a/testing/trino-server-dev/pom.xml b/testing/trino-server-dev/pom.xml index f50604428a32..79936054970f 100644 --- a/testing/trino-server-dev/pom.xml +++ b/testing/trino-server-dev/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/testing/trino-test-jdbc-compatibility-old-driver/pom.xml b/testing/trino-test-jdbc-compatibility-old-driver/pom.xml index 6fe45229c0c5..fdbb90f59ce0 100644 --- a/testing/trino-test-jdbc-compatibility-old-driver/pom.xml +++ b/testing/trino-test-jdbc-compatibility-old-driver/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml @@ -14,7 +14,7 @@ ${project.parent.basedir} - 395-SNAPSHOT + 396-SNAPSHOT diff --git a/testing/trino-test-jdbc-compatibility-old-server/pom.xml b/testing/trino-test-jdbc-compatibility-old-server/pom.xml index d5ea4f7907ef..34fd50dbcc32 100644 --- a/testing/trino-test-jdbc-compatibility-old-server/pom.xml +++ b/testing/trino-test-jdbc-compatibility-old-server/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/testing/trino-testing-containers/pom.xml b/testing/trino-testing-containers/pom.xml index bf6f63e9ad17..204e9232bce3 100644 --- a/testing/trino-testing-containers/pom.xml +++ b/testing/trino-testing-containers/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/testing/trino-testing-containers/src/main/java/io/trino/testing/containers/wait/strategy/SelectedPortWaitStrategy.java b/testing/trino-testing-containers/src/main/java/io/trino/testing/containers/wait/strategy/SelectedPortWaitStrategy.java index 69d365e1e08e..63da3e61f5f9 100644 --- a/testing/trino-testing-containers/src/main/java/io/trino/testing/containers/wait/strategy/SelectedPortWaitStrategy.java +++ b/testing/trino-testing-containers/src/main/java/io/trino/testing/containers/wait/strategy/SelectedPortWaitStrategy.java @@ -73,7 +73,7 @@ protected void waitUntilReady() // We say "timed out" immediately. Failsafe will propagate this only when timeout reached. throw new ContainerLaunchException(format( "Timed out waiting for container port to open (%s ports: %s should be listening)", - waitStrategyTarget.getContainerIpAddress(), + waitStrategyTarget.getHost(), exposedPorts)); } }); diff --git a/testing/trino-testing-kafka/pom.xml b/testing/trino-testing-kafka/pom.xml index bd8e7738a3fe..3663d28f81fd 100644 --- a/testing/trino-testing-kafka/pom.xml +++ b/testing/trino-testing-kafka/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/testing/trino-testing-kafka/src/main/java/io/trino/testing/kafka/TestingKafka.java b/testing/trino-testing-kafka/src/main/java/io/trino/testing/kafka/TestingKafka.java index fa49c9ed3fc6..7967c4f1226b 100644 --- a/testing/trino-testing-kafka/src/main/java/io/trino/testing/kafka/TestingKafka.java +++ b/testing/trino-testing-kafka/src/main/java/io/trino/testing/kafka/TestingKafka.java @@ -215,7 +215,7 @@ private Future send(KafkaProducer producer, Produce public String getConnectString() { - return kafka.getContainerIpAddress() + ":" + kafka.getMappedPort(KAFKA_PORT); + return kafka.getHost() + ":" + kafka.getMappedPort(KAFKA_PORT); } private KafkaProducer createProducer(Map extraProperties) @@ -242,7 +242,7 @@ private static Properties toProperties(Map map) public String getSchemaRegistryConnectString() { - return "http://" + schemaRegistry.getContainerIpAddress() + ":" + schemaRegistry.getMappedPort(SCHEMA_REGISTRY_PORT); + return "http://" + schemaRegistry.getHost() + ":" + schemaRegistry.getMappedPort(SCHEMA_REGISTRY_PORT); } public Network getNetwork() diff --git a/testing/trino-testing-resources/pom.xml b/testing/trino-testing-resources/pom.xml index 543756998495..6ad2ea2a406a 100644 --- a/testing/trino-testing-resources/pom.xml +++ b/testing/trino-testing-resources/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/testing/trino-testing-services/pom.xml b/testing/trino-testing-services/pom.xml index 369926aa2b59..f5395d51031f 100644 --- a/testing/trino-testing-services/pom.xml +++ b/testing/trino-testing-services/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/testing/trino-testing/pom.xml b/testing/trino-testing/pom.xml index dc522a7244a5..94e0005390b3 100644 --- a/testing/trino-testing/pom.xml +++ b/testing/trino-testing/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestingTrinoClient.java b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestingTrinoClient.java index 29f744740f8a..64203abe6c73 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestingTrinoClient.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestingTrinoClient.java @@ -147,27 +147,27 @@ private static ClientSession toClientSession(Session session, URI server, Durati estimates.getCpuTime().ifPresent(e -> resourceEstimates.put(CPU_TIME, e.toString())); estimates.getPeakMemoryBytes().ifPresent(e -> resourceEstimates.put(PEAK_MEMORY, e.toString())); - return new ClientSession( - server, - Optional.of(session.getIdentity().getUser()), - Optional.empty(), - session.getSource().orElse(null), - session.getTraceToken(), - session.getClientTags(), - session.getClientInfo().orElse(null), - session.getCatalog().orElse(null), - session.getSchema().orElse(null), - session.getPath().toString(), - ZoneId.of(session.getTimeZoneKey().getId()), - session.getLocale(), - resourceEstimates.buildOrThrow(), - properties.buildOrThrow(), - session.getPreparedStatements(), - getRoles(session), - session.getIdentity().getExtraCredentials(), - session.getTransactionId().map(Object::toString).orElse(null), - clientRequestTimeout, - true); + return ClientSession.builder() + .server(server) + .principal(Optional.of(session.getIdentity().getUser())) + .source(session.getSource().orElse(null)) + .traceToken(session.getTraceToken()) + .clientTags(session.getClientTags()) + .clientInfo(session.getClientInfo().orElse(null)) + .catalog(session.getCatalog().orElse(null)) + .schema(session.getSchema().orElse(null)) + .path(session.getPath().toString()) + .timeZone(ZoneId.of(session.getTimeZoneKey().getId())) + .locale(session.getLocale()) + .resourceEstimates(resourceEstimates.buildOrThrow()) + .properties(properties.buildOrThrow()) + .preparedStatements(session.getPreparedStatements()) + .roles(getRoles(session)) + .credentials(session.getIdentity().getExtraCredentials()) + .transactionId(session.getTransactionId().map(Object::toString).orElse(null)) + .clientRequestTimeout(clientRequestTimeout) + .compressionDisabled(true) + .build(); } private static Map getRoles(Session session) diff --git a/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java b/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java index e2d8812d786b..43475addf41a 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java @@ -1890,6 +1890,10 @@ public void testAddColumn() assertQueryFails("ALTER TABLE " + table.getName() + " ADD COLUMN q bad_type", ".* Unknown type 'bad_type' for column 'q'"); assertUpdate("ALTER TABLE " + table.getName() + " ADD COLUMN a varchar(50)"); + // Verify table state after adding a column, but before inserting anything to it + assertQuery( + "SELECT * FROM " + table.getName(), + "VALUES ('first', NULL)"); assertUpdate("INSERT INTO " + table.getName() + " SELECT 'second', 'xxx'", 1); assertQuery( "SELECT x, a FROM " + table.getName(), diff --git a/testing/trino-testing/src/main/java/io/trino/testing/StructuralTestUtil.java b/testing/trino-testing/src/main/java/io/trino/testing/StructuralTestUtil.java index d2ae72fb7d23..62d41149759b 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/StructuralTestUtil.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/StructuralTestUtil.java @@ -132,10 +132,8 @@ public static Block decimalArrayBlockOf(DecimalType type, BigDecimal decimal) long longDecimal = decimal.unscaledValue().longValue(); return arrayBlockOf(type, longDecimal); } - else { - Int128 sliceDecimal = Int128.valueOf(decimal.unscaledValue()); - return arrayBlockOf(type, sliceDecimal); - } + Int128 sliceDecimal = Int128.valueOf(decimal.unscaledValue()); + return arrayBlockOf(type, sliceDecimal); } public static Block decimalMapBlockOf(DecimalType type, BigDecimal decimal) @@ -144,10 +142,8 @@ public static Block decimalMapBlockOf(DecimalType type, BigDecimal decimal) long longDecimal = decimal.unscaledValue().longValue(); return mapBlockOf(type, type, longDecimal, longDecimal); } - else { - Int128 sliceDecimal = Int128.valueOf(decimal.unscaledValue()); - return mapBlockOf(type, type, sliceDecimal, sliceDecimal); - } + Int128 sliceDecimal = Int128.valueOf(decimal.unscaledValue()); + return mapBlockOf(type, type, sliceDecimal, sliceDecimal); } public static MapType mapType(Type keyType, Type valueType) diff --git a/testing/trino-testing/src/main/java/io/trino/testing/tpch/AppendingRecordSet.java b/testing/trino-testing/src/main/java/io/trino/testing/tpch/AppendingRecordSet.java index 4526c244a64e..dce32f5bba8a 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/tpch/AppendingRecordSet.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/tpch/AppendingRecordSet.java @@ -100,9 +100,7 @@ public Type getType(int field) if (field < delegateFieldCount) { return delegate.getType(field); } - else { - return appendedTypes.get(field - delegateFieldCount); - } + return appendedTypes.get(field - delegateFieldCount); } @Override @@ -118,9 +116,7 @@ public boolean getBoolean(int field) if (field < delegateFieldCount) { return delegate.getBoolean(field); } - else { - return (Boolean) appendedValues.get(field - delegateFieldCount); - } + return (Boolean) appendedValues.get(field - delegateFieldCount); } @Override @@ -130,9 +126,7 @@ public long getLong(int field) if (field < delegateFieldCount) { return delegate.getLong(field); } - else { - return (Long) appendedValues.get(field - delegateFieldCount); - } + return (Long) appendedValues.get(field - delegateFieldCount); } @Override @@ -142,9 +136,7 @@ public double getDouble(int field) if (field < delegateFieldCount) { return delegate.getDouble(field); } - else { - return (Double) appendedValues.get(field - delegateFieldCount); - } + return (Double) appendedValues.get(field - delegateFieldCount); } @Override @@ -154,9 +146,7 @@ public Slice getSlice(int field) if (field < delegateFieldCount) { return delegate.getSlice(field); } - else { - return (Slice) appendedValues.get(field - delegateFieldCount); - } + return (Slice) appendedValues.get(field - delegateFieldCount); } @Override @@ -172,9 +162,7 @@ public boolean isNull(int field) if (field < delegateFieldCount) { return delegate.isNull(field); } - else { - return appendedValues.get(field - delegateFieldCount) == null; - } + return appendedValues.get(field - delegateFieldCount) == null; } @Override diff --git a/testing/trino-tests/pom.xml b/testing/trino-tests/pom.xml index 9465ace52f09..f3b8bdd956af 100644 --- a/testing/trino-tests/pom.xml +++ b/testing/trino-tests/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 395-SNAPSHOT + 396-SNAPSHOT ../../pom.xml diff --git a/testing/trino-tests/src/test/java/io/trino/execution/TestFinalQueryInfo.java b/testing/trino-tests/src/test/java/io/trino/execution/TestFinalQueryInfo.java index 5b5a7ada569e..8c10fd157ddf 100644 --- a/testing/trino-tests/src/test/java/io/trino/execution/TestFinalQueryInfo.java +++ b/testing/trino-tests/src/test/java/io/trino/execution/TestFinalQueryInfo.java @@ -13,8 +13,6 @@ */ package io.trino.execution; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.SettableFuture; import io.airlift.units.Duration; import io.trino.Session; @@ -63,27 +61,16 @@ private static QueryId startQuery(String sql, DistributedQueryRunner queryRunner { OkHttpClient httpClient = new OkHttpClient(); try { - ClientSession clientSession = new ClientSession( - queryRunner.getCoordinator().getBaseUrl(), - Optional.of("user"), - Optional.empty(), - "source", - Optional.empty(), - ImmutableSet.of(), - null, - null, - null, - null, - ZoneId.of("America/Los_Angeles"), - Locale.ENGLISH, - ImmutableMap.of(), - ImmutableMap.of(), - ImmutableMap.of(), - ImmutableMap.of(), - ImmutableMap.of(), - null, - new Duration(2, MINUTES), - true); + ClientSession clientSession = ClientSession.builder() + .server(queryRunner.getCoordinator().getBaseUrl()) + .principal(Optional.of("user")) + .source("source") + .timeZone(ZoneId.of("America/Los_Angeles")) + .locale(Locale.ENGLISH) + .transactionId(null) + .clientRequestTimeout(new Duration(2, MINUTES)) + .compressionDisabled(true) + .build(); // start query StatementClient client = newStatementClient(httpClient, clientSession, sql); diff --git a/testing/trino-tests/src/test/java/io/trino/execution/TestUserImpersonationAccessControl.java b/testing/trino-tests/src/test/java/io/trino/execution/TestUserImpersonationAccessControl.java index e97c56dbefcd..dc0b3204e590 100644 --- a/testing/trino-tests/src/test/java/io/trino/execution/TestUserImpersonationAccessControl.java +++ b/testing/trino-tests/src/test/java/io/trino/execution/TestUserImpersonationAccessControl.java @@ -14,7 +14,6 @@ package io.trino.execution; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; import io.airlift.units.Duration; import io.trino.client.ClientSession; import io.trino.client.QueryError; @@ -84,27 +83,16 @@ private QueryError trySelectQuery(String assumedUser) { OkHttpClient httpClient = new OkHttpClient(); try { - ClientSession clientSession = new ClientSession( - getDistributedQueryRunner().getCoordinator().getBaseUrl(), - Optional.of("user"), - Optional.of(assumedUser), - "source", - Optional.empty(), - ImmutableSet.of(), - null, - null, - null, - null, - ZoneId.of("America/Los_Angeles"), - Locale.ENGLISH, - ImmutableMap.of(), - ImmutableMap.of(), - ImmutableMap.of(), - ImmutableMap.of(), - ImmutableMap.of(), - null, - new Duration(2, MINUTES), - true); + ClientSession clientSession = ClientSession.builder() + .server(getDistributedQueryRunner().getCoordinator().getBaseUrl()) + .principal(Optional.of("user")) + .user(Optional.of(assumedUser)) + .source("source") + .timeZone(ZoneId.of("America/Los_Angeles")) + .locale(Locale.ENGLISH) + .clientRequestTimeout(new Duration(2, MINUTES)) + .compressionDisabled(true) + .build(); // start query StatementClient client = newStatementClient(httpClient, clientSession, "SELECT * FROM tpch.tiny.nation"); diff --git a/testing/trino-tests/src/test/java/io/trino/security/TestAccessControl.java b/testing/trino-tests/src/test/java/io/trino/security/TestAccessControl.java index 0e45421c7457..04a49b808a52 100644 --- a/testing/trino-tests/src/test/java/io/trino/security/TestAccessControl.java +++ b/testing/trino-tests/src/test/java/io/trino/security/TestAccessControl.java @@ -21,6 +21,9 @@ import io.trino.connector.MockConnectorFactory; import io.trino.connector.MockConnectorPlugin; import io.trino.plugin.blackhole.BlackHolePlugin; +import io.trino.plugin.jdbc.JdbcPlugin; +import io.trino.plugin.jdbc.TestingH2JdbcModule; +import io.trino.plugin.memory.MemoryPlugin; import io.trino.plugin.tpch.TpchPlugin; import io.trino.spi.connector.ConnectorViewDefinition; import io.trino.spi.connector.SchemaTableName; @@ -29,12 +32,15 @@ import io.trino.spi.security.SelectedRole; import io.trino.spi.security.TrinoPrincipal; import io.trino.testing.AbstractTestQueryFramework; +import io.trino.testing.DataProviders; import io.trino.testing.DistributedQueryRunner; import io.trino.testing.QueryRunner; +import io.trino.testing.TestingAccessControlManager; import io.trino.testing.TestingAccessControlManager.TestingPrivilege; import io.trino.testing.TestingSession; import org.testng.annotations.Test; +import java.util.Map; import java.util.Optional; import static io.trino.SystemSessionProperties.QUERY_MAX_MEMORY; @@ -68,6 +74,7 @@ import static io.trino.testing.TestingSession.testSessionBuilder; import static io.trino.testing.sql.TestTable.randomTableSuffix; import static java.lang.String.format; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; public class TestAccessControl @@ -86,6 +93,8 @@ protected QueryRunner createQueryRunner() .build(); queryRunner.installPlugin(new BlackHolePlugin()); queryRunner.createCatalog("blackhole", "blackhole"); + queryRunner.installPlugin(new MemoryPlugin()); + queryRunner.createCatalog("memory", "memory", Map.of()); queryRunner.installPlugin(new TpchPlugin()); queryRunner.createCatalog("tpch", "tpch"); queryRunner.installPlugin(new MockConnectorPlugin(MockConnectorFactory.builder() @@ -113,6 +122,8 @@ protected QueryRunner createQueryRunner() .withListRoleGrants((connectorSession, roles, grantees, limit) -> ImmutableSet.of(new RoleGrant(new TrinoPrincipal(USER, "alice"), "alice_role", false))) .build())); queryRunner.createCatalog("mock", "mock"); + queryRunner.installPlugin(new JdbcPlugin("base-jdbc", new TestingH2JdbcModule())); + queryRunner.createCatalog("jdbc", "base-jdbc", TestingH2JdbcModule.createProperties()); for (String tableName : ImmutableList.of("orders", "nation", "region", "lineitem")) { queryRunner.execute(format("CREATE TABLE %1$s AS SELECT * FROM tpch.tiny.%1$s WITH NO DATA", tableName)); } @@ -354,6 +365,37 @@ public void testCommentView() .hasMessageContaining("This connector does not support setting view comments"); } + @Test(dataProviderClass = DataProviders.class, dataProvider = "trueFalse") + public void testViewWithTableFunction(boolean securityDefiner) + { + Session viewOwner = getSession(); + Session otherUser = Session.builder(getSession()) + .setIdentity(Identity.ofUser(getSession().getUser() + "-someone-else")) + .build(); + + String viewName = "memory.default.definer_view_with_ptf"; + assertUpdate(viewOwner, "CREATE VIEW " + viewName + " SECURITY " + (securityDefiner ? "DEFINER" : "INVOKER") + " AS SELECT * FROM TABLE (jdbc.system.query('SELECT ''from h2'', monthname(CAST(''2005-09-10'' AS date))'))"); + String viewValues = "VALUES ('from h2', 'September') "; + + assertThat(query(viewOwner, "TABLE " + viewName)).matches(viewValues); + assertThat(query(otherUser, "TABLE " + viewName)).matches(viewValues); + + TestingPrivilege grantExecute = TestingAccessControlManager.privilege("jdbc.system.query", GRANT_EXECUTE_FUNCTION); + assertAccessAllowed(viewOwner, "TABLE " + viewName, grantExecute); + if (securityDefiner) { + assertAccessDenied( + otherUser, + "TABLE " + viewName, + "View owner does not have sufficient privileges: 'user' cannot grant 'jdbc.system.query' execution to user 'user-someone-else'", + grantExecute); + } + else { + assertAccessAllowed(otherUser, "TABLE " + viewName, grantExecute); + } + + assertUpdate("DROP VIEW " + viewName); + } + @Test public void testSetTableProperties() {