diff --git a/.github/DEVELOPMENT.md b/.github/DEVELOPMENT.md index c1518da8c6c5..3da05a0f6b41 100644 --- a/.github/DEVELOPMENT.md +++ b/.github/DEVELOPMENT.md @@ -26,22 +26,26 @@ A typical pull request should strive to contain a single logical change (but not necessarily a single commit). Unrelated changes should generally be extracted into their own PRs. -If a pull request does consist of multiple commits, it is expected that every -prefix of it is correct. That is, there might be preparatory commits at the -bottom of the stack that don't bring any value by themselves, but none of the -commits should introduce an error that is fixed by some future commit. Every -commit should build and pass all tests. - -Commit messages and history are also important, as they are used by other -developers to keep track of the motivation behind changes. Keep logical diffs -grouped together in separate commits, and order commits in a way that explains -the progress of the changes. Rewriting and reordering commits may be a necessary -part of the PR review process as the code changes. Mechanical changes (like -refactoring and renaming)should be separated from logical and functional -changes. E.g. deduplicating code or extracting helper methods should happen in a -separate commit from the commit where new features or behavior is introduced. -This makes reviewing the code much easier and reduces the chance of introducing -unintended changes in behavior. +If a pull request contains a stack of more than one commit, then +popping any number of commits from the top of the stack should not +break the PR, i.e. every commit should build and pass all tests. + +Commit messages and history are important as well, because they are +used by other developers to keep track of the motivation behind +changes. Keep logical diffs grouped together in separate commits and +order commits in a way that explains by itself the evolution of the +change. 
Rewriting and reordering commits is a natural part of the +review process. Mechanical changes like refactoring, renaming, removing +duplication, extracting helper methods, static imports should be kept +separated from logical and functional changes like adding a new feature +or modifying code behavior. This makes reviewing the code much easier +and reduces the chance of introducing unintended changes in behavior. + +Whenever in doubt about splitting a change into a separate commit, ask +yourself the following question: if all other work in the PR needs to +be reverted after merging to master for some objective reason (e.g. a +bug has been discovered), is it worth keeping that commit still in +master? ## Code Style diff --git a/.github/bin/s3/setup-empty-s3-bucket.sh b/.github/bin/s3/setup-empty-s3-bucket.sh index c3d25e0eb1f0..89a9d7c9939f 100755 --- a/.github/bin/s3/setup-empty-s3-bucket.sh +++ b/.github/bin/s3/setup-empty-s3-bucket.sh @@ -38,6 +38,9 @@ echo "The AWS S3 bucket ${S3_BUCKET_IDENTIFIER} in the region ${AWS_REGION} exis echo "Tagging the AWS S3 bucket ${S3_BUCKET_IDENTIFIER} with TTL tags" +# "test" environment tag is needed so that the bucket gets cleaned up by the daily AWS resource cleanup job in case the +# temporary bucket is not properly cleaned up by delete-s3-bucket.sh. 
The ttl tag tells the AWS resource cleanup job +# when the bucket is expired and should be cleaned up aws s3api put-bucket-tagging \ --bucket "${S3_BUCKET_IDENTIFIER}" \ - --tagging "TagSet=[{Key=ttl,Value=${S3_BUCKET_TTL}}]" + --tagging "TagSet=[{Key=environment,Value=test},{Key=ttl,Value=${S3_BUCKET_TTL}}]" diff --git a/client/trino-cli/pom.xml b/client/trino-cli/pom.xml index cd343409721b..7808a0bee478 100644 --- a/client/trino-cli/pom.xml +++ b/client/trino-cli/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/client/trino-client/pom.xml b/client/trino-client/pom.xml index a92e41f15417..948e44b7f433 100644 --- a/client/trino-client/pom.xml +++ b/client/trino-client/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml @@ -111,6 +111,12 @@ test + + org.junit.jupiter + junit-jupiter-api + test + + org.testng testng diff --git a/client/trino-client/src/main/java/io/trino/client/ClientCapabilities.java b/client/trino-client/src/main/java/io/trino/client/ClientCapabilities.java index 809e2c43e581..ba55f3872ea4 100644 --- a/client/trino-client/src/main/java/io/trino/client/ClientCapabilities.java +++ b/client/trino-client/src/main/java/io/trino/client/ClientCapabilities.java @@ -23,5 +23,7 @@ public enum ClientCapabilities // time(p) without time zone // interval X(p1) to Y(p2) // When this capability is not set, the server returns datetime types with precision = 3 - PARAMETRIC_DATETIME; + PARAMETRIC_DATETIME, + // Whether clients support the session authorization set/reset feature + SESSION_AUTHORIZATION; } diff --git a/client/trino-client/src/test/java/io/trino/client/TestClientTypeSignature.java b/client/trino-client/src/test/java/io/trino/client/TestClientTypeSignature.java index e1cd159dfffc..4a0a681100ce 100644 --- a/client/trino-client/src/test/java/io/trino/client/TestClientTypeSignature.java +++ b/client/trino-client/src/test/java/io/trino/client/TestClientTypeSignature.java 
@@ -18,7 +18,7 @@ import io.airlift.json.JsonCodecFactory; import io.airlift.json.ObjectMapperProvider; import io.trino.spi.type.StandardTypes; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.util.Optional; diff --git a/client/trino-client/src/test/java/io/trino/client/TestFixJsonDataUtils.java b/client/trino-client/src/test/java/io/trino/client/TestFixJsonDataUtils.java index a231fc5ad056..26c1b8cb7d8e 100644 --- a/client/trino-client/src/test/java/io/trino/client/TestFixJsonDataUtils.java +++ b/client/trino-client/src/test/java/io/trino/client/TestFixJsonDataUtils.java @@ -18,7 +18,7 @@ import io.trino.spi.type.RowType; import io.trino.spi.type.TypeSignature; import io.trino.spi.type.TypeSignatureParameter; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.util.Base64; import java.util.List; diff --git a/client/trino-client/src/test/java/io/trino/client/TestIntervalDayTime.java b/client/trino-client/src/test/java/io/trino/client/TestIntervalDayTime.java index cbe533a5e302..f09b6ef58d51 100644 --- a/client/trino-client/src/test/java/io/trino/client/TestIntervalDayTime.java +++ b/client/trino-client/src/test/java/io/trino/client/TestIntervalDayTime.java @@ -13,7 +13,7 @@ */ package io.trino.client; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import static io.trino.client.IntervalDayTime.formatMillis; import static io.trino.client.IntervalDayTime.parseMillis; diff --git a/client/trino-client/src/test/java/io/trino/client/TestIntervalYearMonth.java b/client/trino-client/src/test/java/io/trino/client/TestIntervalYearMonth.java index 09cd219bd0a3..d67ed4989611 100644 --- a/client/trino-client/src/test/java/io/trino/client/TestIntervalYearMonth.java +++ b/client/trino-client/src/test/java/io/trino/client/TestIntervalYearMonth.java @@ -13,7 +13,7 @@ */ package io.trino.client; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import static 
io.trino.client.IntervalYearMonth.formatMonths; import static io.trino.client.IntervalYearMonth.parseMonths; diff --git a/client/trino-client/src/test/java/io/trino/client/TestJsonCodec.java b/client/trino-client/src/test/java/io/trino/client/TestJsonCodec.java index 9ecb5f5b13fa..9d62db41db50 100644 --- a/client/trino-client/src/test/java/io/trino/client/TestJsonCodec.java +++ b/client/trino-client/src/test/java/io/trino/client/TestJsonCodec.java @@ -14,7 +14,7 @@ package io.trino.client; import com.fasterxml.jackson.core.JsonParseException; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.io.ByteArrayInputStream; diff --git a/client/trino-client/src/test/java/io/trino/client/TestProtocolHeaders.java b/client/trino-client/src/test/java/io/trino/client/TestProtocolHeaders.java index 12e06a82e834..014e63138b94 100644 --- a/client/trino-client/src/test/java/io/trino/client/TestProtocolHeaders.java +++ b/client/trino-client/src/test/java/io/trino/client/TestProtocolHeaders.java @@ -14,7 +14,7 @@ package io.trino.client; import com.google.common.collect.ImmutableSet; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.util.Optional; diff --git a/client/trino-client/src/test/java/io/trino/client/TestQueryResults.java b/client/trino-client/src/test/java/io/trino/client/TestQueryResults.java index 0245a0178a8f..0f77f0354291 100644 --- a/client/trino-client/src/test/java/io/trino/client/TestQueryResults.java +++ b/client/trino-client/src/test/java/io/trino/client/TestQueryResults.java @@ -16,7 +16,7 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.StreamReadConstraints; import com.google.common.base.Strings; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import static io.trino.client.JsonCodec.jsonCodec; import static java.lang.String.format; diff --git a/client/trino-client/src/test/java/io/trino/client/TestRetry.java 
b/client/trino-client/src/test/java/io/trino/client/TestRetry.java index 249f8602a3df..852b02a3849a 100644 --- a/client/trino-client/src/test/java/io/trino/client/TestRetry.java +++ b/client/trino-client/src/test/java/io/trino/client/TestRetry.java @@ -20,9 +20,10 @@ import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.SocketPolicy; -import org.testng.annotations.AfterMethod; -import org.testng.annotations.BeforeMethod; -import org.testng.annotations.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; import java.io.IOException; import java.net.URI; @@ -42,15 +43,16 @@ import static java.net.HttpURLConnection.HTTP_OK; import static java.util.stream.Collectors.toList; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_METHOD; import static org.testng.Assert.assertTrue; -@Test(singleThreaded = true) +@TestInstance(PER_METHOD) public class TestRetry { private MockWebServer server; private static final JsonCodec QUERY_RESULTS_CODEC = jsonCodec(QueryResults.class); - @BeforeMethod(alwaysRun = true) + @BeforeEach public void setup() throws Exception { @@ -58,7 +60,7 @@ public void setup() server.start(); } - @AfterMethod(alwaysRun = true) + @AfterEach public void teardown() throws IOException { diff --git a/client/trino-client/src/test/java/io/trino/client/TestServerInfo.java b/client/trino-client/src/test/java/io/trino/client/TestServerInfo.java index 66390b162f0e..524196d6464b 100644 --- a/client/trino-client/src/test/java/io/trino/client/TestServerInfo.java +++ b/client/trino-client/src/test/java/io/trino/client/TestServerInfo.java @@ -15,7 +15,7 @@ import io.airlift.json.JsonCodec; import io.airlift.units.Duration; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.util.Optional; diff 
--git a/client/trino-client/src/test/java/io/trino/client/auth/external/TestExternalAuthentication.java b/client/trino-client/src/test/java/io/trino/client/auth/external/TestExternalAuthentication.java index 9d81a5bf4499..7aa42b8cb95e 100644 --- a/client/trino-client/src/test/java/io/trino/client/auth/external/TestExternalAuthentication.java +++ b/client/trino-client/src/test/java/io/trino/client/auth/external/TestExternalAuthentication.java @@ -14,7 +14,7 @@ package io.trino.client.auth.external; import io.trino.client.ClientException; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.io.IOException; import java.io.UncheckedIOException; diff --git a/client/trino-client/src/test/java/io/trino/client/auth/external/TestExternalAuthenticator.java b/client/trino-client/src/test/java/io/trino/client/auth/external/TestExternalAuthenticator.java index 4615ec31b758..b7daaa378dac 100644 --- a/client/trino-client/src/test/java/io/trino/client/auth/external/TestExternalAuthenticator.java +++ b/client/trino-client/src/test/java/io/trino/client/auth/external/TestExternalAuthenticator.java @@ -21,8 +21,10 @@ import okhttp3.Response; import org.assertj.core.api.ListAssert; import org.assertj.core.api.ThrowableAssert; -import org.testng.annotations.AfterClass; -import org.testng.annotations.Test; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.Timeout; import java.net.URI; import java.net.URISyntaxException; @@ -51,13 +53,14 @@ import static java.util.concurrent.Executors.newCachedThreadPool; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; -@Test(singleThreaded = true) +@TestInstance(PER_CLASS) public class TestExternalAuthenticator { private static final ExecutorService executor = 
newCachedThreadPool(daemonThreadsNamed(TestExternalAuthenticator.class.getName() + "-%d")); - @AfterClass(alwaysRun = true) + @AfterAll public void shutDownThreadPool() { executor.shutdownNow(); @@ -158,7 +161,8 @@ public void testReAuthenticationAfterRejectingToken() .containsExactly("Bearer second-token"); } - @Test(timeOut = 2000) + @Test + @Timeout(2) public void testAuthenticationFromMultipleThreadsWithLocallyStoredToken() { MockTokenPoller tokenPoller = new MockTokenPoller() @@ -184,7 +188,8 @@ public void testAuthenticationFromMultipleThreadsWithLocallyStoredToken() assertThat(redirectHandler.getRedirectionCount()).isEqualTo(4); } - @Test(timeOut = 2000) + @Test + @Timeout(2) public void testAuthenticationFromMultipleThreadsWithCachedToken() { MockTokenPoller tokenPoller = new MockTokenPoller() @@ -208,7 +213,8 @@ public void testAuthenticationFromMultipleThreadsWithCachedToken() assertThat(redirectHandler.getRedirectionCount()).isEqualTo(1); } - @Test(timeOut = 2000) + @Test + @Timeout(2) public void testAuthenticationFromMultipleThreadsWithCachedTokenAfterAuthenticateFails() { MockTokenPoller tokenPoller = new MockTokenPoller() @@ -235,7 +241,8 @@ public void testAuthenticationFromMultipleThreadsWithCachedTokenAfterAuthenticat assertThat(redirectHandler.getRedirectionCount()).isEqualTo(2); } - @Test(timeOut = 2000) + @Test + @Timeout(2) public void testAuthenticationFromMultipleThreadsWithCachedTokenAfterAuthenticateTimesOut() { MockRedirectHandler redirectHandler = new MockRedirectHandler() @@ -255,7 +262,8 @@ public void testAuthenticationFromMultipleThreadsWithCachedTokenAfterAuthenticat assertThat(redirectHandler.getRedirectionCount()).isEqualTo(1); } - @Test(timeOut = 2000) + @Test + @Timeout(2) public void testAuthenticationFromMultipleThreadsWithCachedTokenAfterAuthenticateIsInterrupted() throws Exception { diff --git a/client/trino-client/src/test/java/io/trino/client/auth/external/TestHttpTokenPoller.java 
b/client/trino-client/src/test/java/io/trino/client/auth/external/TestHttpTokenPoller.java index cc342201acc4..2d1bd130a186 100644 --- a/client/trino-client/src/test/java/io/trino/client/auth/external/TestHttpTokenPoller.java +++ b/client/trino-client/src/test/java/io/trino/client/auth/external/TestHttpTokenPoller.java @@ -18,9 +18,10 @@ import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; -import org.testng.annotations.AfterMethod; -import org.testng.annotations.BeforeMethod; -import org.testng.annotations.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; import java.io.IOException; import java.io.UncheckedIOException; @@ -37,8 +38,9 @@ import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_METHOD; -@Test(singleThreaded = true) +@TestInstance(PER_METHOD) public class TestHttpTokenPoller { private static final String TOKEN_PATH = "/v1/authentications/sso/test/token"; @@ -47,7 +49,7 @@ public class TestHttpTokenPoller private TokenPoller tokenPoller; private MockWebServer server; - @BeforeMethod(alwaysRun = true) + @BeforeEach public void setup() throws Exception { @@ -59,7 +61,7 @@ public void setup() .build()); } - @AfterMethod(alwaysRun = true) + @AfterEach public void teardown() throws IOException { diff --git a/client/trino-client/src/test/java/io/trino/client/uri/TestTrinoUri.java b/client/trino-client/src/test/java/io/trino/client/uri/TestTrinoUri.java index 0fe6543253e7..860f6be08bc9 100644 --- a/client/trino-client/src/test/java/io/trino/client/uri/TestTrinoUri.java +++ b/client/trino-client/src/test/java/io/trino/client/uri/TestTrinoUri.java @@ -13,7 +13,7 @@ */ package 
io.trino.client.uri; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.net.URI; import java.sql.SQLException; @@ -184,11 +184,12 @@ public void testInvalidUrls() "Connection property assumeLiteralNamesInMetadataCallsForNonConformingClients cannot be set if assumeLiteralUnderscoreInMetadataCallsForNonConformingClients is enabled"); } - @Test(expectedExceptions = SQLException.class, expectedExceptionsMessageRegExp = "Connection property user value is empty") + @Test public void testEmptyUser() - throws Exception { - TrinoUri.create("trino://localhost:8080?user=", new Properties()); + assertThatThrownBy(() -> TrinoUri.create("trino://localhost:8080?user=", new Properties())) + .isInstanceOf(SQLException.class) + .hasMessage("Connection property user value is empty"); } @Test diff --git a/client/trino-jdbc/pom.xml b/client/trino-jdbc/pom.xml index 1cccb4c772a9..339ced4c64d9 100644 --- a/client/trino-jdbc/pom.xml +++ b/client/trino-jdbc/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/core/trino-grammar/pom.xml b/core/trino-grammar/pom.xml index a599f2516c33..730d3194d300 100644 --- a/core/trino-grammar/pom.xml +++ b/core/trino-grammar/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/core/trino-main/pom.xml b/core/trino-main/pom.xml index 5aeb7edb8c46..dafa7b6faab7 100644 --- a/core/trino-main/pom.xml +++ b/core/trino-main/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/core/trino-main/src/main/java/io/trino/execution/ResetSessionAuthorizationTask.java b/core/trino-main/src/main/java/io/trino/execution/ResetSessionAuthorizationTask.java index 6bc1b2e66867..8af7bf504227 100644 --- a/core/trino-main/src/main/java/io/trino/execution/ResetSessionAuthorizationTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/ResetSessionAuthorizationTask.java @@ -16,6 +16,7 @@ import 
com.google.common.util.concurrent.ListenableFuture; import com.google.inject.Inject; import io.trino.Session; +import io.trino.client.ClientCapabilities; import io.trino.execution.warnings.WarningCollector; import io.trino.spi.TrinoException; import io.trino.sql.tree.Expression; @@ -26,6 +27,7 @@ import static com.google.common.util.concurrent.Futures.immediateFuture; import static io.trino.spi.StandardErrorCode.GENERIC_USER_ERROR; +import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static java.util.Objects.requireNonNull; public class ResetSessionAuthorizationTask @@ -53,6 +55,9 @@ public ListenableFuture execute( WarningCollector warningCollector) { Session session = stateMachine.getSession(); + if (!session.getClientCapabilities().contains(ClientCapabilities.SESSION_AUTHORIZATION.toString())) { + throw new TrinoException(NOT_SUPPORTED, "RESET SESSION AUTHORIZATION not supported by client"); + } session.getTransactionId().ifPresent(transactionId -> { if (!transactionManager.getTransactionInfo(transactionId).isAutoCommitContext()) { throw new TrinoException(GENERIC_USER_ERROR, "Can't reset authorization user in the middle of a transaction"); diff --git a/core/trino-main/src/main/java/io/trino/execution/SetSessionAuthorizationTask.java b/core/trino-main/src/main/java/io/trino/execution/SetSessionAuthorizationTask.java index 99afede9118c..b351549a23ee 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SetSessionAuthorizationTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/SetSessionAuthorizationTask.java @@ -16,6 +16,7 @@ import com.google.common.util.concurrent.ListenableFuture; import com.google.inject.Inject; import io.trino.Session; +import io.trino.client.ClientCapabilities; import io.trino.execution.warnings.WarningCollector; import io.trino.security.AccessControl; import io.trino.spi.TrinoException; @@ -31,6 +32,7 @@ import static com.google.common.base.Preconditions.checkState; import static 
com.google.common.util.concurrent.Futures.immediateFuture; import static io.trino.spi.StandardErrorCode.GENERIC_USER_ERROR; +import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static java.util.Objects.requireNonNull; public class SetSessionAuthorizationTask @@ -60,6 +62,9 @@ public ListenableFuture execute( WarningCollector warningCollector) { Session session = stateMachine.getSession(); + if (!session.getClientCapabilities().contains(ClientCapabilities.SESSION_AUTHORIZATION.toString())) { + throw new TrinoException(NOT_SUPPORTED, "SET SESSION AUTHORIZATION not supported by client"); + } Identity originalIdentity = session.getOriginalIdentity(); // Set authorization user in the middle of a transaction is disallowed by the SQL spec session.getTransactionId().ifPresent(transactionId -> { diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/NoMemoryAwarePartitionMemoryEstimator.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/NoMemoryAwarePartitionMemoryEstimator.java index bceb6b20159d..2703acc2ce09 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/NoMemoryAwarePartitionMemoryEstimator.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/NoMemoryAwarePartitionMemoryEstimator.java @@ -95,9 +95,7 @@ private boolean isNoMemoryFragment(PlanFragment fragment, Function 0, "positionCount should be > 0, but is %s", positionCount); ColumnarRow mergeRow = toColumnarRow(inputPage.getBlock(mergeRowChannel)); - checkArgument(!mergeRow.mayHaveNull(), "The mergeRow may not have null rows"); + if (mergeRow.mayHaveNull()) { + for (int position = 0; position < positionCount; position++) { + checkArgument(!mergeRow.isNull(position), "The mergeRow may not have null rows"); + } + } // We've verified that the mergeRow block has no null rows, so it's okay to get the field blocks diff --git 
a/core/trino-main/src/main/java/io/trino/operator/FlatGroupByHash.java b/core/trino-main/src/main/java/io/trino/operator/FlatGroupByHash.java index cdb85915dd01..14e8fb6d42ad 100644 --- a/core/trino-main/src/main/java/io/trino/operator/FlatGroupByHash.java +++ b/core/trino-main/src/main/java/io/trino/operator/FlatGroupByHash.java @@ -58,6 +58,8 @@ public class FlatGroupByHash // reusable arrays for the blocks and block builders private final Block[] currentBlocks; private final BlockBuilder[] currentBlockBuilders; + // reusable array for computing hash batches into + private long[] currentHashes; public FlatGroupByHash( List hashTypes, @@ -98,6 +100,7 @@ public long getEstimatedSize() INSTANCE_SIZE, flatHash.getEstimatedSize(), currentPageSizeInBytes, + sizeOf(currentHashes), (dictionaryLookBack != null ? dictionaryLookBack.getRetainedSizeInBytes() : 0)); } @@ -175,6 +178,14 @@ private int putIfAbsent(Block[] blocks, int position) return flatHash.putIfAbsent(blocks, position); } + private long[] getHashesBufferArray() + { + if (currentHashes == null) { + currentHashes = new long[BATCH_SIZE]; + } + return currentHashes; + } + private Block[] getBlocksFromPage(Page page) { Block[] blocks = currentBlocks; @@ -308,14 +319,16 @@ public boolean process() int remainingPositions = positionCount - lastPosition; + long[] hashes = getHashesBufferArray(); while (remainingPositions != 0) { - int batchSize = min(remainingPositions, BATCH_SIZE); + int batchSize = min(remainingPositions, hashes.length); if (!flatHash.ensureAvailableCapacity(batchSize)) { return false; } - for (int i = lastPosition; i < lastPosition + batchSize; i++) { - putIfAbsent(blocks, i); + flatHash.computeHashes(blocks, hashes, lastPosition, batchSize); + for (int i = 0; i < batchSize; i++) { + flatHash.putIfAbsent(blocks, lastPosition + i, hashes[i]); } lastPosition += batchSize; @@ -473,14 +486,16 @@ public boolean process() int remainingPositions = positionCount - lastPosition; + long[] hashes = 
getHashesBufferArray(); while (remainingPositions != 0) { - int batchSize = min(remainingPositions, BATCH_SIZE); + int batchSize = min(remainingPositions, hashes.length); if (!flatHash.ensureAvailableCapacity(batchSize)) { return false; } - for (int i = lastPosition; i < lastPosition + batchSize; i++) { - groupIds[i] = putIfAbsent(blocks, i); + flatHash.computeHashes(blocks, hashes, lastPosition, batchSize); + for (int i = 0, position = lastPosition; i < batchSize; i++, position++) { + groupIds[position] = flatHash.putIfAbsent(blocks, position, hashes[i]); } lastPosition += batchSize; diff --git a/core/trino-main/src/main/java/io/trino/operator/FlatHash.java b/core/trino-main/src/main/java/io/trino/operator/FlatHash.java index c7c0360946db..3086d9594032 100644 --- a/core/trino-main/src/main/java/io/trino/operator/FlatHash.java +++ b/core/trino-main/src/main/java/io/trino/operator/FlatHash.java @@ -174,16 +174,21 @@ public boolean contains(Block[] blocks, int position, long hash) return getIndex(blocks, position, hash) >= 0; } - public int putIfAbsent(Block[] blocks, int position) + public void computeHashes(Block[] blocks, long[] hashes, int offset, int length) { - long hash; if (hasPrecomputedHash) { - hash = BIGINT.getLong(blocks[blocks.length - 1], position); + Block hashBlock = blocks[blocks.length - 1]; + for (int i = 0; i < length; i++) { + hashes[i] = BIGINT.getLong(hashBlock, offset + i); + } } else { - hash = flatHashStrategy.hash(blocks, position); + flatHashStrategy.hashBlocksBatched(blocks, hashes, offset, length); } + } + public int putIfAbsent(Block[] blocks, int position, long hash) + { int index = getIndex(blocks, position, hash); if (index >= 0) { return (int) INT_HANDLE.get(getRecords(index), getRecordOffset(index) + recordGroupIdOffset); @@ -197,6 +202,19 @@ public int putIfAbsent(Block[] blocks, int position) return groupId; } + public int putIfAbsent(Block[] blocks, int position) + { + long hash; + if (hasPrecomputedHash) { + hash = 
BIGINT.getLong(blocks[blocks.length - 1], position); + } + else { + hash = flatHashStrategy.hash(blocks, position); + } + + return putIfAbsent(blocks, position, hash); + } + private int getIndex(Block[] blocks, int position, long hash) { byte hashPrefix = (byte) (hash & 0x7F | 0x80); diff --git a/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategy.java b/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategy.java index c4ddeee500af..7213d4fbc5e9 100644 --- a/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategy.java +++ b/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategy.java @@ -38,4 +38,6 @@ boolean valueNotDistinctFrom( long hash(Block[] blocks, int position); long hash(byte[] fixedChunk, int fixedOffset, byte[] variableChunk); + + void hashBlocksBatched(Block[] blocks, long[] hashes, int offset, int length); } diff --git a/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategyCompiler.java b/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategyCompiler.java index 6097ea2add20..1919c222ba91 100644 --- a/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategyCompiler.java +++ b/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategyCompiler.java @@ -21,6 +21,7 @@ import io.airlift.bytecode.Parameter; import io.airlift.bytecode.Scope; import io.airlift.bytecode.Variable; +import io.airlift.bytecode.control.ForLoop; import io.airlift.bytecode.control.IfStatement; import io.trino.operator.scalar.CombineHashFunction; import io.trino.spi.block.Block; @@ -31,7 +32,10 @@ import java.lang.invoke.MethodHandle; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.Objects; import static io.airlift.bytecode.Access.FINAL; import static io.airlift.bytecode.Access.PRIVATE; @@ -41,6 +45,7 @@ import static io.airlift.bytecode.Parameter.arg; import static io.airlift.bytecode.ParameterizedType.type; import static 
io.airlift.bytecode.expression.BytecodeExpressions.add; +import static io.airlift.bytecode.expression.BytecodeExpressions.and; import static io.airlift.bytecode.expression.BytecodeExpressions.constantBoolean; import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt; @@ -48,6 +53,7 @@ import static io.airlift.bytecode.expression.BytecodeExpressions.constantTrue; import static io.airlift.bytecode.expression.BytecodeExpressions.invokeDynamic; import static io.airlift.bytecode.expression.BytecodeExpressions.invokeStatic; +import static io.airlift.bytecode.expression.BytecodeExpressions.lessThan; import static io.airlift.bytecode.expression.BytecodeExpressions.not; import static io.airlift.bytecode.expression.BytecodeExpressions.notEqual; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; @@ -120,6 +126,7 @@ public static FlatHashStrategy compileFlatHashStrategy(List types, TypeOpe generateNotDistinctFromMethod(definition, keyFields, callSiteBinder); generateHashBlock(definition, keyFields, callSiteBinder); generateHashFlat(definition, keyFields, callSiteBinder); + generateHashBlocksBatched(definition, keyFields, callSiteBinder); try { return defineClass(definition, FlatHashStrategy.class, callSiteBinder.getBindings(), FlatHashStrategyCompiler.class.getClassLoader()) @@ -352,6 +359,103 @@ private static void generateHashBlock(ClassDefinition definition, List body.append(result.ret()); } + private static void generateHashBlocksBatched(ClassDefinition definition, List keyFields, CallSiteBinder callSiteBinder) + { + Parameter blocks = arg("blocks", type(Block[].class)); + Parameter hashes = arg("hashes", type(long[].class)); + Parameter offset = arg("offset", type(int.class)); + Parameter length = arg("length", type(int.class)); + + MethodDefinition methodDefinition = definition.declareMethod( + a(PUBLIC), + 
"hashBlocksBatched", + type(void.class), + blocks, + hashes, + offset, + length); + + BytecodeBlock body = methodDefinition.getBody(); + body.append(invokeStatic(Objects.class, "checkFromIndexSize", int.class, constantInt(0), length, hashes.length()).pop()); + + Map typeMethods = new HashMap<>(); + for (KeyField keyField : keyFields) { + MethodDefinition method; + // First hash method implementation does not combine hashes, so it can't be reused + if (keyField.index() == 0) { + method = generateHashBlockVectorized(definition, keyField, callSiteBinder); + } + else { + // Columns of the same type can reuse the same static method implementation + method = typeMethods.get(keyField.type()); + if (method == null) { + method = generateHashBlockVectorized(definition, keyField, callSiteBinder); + typeMethods.put(keyField.type(), method); + } + } + body.append(invokeStatic(method, blocks.getElement(keyField.index()), hashes, offset, length)); + } + body.ret(); + } + + private static MethodDefinition generateHashBlockVectorized(ClassDefinition definition, KeyField field, CallSiteBinder callSiteBinder) + { + Parameter block = arg("block", type(Block.class)); + Parameter hashes = arg("hashes", type(long[].class)); + Parameter offset = arg("offset", type(int.class)); + Parameter length = arg("length", type(int.class)); + + MethodDefinition methodDefinition = definition.declareMethod( + a(PRIVATE, STATIC), + "hashBlockVectorized_" + field.index(), + type(void.class), + block, + hashes, + offset, + length); + + Scope scope = methodDefinition.getScope(); + BytecodeBlock body = methodDefinition.getBody(); + + Variable index = scope.declareVariable(int.class, "index"); + Variable position = scope.declareVariable(int.class, "position"); + Variable mayHaveNull = scope.declareVariable(boolean.class, "mayHaveNull"); + Variable hash = scope.declareVariable(long.class, "hash"); + + body.append(mayHaveNull.set(block.invoke("mayHaveNull", boolean.class))); + 
body.append(position.set(invokeStatic(Objects.class, "checkFromToIndex", int.class, offset, add(offset, length), block.invoke("getPositionCount", int.class)))); + body.append(invokeStatic(Objects.class, "checkFromIndexSize", int.class, constantInt(0), length, hashes.length()).pop()); + + BytecodeBlock loopBody = new BytecodeBlock().append(new IfStatement("if (mayHaveNull && block.isNull(position))") + .condition(and(mayHaveNull, block.invoke("isNull", boolean.class, position))) + .ifTrue(hash.set(constantLong(NULL_HASH_CODE))) + .ifFalse(hash.set(invokeDynamic( + BOOTSTRAP_METHOD, + ImmutableList.of(callSiteBinder.bind(field.hashBlockMethod()).getBindingId()), + "hash", + long.class, + block, + position)))); + if (field.index() == 0) { + // hashes[index] = hash; + loopBody.append(hashes.setElement(index, hash)); + } + else { + // hashes[index] = CombineHashFunction.getHash(hashes[index], hash); + loopBody.append(hashes.setElement(index, invokeStatic(CombineHashFunction.class, "getHash", long.class, hashes.getElement(index), hash))); + } + loopBody.append(position.increment()); + + body.append(new ForLoop("for (index = 0; index < length; index++)") + .initialize(index.set(constantInt(0))) + .condition(lessThan(index, length)) + .update(index.increment()) + .body(loopBody)) + .ret(); + + return methodDefinition; + } + private static void generateHashFlat(ClassDefinition definition, List keyFields, CallSiteBinder callSiteBinder) { Parameter fixedChunk = arg("fixedChunk", type(byte[].class)); diff --git a/core/trino-main/src/main/java/io/trino/operator/JoinDomainBuilder.java b/core/trino-main/src/main/java/io/trino/operator/JoinDomainBuilder.java index c847fbe6facd..77a3f5789ab0 100644 --- a/core/trino-main/src/main/java/io/trino/operator/JoinDomainBuilder.java +++ b/core/trino-main/src/main/java/io/trino/operator/JoinDomainBuilder.java @@ -18,6 +18,9 @@ import io.airlift.units.DataSize; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; +import 
io.trino.spi.block.DictionaryBlock; +import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.ValueSet; import io.trino.spi.type.Type; @@ -32,8 +35,8 @@ import static io.trino.operator.VariableWidthData.EMPTY_CHUNK; import static io.trino.operator.VariableWidthData.POINTER_SIZE; import static io.trino.spi.StandardErrorCode.GENERIC_INSUFFICIENT_RESOURCES; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; @@ -88,8 +91,8 @@ public class JoinDomainBuilder private int distinctSize; private int distinctMaxFill; - private Block minValue; - private Block maxValue; + private ValueBlock minValue; + private ValueBlock maxValue; private boolean collectDistinctValues = true; private boolean collectMinMax; @@ -116,15 +119,15 @@ public JoinDomainBuilder( MethodHandle readOperator = typeOperators.getReadValueOperator(type, simpleConvention(NULLABLE_RETURN, FLAT)); readOperator = readOperator.asType(readOperator.type().changeReturnType(Object.class)); this.readFlat = readOperator; - this.writeFlat = typeOperators.getReadValueOperator(type, simpleConvention(FLAT_RETURN, BLOCK_POSITION_NOT_NULL)); + this.writeFlat = typeOperators.getReadValueOperator(type, simpleConvention(FLAT_RETURN, VALUE_BLOCK_POSITION_NOT_NULL)); this.hashFlat = typeOperators.getHashCodeOperator(type, simpleConvention(FAIL_ON_NULL, FLAT)); - this.hashBlock = 
typeOperators.getHashCodeOperator(type, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION_NOT_NULL)); + this.hashBlock = typeOperators.getHashCodeOperator(type, simpleConvention(FAIL_ON_NULL, VALUE_BLOCK_POSITION_NOT_NULL)); this.distinctFlatFlat = typeOperators.getDistinctFromOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, FLAT)); - this.distinctFlatBlock = typeOperators.getDistinctFromOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, BLOCK_POSITION_NOT_NULL)); + this.distinctFlatBlock = typeOperators.getDistinctFromOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, VALUE_BLOCK_POSITION_NOT_NULL)); if (collectMinMax) { this.compareFlatFlat = typeOperators.getComparisonUnorderedLastOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, FLAT)); - this.compareBlockBlock = typeOperators.getComparisonUnorderedLastOperator(type, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION_NOT_NULL, BLOCK_POSITION_NOT_NULL)); + this.compareBlockBlock = typeOperators.getComparisonUnorderedLastOperator(type, simpleConvention(FAIL_ON_NULL, VALUE_BLOCK_POSITION_NOT_NULL, VALUE_BLOCK_POSITION_NOT_NULL)); } else { this.compareFlatFlat = null; @@ -157,9 +160,24 @@ public boolean isCollecting() public void add(Block block) { + block = block.getLoadedBlock(); if (collectDistinctValues) { - for (int position = 0; position < block.getPositionCount(); ++position) { - add(block, position); + if (block instanceof ValueBlock valueBlock) { + for (int position = 0; position < block.getPositionCount(); position++) { + add(valueBlock, position); + } + } + else if (block instanceof RunLengthEncodedBlock rleBlock) { + add(rleBlock.getValue(), 0); + } + else if (block instanceof DictionaryBlock dictionaryBlock) { + ValueBlock dictionary = dictionaryBlock.getDictionary(); + for (int i = 0; i < dictionaryBlock.getPositionCount(); i++) { + add(dictionary, dictionaryBlock.getId(i)); + } + } + else { + throw new IllegalArgumentException("Unsupported block type: " + block.getClass().getSimpleName()); } // if the 
distinct size is too large, fall back to min max, and drop the distinct values @@ -207,8 +225,10 @@ else if (collectMinMax) { int minValuePosition = -1; int maxValuePosition = -1; - for (int position = 0; position < block.getPositionCount(); ++position) { - if (block.isNull(position)) { + ValueBlock valueBlock = block.getUnderlyingValueBlock(); + for (int i = 0; i < block.getPositionCount(); i++) { + int position = block.getUnderlyingValuePosition(i); + if (valueBlock.isNull(position)) { continue; } if (minValuePosition == -1) { @@ -217,10 +237,10 @@ else if (collectMinMax) { maxValuePosition = position; continue; } - if (valueCompare(block, position, block, minValuePosition) < 0) { + if (valueCompare(valueBlock, position, valueBlock, minValuePosition) < 0) { minValuePosition = position; } - else if (valueCompare(block, position, block, maxValuePosition) > 0) { + else if (valueCompare(valueBlock, position, valueBlock, maxValuePosition) > 0) { maxValuePosition = position; } } @@ -231,18 +251,18 @@ else if (valueCompare(block, position, block, maxValuePosition) > 0) { } if (minValue == null) { - minValue = block.getSingleValueBlock(minValuePosition); - maxValue = block.getSingleValueBlock(maxValuePosition); + minValue = valueBlock.getSingleValueBlock(minValuePosition); + maxValue = valueBlock.getSingleValueBlock(maxValuePosition); return; } - if (valueCompare(block, minValuePosition, minValue, 0) < 0) { + if (valueCompare(valueBlock, minValuePosition, minValue, 0) < 0) { retainedSizeInBytes -= minValue.getRetainedSizeInBytes(); - minValue = block.getSingleValueBlock(minValuePosition); + minValue = valueBlock.getSingleValueBlock(minValuePosition); retainedSizeInBytes += minValue.getRetainedSizeInBytes(); } - if (valueCompare(block, maxValuePosition, maxValue, 0) > 0) { + if (valueCompare(valueBlock, maxValuePosition, maxValue, 0) > 0) { retainedSizeInBytes -= maxValue.getRetainedSizeInBytes(); - maxValue = block.getSingleValueBlock(maxValuePosition); + maxValue = 
valueBlock.getSingleValueBlock(maxValuePosition); retainedSizeInBytes += maxValue.getRetainedSizeInBytes(); } } @@ -289,7 +309,7 @@ public Domain build() return Domain.all(type); } - private void add(Block block, int position) + private void add(ValueBlock block, int position) { // Inner and right join doesn't match rows with null key column values. if (block.isNull(position)) { @@ -343,7 +363,7 @@ private int matchInVector(byte[] otherValues, VariableWidthData otherVariableWid return -1; } - private int matchInVector(Block block, int position, int vectorStartBucket, long repeated, long controlVector) + private int matchInVector(ValueBlock block, int position, int vectorStartBucket, long repeated, long controlVector) { long controlMatches = match(controlVector, repeated); while (controlMatches != 0) { @@ -367,7 +387,7 @@ private int findEmptyInVector(long vector, int vectorStartBucket) return bucket(vectorStartBucket + slot); } - private void insert(int index, Block block, int position, byte hashPrefix) + private void insert(int index, ValueBlock block, int position, byte hashPrefix) { setControl(index, hashPrefix); @@ -512,7 +532,7 @@ private Object readValueToObject(int position) } } - private Block readValueToBlock(int position) + private ValueBlock readValueToBlock(int position) { return writeNativeValue(type, readValueToObject(position)); } @@ -538,7 +558,7 @@ private long valueHashCode(byte[] values, int position) } } - private long valueHashCode(Block right, int rightPosition) + private long valueHashCode(ValueBlock right, int rightPosition) { try { return (long) hashBlock.invokeExact(right, rightPosition); @@ -549,7 +569,7 @@ private long valueHashCode(Block right, int rightPosition) } } - private boolean valueNotDistinctFrom(int leftPosition, Block right, int rightPosition) + private boolean valueNotDistinctFrom(int leftPosition, ValueBlock right, int rightPosition) { byte[] leftFixedRecordChunk = distinctRecords; int leftRecordOffset = 
getRecordOffset(leftPosition); @@ -603,7 +623,7 @@ private boolean valueNotDistinctFrom(int leftPosition, byte[] rightValues, Varia } } - private int valueCompare(Block left, int leftPosition, Block right, int rightPosition) + private int valueCompare(ValueBlock left, int leftPosition, ValueBlock right, int rightPosition) { try { return (int) (long) compareBlockBlock.invokeExact( diff --git a/core/trino-main/src/main/java/io/trino/operator/PartitionFunction.java b/core/trino-main/src/main/java/io/trino/operator/PartitionFunction.java index 6f9768affeb8..ef1627843f86 100644 --- a/core/trino-main/src/main/java/io/trino/operator/PartitionFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/PartitionFunction.java @@ -17,7 +17,7 @@ public interface PartitionFunction { - int getPartitionCount(); + int partitionCount(); /** * @param page the arguments to bucketing function in order (no extra columns) diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AbstractMapAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AbstractMapAggregationState.java index 02c3091a6cb5..e410366b95db 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AbstractMapAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AbstractMapAggregationState.java @@ -17,9 +17,9 @@ import com.google.common.primitives.Ints; import io.trino.operator.VariableWidthData; import io.trino.spi.TrinoException; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.MapBlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import jakarta.annotation.Nullable; @@ -285,7 +285,7 @@ private void serializeEntry(BlockBuilder keyBuilder, BlockBuilder valueBuilder, } } - protected void add(int groupId, Block keyBlock, int keyPosition, Block valueBlock, int valuePosition) + protected void add(int groupId, ValueBlock keyBlock, int 
keyPosition, ValueBlock valueBlock, int valuePosition) { checkArgument(!keyBlock.isNull(keyPosition), "key must not be null"); checkArgument(groupId == 0 || groupRecordIndex != null, "groupId must be zero when grouping is not enabled"); @@ -322,7 +322,7 @@ protected void add(int groupId, Block keyBlock, int keyPosition, Block valueBloc } } - private int matchInVector(int groupId, Block block, int position, int vectorStartBucket, long repeated, long controlVector) + private int matchInVector(int groupId, ValueBlock block, int position, int vectorStartBucket, long repeated, long controlVector) { long controlMatches = match(controlVector, repeated); while (controlMatches != 0) { @@ -346,7 +346,7 @@ private int findEmptyInVector(long vector, int vectorStartBucket) return bucket(vectorStartBucket + slot); } - private void insert(int index, int groupId, Block keyBlock, int keyPosition, Block valueBlock, int valuePosition, byte hashPrefix) + private void insert(int index, int groupId, ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition, byte hashPrefix) { setControl(index, hashPrefix); @@ -499,7 +499,7 @@ private long keyHashCode(int groupId, byte[] records, int index) } } - private long keyHashCode(int groupId, Block right, int rightPosition) + private long keyHashCode(int groupId, ValueBlock right, int rightPosition) { try { long valueHash = (long) keyHashBlock.invokeExact(right, rightPosition); @@ -511,7 +511,7 @@ private long keyHashCode(int groupId, Block right, int rightPosition) } } - private boolean keyNotDistinctFrom(int leftPosition, Block right, int rightPosition, int rightGroupId) + private boolean keyNotDistinctFrom(int leftPosition, ValueBlock right, int rightPosition, int rightGroupId) { byte[] leftRecords = getRecords(leftPosition); int leftRecordOffset = getRecordOffset(leftPosition); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java 
b/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java index 2333a73b36a4..98b417aea187 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java @@ -33,6 +33,7 @@ import io.trino.spi.block.ColumnarRow; import io.trino.spi.block.RowBlockBuilder; import io.trino.spi.block.RowValueBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.AccumulatorStateFactory; import io.trino.spi.function.AccumulatorStateSerializer; @@ -71,6 +72,7 @@ import static io.airlift.bytecode.expression.BytecodeExpressions.invokeDynamic; import static io.airlift.bytecode.expression.BytecodeExpressions.invokeStatic; import static io.airlift.bytecode.expression.BytecodeExpressions.newInstance; +import static io.trino.operator.aggregation.AggregationLoopBuilder.buildLoop; import static io.trino.operator.aggregation.AggregationMaskCompiler.generateAggregationMaskBuilder; import static io.trino.sql.gen.Bootstrap.BOOTSTRAP_METHOD; import static io.trino.sql.gen.BytecodeUtils.invoke; @@ -88,7 +90,8 @@ private AccumulatorCompiler() {} public static AccumulatorFactory generateAccumulatorFactory( BoundSignature boundSignature, AggregationImplementation implementation, - FunctionNullability functionNullability) + FunctionNullability functionNullability, + boolean specializedLoops) { // change types used in Aggregation methods to types used in the core Trino engine to simplify code generation implementation = normalizeAggregationMethods(implementation); @@ -98,19 +101,21 @@ public static AccumulatorFactory generateAccumulatorFactory( List argumentNullable = functionNullability.getArgumentNullable() .subList(0, functionNullability.getArgumentNullable().size() - implementation.getLambdaInterfaces().size()); - Constructor accumulatorConstructor = generateAccumulatorClass( + 
Constructor groupedAccumulatorConstructor = generateAccumulatorClass( boundSignature, - Accumulator.class, + GroupedAccumulator.class, implementation, argumentNullable, - classLoader); + classLoader, + specializedLoops); - Constructor groupedAccumulatorConstructor = generateAccumulatorClass( + Constructor accumulatorConstructor = generateAccumulatorClass( boundSignature, - GroupedAccumulator.class, + Accumulator.class, implementation, argumentNullable, - classLoader); + classLoader, + specializedLoops); List nonNullArguments = new ArrayList<>(); for (int argumentIndex = 0; argumentIndex < argumentNullable.size(); argumentIndex++) { @@ -132,7 +137,8 @@ private static Constructor generateAccumulatorClass( Class accumulatorInterface, AggregationImplementation implementation, List argumentNullable, - DynamicClassLoader classLoader) + DynamicClassLoader classLoader, + boolean specializedLoops) { boolean grouped = accumulatorInterface == GroupedAccumulator.class; @@ -180,6 +186,7 @@ private static Constructor generateAccumulatorClass( generateAddInput( definition, + specializedLoops, stateFields, argumentNullable, lambdaProviderFields, @@ -363,6 +370,7 @@ private static void generateSetGroupCount(ClassDefinition definition, List stateField, List argumentNullable, List lambdaProviderFields, @@ -395,6 +403,7 @@ private static void generateAddInput( } BytecodeBlock block = generateInputForLoop( + specializedLoops, stateField, inputFunction, scope, @@ -429,25 +438,40 @@ private static void generateAddOrRemoveInputWindowIndex( type(void.class), ImmutableList.of(index, startPosition, endPosition)); Scope scope = method.getScope(); + BytecodeBlock body = method.getBody(); Variable position = scope.declareVariable(int.class, "position"); + // input parameters + Variable inputBlockPosition = scope.declareVariable(int.class, "inputBlockPosition"); + List inputBlockVariables = new ArrayList<>(); + for (int i = 0; i < argumentNullable.size(); i++) { + 
inputBlockVariables.add(scope.declareVariable(Block.class, "inputBlock" + i)); + } + Binding binding = callSiteBinder.bind(inputFunction); - BytecodeExpression invokeInputFunction = invokeDynamic( + BytecodeBlock invokeInputFunction = new BytecodeBlock(); + // WindowIndex is built on PagesIndex, which simply wraps Blocks + // and currently does not understand ValueBlocks. + // Until PagesIndex is updated to understand ValueBlocks, the + // input function parameters must be directly unwrapped to ValueBlocks. + invokeInputFunction.append(inputBlockPosition.set(index.cast(InternalWindowIndex.class).invoke("getRawBlockPosition", int.class, position))); + for (int i = 0; i < inputBlockVariables.size(); i++) { + invokeInputFunction.append(inputBlockVariables.get(i).set(index.cast(InternalWindowIndex.class).invoke("getRawBlock", Block.class, constantInt(i), position))); + } + invokeInputFunction.append(invokeDynamic( BOOTSTRAP_METHOD, ImmutableList.of(binding.getBindingId()), generatedFunctionName, binding.getType(), getInvokeFunctionOnWindowIndexParameters( - scope, - argumentNullable.size(), - lambdaProviderFields, + scope.getThis(), stateField, - index, - position)); + inputBlockPosition, + inputBlockVariables, + lambdaProviderFields))); - method.getBody() - .append(new ForLoop() + body.append(new ForLoop() .initialize(position.set(startPosition)) .condition(BytecodeExpressions.lessThanOrEqual(position, endPosition)) .update(position.increment()) @@ -473,33 +497,28 @@ private static BytecodeExpression anyParametersAreNull( } private static List getInvokeFunctionOnWindowIndexParameters( - Scope scope, - int inputParameterCount, - List lambdaProviderFields, + Variable thisVariable, List stateField, - Variable index, - Variable position) + Variable inputBlockPosition, + List inputBlockVariables, + List lambdaProviderFields) { List expressions = new ArrayList<>(); // state parameters for (FieldDefinition field : stateField) { - 
expressions.add(scope.getThis().getField(field)); + expressions.add(thisVariable.getField(field)); } // input parameters - for (int i = 0; i < inputParameterCount; i++) { - expressions.add(index.cast(InternalWindowIndex.class).invoke("getRawBlock", Block.class, constantInt(i), position)); - } - - // position parameter - if (inputParameterCount > 0) { - expressions.add(index.cast(InternalWindowIndex.class).invoke("getRawBlockPosition", int.class, position)); + for (Variable blockVariable : inputBlockVariables) { + expressions.add(blockVariable.invoke("getUnderlyingValueBlock", ValueBlock.class)); + expressions.add(blockVariable.invoke("getUnderlyingValuePosition", int.class, inputBlockPosition)); } // lambda parameters for (FieldDefinition lambdaProviderField : lambdaProviderFields) { - expressions.add(scope.getThis().getField(lambdaProviderField) + expressions.add(thisVariable.getField(lambdaProviderField) .invoke("get", Object.class)); } @@ -507,6 +526,7 @@ private static List getInvokeFunctionOnWindowIndexParameters } private static BytecodeBlock generateInputForLoop( + boolean specializedLoops, List stateField, MethodHandle inputFunction, Scope scope, @@ -516,6 +536,30 @@ private static BytecodeBlock generateInputForLoop( CallSiteBinder callSiteBinder, boolean grouped) { + if (specializedLoops) { + BytecodeBlock newBlock = new BytecodeBlock(); + Variable thisVariable = scope.getThis(); + + MethodHandle mainLoop = buildLoop(inputFunction, stateField.size(), parameterVariables.size(), grouped); + + ImmutableList.Builder parameters = ImmutableList.builder(); + parameters.add(mask); + if (grouped) { + parameters.add(scope.getVariable("groupIds")); + } + for (FieldDefinition fieldDefinition : stateField) { + parameters.add(thisVariable.getField(fieldDefinition)); + } + parameters.addAll(parameterVariables); + for (FieldDefinition lambdaProviderField : lambdaProviderFields) { + parameters.add(scope.getThis().getField(lambdaProviderField) + .invoke("get", 
Object.class)); + } + + newBlock.append(invoke(callSiteBinder.bind(mainLoop), "mainLoop", parameters.build())); + return newBlock; + } + // For-loop over rows Variable positionVariable = scope.declareVariable(int.class, "position"); Variable rowsVariable = scope.declareVariable(int.class, "rows"); @@ -596,11 +640,9 @@ private static BytecodeBlock generateInvokeInputFunction( } // input parameters - parameters.addAll(parameterVariables); - - // position parameter - if (!parameterVariables.isEmpty()) { - parameters.add(position); + for (Variable variable : parameterVariables) { + parameters.add(variable.invoke("getUnderlyingValueBlock", ValueBlock.class)); + parameters.add(variable.invoke("getUnderlyingValuePosition", int.class, position)); } // lambda parameters @@ -1054,32 +1096,38 @@ private static BytecodeExpression generateRequireNotNull(BytecodeExpression expr private static AggregationImplementation normalizeAggregationMethods(AggregationImplementation implementation) { // change aggregations state variables to simply AccumulatorState to avoid any class loader issues in generated code - int stateParameterCount = implementation.getAccumulatorStateDescriptors().size(); int lambdaParameterCount = implementation.getLambdaInterfaces().size(); AggregationImplementation.Builder builder = AggregationImplementation.builder(); - builder.inputFunction(castStateParameters(implementation.getInputFunction(), stateParameterCount, lambdaParameterCount)); + builder.inputFunction(normalizeParameters(implementation.getInputFunction(), lambdaParameterCount)); implementation.getRemoveInputFunction() - .map(removeFunction -> castStateParameters(removeFunction, stateParameterCount, lambdaParameterCount)) + .map(removeFunction -> normalizeParameters(removeFunction, lambdaParameterCount)) .ifPresent(builder::removeInputFunction); implementation.getCombineFunction() - .map(combineFunction -> castStateParameters(combineFunction, stateParameterCount * 2, lambdaParameterCount)) + 
.map(combineFunction -> normalizeParameters(combineFunction, lambdaParameterCount)) .ifPresent(builder::combineFunction); - builder.outputFunction(castStateParameters(implementation.getOutputFunction(), stateParameterCount, 0)); + builder.outputFunction(normalizeParameters(implementation.getOutputFunction(), 0)); builder.accumulatorStateDescriptors(implementation.getAccumulatorStateDescriptors()); builder.lambdaInterfaces(implementation.getLambdaInterfaces()); return builder.build(); } - private static MethodHandle castStateParameters(MethodHandle inputFunction, int stateParameterCount, int lambdaParameterCount) + private static MethodHandle normalizeParameters(MethodHandle function, int lambdaParameterCount) { - Class[] parameterTypes = inputFunction.type().parameterArray(); - for (int i = 0; i < stateParameterCount; i++) { - parameterTypes[i] = AccumulatorState.class; + Class[] parameterTypes = function.type().parameterArray(); + for (int i = 0; i < parameterTypes.length; i++) { + Class parameterType = parameterTypes[i]; + if (AccumulatorState.class.isAssignableFrom(parameterType)) { + parameterTypes[i] = AccumulatorState.class; + } + else if (ValueBlock.class.isAssignableFrom(parameterType)) { + parameterTypes[i] = ValueBlock.class; + } } for (int i = parameterTypes.length - lambdaParameterCount; i < parameterTypes.length; i++) { parameterTypes[i] = Object.class; } - return MethodHandles.explicitCastArguments(inputFunction, MethodType.methodType(inputFunction.type().returnType(), parameterTypes)); + MethodType newType = MethodType.methodType(function.type().returnType(), parameterTypes); + return MethodHandles.explicitCastArguments(function, newType); } private static class StateFieldAndDescriptor diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java index f27391dbefd5..a46cd8bcaa73 100644 --- 
a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java @@ -280,7 +280,7 @@ private static List getInputFunctions(Class clazz, List 1) { List> parameterAnnotations = getNonDependencyParameterAnnotations(inputFunction) .subList(0, stateDetails.size()); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFunctionAdapter.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFunctionAdapter.java index 84d20bfddf86..6315b354cdd7 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFunctionAdapter.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFunctionAdapter.java @@ -15,15 +15,13 @@ import com.google.common.collect.ImmutableList; import io.trino.spi.block.Block; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.BoundSignature; import io.trino.spi.type.Type; import java.lang.invoke.MethodHandle; -import java.lang.invoke.MethodType; import java.util.ArrayList; import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.IntStream; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -34,7 +32,7 @@ import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE; import static java.lang.invoke.MethodHandles.collectArguments; import static java.lang.invoke.MethodHandles.lookup; -import static java.lang.invoke.MethodHandles.permuteArguments; +import static java.lang.invoke.MethodType.methodType; import static java.util.Objects.requireNonNull; public final class AggregationFunctionAdapter @@ -55,10 +53,14 @@ public enum AggregationParameterKind static { try { - BOOLEAN_TYPE_GETTER = lookup().findVirtual(Type.class, "getBoolean", 
MethodType.methodType(boolean.class, Block.class, int.class)); - LONG_TYPE_GETTER = lookup().findVirtual(Type.class, "getLong", MethodType.methodType(long.class, Block.class, int.class)); - DOUBLE_TYPE_GETTER = lookup().findVirtual(Type.class, "getDouble", MethodType.methodType(double.class, Block.class, int.class)); - OBJECT_TYPE_GETTER = lookup().findVirtual(Type.class, "getObject", MethodType.methodType(Object.class, Block.class, int.class)); + BOOLEAN_TYPE_GETTER = lookup().findVirtual(Type.class, "getBoolean", methodType(boolean.class, Block.class, int.class)) + .asType(methodType(boolean.class, Type.class, ValueBlock.class, int.class)); + LONG_TYPE_GETTER = lookup().findVirtual(Type.class, "getLong", methodType(long.class, Block.class, int.class)) + .asType(methodType(long.class, Type.class, ValueBlock.class, int.class)); + DOUBLE_TYPE_GETTER = lookup().findVirtual(Type.class, "getDouble", methodType(double.class, Block.class, int.class)) + .asType(methodType(double.class, Type.class, ValueBlock.class, int.class)); + OBJECT_TYPE_GETTER = lookup().findVirtual(Type.class, "getObject", methodType(Object.class, Block.class, int.class)) + .asType(methodType(Object.class, Type.class, ValueBlock.class, int.class)); } catch (ReflectiveOperationException e) { throw new AssertionError(e); @@ -103,7 +105,6 @@ public static MethodHandle normalizeInputMethod( List inputArgumentKinds = parameterKinds.stream() .filter(kind -> kind == INPUT_CHANNEL || kind == BLOCK_INPUT_CHANNEL || kind == NULLABLE_BLOCK_INPUT_CHANNEL) .collect(toImmutableList()); - boolean hasInputChannel = parameterKinds.stream().anyMatch(kind -> kind == BLOCK_INPUT_CHANNEL || kind == NULLABLE_BLOCK_INPUT_CHANNEL); checkArgument( boundSignature.getArgumentTypes().size() - lambdaCount == inputArgumentKinds.size(), @@ -113,21 +114,26 @@ public static MethodHandle normalizeInputMethod( List expectedInputArgumentKinds = new ArrayList<>(); expectedInputArgumentKinds.addAll(stateArgumentKinds); - 
expectedInputArgumentKinds.addAll(inputArgumentKinds); - if (hasInputChannel) { - expectedInputArgumentKinds.add(BLOCK_INDEX); + for (AggregationParameterKind kind : inputArgumentKinds) { + expectedInputArgumentKinds.add(kind); + if (kind == BLOCK_INPUT_CHANNEL || kind == NULLABLE_BLOCK_INPUT_CHANNEL) { + expectedInputArgumentKinds.add(BLOCK_INDEX); + } } + checkArgument( expectedInputArgumentKinds.equals(parameterKinds), "Expected input parameter kinds %s, but got %s", expectedInputArgumentKinds, parameterKinds); - MethodType inputMethodType = inputMethod.type(); for (int argumentIndex = 0; argumentIndex < inputArgumentKinds.size(); argumentIndex++) { - int parameterIndex = stateArgumentKinds.size() + argumentIndex; + int parameterIndex = stateArgumentKinds.size() + (argumentIndex * 2); AggregationParameterKind inputArgument = inputArgumentKinds.get(argumentIndex); if (inputArgument != INPUT_CHANNEL) { + if (inputArgument == BLOCK_INPUT_CHANNEL || inputArgument == NULLABLE_BLOCK_INPUT_CHANNEL) { + checkArgument(ValueBlock.class.isAssignableFrom(inputMethod.type().parameterType(parameterIndex)), "Expected parameter %s to be a ValueBlock", parameterIndex); + } continue; } Type argumentType = boundSignature.getArgumentType(argumentIndex); @@ -145,27 +151,9 @@ else if (argumentType.getJavaType().equals(double.class)) { } else { valueGetter = OBJECT_TYPE_GETTER.bindTo(argumentType); - valueGetter = valueGetter.asType(valueGetter.type().changeReturnType(inputMethodType.parameterType(parameterIndex))); + valueGetter = valueGetter.asType(valueGetter.type().changeReturnType(inputMethod.type().parameterType(parameterIndex))); } inputMethod = collectArguments(inputMethod, parameterIndex, valueGetter); - - // move the position argument to the end (and combine with other existing position argument) - inputMethodType = inputMethodType.changeParameterType(parameterIndex, Block.class); - - ArrayList reorder; - if (hasInputChannel) { - reorder = IntStream.range(0, 
inputMethodType.parameterCount()).boxed().collect(Collectors.toCollection(ArrayList::new)); - reorder.add(parameterIndex + 1, inputMethodType.parameterCount() - 1 - lambdaCount); - } - else { - inputMethodType = inputMethodType.insertParameterTypes(inputMethodType.parameterCount() - lambdaCount, int.class); - reorder = IntStream.range(0, inputMethodType.parameterCount()).boxed().collect(Collectors.toCollection(ArrayList::new)); - int positionParameterIndex = inputMethodType.parameterCount() - 1 - lambdaCount; - reorder.remove(positionParameterIndex); - reorder.add(parameterIndex + 1, positionParameterIndex); - hasInputChannel = true; - } - inputMethod = permuteArguments(inputMethod, inputMethodType, reorder.stream().mapToInt(Integer::intValue).toArray()); } return inputMethod; } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationLoopBuilder.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationLoopBuilder.java new file mode 100644 index 000000000000..e7b7dd678452 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationLoopBuilder.java @@ -0,0 +1,331 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.operator.aggregation; + +import com.google.common.collect.ImmutableList; +import io.airlift.bytecode.BytecodeBlock; +import io.airlift.bytecode.BytecodeNode; +import io.airlift.bytecode.ClassDefinition; +import io.airlift.bytecode.MethodDefinition; +import io.airlift.bytecode.Parameter; +import io.airlift.bytecode.Scope; +import io.airlift.bytecode.Variable; +import io.airlift.bytecode.control.ForLoop; +import io.airlift.bytecode.control.IfStatement; +import io.airlift.bytecode.expression.BytecodeExpression; +import io.airlift.bytecode.expression.BytecodeExpressions; +import io.trino.spi.block.Block; +import io.trino.spi.block.DictionaryBlock; +import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; +import io.trino.spi.function.GroupedAccumulatorState; +import io.trino.sql.gen.CallSiteBinder; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodType; +import java.lang.reflect.Method; +import java.util.ArrayDeque; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.function.Function; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.Iterables.cycle; +import static com.google.common.collect.Iterables.limit; +import static com.google.common.collect.MoreCollectors.onlyElement; +import static io.airlift.bytecode.Access.FINAL; +import static io.airlift.bytecode.Access.PRIVATE; +import static io.airlift.bytecode.Access.PUBLIC; +import static io.airlift.bytecode.Access.STATIC; +import static io.airlift.bytecode.Access.a; +import static io.airlift.bytecode.Parameter.arg; +import static io.airlift.bytecode.ParameterizedType.type; +import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt; +import static io.airlift.bytecode.expression.BytecodeExpressions.constantString; +import static io.airlift.bytecode.expression.BytecodeExpressions.invokeStatic; +import static 
io.airlift.bytecode.expression.BytecodeExpressions.lessThan; +import static io.airlift.bytecode.expression.BytecodeExpressions.newInstance; +import static io.trino.sql.gen.BytecodeUtils.invoke; +import static io.trino.util.CompilerUtils.defineClass; +import static io.trino.util.CompilerUtils.makeClassName; +import static java.lang.invoke.MethodHandles.lookup; +import static java.lang.invoke.MethodType.methodType; + +final class AggregationLoopBuilder +{ + private AggregationLoopBuilder() {} + + /** + * Build a loop over the aggregation function. Internally, there are multiple loops generated that are specialized for + * RLE, Dictionary, and basic blocks, and for masked or unmasked input. The method handle is expected to have a {@link Block} and int + * position argument for each parameter. The returned method handle signature will start with an {@link AggregationMask}, + * followed by a single {@link Block} for each parameter. + */ + public static MethodHandle buildLoop(MethodHandle function, int stateCount, int parameterCount, boolean grouped) + { + verifyFunctionSignature(function, stateCount, parameterCount); + CallSiteBinder binder = new CallSiteBinder(); + ClassDefinition definition = new ClassDefinition( + a(PUBLIC, STATIC, FINAL), + makeClassName("AggregationLoop"), + type(Object.class)); + + definition.declareDefaultConstructor(a(PRIVATE)); + + buildSpecializedLoop(binder, definition, function, stateCount, parameterCount, grouped); + + Class clazz = defineClass(definition, Object.class, binder.getBindings(), AggregationLoopBuilder.class.getClassLoader()); + + // it is simpler to find the method with reflection than using lookup().findStatic because of the complex signature + Method invokeMethod = Arrays.stream(clazz.getMethods()) + .filter(method -> method.getName().equals("invoke")) + .collect(onlyElement()); + + try { + return lookup().unreflect(invokeMethod); + } + catch (IllegalAccessException e) { + throw new RuntimeException(e); + } + } + + private
static void buildSpecializedLoop(CallSiteBinder binder, ClassDefinition classDefinition, MethodHandle function, int stateCount, int parameterCount, boolean grouped) + { + AggregationParameters aggregationParameters = AggregationParameters.create(function, stateCount, parameterCount, grouped); + MethodDefinition methodDefinition = classDefinition.declareMethod( + a(PUBLIC, STATIC), + "invoke", + type(void.class), + aggregationParameters.allParameters()); + + Function, BytecodeNode> coreLoopBuilder = (blockTypes) -> { + MethodDefinition method = buildCoreLoop(binder, classDefinition, function, blockTypes, aggregationParameters); + return invokeStatic(method, aggregationParameters.allParameters().toArray(new BytecodeExpression[0])); + }; + + BytecodeNode bytecodeNode = buildLoopSelection(coreLoopBuilder, new ArrayDeque<>(parameterCount), new ArrayDeque<>(aggregationParameters.blocks())); + methodDefinition.getBody() + .append(bytecodeNode) + .ret(); + } + + private static BytecodeNode buildLoopSelection(Function, BytecodeNode> coreLoopBuilder, ArrayDeque currentTypes, ArrayDeque remainingParameters) + { + if (remainingParameters.isEmpty()) { + return coreLoopBuilder.apply(ImmutableList.copyOf(currentTypes)); + } + + // remove the next parameter from the queue + Parameter blockParameter = remainingParameters.removeFirst(); + + currentTypes.addLast(BlockType.VALUE); + BytecodeNode valueLoop = buildLoopSelection(coreLoopBuilder, currentTypes, remainingParameters); + currentTypes.removeLast(); + + currentTypes.addLast(BlockType.DICTIONARY); + BytecodeNode dictionaryLoop = buildLoopSelection(coreLoopBuilder, currentTypes, remainingParameters); + currentTypes.removeLast(); + + currentTypes.addLast(BlockType.RLE); + BytecodeNode rleLoop = buildLoopSelection(coreLoopBuilder, currentTypes, remainingParameters); + currentTypes.removeLast(); + + IfStatement blockTypeSelection = new IfStatement() + .condition(blockParameter.instanceOf(ValueBlock.class)) + .ifTrue(valueLoop) + 
.ifFalse(new IfStatement() + .condition(blockParameter.instanceOf(DictionaryBlock.class)) + .ifTrue(dictionaryLoop) + .ifFalse(new IfStatement() + .condition(blockParameter.instanceOf(RunLengthEncodedBlock.class)) + .ifTrue(rleLoop) + .ifFalse(new BytecodeBlock() + .append(newInstance(UnsupportedOperationException.class, constantString("Aggregation is not decomposable"))) + .throwObject()))); + + // restore the parameter to the queue + remainingParameters.addFirst(blockParameter); + + return blockTypeSelection; + } + + private static MethodDefinition buildCoreLoop( + CallSiteBinder binder, + ClassDefinition classDefinition, + MethodHandle function, + List blockTypes, + AggregationParameters aggregationParameters) + { + StringBuilder methodName = new StringBuilder("invoke_"); + for (BlockType blockType : blockTypes) { + methodName.append(blockType.name().charAt(0)); + } + + MethodDefinition methodDefinition = classDefinition.declareMethod( + a(PUBLIC, STATIC), + methodName.toString(), + type(void.class), + aggregationParameters.allParameters()); + Scope scope = methodDefinition.getScope(); + BytecodeBlock body = methodDefinition.getBody(); + + Variable position = scope.declareVariable(int.class, "position"); + + ImmutableList.Builder aggregationArguments = ImmutableList.builder(); + aggregationArguments.addAll(aggregationParameters.states()); + addBlockPositionArguments(methodDefinition, position, blockTypes, aggregationParameters.blocks(), aggregationArguments); + aggregationArguments.addAll(aggregationParameters.lambdas()); + + BytecodeBlock invokeFunction = new BytecodeBlock(); + if (aggregationParameters.groupIds().isPresent()) { + // set groupId on state variables + Variable groupId = scope.declareVariable(int.class, "groupId"); + invokeFunction.append(groupId.set(aggregationParameters.groupIds().get().getElement(position))); + for (Parameter stateParameter : aggregationParameters.states()) { + 
invokeFunction.append(stateParameter.cast(GroupedAccumulatorState.class).invoke("setGroupId", void.class, groupId.cast(long.class))); + } + } + invokeFunction.append(invoke(binder.bind(function), "input", aggregationArguments.build())); + + Variable positionCount = scope.declareVariable("positionCount", body, aggregationParameters.mask().invoke("getSelectedPositionCount", int.class)); + + ForLoop selectAllLoop = new ForLoop() + .initialize(position.set(constantInt(0))) + .condition(lessThan(position, positionCount)) + .update(position.increment()) + .body(invokeFunction); + + Variable index = scope.declareVariable("index", body, constantInt(0)); + Variable selectedPositions = scope.declareVariable(int[].class, "selectedPositions"); + ForLoop maskedLoop = new ForLoop() + .initialize(selectedPositions.set(aggregationParameters.mask().invoke("getSelectedPositions", int[].class))) + .condition(lessThan(index, positionCount)) + .update(index.increment()) + .body(new BytecodeBlock() + .append(position.set(selectedPositions.getElement(index))) + .append(invokeFunction)); + + body.append(new IfStatement() + .condition(aggregationParameters.mask().invoke("isSelectAll", boolean.class)) + .ifTrue(selectAllLoop) + .ifFalse(maskedLoop)); + body.ret(); + return methodDefinition; + } + + private static void addBlockPositionArguments( + MethodDefinition methodDefinition, + Variable position, + List blockTypes, + List blockParameters, + ImmutableList.Builder aggregationArguments) + { + Scope scope = methodDefinition.getScope(); + BytecodeBlock methodBody = methodDefinition.getBody(); + + for (int i = 0; i < blockTypes.size(); i++) { + BlockType blockType = blockTypes.get(i); + switch (blockType) { + case VALUE -> { + aggregationArguments.add(blockParameters.get(i).cast(ValueBlock.class)); + aggregationArguments.add(position); + } + case DICTIONARY -> { + Variable valueBlock = scope.declareVariable( + "valueBlock" + i, + methodBody, + 
blockParameters.get(i).cast(DictionaryBlock.class).invoke("getDictionary", ValueBlock.class)); + Variable rawIds = scope.declareVariable( + "rawIds" + i, + methodBody, + blockParameters.get(i).cast(DictionaryBlock.class).invoke("getRawIds", int[].class)); + Variable rawIdsOffset = scope.declareVariable( + "rawIdsOffset" + i, + methodBody, + blockParameters.get(i).cast(DictionaryBlock.class).invoke("getRawIdsOffset", int.class)); + aggregationArguments.add(valueBlock); + aggregationArguments.add(rawIds.getElement(BytecodeExpressions.add(rawIdsOffset, position))); + } + case RLE -> { + Variable valueBlock = scope.declareVariable( + "valueBlock" + i, + methodBody, + blockParameters.get(i).cast(RunLengthEncodedBlock.class).invoke("getValue", ValueBlock.class)); + aggregationArguments.add(valueBlock); + aggregationArguments.add(constantInt(0)); + } + } + } + } + + private static void verifyFunctionSignature(MethodHandle function, int stateCount, int parameterCount) + { + // verify signature + List> expectedParameterTypes = ImmutableList.>builder() + .addAll(function.type().parameterList().subList(0, stateCount)) + .addAll(limit(cycle(ValueBlock.class, int.class), parameterCount * 2)) + .addAll(function.type().parameterList().subList(stateCount + (parameterCount * 2), function.type().parameterCount())) + .build(); + MethodType expectedSignature = methodType(void.class, expectedParameterTypes); + checkArgument(function.type().equals(expectedSignature), "Expected function signature to be %s, but is %s", expectedSignature, function.type()); + } + + private record AggregationParameters(Parameter mask, Optional groupIds, List states, List blocks, List lambdas) + { + static AggregationParameters create(MethodHandle function, int stateCount, int parameterCount, boolean grouped) + { + Parameter mask = arg("aggregationMask", AggregationMask.class); + + Optional groupIds = Optional.empty(); + if (grouped) { + groupIds = Optional.of(arg("groupIds", int[].class)); + } + + 
ImmutableList.Builder states = ImmutableList.builder(); + for (int i = 0; i < stateCount; i++) { + states.add(arg("state" + i, function.type().parameterType(i))); + } + + ImmutableList.Builder parameters = ImmutableList.builder(); + for (int i = 0; i < parameterCount; i++) { + parameters.add(arg("block" + i, Block.class)); + } + + ImmutableList.Builder lambdas = ImmutableList.builder(); + int lambdaFunctionOffset = stateCount + (parameterCount * 2); + for (int i = 0; i < function.type().parameterCount() - lambdaFunctionOffset; i++) { + lambdas.add(arg("lambda" + i, function.type().parameterType(lambdaFunctionOffset + i))); + } + + return new AggregationParameters(mask, groupIds, states.build(), parameters.build(), lambdas.build()); + } + + public List allParameters() + { + return ImmutableList.builder() + .add(mask) + .addAll(groupIds.stream().iterator()) + .addAll(states) + .addAll(blocks) + .addAll(lambdas) + .build(); + } + } + + private enum BlockType + { + RLE, DICTIONARY, VALUE + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateCountDistinctAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateCountDistinctAggregation.java index e28492becf3f..1679e3ece3ea 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateCountDistinctAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateCountDistinctAggregation.java @@ -16,8 +16,8 @@ import com.google.common.annotations.VisibleForTesting; import io.airlift.stats.cardinality.HyperLogLog; import io.trino.operator.aggregation.state.HyperLogLogState; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -52,7 +52,7 @@ private ApproximateCountDistinctAggregation() {} @InputFunction 
public static void input( @AggregationState HyperLogLogState state, - @BlockPosition @SqlType("unknown") Block block, + @BlockPosition @SqlType("unknown") ValueBlock block, @BlockIndex int index, @SqlType(StandardTypes.DOUBLE) double maxStandardError) { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateSetGenericAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateSetGenericAggregation.java index 459dabae0daf..4791fe78a83d 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateSetGenericAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateSetGenericAggregation.java @@ -16,8 +16,8 @@ import io.airlift.stats.cardinality.HyperLogLog; import io.trino.operator.aggregation.state.HyperLogLogState; import io.trino.operator.aggregation.state.StateCompiler; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AccumulatorStateSerializer; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; @@ -51,7 +51,7 @@ private ApproximateSetGenericAggregation() {} @InputFunction public static void input( @AggregationState HyperLogLogState state, - @BlockPosition @SqlType("unknown") Block block, + @BlockPosition @SqlType("unknown") ValueBlock block, @BlockIndex int index) { // do nothing -- unknown type is always NULL diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ArbitraryAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ArbitraryAggregationFunction.java index 28d778a5489a..e62a9ea91adf 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ArbitraryAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ArbitraryAggregationFunction.java @@ -13,8 +13,8 @@ */ package io.trino.operator.aggregation; -import 
io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -37,7 +37,7 @@ private ArbitraryAggregationFunction() {} @TypeParameter("T") public static void input( @AggregationState("T") InOut state, - @BlockPosition @SqlType("T") Block block, + @BlockPosition @SqlType("T") ValueBlock block, @BlockIndex int position) throws Throwable { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ChecksumAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ChecksumAggregationFunction.java index 51452b0e4022..92bbc6e40326 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ChecksumAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ChecksumAggregationFunction.java @@ -17,8 +17,8 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; import io.trino.operator.aggregation.state.NullableLongState; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -36,7 +36,7 @@ import java.lang.invoke.MethodHandle; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.type.VarbinaryType.VARBINARY; @@ -55,10 +55,10 @@ public static void input( @OperatorDependency( operator = OperatorType.XX_HASH_64, argumentTypes = "T", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = 
FAIL_ON_NULL)) + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FAIL_ON_NULL)) MethodHandle xxHash64Operator, @AggregationState NullableLongState state, - @SqlNullable @BlockPosition @SqlType("T") Block block, + @SqlNullable @BlockPosition @SqlType("T") ValueBlock block, @BlockIndex int position) throws Throwable { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/CountColumn.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/CountColumn.java index 9163dab81c84..87ccef50fbec 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/CountColumn.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/CountColumn.java @@ -14,8 +14,8 @@ package io.trino.operator.aggregation; import io.trino.operator.aggregation.state.LongState; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -40,7 +40,7 @@ private CountColumn() {} @TypeParameter("T") public static void input( @AggregationState LongState state, - @BlockPosition @SqlType("T") Block block, + @BlockPosition @SqlType("T") ValueBlock block, @BlockIndex int position) { state.setValue(state.getValue() + 1); @@ -49,7 +49,7 @@ public static void input( @RemoveInputFunction public static void removeInput( @AggregationState LongState state, - @BlockPosition @SqlType("T") Block block, + @BlockPosition @SqlType("T") ValueBlock block, @BlockIndex int position) { state.setValue(state.getValue() - 1); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalAverageAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalAverageAggregation.java index 69728b943941..c62754cac935 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalAverageAggregation.java +++ 
b/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalAverageAggregation.java @@ -15,8 +15,8 @@ import com.google.common.annotations.VisibleForTesting; import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongState; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.Int128ArrayBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -80,7 +80,7 @@ public static void inputShortDecimal( @LiteralParameters({"p", "s"}) public static void inputLongDecimal( @AggregationState LongDecimalWithOverflowAndLongState state, - @BlockPosition @SqlType(value = "decimal(p, s)", nativeContainerType = Int128.class) Block block, + @BlockPosition @SqlType(value = "decimal(p, s)", nativeContainerType = Int128.class) Int128ArrayBlock block, @BlockIndex int position) { state.addLong(1); // row counter diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalSumAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalSumAggregation.java index 6439dbc23483..f256a8748777 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalSumAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalSumAggregation.java @@ -14,8 +14,8 @@ package io.trino.operator.aggregation; import io.trino.operator.aggregation.state.LongDecimalWithOverflowState; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.Int128ArrayBlock; import io.trino.spi.block.Int128ArrayBlockBuilder; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; @@ -66,7 +66,7 @@ public static void inputShortDecimal( @LiteralParameters({"p", "s"}) public static void inputLongDecimal( @AggregationState LongDecimalWithOverflowState state, - @BlockPosition @SqlType(value = 
"decimal(p,s)", nativeContainerType = Int128.class) Block block, + @BlockPosition @SqlType(value = "decimal(p,s)", nativeContainerType = Int128.class) Int128ArrayBlock block, @BlockIndex int position) { state.setNotNull(); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/DefaultApproximateCountDistinctAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/DefaultApproximateCountDistinctAggregation.java index ee7c9e7de10e..517fa4df07b4 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/DefaultApproximateCountDistinctAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/DefaultApproximateCountDistinctAggregation.java @@ -14,8 +14,8 @@ package io.trino.operator.aggregation; import io.trino.operator.aggregation.state.HyperLogLogState; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -45,7 +45,7 @@ private DefaultApproximateCountDistinctAggregation() {} @InputFunction public static void input( @AggregationState HyperLogLogState state, - @BlockPosition @SqlType("unknown") Block block, + @BlockPosition @SqlType("unknown") ValueBlock block, @BlockIndex int index) { // do nothing diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedAggregator.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedAggregator.java index 5ef75bbb6169..998a914830d6 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedAggregator.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedAggregator.java @@ -78,13 +78,15 @@ public void processPage(int groupCount, int[] groupIds, Page page) Page arguments = page.getColumns(inputChannels); Optional maskBlock = Optional.empty(); if (maskChannel.isPresent()) { - 
maskBlock = Optional.of(page.getBlock(maskChannel.getAsInt())); + maskBlock = Optional.of(page.getBlock(maskChannel.getAsInt()).getLoadedBlock()); } AggregationMask mask = maskBuilder.buildAggregationMask(arguments, maskBlock); if (mask.isSelectNone()) { return; } + // Unwrap any LazyBlock values before evaluating the accumulator + arguments = arguments.getLoadedPage(); accumulator.addInput(groupIds, arguments, mask); } else { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedMapAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedMapAggregationState.java index 2084d33a0964..0320d955a761 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedMapAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedMapAggregationState.java @@ -13,8 +13,8 @@ */ package io.trino.operator.aggregation; -import io.trino.spi.block.Block; import io.trino.spi.block.MapBlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.GroupedAccumulatorState; import io.trino.spi.type.Type; @@ -65,7 +65,7 @@ public void ensureCapacity(long size) } @Override - public void add(Block keyBlock, int keyPosition, Block valueBlock, int valuePosition) + public void add(ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition) { add(groupId, keyBlock, keyPosition, valueBlock, valuePosition); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationFunction.java index 0cf618037f66..a1f1ce4b57cd 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationFunction.java @@ -13,9 +13,9 @@ */ package io.trino.operator.aggregation; -import io.trino.spi.block.Block; import 
io.trino.spi.block.BlockBuilder; import io.trino.spi.block.MapBlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -39,11 +39,12 @@ private MapAggregationFunction() {} @TypeParameter("V") public static void input( @AggregationState({"K", "V"}) MapAggregationState state, - @BlockPosition @SqlType("K") Block key, - @SqlNullable @BlockPosition @SqlType("V") Block value, - @BlockIndex int position) + @BlockPosition @SqlType("K") ValueBlock key, + @BlockIndex int keyPosition, + @SqlNullable @BlockPosition @SqlType("V") ValueBlock value, + @BlockIndex int valuePosition) { - state.add(key, position, value, position); + state.add(key, keyPosition, value, valuePosition); } @CombineFunction diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationState.java index 0d4a7886a1d3..f1fdbe3122f9 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationState.java @@ -16,6 +16,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.MapBlockBuilder; import io.trino.spi.block.SqlMap; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.AccumulatorStateMetadata; @@ -27,7 +28,7 @@ public interface MapAggregationState extends AccumulatorState { - void add(Block keyBlock, int keyPosition, Block valueBlock, int valuePosition); + void add(ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition); default void merge(MapAggregationState other) { @@ -36,8 +37,10 @@ default void merge(MapAggregationState other) Block rawKeyBlock = serializedState.getRawKeyBlock(); Block rawValueBlock = serializedState.getRawValueBlock(); + 
ValueBlock rawKeyValues = rawKeyBlock.getUnderlyingValueBlock(); + ValueBlock rawValueValues = rawValueBlock.getUnderlyingValueBlock(); for (int i = 0; i < serializedState.getSize(); i++) { - add(rawKeyBlock, rawOffset + i, rawValueBlock, rawOffset + i); + add(rawKeyValues, rawKeyBlock.getUnderlyingValuePosition(rawOffset + i), rawValueValues, rawValueBlock.getUnderlyingValuePosition(rawOffset + i)); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationStateFactory.java index 8f6ae5c435db..ddb2a4630a54 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationStateFactory.java @@ -22,8 +22,8 @@ import java.lang.invoke.MethodHandle; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -52,7 +52,7 @@ public MapAggregationStateFactory( @OperatorDependency( operator = OperatorType.READ_VALUE, argumentTypes = "K", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle keyWriteFlat, + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle keyWriteFlat, @OperatorDependency( operator = OperatorType.HASH_CODE, argumentTypes = "K", @@ -60,11 +60,11 @@ public 
MapAggregationStateFactory( @OperatorDependency( operator = OperatorType.IS_DISTINCT_FROM, argumentTypes = {"K", "K"}, - convention = @Convention(arguments = {FLAT, BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle keyDistinctFlatBlock, + convention = @Convention(arguments = {FLAT, VALUE_BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle keyDistinctFlatBlock, @OperatorDependency( operator = OperatorType.HASH_CODE, argumentTypes = "K", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FAIL_ON_NULL)) MethodHandle keyHashBlock, + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FAIL_ON_NULL)) MethodHandle keyHashBlock, @TypeParameter("V") Type valueType, @OperatorDependency( operator = OperatorType.READ_VALUE, @@ -73,7 +73,7 @@ public MapAggregationStateFactory( @OperatorDependency( operator = OperatorType.READ_VALUE, argumentTypes = "V", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle valueWriteFlat) + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle valueWriteFlat) { this.keyType = requireNonNull(keyType, "keyType is null"); this.keyReadFlat = requireNonNull(keyReadFlat, "keyReadFlat is null"); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MapUnionAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MapUnionAggregation.java index 9a247f9e8b82..718090b4601f 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MapUnionAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MapUnionAggregation.java @@ -17,6 +17,7 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.MapBlockBuilder; import io.trino.spi.block.SqlMap; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import 
io.trino.spi.function.CombineFunction; @@ -45,8 +46,10 @@ public static void input( Block rawKeyBlock = value.getRawKeyBlock(); Block rawValueBlock = value.getRawValueBlock(); + ValueBlock rawKeyValues = rawKeyBlock.getUnderlyingValueBlock(); + ValueBlock rawValueValues = rawValueBlock.getUnderlyingValueBlock(); for (int i = 0; i < value.getSize(); i++) { - state.add(rawKeyBlock, rawOffset + i, rawValueBlock, rawOffset + i); + state.add(rawKeyValues, rawKeyBlock.getUnderlyingValuePosition(rawOffset + i), rawValueValues, rawValueBlock.getUnderlyingValuePosition(rawOffset + i)); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxAggregationFunction.java index b3b500720ddc..6ec2f540c84f 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxAggregationFunction.java @@ -13,8 +13,8 @@ */ package io.trino.operator.aggregation; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -32,8 +32,8 @@ import java.lang.invoke.MethodHandle; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.IN_OUT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; @AggregationFunction("max") @@ -48,10 +48,10 @@ public static void input( @OperatorDependency( operator = OperatorType.COMPARISON_UNORDERED_FIRST, argumentTypes = {"T", "T"}, - convention = 
@Convention(arguments = {BLOCK_POSITION_NOT_NULL, IN_OUT}, result = FAIL_ON_NULL)) + convention = @Convention(arguments = {VALUE_BLOCK_POSITION_NOT_NULL, IN_OUT}, result = FAIL_ON_NULL)) MethodHandle compare, @AggregationState("T") InOut state, - @BlockPosition @SqlType("T") Block block, + @BlockPosition @SqlType("T") ValueBlock block, @BlockIndex int position) throws Throwable { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxByAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxByAggregationFunction.java index c3fc6fd8a6ab..1e7a2f1294d9 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxByAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxByAggregationFunction.java @@ -13,8 +13,8 @@ */ package io.trino.operator.aggregation; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -33,8 +33,8 @@ import java.lang.invoke.MethodHandle; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.IN_OUT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; @AggregationFunction("max_by") @@ -50,18 +50,19 @@ public static void input( @OperatorDependency( operator = OperatorType.COMPARISON_UNORDERED_FIRST, argumentTypes = {"K", "K"}, - convention = @Convention(arguments = {BLOCK_POSITION_NOT_NULL, IN_OUT}, result = FAIL_ON_NULL)) + convention = @Convention(arguments = {VALUE_BLOCK_POSITION_NOT_NULL, IN_OUT}, result = FAIL_ON_NULL)) MethodHandle 
compare, @AggregationState("K") InOut keyState, @AggregationState("V") InOut valueState, - @SqlNullable @BlockPosition @SqlType("V") Block valueBlock, - @BlockPosition @SqlType("K") Block keyBlock, - @BlockIndex int position) + @SqlNullable @BlockPosition @SqlType("V") ValueBlock valueBlock, + @BlockIndex int valuePosition, + @BlockPosition @SqlType("K") ValueBlock keyBlock, + @BlockIndex int keyPosition) throws Throwable { - if (keyState.isNull() || ((long) compare.invokeExact(keyBlock, position, keyState)) > 0) { - keyState.set(keyBlock, position); - valueState.set(valueBlock, position); + if (keyState.isNull() || ((long) compare.invokeExact(keyBlock, keyPosition, keyState)) > 0) { + keyState.set(keyBlock, keyPosition); + valueState.set(valueBlock, valuePosition); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxDataSizeForStats.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxDataSizeForStats.java index d2ab6797150c..317e16ba8649 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxDataSizeForStats.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxDataSizeForStats.java @@ -14,8 +14,8 @@ package io.trino.operator.aggregation; import io.trino.operator.aggregation.state.LongState; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -40,7 +40,7 @@ private MaxDataSizeForStats() {} @InputFunction @TypeParameter("T") - public static void input(@AggregationState LongState state, @SqlNullable @BlockPosition @SqlType("T") Block block, @BlockIndex int index) + public static void input(@AggregationState LongState state, @SqlNullable @BlockPosition @SqlType("T") ValueBlock block, @BlockIndex int index) { update(state, block.getEstimatedDataSizeForStats(index)); } diff --git 
a/core/trino-main/src/main/java/io/trino/operator/aggregation/MergeQuantileDigestFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MergeQuantileDigestFunction.java index 2c18f112974d..5076734a0a93 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MergeQuantileDigestFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MergeQuantileDigestFunction.java @@ -15,8 +15,8 @@ import io.airlift.stats.QuantileDigest; import io.trino.operator.aggregation.state.QuantileDigestState; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -45,7 +45,7 @@ private MergeQuantileDigestFunction() {} public static void input( @TypeParameter("qdigest(V)") Type type, @AggregationState QuantileDigestState state, - @BlockPosition @SqlType("qdigest(V)") Block value, + @BlockPosition @SqlType("qdigest(V)") ValueBlock value, @BlockIndex int index) { merge(state, new QuantileDigest(type.getSlice(value, index))); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MinAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MinAggregationFunction.java index acf8e408dbea..8616b7c2116c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MinAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MinAggregationFunction.java @@ -13,8 +13,8 @@ */ package io.trino.operator.aggregation; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -32,8 +32,8 @@ import java.lang.invoke.MethodHandle; -import static 
io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.IN_OUT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; @AggregationFunction("min") @@ -48,10 +48,10 @@ public static void input( @OperatorDependency( operator = OperatorType.COMPARISON_UNORDERED_LAST, argumentTypes = {"T", "T"}, - convention = @Convention(arguments = {BLOCK_POSITION_NOT_NULL, IN_OUT}, result = FAIL_ON_NULL)) + convention = @Convention(arguments = {VALUE_BLOCK_POSITION_NOT_NULL, IN_OUT}, result = FAIL_ON_NULL)) MethodHandle compare, @AggregationState("T") InOut state, - @BlockPosition @SqlType("T") Block block, + @BlockPosition @SqlType("T") ValueBlock block, @BlockIndex int position) throws Throwable { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MinByAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MinByAggregationFunction.java index 6d8520d3cf0f..3c79a80adc1f 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MinByAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MinByAggregationFunction.java @@ -13,8 +13,8 @@ */ package io.trino.operator.aggregation; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -33,8 +33,8 @@ import java.lang.invoke.MethodHandle; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.IN_OUT; +import static 
io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; @AggregationFunction("min_by") @@ -50,18 +50,19 @@ public static void input( @OperatorDependency( operator = OperatorType.COMPARISON_UNORDERED_LAST, argumentTypes = {"K", "K"}, - convention = @Convention(arguments = {BLOCK_POSITION_NOT_NULL, IN_OUT}, result = FAIL_ON_NULL)) + convention = @Convention(arguments = {VALUE_BLOCK_POSITION_NOT_NULL, IN_OUT}, result = FAIL_ON_NULL)) MethodHandle compare, @AggregationState("K") InOut keyState, @AggregationState("V") InOut valueState, - @SqlNullable @BlockPosition @SqlType("V") Block valueBlock, - @BlockPosition @SqlType("K") Block keyBlock, - @BlockIndex int position) + @SqlNullable @BlockPosition @SqlType("V") ValueBlock valueBlock, + @BlockIndex int valuePosition, + @BlockPosition @SqlType("K") ValueBlock keyBlock, + @BlockIndex int keyPosition) throws Throwable { - if (keyState.isNull() || ((long) compare.invokeExact(keyBlock, position, keyState)) < 0) { - keyState.set(keyBlock, position); - valueState.set(valueBlock, position); + if (keyState.isNull() || ((long) compare.invokeExact(keyBlock, keyPosition, keyState)) < 0) { + keyState.set(keyBlock, keyPosition); + valueState.set(valueBlock, valuePosition); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregationImplementation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregationImplementation.java index 3aef3c7f2ff5..ba77f57d03dc 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregationImplementation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregationImplementation.java @@ -19,7 +19,7 @@ import io.trino.operator.aggregation.AggregationFromAnnotationsParser.AccumulatorStateDetails; import 
io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind; import io.trino.operator.annotations.ImplementationDependency; -import io.trino.spi.block.Block; +import io.trino.spi.block.ValueBlock; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -220,7 +220,7 @@ public boolean areTypesAssignable(BoundSignature boundSignature) // block and position works for any type, but if block is annotated with SqlType nativeContainerType, then only types with the // specified container type match - if (isCurrentBlockPosition && methodDeclaredType.isAssignableFrom(Block.class)) { + if (isCurrentBlockPosition && ValueBlock.class.isAssignableFrom(methodDeclaredType)) { continue; } if (methodDeclaredType.isAssignableFrom(argumentType)) { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/SingleMapAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/SingleMapAggregationState.java index 3606543c8f9f..1d6fd3b8421d 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/SingleMapAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/SingleMapAggregationState.java @@ -13,9 +13,9 @@ */ package io.trino.operator.aggregation; -import io.trino.spi.block.Block; import io.trino.spi.block.MapBlockBuilder; import io.trino.spi.block.SqlMap; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AccumulatorState; import io.trino.spi.type.Type; @@ -61,7 +61,7 @@ private SingleMapAggregationState(SingleMapAggregationState state) } @Override - public void add(Block keyBlock, int keyPosition, Block valueBlock, int valuePosition) + public void add(ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition) { add(0, keyBlock, keyPosition, valueBlock, valuePosition); } diff --git 
a/core/trino-main/src/main/java/io/trino/operator/aggregation/SumDataSizeForStats.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/SumDataSizeForStats.java index 415299c729d0..04f2f607f8dd 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/SumDataSizeForStats.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/SumDataSizeForStats.java @@ -14,8 +14,8 @@ package io.trino.operator.aggregation; import io.trino.operator.aggregation.state.LongState; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -38,7 +38,7 @@ private SumDataSizeForStats() {} @InputFunction @TypeParameter("T") - public static void input(@AggregationState LongState state, @SqlNullable @BlockPosition @SqlType("T") Block block, @BlockIndex int index) + public static void input(@AggregationState LongState state, @SqlNullable @BlockPosition @SqlType("T") ValueBlock block, @BlockIndex int index) { update(state, block.getEstimatedDataSizeForStats(index)); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationFunction.java index 515780bb75fc..d4e969d50bd6 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationFunction.java @@ -14,8 +14,8 @@ package io.trino.operator.aggregation.arrayagg; import io.trino.spi.block.ArrayBlockBuilder; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import 
io.trino.spi.function.BlockIndex; @@ -39,7 +39,7 @@ private ArrayAggregationFunction() {} @TypeParameter("T") public static void input( @AggregationState("T") ArrayAggregationState state, - @SqlNullable @BlockPosition @SqlType("T") Block value, + @SqlNullable @BlockPosition @SqlType("T") ValueBlock value, @BlockIndex int position) { state.add(value, position); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationState.java index 77018642d751..4488e1708ff4 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationState.java @@ -15,6 +15,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.AccumulatorStateMetadata; @@ -26,9 +27,7 @@ public interface ArrayAggregationState extends AccumulatorState { - void addAll(Block block); - - void add(Block block, int position); + void add(ValueBlock block, int position); void writeAll(BlockBuilder blockBuilder); @@ -36,6 +35,10 @@ public interface ArrayAggregationState default void merge(ArrayAggregationState otherState) { - addAll(((SingleArrayAggregationState) otherState).removeTempDeserializeBlock()); + Block block = ((SingleArrayAggregationState) otherState).removeTempDeserializeBlock(); + ValueBlock valueBlock = block.getUnderlyingValueBlock(); + for (int position = 0; position < block.getPositionCount(); position++) { + add(valueBlock, block.getUnderlyingValuePosition(position)); + } } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationStateFactory.java 
index 7694e78b75aa..9176c313398e 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationStateFactory.java @@ -22,8 +22,8 @@ import java.lang.invoke.MethodHandle; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -43,7 +43,7 @@ public ArrayAggregationStateFactory( @OperatorDependency( operator = OperatorType.READ_VALUE, argumentTypes = "T", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle writeFlat, @TypeParameter("T") Type type) { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/FlatArrayBuilder.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/FlatArrayBuilder.java index 1e2e3be1140f..57c2121508b2 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/FlatArrayBuilder.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/FlatArrayBuilder.java @@ -15,8 +15,8 @@ import com.google.common.base.Throwables; import io.trino.operator.VariableWidthData; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import java.lang.invoke.MethodHandle; @@ -154,7 +154,7 @@ public void setNextIndex(long tailIndex, long nextIndex) 
LONG_HANDLE.set(records, recordOffset + recordNextIndexOffset, nextIndex); } - public void add(Block block, int position) + public void add(ValueBlock block, int position) { if (size == capacity) { growCapacity(); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/GroupArrayAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/GroupArrayAggregationState.java index 5d5f3e9bcba3..84381a9aa0ac 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/GroupArrayAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/GroupArrayAggregationState.java @@ -15,8 +15,8 @@ import com.google.common.primitives.Ints; import io.trino.operator.aggregation.state.AbstractGroupedAccumulatorState; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import java.lang.invoke.MethodHandle; @@ -72,15 +72,7 @@ public void ensureCapacity(long maxGroupId) } @Override - public void addAll(Block block) - { - for (int position = 0; position < block.getPositionCount(); position++) { - add(block, position); - } - } - - @Override - public void add(Block block, int position) + public void add(ValueBlock block, int position) { int groupId = (int) getGroupId(); long index = arrayBuilder.size(); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/SingleArrayAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/SingleArrayAggregationState.java index 64acf0744148..30fcb7acdbc3 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/SingleArrayAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/SingleArrayAggregationState.java @@ -15,6 +15,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import 
io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import java.lang.invoke.MethodHandle; @@ -52,15 +53,7 @@ public long getEstimatedSize() } @Override - public void addAll(Block block) - { - for (int position = 0; position < block.getPositionCount(); position++) { - add(block, position); - } - } - - @Override - public void add(Block block, int position) + public void add(ValueBlock block, int position) { arrayBuilder.add(block, position); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/InMemoryHashAggregationBuilder.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/InMemoryHashAggregationBuilder.java index c854e3258616..bcc661dfb389 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/InMemoryHashAggregationBuilder.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/InMemoryHashAggregationBuilder.java @@ -147,10 +147,10 @@ public void close() {} public Work processPage(Page page) { if (groupedAggregators.isEmpty()) { - return groupByHash.addPage(page.getColumns(groupByChannels)); + return groupByHash.addPage(page.getLoadedPage(groupByChannels)); } return new TransformWork<>( - groupByHash.getGroupIds(page.getColumns(groupByChannels)), + groupByHash.getGroupIds(page.getLoadedPage(groupByChannels)), groupByIdBlock -> { int groupCount = groupByHash.getGroupCount(); for (GroupedAggregator groupedAggregator : groupedAggregators) { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/GroupedHistogramState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/GroupedHistogramState.java index 2697ae84556b..ad9675303706 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/GroupedHistogramState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/GroupedHistogramState.java @@ -15,8 +15,8 @@ package io.trino.operator.aggregation.histogram; 
import io.trino.operator.aggregation.state.AbstractGroupedAccumulatorState; -import io.trino.spi.block.Block; import io.trino.spi.block.MapBlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import java.lang.invoke.MethodHandle; @@ -50,7 +50,7 @@ public void ensureCapacity(long size) } @Override - public void add(Block block, int position, long count) + public void add(ValueBlock block, int position, long count) { histogram.add(toIntExact(getGroupId()), block, position, count); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/Histogram.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/Histogram.java index 7dccd97cd48e..a835c4780cc0 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/Histogram.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/Histogram.java @@ -13,9 +13,9 @@ */ package io.trino.operator.aggregation.histogram; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.MapBlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -39,7 +39,7 @@ private Histogram() {} public static void input( @TypeParameter("T") Type type, @AggregationState("T") HistogramState state, - @BlockPosition @SqlType("T") Block key, + @BlockPosition @SqlType("T") ValueBlock key, @BlockIndex int position) { state.add(key, position, 1L); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/HistogramState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/HistogramState.java index 32b18321b2ee..b0ae54e64333 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/HistogramState.java +++ 
b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/HistogramState.java @@ -16,6 +16,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.MapBlockBuilder; import io.trino.spi.block.SqlMap; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.AccumulatorStateMetadata; @@ -29,7 +30,7 @@ public interface HistogramState extends AccumulatorState { - void add(Block block, int position, long count); + void add(ValueBlock block, int position, long count); default void merge(HistogramState other) { @@ -38,8 +39,10 @@ default void merge(HistogramState other) Block rawKeyBlock = serializedState.getRawKeyBlock(); Block rawValueBlock = serializedState.getRawValueBlock(); + ValueBlock rawKeyValues = rawKeyBlock.getUnderlyingValueBlock(); + ValueBlock rawValueValues = rawValueBlock.getUnderlyingValueBlock(); for (int i = 0; i < serializedState.getSize(); i++) { - add(rawKeyBlock, rawOffset + i, BIGINT.getLong(rawValueBlock, rawOffset + i)); + add(rawKeyValues, rawKeyBlock.getUnderlyingValuePosition(rawOffset + i), BIGINT.getLong(rawValueValues, rawValueBlock.getUnderlyingValuePosition(rawOffset + i))); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/HistogramStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/HistogramStateFactory.java index 096c10e044e9..4a11a67f39d1 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/HistogramStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/HistogramStateFactory.java @@ -22,8 +22,8 @@ import java.lang.invoke.MethodHandle; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; +import static 
io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -48,7 +48,7 @@ public HistogramStateFactory( @OperatorDependency( operator = OperatorType.READ_VALUE, argumentTypes = "T", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle writeFlat, + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle writeFlat, @OperatorDependency( operator = OperatorType.HASH_CODE, argumentTypes = "T", @@ -56,11 +56,11 @@ public HistogramStateFactory( @OperatorDependency( operator = OperatorType.IS_DISTINCT_FROM, argumentTypes = {"T", "T"}, - convention = @Convention(arguments = {FLAT, BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle distinctFlatBlock, + convention = @Convention(arguments = {FLAT, VALUE_BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle distinctFlatBlock, @OperatorDependency( operator = OperatorType.HASH_CODE, argumentTypes = "T", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FAIL_ON_NULL)) MethodHandle hashBlock) + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FAIL_ON_NULL)) MethodHandle hashBlock) { this.type = requireNonNull(type, "type is null"); this.readFlat = requireNonNull(readFlat, "readFlat is null"); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/SingleHistogramState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/SingleHistogramState.java index 73a15cc2dd71..c6a3494b6bae 100644 --- 
a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/SingleHistogramState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/SingleHistogramState.java @@ -14,9 +14,9 @@ package io.trino.operator.aggregation.histogram; -import io.trino.spi.block.Block; import io.trino.spi.block.MapBlockBuilder; import io.trino.spi.block.SqlMap; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import java.lang.invoke.MethodHandle; @@ -56,7 +56,7 @@ public SingleHistogramState( } @Override - public void add(Block block, int position, long count) + public void add(ValueBlock block, int position, long count) { if (typedHistogram == null) { typedHistogram = new TypedHistogram(keyType, readFlat, writeFlat, hashFlat, distinctFlatBlock, hashBlock, false); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/TypedHistogram.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/TypedHistogram.java index 95656f8e6ca3..e40f503047a0 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/TypedHistogram.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/TypedHistogram.java @@ -17,9 +17,9 @@ import com.google.common.primitives.Ints; import io.trino.operator.VariableWidthData; import io.trino.spi.TrinoException; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.MapBlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import jakarta.annotation.Nullable; @@ -237,7 +237,7 @@ private void serializeEntry(BlockBuilder keyBuilder, BlockBuilder valueBuilder, BIGINT.writeLong(valueBuilder, (long) LONG_HANDLE.get(records, recordOffset + recordCountOffset)); } - public void add(int groupId, Block block, int position, long count) + public void add(int groupId, ValueBlock block, int position, long count) { checkArgument(!block.isNull(position), "value must 
not be null"); checkArgument(groupId == 0 || groupRecordIndex != null, "groupId must be zero when grouping is not enabled"); @@ -275,7 +275,7 @@ public void add(int groupId, Block block, int position, long count) } } - private int matchInVector(int groupId, Block block, int position, int vectorStartBucket, long repeated, long controlVector) + private int matchInVector(int groupId, ValueBlock block, int position, int vectorStartBucket, long repeated, long controlVector) { long controlMatches = match(controlVector, repeated); while (controlMatches != 0) { @@ -306,7 +306,7 @@ private void addCount(int index, long increment) LONG_HANDLE.set(records, countOffset, (long) LONG_HANDLE.get(records, countOffset) + increment); } - private void insert(int index, int groupId, Block block, int position, long count, byte hashPrefix) + private void insert(int index, int groupId, ValueBlock block, int position, long count, byte hashPrefix) { setControl(index, hashPrefix); @@ -455,7 +455,7 @@ private long valueHashCode(int groupId, byte[] records, int index) } } - private long valueHashCode(int groupId, Block right, int rightPosition) + private long valueHashCode(int groupId, ValueBlock right, int rightPosition) { try { long valueHash = (long) hashBlock.invokeExact(right, rightPosition); @@ -467,7 +467,7 @@ private long valueHashCode(int groupId, Block right, int rightPosition) } } - private boolean valueNotDistinctFrom(int leftPosition, Block right, int rightPosition, int rightGroupId) + private boolean valueNotDistinctFrom(int leftPosition, ValueBlock right, int rightPosition, int rightGroupId) { byte[] leftRecords = getRecords(leftPosition); int leftRecordOffset = getRecordOffset(leftPosition); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/AbstractListaggAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/AbstractListaggAggregationState.java index ea225a2a9af4..d06b00295590 100644 --- 
a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/AbstractListaggAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/AbstractListaggAggregationState.java @@ -23,6 +23,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.RowBlockBuilder; import io.trino.spi.block.SqlRow; +import io.trino.spi.block.ValueBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.type.ArrayType; @@ -153,7 +154,7 @@ void setMaxOutputLength(int maxOutputLength) } @Override - public void add(Block block, int position) + public void add(ValueBlock block, int position) { checkArgument(!block.isNull(position), "element is null"); @@ -231,9 +232,10 @@ public void merge(ListaggAggregationState other) boolean showOverflowEntryCount = BOOLEAN.getBoolean(fields.get(3), index); initialize(separator, overflowError, overflowFiller, showOverflowEntryCount); - Block values = new ArrayType(VARCHAR).getObject(fields.get(4), index); - for (int i = 0; i < values.getPositionCount(); i++) { - add(values, i); + Block array = new ArrayType(VARCHAR).getObject(fields.get(4), index); + ValueBlock arrayValues = array.getUnderlyingValueBlock(); + for (int i = 0; i < array.getPositionCount(); i++) { + add(arrayValues, arrayValues.getUnderlyingValuePosition(i)); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/GroupListaggAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/GroupListaggAggregationState.java index 716b10607657..863c3987c51a 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/GroupListaggAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/GroupListaggAggregationState.java @@ -15,7 +15,7 @@ import com.google.common.primitives.Ints; import io.airlift.slice.SliceOutput; -import io.trino.spi.block.Block; +import io.trino.spi.block.ValueBlock; import 
io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.GroupedAccumulatorState; @@ -102,7 +102,7 @@ public void ensureCapacity(long maxGroupId) } @Override - public void add(Block block, int position) + public void add(ValueBlock block, int position) { super.add(block, position); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationFunction.java index 1bd177cedc97..738b1ffb806a 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationFunction.java @@ -14,8 +14,8 @@ package io.trino.operator.aggregation.listagg; import io.airlift.slice.Slice; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; @@ -36,12 +36,12 @@ private ListaggAggregationFunction() {} @InputFunction public static void input( @AggregationState ListaggAggregationState state, - @BlockPosition @SqlType("VARCHAR") Block value, + @BlockPosition @SqlType("VARCHAR") ValueBlock value, + @BlockIndex int position, @SqlType("VARCHAR") Slice separator, @SqlType("BOOLEAN") boolean overflowError, @SqlType("VARCHAR") Slice overflowFiller, - @SqlType("BOOLEAN") boolean showOverflowEntryCount, - @BlockIndex int position) + @SqlType("BOOLEAN") boolean showOverflowEntryCount) { state.initialize(separator, overflowError, overflowFiller, showOverflowEntryCount); state.add(value, position); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationState.java 
b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationState.java index 107350904832..c5b168c06d3d 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationState.java @@ -14,8 +14,8 @@ package io.trino.operator.aggregation.listagg; import io.airlift.slice.Slice; -import io.trino.spi.block.Block; import io.trino.spi.block.RowBlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.AccumulatorStateMetadata; @@ -28,7 +28,7 @@ public interface ListaggAggregationState { void initialize(Slice separator, boolean overflowError, Slice overflowFiller, boolean showOverflowEntryCount); - void add(Block block, int position); + void add(ValueBlock block, int position); void serialize(RowBlockBuilder out); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNAggregationFunction.java index bbeef22c095c..2b3ed17f7512 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNAggregationFunction.java @@ -13,8 +13,8 @@ */ package io.trino.operator.aggregation.minmaxbyn; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -38,13 +38,14 @@ private MaxByNAggregationFunction() {} @TypeParameter("V") public static void input( @AggregationState({"K", "V"}) MaxByNState state, - @SqlNullable @BlockPosition @SqlType("V") Block 
valueBlock, - @BlockPosition @SqlType("K") Block keyBlock, - @SqlType("BIGINT") long n, - @BlockIndex int blockIndex) + @SqlNullable @BlockPosition @SqlType("V") ValueBlock valueBlock, + @BlockIndex int valuePosition, + @BlockPosition @SqlType("K") ValueBlock keyBlock, + @BlockIndex int keyPosition, + @SqlType("BIGINT") long n) { state.initialize(n); - state.add(keyBlock, valueBlock, blockIndex); + state.add(keyBlock, keyPosition, valueBlock, valuePosition); } @CombineFunction diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNStateFactory.java index a1fa006bf3fa..6c2b06af5ddf 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNStateFactory.java @@ -27,8 +27,8 @@ import java.util.function.LongFunction; import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -50,7 +50,7 @@ public MaxByNStateFactory( @OperatorDependency( operator = OperatorType.READ_VALUE, argumentTypes = "K", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle 
keyWriteFlat, @OperatorDependency( operator = OperatorType.READ_VALUE, @@ -60,7 +60,7 @@ public MaxByNStateFactory( @OperatorDependency( operator = OperatorType.READ_VALUE, argumentTypes = "V", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle valueWriteFlat, @OperatorDependency( operator = OperatorType.COMPARISON_UNORDERED_FIRST, @@ -70,7 +70,7 @@ public MaxByNStateFactory( @OperatorDependency( operator = OperatorType.COMPARISON_UNORDERED_FIRST, argumentTypes = {"K", "K"}, - convention = @Convention(arguments = {FLAT, BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) + convention = @Convention(arguments = {FLAT, VALUE_BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle compareFlatBlock, @TypeParameter("K") Type keyType, @TypeParameter("V") Type valueType) diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNAggregationFunction.java index 5036d19f87ab..451240b03d6b 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNAggregationFunction.java @@ -13,8 +13,8 @@ */ package io.trino.operator.aggregation.minmaxbyn; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -38,13 +38,14 @@ private MinByNAggregationFunction() {} @TypeParameter("V") public static void input( @AggregationState({"K", "V"}) MinByNState state, - @SqlNullable @BlockPosition @SqlType("V") Block valueBlock, - @BlockPosition @SqlType("K") Block keyBlock, - @SqlType("BIGINT") 
long n, - @BlockIndex int blockIndex) + @SqlNullable @BlockPosition @SqlType("V") ValueBlock valueBlock, + @BlockIndex int valuePosition, + @BlockPosition @SqlType("K") ValueBlock keyBlock, + @BlockIndex int keyPosition, + @SqlType("BIGINT") long n) { state.initialize(n); - state.add(keyBlock, valueBlock, blockIndex); + state.add(keyBlock, keyPosition, valueBlock, valuePosition); } @CombineFunction diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNStateFactory.java index 79404c8e337b..644f586789e9 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNStateFactory.java @@ -27,8 +27,8 @@ import java.util.function.LongFunction; import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -50,7 +50,7 @@ public MinByNStateFactory( @OperatorDependency( operator = OperatorType.READ_VALUE, argumentTypes = "K", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle keyWriteFlat, @OperatorDependency( operator = OperatorType.READ_VALUE, @@ -60,7 +60,7 @@ 
public MinByNStateFactory( @OperatorDependency( operator = OperatorType.READ_VALUE, argumentTypes = "V", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle valueWriteFlat, @OperatorDependency( operator = OperatorType.COMPARISON_UNORDERED_LAST, @@ -70,7 +70,7 @@ public MinByNStateFactory( @OperatorDependency( operator = OperatorType.COMPARISON_UNORDERED_LAST, argumentTypes = {"K", "K"}, - convention = @Convention(arguments = {FLAT, BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) + convention = @Convention(arguments = {FLAT, VALUE_BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle compareFlatBlock, @TypeParameter("K") Type keyType, @TypeParameter("V") Type valueType) diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNState.java index 516a0d2fea1b..adf254926460 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNState.java @@ -13,8 +13,8 @@ */ package io.trino.operator.aggregation.minmaxbyn; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AccumulatorState; public interface MinMaxByNState @@ -29,7 +29,7 @@ public interface MinMaxByNState /** * Adds the value to this state. */ - void add(Block keyBlock, Block valueBlock, int position); + void add(ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition); /** * Merge with the specified state. 
diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNStateFactory.java index b69de366e942..a2b63e971139 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNStateFactory.java @@ -19,6 +19,7 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.RowBlockBuilder; import io.trino.spi.block.SqlRow; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.GroupedAccumulatorState; import io.trino.spi.type.ArrayType; @@ -50,7 +51,12 @@ public final void merge(MinMaxByNState other) Block keys = new ArrayType(typedKeyValueHeap.getKeyType()).getObject(sqlRow.getRawFieldBlock(1), rawIndex); Block values = new ArrayType(typedKeyValueHeap.getValueType()).getObject(sqlRow.getRawFieldBlock(2), rawIndex); - typedKeyValueHeap.addAll(keys, values); + + ValueBlock rawKeyValues = keys.getUnderlyingValueBlock(); + ValueBlock rawValueValues = values.getUnderlyingValueBlock(); + for (int i = 0; i < keys.getPositionCount(); i++) { + typedKeyValueHeap.add(rawKeyValues, keys.getUnderlyingValuePosition(i), rawValueValues, values.getUnderlyingValuePosition(i)); + } } @Override @@ -118,12 +124,12 @@ public final void initialize(long n) } @Override - public final void add(Block keyBlock, Block valueBlock, int position) + public final void add(ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition) { TypedKeyValueHeap typedHeap = getTypedKeyValueHeap(); size -= typedHeap.getEstimatedSize(); - typedHeap.add(keyBlock, valueBlock, position); + typedHeap.add(keyBlock, keyPosition, valueBlock, valuePosition); size += typedHeap.getEstimatedSize(); } @@ -203,9 +209,9 @@ public final void initialize(long n) } @Override - public final 
void add(Block keyBlock, Block valueBlock, int position) + public final void add(ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition) { - typedHeap.add(keyBlock, valueBlock, position); + typedHeap.add(keyBlock, keyPosition, valueBlock, valuePosition); } @Override diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/TypedKeyValueHeap.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/TypedKeyValueHeap.java index afe2b2b6a394..4b7afb267fa2 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/TypedKeyValueHeap.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/TypedKeyValueHeap.java @@ -16,8 +16,8 @@ import com.google.common.base.Throwables; import io.airlift.slice.SizeOf; import io.trino.operator.VariableWidthData; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import it.unimi.dsi.fastutil.ints.IntArrays; import jakarta.annotation.Nullable; @@ -227,27 +227,20 @@ private void write(int index, @Nullable BlockBuilder keyBlockBuilder, BlockBuild } } - public void addAll(Block keyBlock, Block valueBlock) + public void add(ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition) { - for (int i = 0; i < keyBlock.getPositionCount(); i++) { - add(keyBlock, valueBlock, i); - } - } - - public void add(Block keyBlock, Block valueBlock, int position) - { - checkArgument(!keyBlock.isNull(position)); + checkArgument(!keyBlock.isNull(keyPosition)); if (positionCount == capacity) { // is it possible the value is within the top N values? 
- if (!shouldConsiderValue(keyBlock, position)) { + if (!shouldConsiderValue(keyBlock, keyPosition)) { return; } clear(0); - set(0, keyBlock, valueBlock, position); + set(0, keyBlock, keyPosition, valueBlock, valuePosition); siftDown(); } else { - set(positionCount, keyBlock, valueBlock, position); + set(positionCount, keyBlock, keyPosition, valueBlock, valuePosition); positionCount++; siftUp(); } @@ -274,7 +267,7 @@ private void clear(int index) }); } - private void set(int index, Block keyBlock, Block valueBlock, int position) + private void set(int index, ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition) { int recordOffset = getRecordOffset(index); @@ -283,28 +276,28 @@ private void set(int index, Block keyBlock, Block valueBlock, int position) int keyVariableWidthLength = 0; if (variableWidthData != null) { if (keyVariableWidth) { - keyVariableWidthLength = keyType.getFlatVariableWidthSize(keyBlock, position); + keyVariableWidthLength = keyType.getFlatVariableWidthSize(keyBlock, keyPosition); } - int valueVariableWidthLength = valueType.getFlatVariableWidthSize(valueBlock, position); + int valueVariableWidthLength = valueType.getFlatVariableWidthSize(valueBlock, valuePosition); variableWidthChunk = variableWidthData.allocate(fixedChunk, recordOffset, keyVariableWidthLength + valueVariableWidthLength); variableWidthChunkOffset = getChunkOffset(fixedChunk, recordOffset); } try { - keyWriteFlat.invokeExact(keyBlock, position, fixedChunk, recordOffset + recordKeyOffset, variableWidthChunk, variableWidthChunkOffset); + keyWriteFlat.invokeExact(keyBlock, keyPosition, fixedChunk, recordOffset + recordKeyOffset, variableWidthChunk, variableWidthChunkOffset); } catch (Throwable throwable) { Throwables.throwIfUnchecked(throwable); throw new RuntimeException(throwable); } - if (valueBlock.isNull(position)) { + if (valueBlock.isNull(valuePosition)) { fixedChunk[recordOffset + recordKeyOffset - 1] = 1; } else { try { valueWriteFlat.invokeExact( 
valueBlock, - position, + valuePosition, fixedChunk, recordOffset + recordValueOffset, variableWidthChunk, @@ -394,7 +387,7 @@ private int compare(int leftPosition, int rightPosition) } } - private boolean shouldConsiderValue(Block right, int rightPosition) + private boolean shouldConsiderValue(ValueBlock right, int rightPosition) { byte[] leftFixedRecordChunk = fixedChunk; int leftRecordOffset = getRecordOffset(0); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNAggregationFunction.java index 1c385121520f..df02803a46f4 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNAggregationFunction.java @@ -13,8 +13,8 @@ */ package io.trino.operator.aggregation.minmaxn; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -36,9 +36,9 @@ private MaxNAggregationFunction() {} @TypeParameter("E") public static void input( @AggregationState("E") MaxNState state, - @BlockPosition @SqlType("E") Block block, - @SqlType("BIGINT") long n, - @BlockIndex int blockIndex) + @BlockPosition @SqlType("E") ValueBlock block, + @BlockIndex int blockIndex, + @SqlType("BIGINT") long n) { state.initialize(n); state.add(block, blockIndex); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNStateFactory.java index 9a49b8057a46..81c8fa1b9681 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNStateFactory.java +++ 
b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNStateFactory.java @@ -27,8 +27,8 @@ import java.util.function.LongFunction; import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -50,7 +50,7 @@ public MaxNStateFactory( @OperatorDependency( operator = OperatorType.READ_VALUE, argumentTypes = "T", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle writeFlat, @OperatorDependency( operator = OperatorType.COMPARISON_UNORDERED_FIRST, @@ -60,7 +60,7 @@ public MaxNStateFactory( @OperatorDependency( operator = OperatorType.COMPARISON_UNORDERED_FIRST, argumentTypes = {"T", "T"}, - convention = @Convention(arguments = {FLAT, BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) + convention = @Convention(arguments = {FLAT, VALUE_BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle compareFlatBlock, @TypeParameter("T") Type elementType) { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNState.java index 97144ddc2f04..1c61c61fec2d 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNState.java +++ 
b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNState.java @@ -13,8 +13,8 @@ */ package io.trino.operator.aggregation.minmaxn; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AccumulatorState; public interface MinMaxNState @@ -29,7 +29,7 @@ public interface MinMaxNState /** * Adds the value to this state. */ - void add(Block block, int position); + void add(ValueBlock block, int position); /** * Merge with the specified state. diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNStateFactory.java index bd74cce3dbc9..fc94133942b3 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNStateFactory.java @@ -19,6 +19,7 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.RowBlockBuilder; import io.trino.spi.block.SqlRow; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.GroupedAccumulatorState; import io.trino.spi.type.ArrayType; @@ -51,8 +52,11 @@ public final void merge(MinMaxNState other) initialize(capacity); TypedHeap typedHeap = getTypedHeap(); - Block values = new ArrayType(typedHeap.getElementType()).getObject(sqlRow.getRawFieldBlock(1), rawIndex); - typedHeap.addAll(values); + Block array = new ArrayType(typedHeap.getElementType()).getObject(sqlRow.getRawFieldBlock(1), rawIndex); + ValueBlock arrayValues = array.getUnderlyingValueBlock(); + for (int i = 0; i < array.getPositionCount(); i++) { + typedHeap.add(arrayValues, array.getUnderlyingValuePosition(i)); + } } @Override @@ -118,7 +122,7 @@ public final void initialize(long n) } @Override - public final void add(Block block, int position) + public 
final void add(ValueBlock block, int position) { TypedHeap typedHeap = getTypedHeap(); @@ -200,7 +204,7 @@ public final void initialize(long n) } @Override - public final void add(Block block, int position) + public final void add(ValueBlock block, int position) { typedHeap.add(block, position); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNAggregationFunction.java index 4521f979d6ce..3f4f5a78ceb0 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNAggregationFunction.java @@ -13,8 +13,8 @@ */ package io.trino.operator.aggregation.minmaxn; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -36,9 +36,9 @@ private MinNAggregationFunction() {} @TypeParameter("E") public static void input( @AggregationState("E") MinNState state, - @BlockPosition @SqlType("E") Block block, - @SqlType("BIGINT") long n, - @BlockIndex int blockIndex) + @BlockPosition @SqlType("E") ValueBlock block, + @BlockIndex int blockIndex, + @SqlType("BIGINT") long n) { state.initialize(n); state.add(block, blockIndex); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNStateFactory.java index 46715f2115bc..99fe5b6496d4 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNStateFactory.java @@ -26,8 +26,8 @@ import java.util.function.LongFunction; import static 
io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -49,7 +49,7 @@ public MinNStateFactory( @OperatorDependency( operator = OperatorType.READ_VALUE, argumentTypes = "T", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle writeFlat, @OperatorDependency( operator = OperatorType.COMPARISON_UNORDERED_LAST, @@ -59,7 +59,7 @@ public MinNStateFactory( @OperatorDependency( operator = OperatorType.COMPARISON_UNORDERED_LAST, argumentTypes = {"T", "T"}, - convention = @Convention(arguments = {FLAT, BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) + convention = @Convention(arguments = {FLAT, VALUE_BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle compareFlatBlock, @TypeParameter("T") Type elementType) { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/TypedHeap.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/TypedHeap.java index 586e0e372477..7ba0168077d8 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/TypedHeap.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/TypedHeap.java @@ -17,8 +17,8 @@ import com.google.common.primitives.Ints; import io.airlift.slice.SizeOf; import 
io.trino.operator.VariableWidthData; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import it.unimi.dsi.fastutil.ints.IntArrays; @@ -184,14 +184,7 @@ private void write(int index, BlockBuilder blockBuilder) } } - public void addAll(Block block) - { - for (int i = 0; i < block.getPositionCount(); i++) { - add(block, i); - } - } - - public void add(Block block, int position) + public void add(ValueBlock block, int position) { checkArgument(!block.isNull(position)); if (positionCount == capacity) { @@ -227,7 +220,7 @@ private void clear(int index) elementType.relocateFlatVariableWidthOffsets(fixedChunk, fixedSizeOffset + recordElementOffset, variableWidthChunk, variableWidthChunkOffset)); } - private void set(int index, Block block, int position) + private void set(int index, ValueBlock block, int position) { int recordOffset = getRecordOffset(index); @@ -325,7 +318,7 @@ private int compare(int leftPosition, int rightPosition) } } - private boolean shouldConsiderValue(Block right, int rightPosition) + private boolean shouldConsiderValue(ValueBlock right, int rightPosition) { byte[] leftFixedRecordChunk = fixedChunk; int leftRecordOffset = getRecordOffset(0); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/AbstractMultimapAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/AbstractMultimapAggregationState.java index 1ff2b3e4d7d5..5a69677e9168 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/AbstractMultimapAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/AbstractMultimapAggregationState.java @@ -23,6 +23,7 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.MapBlockBuilder; import io.trino.spi.block.SqlMap; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.ArrayType; import 
io.trino.spi.type.Type; import jakarta.annotation.Nullable; @@ -299,24 +300,26 @@ protected void deserialize(int groupId, SqlMap serializedState) Block rawKeyBlock = serializedState.getRawKeyBlock(); Block rawValueBlock = serializedState.getRawValueBlock(); + ValueBlock rawKeyValues = rawKeyBlock.getUnderlyingValueBlock(); ArrayType arrayType = new ArrayType(valueArrayBuilder.type()); for (int i = 0; i < serializedState.getSize(); i++) { - int keyId = putKeyIfAbsent(groupId, rawKeyBlock, rawOffset + i); + int keyId = putKeyIfAbsent(groupId, rawKeyValues, rawKeyBlock.getUnderlyingValuePosition(rawOffset + i)); Block array = arrayType.getObject(rawValueBlock, rawOffset + i); verify(array.getPositionCount() > 0, "array is empty"); + ValueBlock arrayValuesBlock = array.getUnderlyingValueBlock(); for (int arrayIndex = 0; arrayIndex < array.getPositionCount(); arrayIndex++) { - addKeyValue(keyId, array, arrayIndex); + addKeyValue(keyId, arrayValuesBlock, array.getUnderlyingValuePosition(arrayIndex)); } } } - protected void add(int groupId, Block keyBlock, int keyPosition, Block valueBlock, int valuePosition) + protected void add(int groupId, ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition) { int keyId = putKeyIfAbsent(groupId, keyBlock, keyPosition); addKeyValue(keyId, valueBlock, valuePosition); } - private int putKeyIfAbsent(int groupId, Block keyBlock, int keyPosition) + private int putKeyIfAbsent(int groupId, ValueBlock keyBlock, int keyPosition) { checkArgument(!keyBlock.isNull(keyPosition), "key must not be null"); checkArgument(groupId == 0 || groupRecordIndex != null, "groupId must be zero when grouping is not enabled"); @@ -356,7 +359,7 @@ private int putKeyIfAbsent(int groupId, Block keyBlock, int keyPosition) } } - private int matchInVector(int groupId, Block block, int position, int vectorStartBucket, long repeated, long controlVector) + private int matchInVector(int groupId, ValueBlock block, int position, int 
vectorStartBucket, long repeated, long controlVector) { long controlMatches = match(controlVector, repeated); while (controlMatches != 0) { @@ -380,7 +383,7 @@ private int findEmptyInVector(long vector, int vectorStartBucket) return bucket(vectorStartBucket + slot); } - private int insert(int keyIndex, int groupId, Block keyBlock, int keyPosition, byte hashPrefix) + private int insert(int keyIndex, int groupId, ValueBlock keyBlock, int keyPosition, byte hashPrefix) { setControl(keyIndex, hashPrefix); @@ -430,7 +433,7 @@ private int insert(int keyIndex, int groupId, Block keyBlock, int keyPosition, b return keyId; } - private void addKeyValue(int keyId, Block valueBlock, int valuePosition) + private void addKeyValue(int keyId, ValueBlock valueBlock, int valuePosition) { long index = valueArrayBuilder.size(); if (keyTailPositions[keyId] == -1) { @@ -554,7 +557,7 @@ private long keyHashCode(int groupId, byte[] records, int index) } } - private long keyHashCode(int groupId, Block right, int rightPosition) + private long keyHashCode(int groupId, ValueBlock right, int rightPosition) { try { long valueHash = (long) keyHashBlock.invokeExact(right, rightPosition); @@ -566,7 +569,7 @@ private long keyHashCode(int groupId, Block right, int rightPosition) } } - private boolean keyNotDistinctFrom(int leftPosition, Block right, int rightPosition, int rightGroupId) + private boolean keyNotDistinctFrom(int leftPosition, ValueBlock right, int rightPosition, int rightGroupId) { byte[] leftRecords = getRecords(leftPosition); int leftRecordOffset = getRecordOffset(leftPosition); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/GroupedMultimapAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/GroupedMultimapAggregationState.java index e2f5fd079cee..3116bbcd16d8 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/GroupedMultimapAggregationState.java +++ 
b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/GroupedMultimapAggregationState.java @@ -13,9 +13,9 @@ */ package io.trino.operator.aggregation.multimapagg; -import io.trino.spi.block.Block; import io.trino.spi.block.MapBlockBuilder; import io.trino.spi.block.SqlMap; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.GroupedAccumulatorState; import io.trino.spi.type.Type; @@ -66,7 +66,7 @@ public void ensureCapacity(long size) } @Override - public void add(Block keyBlock, int keyPosition, Block valueBlock, int valuePosition) + public void add(ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition) { add(groupId, keyBlock, keyPosition, valueBlock, valuePosition); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationFunction.java index d4f64d9a3660..374de6495cee 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationFunction.java @@ -13,9 +13,9 @@ */ package io.trino.operator.aggregation.multimapagg; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.MapBlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -39,11 +39,12 @@ private MultimapAggregationFunction() {} @TypeParameter("V") public static void input( @AggregationState({"K", "V"}) MultimapAggregationState state, - @BlockPosition @SqlType("K") Block key, - @SqlNullable @BlockPosition @SqlType("V") Block value, - @BlockIndex int position) + @BlockPosition @SqlType("K") ValueBlock key, + @BlockIndex int keyPosition, + @SqlNullable @BlockPosition 
@SqlType("V") ValueBlock value, + @BlockIndex int valuePosition) { - state.add(key, position, value, position); + state.add(key, keyPosition, value, valuePosition); } @CombineFunction diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationState.java index 88a587be3e08..87143c148ec3 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationState.java @@ -13,8 +13,8 @@ */ package io.trino.operator.aggregation.multimapagg; -import io.trino.spi.block.Block; import io.trino.spi.block.MapBlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.AccumulatorStateMetadata; @@ -26,7 +26,7 @@ public interface MultimapAggregationState extends AccumulatorState { - void add(Block keyBlock, int keyPosition, Block valueBlock, int valuePosition); + void add(ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition); void merge(MultimapAggregationState other); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationStateFactory.java index 7e5615dacb7c..a0682a133f0a 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationStateFactory.java @@ -22,8 +22,8 @@ import java.lang.invoke.MethodHandle; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static 
io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -52,7 +52,7 @@ public MultimapAggregationStateFactory( @OperatorDependency( operator = OperatorType.READ_VALUE, argumentTypes = "K", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle keyWriteFlat, + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle keyWriteFlat, @OperatorDependency( operator = OperatorType.HASH_CODE, argumentTypes = "K", @@ -60,11 +60,11 @@ public MultimapAggregationStateFactory( @OperatorDependency( operator = OperatorType.IS_DISTINCT_FROM, argumentTypes = {"K", "K"}, - convention = @Convention(arguments = {FLAT, BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle keyDistinctFlatBlock, + convention = @Convention(arguments = {FLAT, VALUE_BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle keyDistinctFlatBlock, @OperatorDependency( operator = OperatorType.HASH_CODE, argumentTypes = "K", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FAIL_ON_NULL)) MethodHandle keyHashBlock, + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FAIL_ON_NULL)) MethodHandle keyHashBlock, @TypeParameter("V") Type valueType, @OperatorDependency( operator = OperatorType.READ_VALUE, @@ -73,7 +73,7 @@ public MultimapAggregationStateFactory( @OperatorDependency( operator = OperatorType.READ_VALUE, argumentTypes = "V", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle 
valueWriteFlat) + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle valueWriteFlat) { this.keyType = requireNonNull(keyType, "keyType is null"); this.keyReadFlat = requireNonNull(keyReadFlat, "keyReadFlat is null"); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/SingleMultimapAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/SingleMultimapAggregationState.java index 269cab614349..65fa48b81eb1 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/SingleMultimapAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/SingleMultimapAggregationState.java @@ -13,9 +13,9 @@ */ package io.trino.operator.aggregation.multimapagg; -import io.trino.spi.block.Block; import io.trino.spi.block.MapBlockBuilder; import io.trino.spi.block.SqlMap; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AccumulatorState; import io.trino.spi.type.Type; @@ -61,7 +61,7 @@ private SingleMultimapAggregationState(SingleMultimapAggregationState state) } @Override - public void add(Block keyBlock, int keyPosition, Block valueBlock, int valuePosition) + public void add(ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition) { add(0, keyBlock, keyPosition, valueBlock, valuePosition); } diff --git a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalPartitionGenerator.java b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalPartitionGenerator.java index 20a01eb430d6..2c065f182238 100644 --- a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalPartitionGenerator.java +++ b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalPartitionGenerator.java @@ -36,7 +36,7 @@ public LocalPartitionGenerator(HashGenerator hashGenerator, int partitionCount) } @Override - public int getPartitionCount() + public int 
partitionCount() { return partitionCount; } diff --git a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/PartitionedLookupSource.java b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/PartitionedLookupSource.java index ac62fb3eee14..9708ae7092e4 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/PartitionedLookupSource.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/PartitionedLookupSource.java @@ -149,7 +149,7 @@ public long getJoinPosition(int position, Page hashChannelsPage, Page allChannel public void getJoinPosition(int[] positions, Page hashChannelsPage, Page allChannelsPage, long[] rawHashes, long[] result) { int positionCount = positions.length; - int partitionCount = partitionGenerator.getPartitionCount(); + int partitionCount = partitionGenerator.partitionCount(); int[] partitions = new int[positionCount]; int[] partitionPositionsCount = new int[partitionCount]; diff --git a/core/trino-main/src/main/java/io/trino/operator/output/BytePositionsAppender.java b/core/trino-main/src/main/java/io/trino/operator/output/BytePositionsAppender.java index d640db4e0208..73dbf4360055 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/BytePositionsAppender.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/BytePositionsAppender.java @@ -16,11 +16,13 @@ import io.trino.spi.block.Block; import io.trino.spi.block.ByteArrayBlock; import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; import it.unimi.dsi.fastutil.ints.IntArrayList; import java.util.Arrays; import java.util.Optional; +import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.slice.SizeOf.instanceSize; import static io.airlift.slice.SizeOf.sizeOf; import static io.trino.operator.output.PositionsAppenderUtil.calculateBlockResetSize; @@ -55,9 +57,10 @@ public BytePositionsAppender(int expectedEntries) } @Override - // TODO: Make 
PositionsAppender work performant with different block types (https://github.com/trinodb/trino/issues/13267) - public void append(IntArrayList positions, Block block) + public void append(IntArrayList positions, ValueBlock block) { + checkArgument(block instanceof ByteArrayBlock, "Block must be instance of %s", ByteArrayBlock.class); + if (positions.isEmpty()) { return; } @@ -94,8 +97,10 @@ public void append(IntArrayList positions, Block block) } @Override - public void appendRle(Block block, int rlePositionCount) + public void appendRle(ValueBlock block, int rlePositionCount) { + checkArgument(block instanceof ByteArrayBlock, "Block must be instance of %s", ByteArrayBlock.class); + if (rlePositionCount == 0) { return; } @@ -116,8 +121,10 @@ public void appendRle(Block block, int rlePositionCount) } @Override - public void append(int sourcePosition, Block source) + public void append(int sourcePosition, ValueBlock source) { + checkArgument(source instanceof ByteArrayBlock, "Block must be instance of %s", ByteArrayBlock.class); + ensureCapacity(positionCount + 1); if (source.isNull(sourcePosition)) { valueIsNull[positionCount] = true; @@ -185,7 +192,7 @@ private void ensureCapacity(int capacity) newSize = initialEntryCount; initialized = true; } - newSize = Math.max(newSize, capacity); + newSize = max(newSize, capacity); valueIsNull = Arrays.copyOf(valueIsNull, newSize); values = Arrays.copyOf(values, newSize); diff --git a/core/trino-main/src/main/java/io/trino/operator/output/Fixed12PositionsAppender.java b/core/trino-main/src/main/java/io/trino/operator/output/Fixed12PositionsAppender.java index d90694082928..d7125a989c60 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/Fixed12PositionsAppender.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/Fixed12PositionsAppender.java @@ -16,11 +16,13 @@ import io.trino.spi.block.Block; import io.trino.spi.block.Fixed12Block; import io.trino.spi.block.RunLengthEncodedBlock; +import 
io.trino.spi.block.ValueBlock; import it.unimi.dsi.fastutil.ints.IntArrayList; import java.util.Arrays; import java.util.Optional; +import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.slice.SizeOf.SIZE_OF_INT; import static io.airlift.slice.SizeOf.instanceSize; import static io.airlift.slice.SizeOf.sizeOf; @@ -56,9 +58,10 @@ public Fixed12PositionsAppender(int expectedEntries) } @Override - // TODO: Make PositionsAppender work performant with different block types (https://github.com/trinodb/trino/issues/13267) - public void append(IntArrayList positions, Block block) + public void append(IntArrayList positions, ValueBlock block) { + checkArgument(block instanceof Fixed12Block, "Block must be instance of %s", Fixed12Block.class); + if (positions.isEmpty()) { return; } @@ -100,8 +103,10 @@ public void append(IntArrayList positions, Block block) } @Override - public void appendRle(Block block, int rlePositionCount) + public void appendRle(ValueBlock block, int rlePositionCount) { + checkArgument(block instanceof Fixed12Block, "Block must be instance of %s", Fixed12Block.class); + if (rlePositionCount == 0) { return; } @@ -130,8 +135,10 @@ public void appendRle(Block block, int rlePositionCount) } @Override - public void append(int sourcePosition, Block source) + public void append(int sourcePosition, ValueBlock source) { + checkArgument(source instanceof Fixed12Block, "Block must be instance of %s", Fixed12Block.class); + ensureCapacity(positionCount + 1); if (source.isNull(sourcePosition)) { valueIsNull[positionCount] = true; @@ -202,7 +209,7 @@ private void ensureCapacity(int capacity) newSize = initialEntryCount; initialized = true; } - newSize = Math.max(newSize, capacity); + newSize = max(newSize, capacity); valueIsNull = Arrays.copyOf(valueIsNull, newSize); values = Arrays.copyOf(values, newSize * 3); diff --git a/core/trino-main/src/main/java/io/trino/operator/output/Int128PositionsAppender.java 
b/core/trino-main/src/main/java/io/trino/operator/output/Int128PositionsAppender.java index fceb70eb4d28..4198091cc548 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/Int128PositionsAppender.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/Int128PositionsAppender.java @@ -16,11 +16,13 @@ import io.trino.spi.block.Block; import io.trino.spi.block.Int128ArrayBlock; import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; import it.unimi.dsi.fastutil.ints.IntArrayList; import java.util.Arrays; import java.util.Optional; +import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.slice.SizeOf.SIZE_OF_LONG; import static io.airlift.slice.SizeOf.instanceSize; import static io.airlift.slice.SizeOf.sizeOf; @@ -56,9 +58,10 @@ public Int128PositionsAppender(int expectedEntries) } @Override - // TODO: Make PositionsAppender work performant with different block types (https://github.com/trinodb/trino/issues/13267) - public void append(IntArrayList positions, Block block) + public void append(IntArrayList positions, ValueBlock block) { + checkArgument(block instanceof Int128ArrayBlock, "Block must be instance of %s", Int128ArrayBlock.class); + if (positions.isEmpty()) { return; } @@ -101,8 +104,10 @@ public void append(IntArrayList positions, Block block) } @Override - public void appendRle(Block block, int rlePositionCount) + public void appendRle(ValueBlock block, int rlePositionCount) { + checkArgument(block instanceof Int128ArrayBlock, "Block must be instance of %s", Int128ArrayBlock.class); + if (rlePositionCount == 0) { return; } @@ -129,8 +134,10 @@ public void appendRle(Block block, int rlePositionCount) } @Override - public void append(int sourcePosition, Block source) + public void append(int sourcePosition, ValueBlock source) { + checkArgument(source instanceof Int128ArrayBlock, "Block must be instance of %s", Int128ArrayBlock.class); + ensureCapacity(positionCount + 
1); if (source.isNull(sourcePosition)) { valueIsNull[positionCount] = true; @@ -200,7 +207,7 @@ private void ensureCapacity(int capacity) newSize = initialEntryCount; initialized = true; } - newSize = Math.max(newSize, capacity); + newSize = max(newSize, capacity); valueIsNull = Arrays.copyOf(valueIsNull, newSize); values = Arrays.copyOf(values, newSize * 2); diff --git a/core/trino-main/src/main/java/io/trino/operator/output/IntPositionsAppender.java b/core/trino-main/src/main/java/io/trino/operator/output/IntPositionsAppender.java index f4b28b1c5a0b..290d395d5d75 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/IntPositionsAppender.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/IntPositionsAppender.java @@ -16,11 +16,13 @@ import io.trino.spi.block.Block; import io.trino.spi.block.IntArrayBlock; import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; import it.unimi.dsi.fastutil.ints.IntArrayList; import java.util.Arrays; import java.util.Optional; +import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.slice.SizeOf.instanceSize; import static io.airlift.slice.SizeOf.sizeOf; import static io.trino.operator.output.PositionsAppenderUtil.calculateBlockResetSize; @@ -55,9 +57,10 @@ public IntPositionsAppender(int expectedEntries) } @Override - // TODO: Make PositionsAppender work performant with different block types (https://github.com/trinodb/trino/issues/13267) - public void append(IntArrayList positions, Block block) + public void append(IntArrayList positions, ValueBlock block) { + checkArgument(block instanceof IntArrayBlock, "Block must be instance of %s", IntArrayBlock.class); + if (positions.isEmpty()) { return; } @@ -94,8 +97,10 @@ public void append(IntArrayList positions, Block block) } @Override - public void appendRle(Block block, int rlePositionCount) + public void appendRle(ValueBlock block, int rlePositionCount) { + checkArgument(block instanceof 
IntArrayBlock, "Block must be instance of %s", IntArrayBlock.class); + if (rlePositionCount == 0) { return; } @@ -116,8 +121,10 @@ public void appendRle(Block block, int rlePositionCount) } @Override - public void append(int sourcePosition, Block source) + public void append(int sourcePosition, ValueBlock source) { + checkArgument(source instanceof IntArrayBlock, "Block must be instance of %s", IntArrayBlock.class); + ensureCapacity(positionCount + 1); if (source.isNull(sourcePosition)) { valueIsNull[positionCount] = true; @@ -185,7 +192,7 @@ private void ensureCapacity(int capacity) newSize = initialEntryCount; initialized = true; } - newSize = Math.max(newSize, capacity); + newSize = max(newSize, capacity); valueIsNull = Arrays.copyOf(valueIsNull, newSize); values = Arrays.copyOf(values, newSize); diff --git a/core/trino-main/src/main/java/io/trino/operator/output/LongPositionsAppender.java b/core/trino-main/src/main/java/io/trino/operator/output/LongPositionsAppender.java index 2a5910efdec0..1b378ca502b3 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/LongPositionsAppender.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/LongPositionsAppender.java @@ -16,11 +16,13 @@ import io.trino.spi.block.Block; import io.trino.spi.block.LongArrayBlock; import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; import it.unimi.dsi.fastutil.ints.IntArrayList; import java.util.Arrays; import java.util.Optional; +import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.slice.SizeOf.instanceSize; import static io.airlift.slice.SizeOf.sizeOf; import static io.trino.operator.output.PositionsAppenderUtil.calculateBlockResetSize; @@ -55,9 +57,10 @@ public LongPositionsAppender(int expectedEntries) } @Override - // TODO: Make PositionsAppender work performant with different block types (https://github.com/trinodb/trino/issues/13267) - public void append(IntArrayList positions, Block 
block) + public void append(IntArrayList positions, ValueBlock block) { + checkArgument(block instanceof LongArrayBlock, "Block must be instance of %s", LongArrayBlock.class); + if (positions.isEmpty()) { return; } @@ -94,8 +97,10 @@ public void append(IntArrayList positions, Block block) } @Override - public void appendRle(Block block, int rlePositionCount) + public void appendRle(ValueBlock block, int rlePositionCount) { + checkArgument(block instanceof LongArrayBlock, "Block must be instance of %s", LongArrayBlock.class); + if (rlePositionCount == 0) { return; } @@ -116,8 +121,10 @@ public void appendRle(Block block, int rlePositionCount) } @Override - public void append(int sourcePosition, Block source) + public void append(int sourcePosition, ValueBlock source) { + checkArgument(source instanceof LongArrayBlock, "Block must be instance of %s", LongArrayBlock.class); + ensureCapacity(positionCount + 1); if (source.isNull(sourcePosition)) { valueIsNull[positionCount] = true; @@ -185,7 +192,7 @@ private void ensureCapacity(int capacity) newSize = initialEntryCount; initialized = true; } - newSize = Math.max(newSize, capacity); + newSize = max(newSize, capacity); valueIsNull = Arrays.copyOf(valueIsNull, newSize); values = Arrays.copyOf(values, newSize); diff --git a/core/trino-main/src/main/java/io/trino/operator/output/PagePartitioner.java b/core/trino-main/src/main/java/io/trino/operator/output/PagePartitioner.java index 72be3d3e4277..78e4a2ff12f5 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/PagePartitioner.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/PagePartitioner.java @@ -115,7 +115,7 @@ public PagePartitioner( } } - int partitionCount = partitionFunction.getPartitionCount(); + int partitionCount = partitionFunction.partitionCount(); int pageSize = toIntExact(min(DEFAULT_MAX_PAGE_SIZE_IN_BYTES, maxMemory.toBytes() / partitionCount)); pageSize = max(1, pageSize); @@ -146,7 +146,7 @@ public void partitionPage(Page 
page) return; } - if (page.getPositionCount() < partitionFunction.getPartitionCount() * COLUMNAR_STRATEGY_COEFFICIENT) { + if (page.getPositionCount() < partitionFunction.partitionCount() * COLUMNAR_STRATEGY_COEFFICIENT) { // Partition will have on average less than COLUMNAR_STRATEGY_COEFFICIENT rows. // Doing it column-wise would degrade performance, so we fall back to row-wise approach. // Performance degradation is the worst in case of skewed hash distribution when only small subset @@ -209,7 +209,7 @@ public void partitionPageByColumn(Page page) { IntArrayList[] partitionedPositions = partitionPositions(page); - for (int i = 0; i < partitionFunction.getPartitionCount(); i++) { + for (int i = 0; i < partitionFunction.partitionCount(); i++) { IntArrayList partitionPositions = partitionedPositions[i]; if (!partitionPositions.isEmpty()) { positionsAppenders[i].appendToOutputPartition(page, partitionPositions); @@ -259,9 +259,9 @@ private IntArrayList[] initPositions(Page page) // want memory to explode in case there are input pages with many positions, where each page // is assigned to a single partition entirely. // For example this can happen for partition columns if they are represented by RLE blocks. 
- IntArrayList[] partitionPositions = new IntArrayList[partitionFunction.getPartitionCount()]; + IntArrayList[] partitionPositions = new IntArrayList[partitionFunction.partitionCount()]; for (int i = 0; i < partitionPositions.length; i++) { - partitionPositions[i] = new IntArrayList(initialPartitionSize(page.getPositionCount() / partitionFunction.getPartitionCount())); + partitionPositions[i] = new IntArrayList(initialPartitionSize(page.getPositionCount() / partitionFunction.partitionCount())); } return partitionPositions; } @@ -275,7 +275,7 @@ private static int initialPartitionSize(int averagePositionsPerPartition) return (int) (averagePositionsPerPartition * 1.1) + 32; } - private boolean onlyRleBlocks(Page page) + private static boolean onlyRleBlocks(Page page) { for (int i = 0; i < page.getChannelCount(); i++) { if (!(page.getBlock(i) instanceof RunLengthEncodedBlock)) { @@ -308,7 +308,7 @@ private void partitionBySingleRleValue(Page page, int position, Page partitionFu } } - private Page extractRlePage(Page page) + private static Page extractRlePage(Page page) { Block[] valueBlocks = new Block[page.getChannelCount()]; for (int channel = 0; channel < valueBlocks.length; ++channel) { @@ -317,7 +317,7 @@ private Page extractRlePage(Page page) return new Page(valueBlocks); } - private int[] integersInRange(int start, int endExclusive) + private static int[] integersInRange(int start, int endExclusive) { int[] array = new int[endExclusive - start]; int current = start; @@ -327,7 +327,7 @@ private int[] integersInRange(int start, int endExclusive) return array; } - private boolean isDictionaryProcessingFaster(Block block) + private static boolean isDictionaryProcessingFaster(Block block) { if (!(block instanceof DictionaryBlock dictionaryBlock)) { return false; @@ -386,7 +386,7 @@ private void partitionNullablePositions(Page page, int position, IntArrayList[] } } - private void partitionNotNullPositions(Page page, int startingPosition, IntArrayList[] 
partitionPositions, IntUnaryOperator partitionFunction) + private static void partitionNotNullPositions(Page page, int startingPosition, IntArrayList[] partitionPositions, IntUnaryOperator partitionFunction) { int positionCount = page.getPositionCount(); int[] partitionPerPosition = new int[positionCount]; diff --git a/core/trino-main/src/main/java/io/trino/operator/output/PagePartitionerPool.java b/core/trino-main/src/main/java/io/trino/operator/output/PagePartitionerPool.java index 2479f76a7941..1d47760e4e66 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/PagePartitionerPool.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/PagePartitionerPool.java @@ -35,7 +35,7 @@ public class PagePartitionerPool * In normal conditions, in the steady state, * the number of free {@link PagePartitioner}s is going to be close to 0. * There is a possible case though, where initially big number of concurrent drivers, say 128, - * drops to a small number e.g. 32 in a steady state. This could cause a lot of memory + * drops to a small number e.g., 32 in a steady state. This could cause a lot of memory * to be retained by the unused buffers. * To defend against that, {@link #maxFree} limits the number of free buffers, * thus limiting unused memory. 
diff --git a/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppender.java b/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppender.java index 4b0a38b61a53..e861c4e0f685 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppender.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppender.java @@ -14,27 +14,28 @@ package io.trino.operator.output; import io.trino.spi.block.Block; +import io.trino.spi.block.ValueBlock; import it.unimi.dsi.fastutil.ints.IntArrayList; public interface PositionsAppender { - void append(IntArrayList positions, Block source); + void append(IntArrayList positions, ValueBlock source); /** * Appends the specified value positionCount times. - * The result is the same as with using {@link PositionsAppender#append(IntArrayList, Block)} with - * positions list [0...positionCount -1] but with possible performance optimizations. + * The result is the same as with using {@link PositionsAppender#append(IntArrayList, ValueBlock)} with + * a position list [0...positionCount -1] but with possible performance optimizations. */ - void appendRle(Block value, int rlePositionCount); + void appendRle(ValueBlock value, int rlePositionCount); /** * Appends single position. The implementation must be conceptually equal to * {@code append(IntArrayList.wrap(new int[] {position}), source)} but may be optimized. - * Caller should avoid using this method if {@link #append(IntArrayList, Block)} can be used + * Caller should avoid using this method if {@link #append(IntArrayList, ValueBlock)} can be used * as appending positions one by one can be significantly slower and may not support features * like pushing RLE through the appender. */ - void append(int position, Block source); + void append(int position, ValueBlock source); /** * Creates the block from the appender data. 
diff --git a/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppenderFactory.java b/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppenderFactory.java index a597983d6a19..34eab30e020e 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppenderFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppenderFactory.java @@ -13,13 +13,20 @@ */ package io.trino.operator.output; +import io.trino.spi.block.ByteArrayBlock; import io.trino.spi.block.Fixed12Block; import io.trino.spi.block.Int128ArrayBlock; -import io.trino.spi.type.FixedWidthType; +import io.trino.spi.block.IntArrayBlock; +import io.trino.spi.block.LongArrayBlock; +import io.trino.spi.block.RowBlock; +import io.trino.spi.block.ShortArrayBlock; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; -import io.trino.spi.type.VariableWidthType; import io.trino.type.BlockTypeOperators; +import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom; + +import java.util.Optional; import static java.util.Objects.requireNonNull; @@ -32,45 +39,41 @@ public PositionsAppenderFactory(BlockTypeOperators blockTypeOperators) this.blockTypeOperators = requireNonNull(blockTypeOperators, "blockTypeOperators is null"); } - public PositionsAppender create(Type type, int expectedPositions, long maxPageSizeInBytes) + public UnnestingPositionsAppender create(Type type, int expectedPositions, long maxPageSizeInBytes) { - if (!type.isComparable()) { - return new UnnestingPositionsAppender(createPrimitiveAppender(type, expectedPositions, maxPageSizeInBytes)); + Optional distinctFromOperator = Optional.empty(); + if (type.isComparable()) { + distinctFromOperator = Optional.of(blockTypeOperators.getDistinctFromOperator(type)); } - - return new UnnestingPositionsAppender( - new RleAwarePositionsAppender( - blockTypeOperators.getDistinctFromOperator(type), - createPrimitiveAppender(type, 
expectedPositions, maxPageSizeInBytes))); + return new UnnestingPositionsAppender(createPrimitiveAppender(type, expectedPositions, maxPageSizeInBytes), distinctFromOperator); } private PositionsAppender createPrimitiveAppender(Type type, int expectedPositions, long maxPageSizeInBytes) { - if (type instanceof FixedWidthType) { - switch (((FixedWidthType) type).getFixedSize()) { - case Byte.BYTES: - return new BytePositionsAppender(expectedPositions); - case Short.BYTES: - return new ShortPositionsAppender(expectedPositions); - case Integer.BYTES: - return new IntPositionsAppender(expectedPositions); - case Long.BYTES: - return new LongPositionsAppender(expectedPositions); - case Fixed12Block.FIXED12_BYTES: - return new Fixed12PositionsAppender(expectedPositions); - case Int128ArrayBlock.INT128_BYTES: - return new Int128PositionsAppender(expectedPositions); - default: - // size not supported directly, fallback to the generic appender - } + if (type.getValueBlockType() == ByteArrayBlock.class) { + return new BytePositionsAppender(expectedPositions); + } + if (type.getValueBlockType() == ShortArrayBlock.class) { + return new ShortPositionsAppender(expectedPositions); + } + if (type.getValueBlockType() == IntArrayBlock.class) { + return new IntPositionsAppender(expectedPositions); } - else if (type instanceof VariableWidthType) { + if (type.getValueBlockType() == LongArrayBlock.class) { + return new LongPositionsAppender(expectedPositions); + } + if (type.getValueBlockType() == Fixed12Block.class) { + return new Fixed12PositionsAppender(expectedPositions); + } + if (type.getValueBlockType() == Int128ArrayBlock.class) { + return new Int128PositionsAppender(expectedPositions); + } + if (type.getValueBlockType() == VariableWidthBlock.class) { return new SlicePositionsAppender(expectedPositions, maxPageSizeInBytes); } - else if (type instanceof RowType) { + if (type.getValueBlockType() == RowBlock.class) { return RowPositionsAppender.createRowAppender(this, (RowType) type, 
expectedPositions, maxPageSizeInBytes); } - return new TypedPositionsAppender(type, expectedPositions); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppenderPageBuilder.java b/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppenderPageBuilder.java index e19aaeb97401..7b113d87d429 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppenderPageBuilder.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppenderPageBuilder.java @@ -26,7 +26,7 @@ public class PositionsAppenderPageBuilder { private static final int DEFAULT_INITIAL_EXPECTED_ENTRIES = 8; - private final PositionsAppender[] channelAppenders; + private final UnnestingPositionsAppender[] channelAppenders; private final int maxPageSizeInBytes; private int declaredPositions; @@ -45,7 +45,7 @@ private PositionsAppenderPageBuilder( requireNonNull(positionsAppenderFactory, "positionsAppenderFactory is null"); this.maxPageSizeInBytes = maxPageSizeInBytes; - channelAppenders = new PositionsAppender[types.size()]; + channelAppenders = new UnnestingPositionsAppender[types.size()]; for (int i = 0; i < channelAppenders.length; i++) { channelAppenders[i] = positionsAppenderFactory.create(types.get(i), initialExpectedEntries, maxPageSizeInBytes); } @@ -76,7 +76,7 @@ public long getRetainedSizeInBytes() // We use a foreach loop instead of streams // as it has much better performance. 
long retainedSizeInBytes = 0; - for (PositionsAppender positionsAppender : channelAppenders) { + for (UnnestingPositionsAppender positionsAppender : channelAppenders) { retainedSizeInBytes += positionsAppender.getRetainedSizeInBytes(); } return retainedSizeInBytes; @@ -85,13 +85,13 @@ public long getRetainedSizeInBytes() public long getSizeInBytes() { long sizeInBytes = 0; - for (PositionsAppender positionsAppender : channelAppenders) { + for (UnnestingPositionsAppender positionsAppender : channelAppenders) { sizeInBytes += positionsAppender.getSizeInBytes(); } return sizeInBytes; } - public void declarePositions(int positions) + private void declarePositions(int positions) { declaredPositions += positions; } diff --git a/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppenderUtil.java b/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppenderUtil.java index 001e60e460e4..0d1d6b642096 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppenderUtil.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppenderUtil.java @@ -31,7 +31,7 @@ private PositionsAppenderUtil() // Copied from io.trino.spi.block.BlockUtil#calculateNewArraySize static int calculateNewArraySize(int currentSize) { - // grow array by 50% + // grow the array by 50% long newSize = (long) currentSize + (currentSize >> 1); // verify new size is within reasonable bounds diff --git a/core/trino-main/src/main/java/io/trino/operator/output/RleAwarePositionsAppender.java b/core/trino-main/src/main/java/io/trino/operator/output/RleAwarePositionsAppender.java deleted file mode 100644 index 82480d7edce6..000000000000 --- a/core/trino-main/src/main/java/io/trino/operator/output/RleAwarePositionsAppender.java +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.operator.output; - -import io.trino.spi.block.Block; -import io.trino.spi.block.RunLengthEncodedBlock; -import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom; -import it.unimi.dsi.fastutil.ints.IntArrayList; -import jakarta.annotation.Nullable; - -import static com.google.common.base.Preconditions.checkArgument; -import static io.airlift.slice.SizeOf.instanceSize; -import static java.util.Objects.requireNonNull; - -/** - * {@link PositionsAppender} that will produce {@link RunLengthEncodedBlock} output if possible, - * that is all inputs are {@link RunLengthEncodedBlock} blocks with the same value. - */ -public class RleAwarePositionsAppender - implements PositionsAppender -{ - private static final int INSTANCE_SIZE = instanceSize(RleAwarePositionsAppender.class); - private static final int NO_RLE = -1; - - private final BlockPositionIsDistinctFrom isDistinctFromOperator; - private final PositionsAppender delegate; - - @Nullable - private Block rleValue; - - // NO_RLE means flat state, 0 means initial empty state, positive means RLE state and the current RLE position count. 
- private int rlePositionCount; - - public RleAwarePositionsAppender(BlockPositionIsDistinctFrom isDistinctFromOperator, PositionsAppender delegate) - { - this.delegate = requireNonNull(delegate, "delegate is null"); - this.isDistinctFromOperator = requireNonNull(isDistinctFromOperator, "isDistinctFromOperator is null"); - } - - @Override - public void append(IntArrayList positions, Block source) - { - // RleAwarePositionsAppender should be used with UnnestingPositionsAppender that makes sure - // append is called only with flat block - checkArgument(!(source instanceof RunLengthEncodedBlock), "Append should be called with non-RLE block but got %s", source); - switchToFlat(); - delegate.append(positions, source); - } - - @Override - public void appendRle(Block value, int positionCount) - { - if (positionCount == 0) { - return; - } - checkArgument(value.getPositionCount() == 1, "Expected value to contain a single position but has %d positions".formatted(value.getPositionCount())); - - if (rlePositionCount == 0) { - // initial empty state, switch to RLE state - rleValue = value; - rlePositionCount = positionCount; - } - else if (rleValue != null) { - // we are in the RLE state - if (!isDistinctFromOperator.isDistinctFrom(rleValue, 0, value, 0)) { - // the values match. we can just add positions. - this.rlePositionCount += positionCount; - return; - } - // RLE values do not match. 
switch to flat state - switchToFlat(); - delegate.appendRle(value, positionCount); - } - else { - // flat state - delegate.appendRle(value, positionCount); - } - } - - @Override - public void append(int position, Block value) - { - switchToFlat(); - delegate.append(position, value); - } - - @Override - public Block build() - { - Block result; - if (rleValue != null) { - result = RunLengthEncodedBlock.create(rleValue, rlePositionCount); - } - else { - result = delegate.build(); - } - - reset(); - return result; - } - - private void reset() - { - rleValue = null; - rlePositionCount = 0; - } - - @Override - public long getRetainedSizeInBytes() - { - long retainedRleSize = rleValue != null ? rleValue.getRetainedSizeInBytes() : 0; - return INSTANCE_SIZE + retainedRleSize + delegate.getRetainedSizeInBytes(); - } - - @Override - public long getSizeInBytes() - { - long rleSize = rleValue != null ? rleValue.getSizeInBytes() : 0; - return rleSize + delegate.getSizeInBytes(); - } - - private void switchToFlat() - { - if (rleValue != null) { - // we are in the RLE state, flatten all RLE blocks - delegate.appendRle(rleValue, rlePositionCount); - rleValue = null; - } - rlePositionCount = NO_RLE; - } -} diff --git a/core/trino-main/src/main/java/io/trino/operator/output/RowPositionsAppender.java b/core/trino-main/src/main/java/io/trino/operator/output/RowPositionsAppender.java index 84634f3ef66e..da334b7b5506 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/RowPositionsAppender.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/RowPositionsAppender.java @@ -16,6 +16,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.RowBlock; import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.RowType; import it.unimi.dsi.fastutil.ints.IntArrayList; @@ -23,6 +24,7 @@ import java.util.List; import java.util.Optional; +import static com.google.common.base.Preconditions.checkArgument; import 
static io.airlift.slice.SizeOf.instanceSize; import static io.airlift.slice.SizeOf.sizeOf; import static io.trino.operator.output.PositionsAppenderUtil.calculateBlockResetSize; @@ -34,7 +36,7 @@ public class RowPositionsAppender implements PositionsAppender { private static final int INSTANCE_SIZE = instanceSize(RowPositionsAppender.class); - private final PositionsAppender[] fieldAppenders; + private final UnnestingPositionsAppender[] fieldAppenders; private int initialEntryCount; private boolean initialized; @@ -51,14 +53,14 @@ public static RowPositionsAppender createRowAppender( int expectedPositions, long maxPageSizeInBytes) { - PositionsAppender[] fields = new PositionsAppender[type.getFields().size()]; + UnnestingPositionsAppender[] fields = new UnnestingPositionsAppender[type.getFields().size()]; for (int i = 0; i < fields.length; i++) { fields[i] = positionsAppenderFactory.create(type.getFields().get(i).getType(), expectedPositions, maxPageSizeInBytes); } return new RowPositionsAppender(fields, expectedPositions); } - private RowPositionsAppender(PositionsAppender[] fieldAppenders, int expectedPositions) + private RowPositionsAppender(UnnestingPositionsAppender[] fieldAppenders, int expectedPositions) { this.fieldAppenders = requireNonNull(fieldAppenders, "fields is null"); this.initialEntryCount = expectedPositions; @@ -66,39 +68,30 @@ private RowPositionsAppender(PositionsAppender[] fieldAppenders, int expectedPos } @Override - // TODO: Make PositionsAppender work performant with different block types (https://github.com/trinodb/trino/issues/13267) - public void append(IntArrayList positions, Block block) + public void append(IntArrayList positions, ValueBlock block) { + checkArgument(block instanceof RowBlock, "Block must be instance of %s", RowBlock.class); + if (positions.isEmpty()) { return; } ensureCapacity(positions.size()); - if (block instanceof RowBlock sourceRowBlock) { - IntArrayList nonNullPositions; - if (sourceRowBlock.mayHaveNull()) { - 
nonNullPositions = processNullablePositions(positions, sourceRowBlock); - hasNullRow |= nonNullPositions.size() < positions.size(); - hasNonNullRow |= nonNullPositions.size() > 0; - } - else { - // the source Block does not have nulls - nonNullPositions = processNonNullablePositions(positions, sourceRowBlock); - hasNonNullRow = true; - } - - List fieldBlocks = sourceRowBlock.getChildren(); - for (int i = 0; i < fieldAppenders.length; i++) { - fieldAppenders[i].append(nonNullPositions, fieldBlocks.get(i)); - } - } - else if (allPositionsNull(positions, block)) { - // all input positions are null. We can handle that even if block type is not RowBLock. - // append positions.size() nulls - Arrays.fill(rowIsNull, positionCount, positionCount + positions.size(), true); - hasNullRow = true; + RowBlock sourceRowBlock = (RowBlock) block; + IntArrayList nonNullPositions; + if (sourceRowBlock.mayHaveNull()) { + nonNullPositions = processNullablePositions(positions, sourceRowBlock); + hasNullRow |= nonNullPositions.size() < positions.size(); + hasNonNullRow |= !nonNullPositions.isEmpty(); } else { - throw new IllegalArgumentException("unsupported block type: " + block); + // the source Block does not have nulls + nonNullPositions = processNonNullablePositions(positions, sourceRowBlock); + hasNonNullRow = true; + } + + List fieldBlocks = sourceRowBlock.getChildren(); + for (int i = 0; i < fieldAppenders.length; i++) { + fieldAppenders[i].append(nonNullPositions, fieldBlocks.get(i)); } positionCount += positions.size(); @@ -106,62 +99,49 @@ else if (allPositionsNull(positions, block)) { } @Override - public void appendRle(Block value, int rlePositionCount) + public void appendRle(ValueBlock value, int rlePositionCount) { + checkArgument(value instanceof RowBlock, "Block must be instance of %s", RowBlock.class); + ensureCapacity(rlePositionCount); - if (value instanceof RowBlock sourceRowBlock) { - if (sourceRowBlock.isNull(0)) { - // append rlePositionCount nulls - 
Arrays.fill(rowIsNull, positionCount, positionCount + rlePositionCount, true); - hasNullRow = true; - } - else { - // append not null row value - List fieldBlocks = sourceRowBlock.getChildren(); - int fieldPosition = sourceRowBlock.getFieldBlockOffset(0); - for (int i = 0; i < fieldAppenders.length; i++) { - fieldAppenders[i].appendRle(fieldBlocks.get(i).getSingleValueBlock(fieldPosition), rlePositionCount); - } - hasNonNullRow = true; - } - } - else if (value.isNull(0)) { + RowBlock sourceRowBlock = (RowBlock) value; + if (sourceRowBlock.isNull(0)) { // append rlePositionCount nulls Arrays.fill(rowIsNull, positionCount, positionCount + rlePositionCount, true); hasNullRow = true; } else { - throw new IllegalArgumentException("unsupported block type: " + value); + // append not null row value + List fieldBlocks = sourceRowBlock.getChildren(); + int fieldPosition = sourceRowBlock.getFieldBlockOffset(0); + for (int i = 0; i < fieldAppenders.length; i++) { + fieldAppenders[i].appendRle(fieldBlocks.get(i).getSingleValueBlock(fieldPosition), rlePositionCount); + } + hasNonNullRow = true; } positionCount += rlePositionCount; resetSize(); } @Override - public void append(int position, Block value) + public void append(int position, ValueBlock value) { + checkArgument(value instanceof RowBlock, "Block must be instance of %s", RowBlock.class); + ensureCapacity(1); - if (value instanceof RowBlock sourceRowBlock) { - if (sourceRowBlock.isNull(position)) { - rowIsNull[positionCount] = true; - hasNullRow = true; - } - else { - // append not null row value - List fieldBlocks = sourceRowBlock.getChildren(); - int fieldPosition = sourceRowBlock.getFieldBlockOffset(position); - for (int i = 0; i < fieldAppenders.length; i++) { - fieldAppenders[i].append(fieldPosition, fieldBlocks.get(i)); - } - hasNonNullRow = true; - } - } - else if (value.isNull(position)) { + RowBlock sourceRowBlock = (RowBlock) value; + if (sourceRowBlock.isNull(position)) { rowIsNull[positionCount] = true; 
hasNullRow = true; } else { - throw new IllegalArgumentException("unsupported block type: " + value); + // append not null row value + List fieldBlocks = sourceRowBlock.getChildren(); + int fieldPosition = sourceRowBlock.getFieldBlockOffset(position); + for (int i = 0; i < fieldAppenders.length; i++) { + fieldAppenders[i].append(fieldPosition, fieldBlocks.get(i)); + } + hasNonNullRow = true; } positionCount++; resetSize(); @@ -195,7 +175,7 @@ public long getRetainedSizeInBytes() } long size = INSTANCE_SIZE + sizeOf(rowIsNull); - for (PositionsAppender field : fieldAppenders) { + for (UnnestingPositionsAppender field : fieldAppenders) { size += field.getRetainedSizeInBytes(); } @@ -211,7 +191,7 @@ public long getSizeInBytes() } long size = (Integer.BYTES + Byte.BYTES) * (long) positionCount; - for (PositionsAppender field : fieldAppenders) { + for (UnnestingPositionsAppender field : fieldAppenders) { size += field.getSizeInBytes(); } @@ -230,16 +210,6 @@ private void reset() resetSize(); } - private boolean allPositionsNull(IntArrayList positions, Block block) - { - for (int i = 0; i < positions.size(); i++) { - if (!block.isNull(positions.getInt(i))) { - return false; - } - } - return true; - } - private IntArrayList processNullablePositions(IntArrayList positions, RowBlock sourceRowBlock) { int[] nonNullPositions = new int[positions.size()]; @@ -256,7 +226,7 @@ private IntArrayList processNullablePositions(IntArrayList positions, RowBlock s return IntArrayList.wrap(nonNullPositions, nonNullPositionsCount); } - private IntArrayList processNonNullablePositions(IntArrayList positions, RowBlock sourceRowBlock) + private static IntArrayList processNonNullablePositions(IntArrayList positions, RowBlock sourceRowBlock) { int[] nonNullPositions = new int[positions.size()]; for (int i = 0; i < positions.size(); i++) { diff --git a/core/trino-main/src/main/java/io/trino/operator/output/ShortPositionsAppender.java 
b/core/trino-main/src/main/java/io/trino/operator/output/ShortPositionsAppender.java index 21afc3a700bc..acc9f9f23159 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/ShortPositionsAppender.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/ShortPositionsAppender.java @@ -16,11 +16,13 @@ import io.trino.spi.block.Block; import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.block.ShortArrayBlock; +import io.trino.spi.block.ValueBlock; import it.unimi.dsi.fastutil.ints.IntArrayList; import java.util.Arrays; import java.util.Optional; +import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.slice.SizeOf.instanceSize; import static io.airlift.slice.SizeOf.sizeOf; import static io.trino.operator.output.PositionsAppenderUtil.calculateBlockResetSize; @@ -55,9 +57,10 @@ public ShortPositionsAppender(int expectedEntries) } @Override - // TODO: Make PositionsAppender work performant with different block types (https://github.com/trinodb/trino/issues/13267) - public void append(IntArrayList positions, Block block) + public void append(IntArrayList positions, ValueBlock block) { + checkArgument(block instanceof ShortArrayBlock, "Block must be instance of %s", ShortArrayBlock.class); + if (positions.isEmpty()) { return; } @@ -94,8 +97,10 @@ public void append(IntArrayList positions, Block block) } @Override - public void appendRle(Block block, int rlePositionCount) + public void appendRle(ValueBlock block, int rlePositionCount) { + checkArgument(block instanceof ShortArrayBlock, "Block must be instance of %s", ShortArrayBlock.class); + if (rlePositionCount == 0) { return; } @@ -116,8 +121,10 @@ public void appendRle(Block block, int rlePositionCount) } @Override - public void append(int sourcePosition, Block source) + public void append(int sourcePosition, ValueBlock source) { + checkArgument(source instanceof ShortArrayBlock, "Block must be instance of %s", ShortArrayBlock.class); + 
ensureCapacity(positionCount + 1); if (source.isNull(sourcePosition)) { valueIsNull[positionCount] = true; @@ -185,7 +192,7 @@ private void ensureCapacity(int capacity) newSize = initialEntryCount; initialized = true; } - newSize = Math.max(newSize, capacity); + newSize = max(newSize, capacity); valueIsNull = Arrays.copyOf(valueIsNull, newSize); values = Arrays.copyOf(values, newSize); diff --git a/core/trino-main/src/main/java/io/trino/operator/output/SkewedPartitionFunction.java b/core/trino-main/src/main/java/io/trino/operator/output/SkewedPartitionFunction.java index 058e5e49a19a..638fc54f3b16 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/SkewedPartitionFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/SkewedPartitionFunction.java @@ -32,11 +32,11 @@ public SkewedPartitionFunction(PartitionFunction partitionFunction, SkewedPartit this.partitionFunction = requireNonNull(partitionFunction, "partitionFunction is null"); this.skewedPartitionRebalancer = requireNonNull(skewedPartitionRebalancer, "skewedPartitionRebalancer is null"); - this.partitionRowCount = new long[partitionFunction.getPartitionCount()]; + this.partitionRowCount = new long[partitionFunction.partitionCount()]; } @Override - public int getPartitionCount() + public int partitionCount() { return skewedPartitionRebalancer.getTaskCount(); } @@ -50,7 +50,7 @@ public int getPartition(Page page, int position) public void flushPartitionRowCountToRebalancer() { - for (int partition = 0; partition < partitionFunction.getPartitionCount(); partition++) { + for (int partition = 0; partition < partitionFunction.partitionCount(); partition++) { skewedPartitionRebalancer.addPartitionRowCount(partition, partitionRowCount[partition]); partitionRowCount[partition] = 0; } diff --git a/core/trino-main/src/main/java/io/trino/operator/output/SlicePositionsAppender.java b/core/trino-main/src/main/java/io/trino/operator/output/SlicePositionsAppender.java index 
7849f6c61501..1d5d5d64ffd3 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/SlicePositionsAppender.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/SlicePositionsAppender.java @@ -18,12 +18,14 @@ import io.airlift.slice.Slices; import io.trino.spi.block.Block; import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; import io.trino.spi.block.VariableWidthBlock; import it.unimi.dsi.fastutil.ints.IntArrayList; import java.util.Arrays; import java.util.Optional; +import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.slice.SizeOf.SIZE_OF_BYTE; import static io.airlift.slice.SizeOf.SIZE_OF_INT; import static io.airlift.slice.SizeOf.instanceSize; @@ -51,7 +53,7 @@ public class SlicePositionsAppender private boolean hasNullValue; private boolean hasNonNullValue; // it is assumed that the offsets array is one position longer than the valueIsNull array private boolean[] valueIsNull = new boolean[0]; private int[] offsets = new int[1]; @@ -74,54 +76,53 @@ public SlicePositionsAppender(int expectedEntries, int expectedBytes) } @Override - // TODO: Make PositionsAppender work performant with different block types (https://github.com/trinodb/trino/issues/13267) - public void append(IntArrayList positions, Block block) + public void append(IntArrayList positions, ValueBlock block) { + checkArgument(block instanceof VariableWidthBlock, "Block must be instance of %s", VariableWidthBlock.class); + if (positions.isEmpty()) { return; } ensurePositionCapacity(positionCount + positions.size()); - if (block instanceof VariableWidthBlock variableWidthBlock) { - int newByteCount = 0; - int[] lengths = new int[positions.size()]; - int[] sourceOffsets = new int[positions.size()]; - int[] positionArray = positions.elements(); - - if (block.mayHaveNull()) { - for (int i = 0; i < positions.size(); 
i++) { - int position = positionArray[i]; - int length = variableWidthBlock.getSliceLength(position); - lengths[i] = length; - sourceOffsets[i] = variableWidthBlock.getRawSliceOffset(position); - newByteCount += length; - boolean isNull = block.isNull(position); - valueIsNull[positionCount + i] = isNull; - offsets[positionCount + i + 1] = offsets[positionCount + i] + length; - hasNullValue |= isNull; - hasNonNullValue |= !isNull; - } - } - else { - for (int i = 0; i < positions.size(); i++) { - int position = positionArray[i]; - int length = variableWidthBlock.getSliceLength(position); - lengths[i] = length; - sourceOffsets[i] = variableWidthBlock.getRawSliceOffset(position); - newByteCount += length; - offsets[positionCount + i + 1] = offsets[positionCount + i] + length; - } - hasNonNullValue = true; + VariableWidthBlock variableWidthBlock = (VariableWidthBlock) block; + int newByteCount = 0; + int[] lengths = new int[positions.size()]; + int[] sourceOffsets = new int[positions.size()]; + int[] positionArray = positions.elements(); + + if (block.mayHaveNull()) { + for (int i = 0; i < positions.size(); i++) { + int position = positionArray[i]; + int length = variableWidthBlock.getSliceLength(position); + lengths[i] = length; + sourceOffsets[i] = variableWidthBlock.getRawSliceOffset(position); + newByteCount += length; + boolean isNull = block.isNull(position); + valueIsNull[positionCount + i] = isNull; + offsets[positionCount + i + 1] = offsets[positionCount + i] + length; + hasNullValue |= isNull; + hasNonNullValue |= !isNull; } - copyBytes(variableWidthBlock.getRawSlice(), lengths, sourceOffsets, positions.size(), newByteCount); } else { - appendGenericBlock(positions, block); + for (int i = 0; i < positions.size(); i++) { + int position = positionArray[i]; + int length = variableWidthBlock.getSliceLength(position); + lengths[i] = length; + sourceOffsets[i] = variableWidthBlock.getRawSliceOffset(position); + newByteCount += length; + offsets[positionCount + i + 
1] = offsets[positionCount + i] + length; + } + hasNonNullValue = true; } + copyBytes(variableWidthBlock.getRawSlice(), lengths, sourceOffsets, positions.size(), newByteCount); } @Override - public void appendRle(Block block, int rlePositionCount) + public void appendRle(ValueBlock block, int rlePositionCount) { + checkArgument(block instanceof VariableWidthBlock, "Block must be instance of %s", VariableWidthBlock.class); + if (rlePositionCount == 0) { return; } @@ -141,8 +142,10 @@ public void appendRle(Block block, int rlePositionCount) } @Override - public void append(int position, Block source) + public void append(int position, ValueBlock source) { + checkArgument(source instanceof VariableWidthBlock, "Block must be instance of %s but is %s".formatted(VariableWidthBlock.class, source.getClass())); + ensurePositionCapacity(positionCount + 1); if (source.isNull(position)) { valueIsNull[positionCount] = true; @@ -259,30 +262,6 @@ static void duplicateBytes(Slice slice, byte[] bytes, int startOffset, int count System.arraycopy(bytes, startOffset, bytes, startOffset + duplicatedBytes, totalDuplicatedBytes - duplicatedBytes); } - private void appendGenericBlock(IntArrayList positions, Block block) - { - int newByteCount = 0; - for (int i = 0; i < positions.size(); i++) { - int position = positions.getInt(i); - if (block.isNull(position)) { - offsets[positionCount + 1] = offsets[positionCount]; - valueIsNull[positionCount] = true; - hasNullValue = true; - } - else { - int length = block.getSliceLength(position); - ensureExtraBytesCapacity(length); - Slice slice = block.getSlice(position, 0, length); - slice.getBytes(0, bytes, offsets[positionCount], length); - offsets[positionCount + 1] = offsets[positionCount] + length; - hasNonNullValue = true; - newByteCount += length; - } - positionCount++; - } - updateSize(positions.size(), newByteCount); - } - private void reset() { initialEntryCount = calculateBlockResetSize(positionCount); diff --git 
a/core/trino-main/src/main/java/io/trino/operator/output/TaskOutputOperator.java b/core/trino-main/src/main/java/io/trino/operator/output/TaskOutputOperator.java index 49b9e87595ec..b687ed09ac74 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/TaskOutputOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/TaskOutputOperator.java @@ -134,7 +134,7 @@ public boolean isFinished() @Override public ListenableFuture isBlocked() { - // Avoid re-synchronizing on the output buffer when operator is already blocked + // Avoid re-synchronizing on the output buffer when the operator is already blocked if (isBlocked.isDone()) { isBlocked = outputBuffer.isFull(); if (isBlocked.isDone()) { diff --git a/core/trino-main/src/main/java/io/trino/operator/output/TypedPositionsAppender.java b/core/trino-main/src/main/java/io/trino/operator/output/TypedPositionsAppender.java index 1f66dd05d0dc..9d8ff32d4478 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/TypedPositionsAppender.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/TypedPositionsAppender.java @@ -15,6 +15,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import it.unimi.dsi.fastutil.ints.IntArrayList; @@ -30,20 +31,13 @@ class TypedPositionsAppender private BlockBuilder blockBuilder; TypedPositionsAppender(Type type, int expectedPositions) - { - this( - type, - type.createBlockBuilder(null, expectedPositions)); - } - - TypedPositionsAppender(Type type, BlockBuilder blockBuilder) { this.type = requireNonNull(type, "type is null"); - this.blockBuilder = requireNonNull(blockBuilder, "blockBuilder is null"); + this.blockBuilder = type.createBlockBuilder(null, expectedPositions); } @Override - public void append(IntArrayList positions, Block source) + public void append(IntArrayList positions, ValueBlock source) { int[] positionArray = positions.elements(); for 
(int i = 0; i < positions.size(); i++) { @@ -52,7 +46,7 @@ public void append(IntArrayList positions, Block source) } @Override - public void appendRle(Block block, int rlePositionCount) + public void appendRle(ValueBlock block, int rlePositionCount) { for (int i = 0; i < rlePositionCount; i++) { type.appendTo(block, 0, blockBuilder); @@ -60,7 +54,7 @@ public void appendRle(Block block, int rlePositionCount) } @Override - public void append(int position, Block source) + public void append(int position, ValueBlock source) { type.appendTo(source, position, blockBuilder); } diff --git a/core/trino-main/src/main/java/io/trino/operator/output/UnnestingPositionsAppender.java b/core/trino-main/src/main/java/io/trino/operator/output/UnnestingPositionsAppender.java index cedc3ece1cfd..c360c9b7cb8f 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/UnnestingPositionsAppender.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/UnnestingPositionsAppender.java @@ -16,11 +16,18 @@ import io.trino.spi.block.Block; import io.trino.spi.block.DictionaryBlock; import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; +import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom; import it.unimi.dsi.fastutil.ints.IntArrayList; import it.unimi.dsi.fastutil.ints.IntArrays; +import jakarta.annotation.Nullable; + +import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verify; import static io.airlift.slice.SizeOf.instanceSize; +import static io.airlift.slice.SizeOf.sizeOf; import static io.trino.operator.output.PositionsAppenderUtil.calculateBlockResetSize; import static io.trino.operator.output.PositionsAppenderUtil.calculateNewArraySize; import static java.lang.Math.max; @@ -30,52 +37,110 @@ * Dispatches the {@link #append} and {@link #appendRle} methods to the {@link #delegate} depending on the input {@link Block} class. 
*/ public class UnnestingPositionsAppender - implements PositionsAppender { private static final int INSTANCE_SIZE = instanceSize(UnnestingPositionsAppender.class); + // The initial state will transition to either the DICTIONARY or RLE state, and from there to the DIRECT state if necessary. + private enum State + { + UNINITIALIZED, DICTIONARY, RLE, DIRECT + } + private final PositionsAppender delegate; - private DictionaryBlockBuilder dictionaryBlockBuilder; + @Nullable + private final BlockPositionIsDistinctFrom isDistinctFromOperator; - public UnnestingPositionsAppender(PositionsAppender delegate) + private State state = State.UNINITIALIZED; + + private ValueBlock dictionary; + private DictionaryIdsBuilder dictionaryIdsBuilder; + + @Nullable + private ValueBlock rleValue; + private int rlePositionCount; + + public UnnestingPositionsAppender(PositionsAppender delegate, Optional isDistinctFromOperator) { this.delegate = requireNonNull(delegate, "delegate is null"); - this.dictionaryBlockBuilder = new DictionaryBlockBuilder(); + this.dictionaryIdsBuilder = new DictionaryIdsBuilder(1024); + this.isDistinctFromOperator = isDistinctFromOperator.orElse(null); } - @Override public void append(IntArrayList positions, Block source) { if (positions.isEmpty()) { return; } - if (source instanceof RunLengthEncodedBlock) { - dictionaryBlockBuilder.flushDictionary(delegate); - delegate.appendRle(((RunLengthEncodedBlock) source).getValue(), positions.size()); + + if (source instanceof RunLengthEncodedBlock rleBlock) { + appendRle(rleBlock.getValue(), positions.size()); + } + else if (source instanceof DictionaryBlock dictionaryBlock) { + ValueBlock dictionary = dictionaryBlock.getDictionary(); + if (state == State.UNINITIALIZED) { + state = State.DICTIONARY; + this.dictionary = dictionary; + dictionaryIdsBuilder.appendPositions(positions, dictionaryBlock); + } + else if (state == State.DICTIONARY && this.dictionary == dictionary) { + 
dictionaryIdsBuilder.appendPositions(positions, dictionaryBlock); + } + else { + transitionToDirect(); + + int[] positionArray = new int[positions.size()]; + for (int i = 0; i < positions.size(); i++) { + positionArray[i] = dictionaryBlock.getId(positions.getInt(i)); + } + delegate.append(IntArrayList.wrap(positionArray), dictionary); + } } - else if (source instanceof DictionaryBlock) { - appendDictionary(positions, (DictionaryBlock) source); + else if (source instanceof ValueBlock valueBlock) { + transitionToDirect(); + delegate.append(positions, valueBlock); } else { - dictionaryBlockBuilder.flushDictionary(delegate); - delegate.append(positions, source); + throw new IllegalArgumentException("Unsupported block type: " + source.getClass().getSimpleName()); } } - @Override - public void appendRle(Block block, int rlePositionCount) + public void appendRle(ValueBlock value, int positionCount) { - if (rlePositionCount == 0) { + if (positionCount == 0) { return; } - dictionaryBlockBuilder.flushDictionary(delegate); - delegate.appendRle(block, rlePositionCount); + + if (state == State.DICTIONARY) { + transitionToDirect(); + } + if (isDistinctFromOperator == null) { + transitionToDirect(); + } + + if (state == State.UNINITIALIZED) { + state = State.RLE; + rleValue = value; + rlePositionCount = positionCount; + return; + } + if (state == State.RLE) { + if (!isDistinctFromOperator.isDistinctFrom(rleValue, 0, value, 0)) { + // the values match. we can just add positions. 
+ rlePositionCount += positionCount; + return; + } + transitionToDirect(); + } + + verify(state == State.DIRECT); + delegate.appendRle(value, positionCount); } - @Override public void append(int position, Block source) { - dictionaryBlockBuilder.flushDictionary(delegate); + if (state != State.DIRECT) { + transitionToDirect(); + } if (source instanceof RunLengthEncodedBlock runLengthEncodedBlock) { delegate.append(0, runLengthEncodedBlock.getValue()); @@ -83,134 +148,108 @@ public void append(int position, Block source) else if (source instanceof DictionaryBlock dictionaryBlock) { delegate.append(dictionaryBlock.getId(position), dictionaryBlock.getDictionary()); } + else if (source instanceof ValueBlock valueBlock) { + delegate.append(position, valueBlock); + } else { - delegate.append(position, source); + throw new IllegalArgumentException("Unsupported block type: " + source.getClass().getSimpleName()); } } - @Override - public Block build() + private void transitionToDirect() { - Block result; - if (dictionaryBlockBuilder.isEmpty()) { - result = delegate.build(); + if (state == State.DICTIONARY) { + int[] dictionaryIds = dictionaryIdsBuilder.getDictionaryIds(); + delegate.append(IntArrayList.wrap(dictionaryIds, dictionaryIdsBuilder.size()), dictionary); + dictionary = null; + dictionaryIdsBuilder = dictionaryIdsBuilder.newBuilderLike(); } - else { - result = dictionaryBlockBuilder.build(); + else if (state == State.RLE) { + delegate.appendRle(rleValue, rlePositionCount); + rleValue = null; + rlePositionCount = 0; } - dictionaryBlockBuilder = dictionaryBlockBuilder.newBuilderLike(); - return result; + state = State.DIRECT; } - @Override - public long getRetainedSizeInBytes() + public Block build() { - return INSTANCE_SIZE + delegate.getRetainedSizeInBytes() + dictionaryBlockBuilder.getRetainedSizeInBytes(); - } + Block result = switch (state) { + case DICTIONARY -> DictionaryBlock.create(dictionaryIdsBuilder.size(), dictionary, 
dictionaryIdsBuilder.getDictionaryIds()); + case RLE -> RunLengthEncodedBlock.create(rleValue, rlePositionCount); + case UNINITIALIZED, DIRECT -> delegate.build(); + }; - @Override - public long getSizeInBytes() - { - return delegate.getSizeInBytes(); + state = State.UNINITIALIZED; + dictionary = null; + dictionaryIdsBuilder = dictionaryIdsBuilder.newBuilderLike(); + rleValue = null; + rlePositionCount = 0; + + return result; } - private void appendDictionary(IntArrayList positions, DictionaryBlock source) + public long getRetainedSizeInBytes() { - Block dictionary = source.getDictionary(); - if (dictionary instanceof RunLengthEncodedBlock rleDictionary) { - appendRle(rleDictionary.getValue(), positions.size()); - return; - } - - IntArrayList dictionaryPositions = getDictionaryPositions(positions, source); - if (dictionaryBlockBuilder.canAppend(dictionary)) { - dictionaryBlockBuilder.append(dictionaryPositions, dictionary); - } - else { - dictionaryBlockBuilder.flushDictionary(delegate); - delegate.append(dictionaryPositions, dictionary); - } + return INSTANCE_SIZE + + delegate.getRetainedSizeInBytes() + + dictionaryIdsBuilder.getRetainedSizeInBytes() + + (rleValue != null ? rleValue.getRetainedSizeInBytes() : 0); } - private IntArrayList getDictionaryPositions(IntArrayList positions, DictionaryBlock block) + public long getSizeInBytes() { - int[] positionArray = new int[positions.size()]; - for (int i = 0; i < positions.size(); i++) { - positionArray[i] = block.getId(positions.getInt(i)); - } - return IntArrayList.wrap(positionArray); + return delegate.getSizeInBytes() + + // dictionary size is not included due to the expense of the calculation + (rleValue != null ? 
rleValue.getSizeInBytes() : 0); } - private static class DictionaryBlockBuilder + private static class DictionaryIdsBuilder { - private static final int INSTANCE_SIZE = instanceSize(DictionaryBlockBuilder.class); + private static final int INSTANCE_SIZE = instanceSize(DictionaryIdsBuilder.class); + private final int initialEntryCount; - private Block dictionary; private int[] dictionaryIds; - private int positionCount; - private boolean closed; - - public DictionaryBlockBuilder() - { - this(1024); - } + private int size; - public DictionaryBlockBuilder(int initialEntryCount) + public DictionaryIdsBuilder(int initialEntryCount) { this.initialEntryCount = initialEntryCount; this.dictionaryIds = new int[0]; } - public boolean isEmpty() + public int[] getDictionaryIds() { - return positionCount == 0; + return dictionaryIds; } - public Block build() + public int size() { - return DictionaryBlock.create(positionCount, dictionary, dictionaryIds); + return size; } public long getRetainedSizeInBytes() { - return INSTANCE_SIZE - + (long) dictionaryIds.length * Integer.BYTES - + (dictionary != null ? 
dictionary.getRetainedSizeInBytes() : 0); + return INSTANCE_SIZE + sizeOf(dictionaryIds); } - public boolean canAppend(Block dictionary) + public void appendPositions(IntArrayList positions, DictionaryBlock block) { - return !closed && (dictionary == this.dictionary || this.dictionary == null); - } + checkArgument(!positions.isEmpty(), "positions is empty"); + ensureCapacity(size + positions.size()); - public void append(IntArrayList mappedPositions, Block dictionary) - { - checkArgument(canAppend(dictionary)); - this.dictionary = dictionary; - ensureCapacity(positionCount + mappedPositions.size()); - System.arraycopy(mappedPositions.elements(), 0, dictionaryIds, positionCount, mappedPositions.size()); - positionCount += mappedPositions.size(); - } - - public void flushDictionary(PositionsAppender delegate) - { - if (closed) { - return; - } - if (positionCount > 0) { - requireNonNull(dictionary, () -> "dictionary is null but we have pending dictionaryIds " + positionCount); - delegate.append(IntArrayList.wrap(dictionaryIds, positionCount), dictionary); + for (int i = 0; i < positions.size(); i++) { + dictionaryIds[size + i] = block.getId(positions.getInt(i)); } - - closed = true; - dictionaryIds = new int[0]; - positionCount = 0; - dictionary = null; + size += positions.size(); } - public DictionaryBlockBuilder newBuilderLike() + public DictionaryIdsBuilder newBuilderLike() { - return new DictionaryBlockBuilder(max(calculateBlockResetSize(positionCount), initialEntryCount)); + if (size == 0) { + return this; + } + return new DictionaryIdsBuilder(max(calculateBlockResetSize(size), initialEntryCount)); } private void ensureCapacity(int capacity) @@ -226,9 +265,9 @@ private void ensureCapacity(int capacity) else { newSize = initialEntryCount; } - newSize = Math.max(newSize, capacity); + newSize = max(newSize, capacity); - dictionaryIds = IntArrays.ensureCapacity(dictionaryIds, newSize, positionCount); + dictionaryIds = IntArrays.ensureCapacity(dictionaryIds, newSize, 
size); } } } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayHistogramFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayHistogramFunction.java index 4b386470415e..e0332a537189 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayHistogramFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayHistogramFunction.java @@ -18,6 +18,7 @@ import io.trino.spi.block.MapBlock; import io.trino.spi.block.MapBlockBuilder; import io.trino.spi.block.SqlMap; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.Convention; import io.trino.spi.function.Description; import io.trino.spi.function.OperatorDependency; @@ -30,8 +31,8 @@ import java.lang.invoke.MethodHandle; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -53,7 +54,7 @@ public static SqlMap arrayHistogram( @OperatorDependency( operator = OperatorType.READ_VALUE, argumentTypes = "T", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle writeFlat, + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle writeFlat, @OperatorDependency( operator = OperatorType.HASH_CODE, argumentTypes = "T", @@ -61,19 +62,20 @@ public static SqlMap arrayHistogram( @OperatorDependency( operator = OperatorType.IS_DISTINCT_FROM, argumentTypes = {"T", "T"}, - convention = 
@Convention(arguments = {FLAT, BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle distinctFlatBlock, + convention = @Convention(arguments = {FLAT, VALUE_BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle distinctFlatBlock, @OperatorDependency( operator = OperatorType.HASH_CODE, argumentTypes = "T", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FAIL_ON_NULL)) MethodHandle hashBlock, + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FAIL_ON_NULL)) MethodHandle hashBlock, @TypeParameter("map(T, bigint)") MapType mapType, @SqlType("array(T)") Block arrayBlock) { TypedHistogram histogram = new TypedHistogram(elementType, readFlat, writeFlat, hashFlat, distinctFlatBlock, hashBlock, false); - int positionCount = arrayBlock.getPositionCount(); - for (int position = 0; position < positionCount; position++) { + ValueBlock valueBlock = arrayBlock.getUnderlyingValueBlock(); + for (int i = 0; i < arrayBlock.getPositionCount(); i++) { + int position = arrayBlock.getUnderlyingValuePosition(i); if (!arrayBlock.isNull(position)) { - histogram.add(0, arrayBlock, position, 1L); + histogram.add(0, valueBlock, position, 1L); } } MapBlockBuilder blockBuilder = mapType.createBlockBuilder(null, histogram.size()); diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ChoicesSpecializedSqlScalarFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ChoicesSpecializedSqlScalarFunction.java index 8db7d8638c12..daf0bd6f29fc 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ChoicesSpecializedSqlScalarFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ChoicesSpecializedSqlScalarFunction.java @@ -208,6 +208,10 @@ private static int computeScore(InvocationConvention callingConvention) case BLOCK_POSITION: score += 1000; break; + case VALUE_BLOCK_POSITION_NOT_NULL: + case VALUE_BLOCK_POSITION: + score += 2000; + break; case IN_OUT: score += 10_000; 
break; diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/ParametricScalarImplementation.java b/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/ParametricScalarImplementation.java index bbf3c42aa355..2ced4b20154c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/ParametricScalarImplementation.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/ParametricScalarImplementation.java @@ -26,6 +26,7 @@ import io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction.ScalarImplementationChoice; import io.trino.operator.scalar.SpecializedSqlScalarFunction; import io.trino.spi.block.Block; +import io.trino.spi.block.ValueBlock; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.function.BlockIndex; import io.trino.spi.function.BlockPosition; @@ -59,7 +60,7 @@ import java.util.stream.Stream; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.ImmutableSortedSet.toImmutableSortedSet; @@ -82,6 +83,8 @@ import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.IN_OUT; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NULL_FLAG; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static 
io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; import static io.trino.sql.analyzer.TypeSignatureTranslator.parseTypeSignature; @@ -255,6 +258,11 @@ private static MethodType javaMethodType(ParametricScalarImplementationChoice ch methodHandleParameterTypes.add(Block.class); methodHandleParameterTypes.add(int.class); break; + case VALUE_BLOCK_POSITION: + case VALUE_BLOCK_POSITION_NOT_NULL: + methodHandleParameterTypes.add(ValueBlock.class); + methodHandleParameterTypes.add(int.class); + break; case IN_OUT: methodHandleParameterTypes.add(InOut.class); break; @@ -599,15 +607,21 @@ private void parseArguments(Method method, Signature.Builder signatureBuilder, L else { // value type InvocationArgumentConvention argumentConvention; + boolean nullable = Stream.of(annotations).anyMatch(SqlNullable.class::isInstance); if (Stream.of(annotations).anyMatch(BlockPosition.class::isInstance)) { - checkState(method.getParameterCount() > (parameterIndex + 1)); - checkState(parameterType == Block.class); + verify(method.getParameterCount() > (parameterIndex + 1)); - argumentConvention = Stream.of(annotations).anyMatch(SqlNullable.class::isInstance) ? BLOCK_POSITION : BLOCK_POSITION_NOT_NULL; + if (parameterType == Block.class) { + argumentConvention = nullable ? BLOCK_POSITION : BLOCK_POSITION_NOT_NULL; + } + else { + verify(ValueBlock.class.isAssignableFrom(parameterType)); + argumentConvention = nullable ? 
VALUE_BLOCK_POSITION : VALUE_BLOCK_POSITION_NOT_NULL; + } Annotation[] parameterAnnotations = method.getParameterAnnotations()[parameterIndex + 1]; - checkState(Stream.of(parameterAnnotations).anyMatch(BlockIndex.class::isInstance)); + verify(Stream.of(parameterAnnotations).anyMatch(BlockIndex.class::isInstance)); } - else if (Stream.of(annotations).anyMatch(SqlNullable.class::isInstance)) { + else if (nullable) { checkCondition(!parameterType.isPrimitive(), FUNCTION_IMPLEMENTATION_ERROR, "Method [%s] has parameter with primitive type %s annotated with @SqlNullable", method, parameterType.getSimpleName()); argumentConvention = BOXED_NULLABLE; @@ -641,7 +655,7 @@ else if (parameterType.equals(InOut.class)) { } } - if (argumentConvention == BLOCK_POSITION || argumentConvention == BLOCK_POSITION_NOT_NULL) { + if (argumentConvention == BLOCK_POSITION || argumentConvention == BLOCK_POSITION_NOT_NULL || argumentConvention == VALUE_BLOCK_POSITION || argumentConvention == VALUE_BLOCK_POSITION_NOT_NULL) { argumentNativeContainerTypes.add(Optional.of(type.nativeContainerType())); } else { diff --git a/core/trino-main/src/main/java/io/trino/spiller/GenericPartitioningSpiller.java b/core/trino-main/src/main/java/io/trino/spiller/GenericPartitioningSpiller.java index 2d19dc6bc13c..eb56a954aeaf 100644 --- a/core/trino-main/src/main/java/io/trino/spiller/GenericPartitioningSpiller.java +++ b/core/trino-main/src/main/java/io/trino/spiller/GenericPartitioningSpiller.java @@ -79,7 +79,7 @@ public GenericPartitioningSpiller( requireNonNull(memoryContext, "memoryContext is null"); closer.register(memoryContext::close); this.memoryContext = memoryContext; - int partitionCount = partitionFunction.getPartitionCount(); + int partitionCount = partitionFunction.partitionCount(); ImmutableList.Builder pageBuilders = ImmutableList.builder(); spillers = new ArrayList<>(partitionCount); diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java 
b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index dee0abbde2fc..afdcfa920950 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -4328,8 +4328,12 @@ private void analyzeSelectAllColumns( .orElseThrow(() -> semanticException(TABLE_NOT_FOUND, allColumns, "Unable to resolve reference %s", prefix)); if (identifierChainBasis.getBasisType() == TABLE) { RelationType relationType = identifierChainBasis.getRelationType().orElseThrow(); - List fields = filterInaccessibleFields(relationType.resolveVisibleFieldsWithRelationPrefix(Optional.of(prefix))); + List requestedFields = relationType.resolveVisibleFieldsWithRelationPrefix(Optional.of(prefix)); + List fields = filterInaccessibleFields(requestedFields); if (fields.isEmpty()) { + if (!requestedFields.isEmpty()) { + throw semanticException(TABLE_NOT_FOUND, allColumns, "Relation not found or not allowed"); + } throw semanticException(COLUMN_NOT_FOUND, allColumns, "SELECT * not allowed from relation that has no columns"); } boolean local = scope.isLocalScope(identifierChainBasis.getScope().orElseThrow()); @@ -4354,11 +4358,15 @@ private void analyzeSelectAllColumns( throw semanticException(NOT_SUPPORTED, allColumns, "Column aliases not supported"); } - List fields = filterInaccessibleFields((List) scope.getRelationType().getVisibleFields()); + List requestedFields = (List) scope.getRelationType().getVisibleFields(); + List fields = filterInaccessibleFields(requestedFields); if (fields.isEmpty()) { if (node.getFrom().isEmpty()) { throw semanticException(COLUMN_NOT_FOUND, allColumns, "SELECT * not allowed in queries without FROM clause"); } + if (!requestedFields.isEmpty()) { + throw semanticException(TABLE_NOT_FOUND, allColumns, "Relation not found or not allowed"); + } throw semanticException(COLUMN_NOT_FOUND, allColumns, "SELECT * not allowed from relation that has no 
columns"); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/CompilerConfig.java b/core/trino-main/src/main/java/io/trino/sql/planner/CompilerConfig.java index 43cde586607a..da90def6a9a7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/CompilerConfig.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/CompilerConfig.java @@ -22,6 +22,7 @@ public class CompilerConfig { private int expressionCacheSize = 10_000; + private boolean specializeAggregationLoops = true; @Min(0) public int getExpressionCacheSize() @@ -36,4 +37,16 @@ public CompilerConfig setExpressionCacheSize(int expressionCacheSize) this.expressionCacheSize = expressionCacheSize; return this; } + + public boolean isSpecializeAggregationLoops() + { + return specializeAggregationLoops; + } + + @Config("compiler.specialized-aggregation-loops") + public CompilerConfig setSpecializeAggregationLoops(boolean specializeAggregationLoops) + { + this.specializeAggregationLoops = specializeAggregationLoops; + return this; + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index c5c8f6f7f266..f0286d52289e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -443,6 +443,7 @@ public class LocalExecutionPlanner private final ExchangeManagerRegistry exchangeManagerRegistry; private final PositionsAppenderFactory positionsAppenderFactory; private final NodeVersion version; + private final boolean specializeAggregationLoops; private final NonEvictableCache accumulatorFactoryCache = buildNonEvictableCache(CacheBuilder.newBuilder() .maximumSize(1000) @@ -477,7 +478,8 @@ public LocalExecutionPlanner( TypeOperators typeOperators, TableExecuteContextManager tableExecuteContextManager, ExchangeManagerRegistry exchangeManagerRegistry, - 
NodeVersion version) + NodeVersion version, + CompilerConfig compilerConfig) { this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); this.metadata = plannerContext.getMetadata(); @@ -524,6 +526,7 @@ public LocalExecutionPlanner( this.exchangeManagerRegistry = requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null"); this.positionsAppenderFactory = new PositionsAppenderFactory(blockTypeOperators); this.version = requireNonNull(version, "version is null"); + this.specializeAggregationLoops = compilerConfig.isSpecializeAggregationLoops(); } public LocalExecutionPlan plan( @@ -589,7 +592,7 @@ public LocalExecutionPlan plan( // Keep the task bucket count to 50% of total local writers int taskBucketCount = (int) ceil(0.5 * partitionedWriterCount); skewedPartitionRebalancer = Optional.of(new SkewedPartitionRebalancer( - partitionFunction.getPartitionCount(), + partitionFunction.partitionCount(), taskCount, taskBucketCount, getWriterScalingMinDataProcessed(taskContext.getSession()).toBytes(), @@ -3822,7 +3825,8 @@ private AggregatorFactory buildAggregatorFactory( () -> generateAccumulatorFactory( resolvedFunction.getSignature(), aggregationImplementation, - resolvedFunction.getFunctionNullability())); + resolvedFunction.getFunctionNullability(), + specializeAggregationLoops)); if (aggregation.isDistinct()) { accumulatorFactory = new DistinctAccumulatorFactory( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/MergePartitioningHandle.java b/core/trino-main/src/main/java/io/trino/sql/planner/MergePartitioningHandle.java index 2b9608defade..95f522c66606 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/MergePartitioningHandle.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/MergePartitioningHandle.java @@ -199,13 +199,13 @@ public MergePartitionFunction(PartitionFunction insertFunction, PartitionFunctio this.updateFunction = requireNonNull(updateFunction, "updateFunction is null"); 
this.insertColumns = requireNonNull(insertColumns, "insertColumns is null"); this.updateColumns = requireNonNull(updateColumns, "updateColumns is null"); - checkArgument(insertFunction.getPartitionCount() == updateFunction.getPartitionCount(), "partition counts must match"); + checkArgument(insertFunction.partitionCount() == updateFunction.partitionCount(), "partition counts must match"); } @Override - public int getPartitionCount() + public int partitionCount() { - return insertFunction.getPartitionCount(); + return insertFunction.partitionCount(); } @Override diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java index c2960e5cc630..bee7e1aba871 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java @@ -549,7 +549,7 @@ public PlanNode plan(Delete node) assignmentsBuilder.putIdentity(symbol); } else { - assignmentsBuilder.put(symbol, new NullLiteral()); + assignmentsBuilder.put(symbol, new Cast(new NullLiteral(), toSqlType(symbolAllocator.getTypes().get(symbol)))); } } List columnSymbols = columnSymbolsBuilder.build(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/SubqueryPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/SubqueryPlanner.java index 4d9656b37094..3ed9583276f3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/SubqueryPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/SubqueryPlanner.java @@ -268,7 +268,7 @@ private PlanBuilder planScalarSubquery(PlanBuilder subPlan, Cluster BuiltinFunctionCallBuilder.resolve(plannerContext.getMetadata()) .setName("$current_time") - .setArguments(ImmutableList.of(analysis.getType(node)), ImmutableList.of(new NullLiteral())) + .setArguments(ImmutableList.of(analysis.getType(node)), ImmutableList.of(new Cast(new NullLiteral(), 
toSqlType(analysis.getType(node))))) .build(); case LOCALTIME -> BuiltinFunctionCallBuilder.resolve(plannerContext.getMetadata()) .setName("$localtime") - .setArguments(ImmutableList.of(analysis.getType(node)), ImmutableList.of(new NullLiteral())) + .setArguments(ImmutableList.of(analysis.getType(node)), ImmutableList.of(new Cast(new NullLiteral(), toSqlType(analysis.getType(node))))) .build(); case TIMESTAMP -> BuiltinFunctionCallBuilder.resolve(plannerContext.getMetadata()) .setName("$current_timestamp") - .setArguments(ImmutableList.of(analysis.getType(node)), ImmutableList.of(new NullLiteral())) + .setArguments(ImmutableList.of(analysis.getType(node)), ImmutableList.of(new Cast(new NullLiteral(), toSqlType(analysis.getType(node))))) .build(); case LOCALTIMESTAMP -> BuiltinFunctionCallBuilder.resolve(plannerContext.getMetadata()) .setName("$localtimestamp") - .setArguments(ImmutableList.of(analysis.getType(node)), ImmutableList.of(new NullLiteral())) + .setArguments(ImmutableList.of(analysis.getType(node)), ImmutableList.of(new Cast(new NullLiteral(), toSqlType(analysis.getType(node))))) .build(); }; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveEmptyTableExecute.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveEmptyTableExecute.java index 9360b029254f..9a16f1256b85 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveEmptyTableExecute.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveEmptyTableExecute.java @@ -23,12 +23,15 @@ import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.ValuesNode; +import io.trino.sql.tree.Cast; import io.trino.sql.tree.NullLiteral; import io.trino.sql.tree.Row; import java.util.Optional; import static com.google.common.base.Verify.verify; +import static io.trino.spi.type.BigintType.BIGINT; +import static 
io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; import static io.trino.sql.planner.plan.Patterns.Values.rowCount; import static io.trino.sql.planner.plan.Patterns.tableFinish; import static io.trino.sql.planner.plan.Patterns.values; @@ -86,7 +89,7 @@ public Result apply(TableFinishNode finishNode, Captures captures, Context conte new ValuesNode( finishNode.getId(), finishNode.getOutputSymbols(), - ImmutableList.of(new Row(ImmutableList.of(new NullLiteral()))))); + ImmutableList.of(new Row(ImmutableList.of(new Cast(new NullLiteral(), toSqlType(BIGINT))))))); } private Optional getSingleSourceSkipExchange(PlanNode node, Lookup lookup) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedScalarSubquery.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedScalarSubquery.java index 42de46b02c05..79d00184f2c0 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedScalarSubquery.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedScalarSubquery.java @@ -42,6 +42,7 @@ import static io.trino.sql.planner.LogicalPlanner.failFunction; import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.extractCardinality; +import static io.trino.sql.planner.plan.CorrelatedJoinNode.Type.INNER; import static io.trino.sql.planner.plan.CorrelatedJoinNode.Type.LEFT; import static io.trino.sql.planner.plan.Patterns.CorrelatedJoin.correlation; import static io.trino.sql.planner.plan.Patterns.CorrelatedJoin.filter; @@ -123,7 +124,7 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co correlatedJoinNode.getInput(), rewrittenSubquery, correlatedJoinNode.getCorrelation(), - producesSingleRow ? correlatedJoinNode.getType() : LEFT, + producesSingleRow ? 
INNER : correlatedJoinNode.getType(), correlatedJoinNode.getFilter(), correlatedJoinNode.getOriginSubquery())); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformUncorrelatedSubqueryToJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformUncorrelatedSubqueryToJoin.java index 896a5a31f785..ea7a5458e893 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformUncorrelatedSubqueryToJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformUncorrelatedSubqueryToJoin.java @@ -19,12 +19,14 @@ import com.google.common.collect.Sets; import io.trino.matching.Captures; import io.trino.matching.Pattern; +import io.trino.spi.type.Type; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.CorrelatedJoinNode; import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.ProjectNode; +import io.trino.sql.tree.Cast; import io.trino.sql.tree.Expression; import io.trino.sql.tree.IfExpression; import io.trino.sql.tree.NullLiteral; @@ -33,6 +35,7 @@ import static com.google.common.base.Preconditions.checkState; import static io.trino.matching.Pattern.empty; +import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; import static io.trino.sql.planner.plan.CorrelatedJoinNode.Type.FULL; import static io.trino.sql.planner.plan.CorrelatedJoinNode.Type.INNER; import static io.trino.sql.planner.plan.CorrelatedJoinNode.Type.LEFT; @@ -91,7 +94,8 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co for (Symbol inputSymbol : Sets.intersection( ImmutableSet.copyOf(correlatedJoinNode.getInput().getOutputSymbols()), ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()))) { - assignments.put(inputSymbol, new IfExpression(correlatedJoinNode.getFilter(), inputSymbol.toSymbolReference(), new 
NullLiteral())); + Type inputType = context.getSymbolAllocator().getTypes().get(inputSymbol); + assignments.put(inputSymbol, new IfExpression(correlatedJoinNode.getFilter(), inputSymbol.toSymbolReference(), new Cast(new NullLiteral(), toSqlType(inputType)))); } ProjectNode projectNode = new ProjectNode( context.getIdAllocator().getNextId(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java index ba10bd677e50..aecf89185fc2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java @@ -608,18 +608,6 @@ public PlanWithProperties visitSort(SortNode node, PreferredProperties preferred { PlanWithProperties child = planChild(node, PreferredProperties.undistributed()); - if (child.getProperties().isSingleNode()) { - // current plan so far is single node, so local properties are effectively global properties - // skip the SortNode if the local properties guarantee ordering on Sort keys - // TODO: This should be extracted as a separate optimizer once the planner is able to reason about the ordering of each operator - List> desiredProperties = node.getOrderingScheme().toLocalProperties(); - - if (LocalProperties.match(child.getProperties().getLocalProperties(), desiredProperties).stream() - .noneMatch(Optional::isPresent)) { - return child; - } - } - if (isDistributedSortEnabled(session)) { child = planChild(node, PreferredProperties.any()); // insert round robin exchange to eliminate skewness issues diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java index b62a0a8d7006..991ea2f25b8f 100644 --- 
a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java @@ -56,6 +56,7 @@ import static io.trino.SystemSessionProperties.isOptimizeDistinctAggregationEnabled; import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName; import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE; @@ -374,7 +375,7 @@ else if (aggregationOutputSymbolsMap.containsKey(symbol)) { // add null assignment for mask // unused mask will be removed by PruneUnreferencedOutputs - outputSymbols.put(aggregateInfo.getMask(), new NullLiteral()); + outputSymbols.put(aggregateInfo.getMask(), new Cast(new NullLiteral(), toSqlType(BOOLEAN))); aggregateInfo.setNewNonDistinctAggregateSymbols(outputNonDistinctAggregateSymbols.buildOrThrow()); diff --git a/core/trino-main/src/main/java/io/trino/sql/rewrite/DescribeInputRewrite.java b/core/trino-main/src/main/java/io/trino/sql/rewrite/DescribeInputRewrite.java index f79b1dcc072c..9f917d2e6d44 100644 --- a/core/trino-main/src/main/java/io/trino/sql/rewrite/DescribeInputRewrite.java +++ b/core/trino-main/src/main/java/io/trino/sql/rewrite/DescribeInputRewrite.java @@ -24,6 +24,7 @@ import io.trino.sql.analyzer.AnalyzerFactory; import io.trino.sql.parser.SqlParser; import io.trino.sql.tree.AstVisitor; +import io.trino.sql.tree.Cast; import io.trino.sql.tree.DescribeInput; import io.trino.sql.tree.Expression; import io.trino.sql.tree.Limit; @@ -32,6 +33,7 @@ import io.trino.sql.tree.NodeRef; import io.trino.sql.tree.NullLiteral; import io.trino.sql.tree.Parameter; +import io.trino.sql.tree.Query; import io.trino.sql.tree.Row; import 
io.trino.sql.tree.Statement; import io.trino.sql.tree.StringLiteral; @@ -42,6 +44,8 @@ import static io.trino.SystemSessionProperties.isOmitDateTimeTypePrecision; import static io.trino.execution.ParameterExtractor.extractParameters; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.QueryUtil.aliased; import static io.trino.sql.QueryUtil.ascending; import static io.trino.sql.QueryUtil.identifier; @@ -51,6 +55,7 @@ import static io.trino.sql.QueryUtil.simpleQuery; import static io.trino.sql.QueryUtil.values; import static io.trino.sql.analyzer.QueryType.DESCRIBE; +import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; import static io.trino.type.TypeUtils.getDisplayLabel; import static io.trino.type.UnknownType.UNKNOWN; import static java.util.Objects.requireNonNull; @@ -82,6 +87,12 @@ public Statement rewrite( private static final class Visitor extends AstVisitor { + private static final Query EMPTY_INPUT = createDescribeInputQuery( + new Row[]{row( + new Cast(new NullLiteral(), toSqlType(BIGINT)), + new Cast(new NullLiteral(), toSqlType(VARCHAR)))}, + Optional.of(new Limit(new LongLiteral("0")))); + private final Session session; private final SqlParser parser; private final AnalyzerFactory analyzerFactory; @@ -130,10 +141,14 @@ protected Node visitDescribeInput(DescribeInput node, Void context) Row[] rows = builder.build().toArray(Row[]::new); Optional limit = Optional.empty(); if (rows.length == 0) { - rows = new Row[] {row(new NullLiteral(), new NullLiteral())}; - limit = Optional.of(new Limit(new LongLiteral("0"))); + return EMPTY_INPUT; } + return createDescribeInputQuery(rows, limit); + } + + private static Query createDescribeInputQuery(Row[] rows, Optional limit) + { return simpleQuery( selectList(identifier("Position"), identifier("Type")), aliased( diff --git a/core/trino-main/src/main/java/io/trino/sql/rewrite/DescribeOutputRewrite.java 
b/core/trino-main/src/main/java/io/trino/sql/rewrite/DescribeOutputRewrite.java index b0ae19fe9821..b8a78d502e29 100644 --- a/core/trino-main/src/main/java/io/trino/sql/rewrite/DescribeOutputRewrite.java +++ b/core/trino-main/src/main/java/io/trino/sql/rewrite/DescribeOutputRewrite.java @@ -27,6 +27,7 @@ import io.trino.sql.parser.SqlParser; import io.trino.sql.tree.AstVisitor; import io.trino.sql.tree.BooleanLiteral; +import io.trino.sql.tree.Cast; import io.trino.sql.tree.DescribeOutput; import io.trino.sql.tree.Expression; import io.trino.sql.tree.Limit; @@ -35,6 +36,7 @@ import io.trino.sql.tree.NodeRef; import io.trino.sql.tree.NullLiteral; import io.trino.sql.tree.Parameter; +import io.trino.sql.tree.Query; import io.trino.sql.tree.Row; import io.trino.sql.tree.Statement; import io.trino.sql.tree.StringLiteral; @@ -44,6 +46,9 @@ import java.util.Optional; import static io.trino.SystemSessionProperties.isOmitDateTimeTypePrecision; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.QueryUtil.aliased; import static io.trino.sql.QueryUtil.identifier; import static io.trino.sql.QueryUtil.row; @@ -51,6 +56,7 @@ import static io.trino.sql.QueryUtil.simpleQuery; import static io.trino.sql.QueryUtil.values; import static io.trino.sql.analyzer.QueryType.DESCRIBE; +import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; import static io.trino.type.TypeUtils.getDisplayLabel; import static java.util.Objects.requireNonNull; @@ -81,6 +87,17 @@ public Statement rewrite( private static final class Visitor extends AstVisitor { + private static final Query EMPTY_OUTPUT = createDesctibeOutputQuery( + new Row[]{row( + new Cast(new NullLiteral(), toSqlType(VARCHAR)), + new Cast(new NullLiteral(), toSqlType(VARCHAR)), + new Cast(new NullLiteral(), toSqlType(VARCHAR)), + new Cast(new NullLiteral(), toSqlType(VARCHAR)), + new Cast(new 
NullLiteral(), toSqlType(VARCHAR)), + new Cast(new NullLiteral(), toSqlType(BIGINT)), + new Cast(new NullLiteral(), toSqlType(BOOLEAN)))}, + Optional.of(new Limit(new LongLiteral("0")))); + private final Session session; private final SqlParser parser; private final AnalyzerFactory analyzerFactory; @@ -119,10 +136,13 @@ protected Node visitDescribeOutput(DescribeOutput node, Void context) Optional limit = Optional.empty(); Row[] rows = analysis.getRootScope().getRelationType().getVisibleFields().stream().map(field -> createDescribeOutputRow(field, analysis)).toArray(Row[]::new); if (rows.length == 0) { - NullLiteral nullLiteral = new NullLiteral(); - rows = new Row[] {row(nullLiteral, nullLiteral, nullLiteral, nullLiteral, nullLiteral, nullLiteral, nullLiteral)}; - limit = Optional.of(new Limit(new LongLiteral("0"))); + return EMPTY_OUTPUT; } + return createDesctibeOutputQuery(rows, limit); + } + + private static Query createDesctibeOutputQuery(Row[] rows, Optional limit) + { return simpleQuery( selectList( identifier("Column Name"), diff --git a/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowQueriesRewrite.java b/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowQueriesRewrite.java index 2f3199c29d27..54c5ddec92f2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowQueriesRewrite.java +++ b/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowQueriesRewrite.java @@ -54,6 +54,7 @@ import io.trino.spi.security.PrincipalType; import io.trino.spi.security.TrinoPrincipal; import io.trino.spi.session.PropertyMetadata; +import io.trino.spi.type.Type; import io.trino.sql.analyzer.AnalyzerFactory; import io.trino.sql.parser.ParsingException; import io.trino.sql.parser.SqlParser; @@ -61,6 +62,7 @@ import io.trino.sql.tree.Array; import io.trino.sql.tree.AstVisitor; import io.trino.sql.tree.BooleanLiteral; +import io.trino.sql.tree.Cast; import io.trino.sql.tree.ColumnDefinition; import io.trino.sql.tree.CreateMaterializedView; import 
io.trino.sql.tree.CreateSchema; @@ -75,13 +77,16 @@ import io.trino.sql.tree.LongLiteral; import io.trino.sql.tree.Node; import io.trino.sql.tree.NodeRef; +import io.trino.sql.tree.NullLiteral; import io.trino.sql.tree.Parameter; import io.trino.sql.tree.PrincipalSpecification; import io.trino.sql.tree.Property; import io.trino.sql.tree.QualifiedName; import io.trino.sql.tree.Query; +import io.trino.sql.tree.QuerySpecification; import io.trino.sql.tree.Relation; import io.trino.sql.tree.Row; +import io.trino.sql.tree.SelectItem; import io.trino.sql.tree.ShowCatalogs; import io.trino.sql.tree.ShowColumns; import io.trino.sql.tree.ShowCreate; @@ -92,6 +97,7 @@ import io.trino.sql.tree.ShowSchemas; import io.trino.sql.tree.ShowSession; import io.trino.sql.tree.ShowTables; +import io.trino.sql.tree.SingleColumn; import io.trino.sql.tree.SortItem; import io.trino.sql.tree.Statement; import io.trino.sql.tree.StringLiteral; @@ -128,18 +134,19 @@ import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.StandardErrorCode.SCHEMA_NOT_FOUND; import static io.trino.spi.StandardErrorCode.TABLE_NOT_FOUND; +import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.ExpressionUtils.combineConjuncts; import static io.trino.sql.QueryUtil.aliased; import static io.trino.sql.QueryUtil.aliasedName; import static io.trino.sql.QueryUtil.aliasedNullToEmpty; import static io.trino.sql.QueryUtil.ascending; -import static io.trino.sql.QueryUtil.emptyQuery; import static io.trino.sql.QueryUtil.equal; import static io.trino.sql.QueryUtil.functionCall; import static io.trino.sql.QueryUtil.identifier; import static io.trino.sql.QueryUtil.logicalAnd; import static io.trino.sql.QueryUtil.ordering; +import static io.trino.sql.QueryUtil.query; import static io.trino.sql.QueryUtil.row; import static io.trino.sql.QueryUtil.selectAll; import static io.trino.sql.QueryUtil.selectList; @@ -377,13 +384,13 
@@ protected Node visitShowRoles(ShowRoles node, Void context) List rows = enabledRoles.stream() .map(role -> row(new StringLiteral(role))) .collect(toList()); - return singleColumnValues(rows, "Role"); + return singleColumnValues(rows, "Role", VARCHAR); } accessControl.checkCanShowRoles(session.toSecurityContext(), catalog); List rows = metadata.listRoles(session, catalog).stream() .map(role -> row(new StringLiteral(role))) .collect(toList()); - return singleColumnValues(rows, "Role"); + return singleColumnValues(rows, "Role", VARCHAR); } @Override @@ -402,14 +409,14 @@ protected Node visitShowRoleGrants(ShowRoleGrants node, Void context) .map(roleGrant -> row(new StringLiteral(roleGrant.getRoleName()))) .collect(toList()); - return singleColumnValues(rows, "Role Grants"); + return singleColumnValues(rows, "Role Grants", VARCHAR); } - private static Query singleColumnValues(List rows, String columnName) + private static Query singleColumnValues(List rows, String columnName, Type type) { List columns = ImmutableList.of(columnName); if (rows.isEmpty()) { - return emptyQuery(columns); + return emptyQuery(columns, ImmutableList.of(type)); } return simpleQuery( selectList(new AllColumns()), @@ -803,7 +810,7 @@ protected Node visitShowFunctions(ShowFunctions node, Void context) .buildOrThrow(); if (rows.isEmpty()) { - return emptyQuery(ImmutableList.copyOf(columns.values())); + return emptyQuery(ImmutableList.copyOf(columns.values()), ImmutableList.of(VARCHAR, VARCHAR, VARCHAR, VARCHAR, BOOLEAN, VARCHAR)); } return simpleQuery( @@ -949,5 +956,24 @@ protected Node visitNode(Node node, Void context) { return node; } + + public static Query emptyQuery(List columns, List types) + { + ImmutableList.Builder items = ImmutableList.builder(); + for (int i = 0; i < columns.size(); i++) { + items.add(new SingleColumn(new Cast(new NullLiteral(), toSqlType(types.get(i))), identifier(columns.get(i)))); + } + Optional where = Optional.of(FALSE_LITERAL); + return query(new 
QuerySpecification( + selectAll(items.build()), + Optional.empty(), + where, + Optional.empty(), + Optional.empty(), + ImmutableList.of(), + Optional.empty(), + Optional.empty(), + Optional.empty())); + } } } diff --git a/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java b/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java index 0e6ce21f962d..c3947b7a9d6d 100644 --- a/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java +++ b/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java @@ -168,6 +168,7 @@ import io.trino.sql.gen.OrderingCompiler; import io.trino.sql.gen.PageFunctionCompiler; import io.trino.sql.parser.SqlParser; +import io.trino.sql.planner.CompilerConfig; import io.trino.sql.planner.LocalExecutionPlanner; import io.trino.sql.planner.LocalExecutionPlanner.LocalExecutionPlan; import io.trino.sql.planner.LogicalPlanner; @@ -999,7 +1000,8 @@ private List createDrivers(Session session, Plan plan, OutputFactory out typeOperators, tableExecuteContextManager, exchangeManagerRegistry, - nodeManager.getCurrentNode().getNodeVersion()); + nodeManager.getCurrentNode().getNodeVersion(), + new CompilerConfig()); // plan query LocalExecutionPlan localExecutionPlan = executionPlanner.plan( diff --git a/core/trino-main/src/main/java/io/trino/type/CodePointsType.java b/core/trino-main/src/main/java/io/trino/type/CodePointsType.java index 699ab2ce504a..c5c33f9afc92 100644 --- a/core/trino-main/src/main/java/io/trino/type/CodePointsType.java +++ b/core/trino-main/src/main/java/io/trino/type/CodePointsType.java @@ -17,6 +17,7 @@ import io.airlift.slice.Slices; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; @@ -46,7 +47,9 @@ public Object getObject(Block block, int position) return null; 
} - Slice slice = block.getSlice(position, 0, block.getSliceLength(position)); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + Slice slice = valueBlock.getSlice(valuePosition); int[] codePoints = new int[slice.length() / Integer.BYTES]; slice.getInts(0, codePoints); return codePoints; diff --git a/core/trino-main/src/main/java/io/trino/type/ColorType.java b/core/trino-main/src/main/java/io/trino/type/ColorType.java index e6329872dcef..cac180fbbc6d 100644 --- a/core/trino-main/src/main/java/io/trino/type/ColorType.java +++ b/core/trino-main/src/main/java/io/trino/type/ColorType.java @@ -46,7 +46,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - int color = block.getInt(position, 0); + int color = getInt(block, position); if (color < 0) { return ColorFunctions.SystemColor.valueOf(-(color + 1)).getName(); } diff --git a/core/trino-main/src/main/java/io/trino/type/FunctionType.java b/core/trino-main/src/main/java/io/trino/type/FunctionType.java index 5757fe91c7e0..2a5e6a790f4c 100644 --- a/core/trino-main/src/main/java/io/trino/type/FunctionType.java +++ b/core/trino-main/src/main/java/io/trino/type/FunctionType.java @@ -19,6 +19,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.ValueBlock; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -92,7 +93,13 @@ public String getDisplayName() @Override public final Class getJavaType() { - throw new UnsupportedOperationException(getTypeSignature() + " type does not have Java type"); + throw new UnsupportedOperationException(getTypeSignature() + " type does not have a Java type"); + } + + @Override + public Class getValueBlockType() + { + throw new UnsupportedOperationException(getTypeSignature() + " 
type does not have a ValueBlock type"); } @Override diff --git a/core/trino-main/src/main/java/io/trino/type/IntervalYearMonthType.java b/core/trino-main/src/main/java/io/trino/type/IntervalYearMonthType.java index 0cb9ed359e9d..ebcfe8dea965 100644 --- a/core/trino-main/src/main/java/io/trino/type/IntervalYearMonthType.java +++ b/core/trino-main/src/main/java/io/trino/type/IntervalYearMonthType.java @@ -35,7 +35,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position if (block.isNull(position)) { return null; } - return new SqlIntervalYearMonth(block.getInt(position, 0)); + return new SqlIntervalYearMonth(getInt(block, position)); } @Override diff --git a/core/trino-main/src/main/java/io/trino/type/IpAddressType.java b/core/trino-main/src/main/java/io/trino/type/IpAddressType.java index b248ec808175..c0bff38df5bc 100644 --- a/core/trino-main/src/main/java/io/trino/type/IpAddressType.java +++ b/core/trino-main/src/main/java/io/trino/type/IpAddressType.java @@ -20,6 +20,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.Int128ArrayBlock; import io.trino.spi.block.Int128ArrayBlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorSession; @@ -63,7 +64,7 @@ public class IpAddressType private IpAddressType() { - super(new TypeSignature(StandardTypes.IPADDRESS), Slice.class); + super(new TypeSignature(StandardTypes.IPADDRESS), Slice.class, Int128ArrayBlock.class); } @Override @@ -219,7 +220,7 @@ private static boolean equalOperator(Slice left, Slice right) } @ScalarOperator(EQUAL) - private static boolean equalOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static boolean equalOperator(@BlockPosition Int128ArrayBlock leftBlock, @BlockIndex int leftPosition, @BlockPosition Int128ArrayBlock rightBlock, 
@BlockIndex int rightPosition) { return equal( leftBlock.getLong(leftPosition, 0), @@ -240,7 +241,7 @@ private static long xxHash64Operator(Slice value) } @ScalarOperator(XX_HASH_64) - private static long xxHash64Operator(@BlockPosition Block block, @BlockIndex int position) + private static long xxHash64Operator(@BlockPosition Int128ArrayBlock block, @BlockIndex int position) { return xxHash64(block.getLong(position, 0), block.getLong(position, SIZE_OF_LONG)); } @@ -261,7 +262,7 @@ private static long comparisonOperator(Slice left, Slice right) } @ScalarOperator(COMPARISON_UNORDERED_LAST) - private static long comparisonOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static long comparisonOperator(@BlockPosition Int128ArrayBlock leftBlock, @BlockIndex int leftPosition, @BlockPosition Int128ArrayBlock rightBlock, @BlockIndex int rightPosition) { return compareBigEndian( leftBlock.getLong(leftPosition, 0), diff --git a/core/trino-main/src/main/java/io/trino/type/JoniRegexpType.java b/core/trino-main/src/main/java/io/trino/type/JoniRegexpType.java index 6bc6bf3587b9..606639ce4110 100644 --- a/core/trino-main/src/main/java/io/trino/type/JoniRegexpType.java +++ b/core/trino-main/src/main/java/io/trino/type/JoniRegexpType.java @@ -16,6 +16,7 @@ import io.airlift.slice.Slice; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; @@ -47,7 +48,9 @@ public Object getObject(Block block, int position) return null; } - return joniRegexp(block.getSlice(position, 0, block.getSliceLength(position))); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return 
joniRegexp(valueBlock.getSlice(valuePosition)); } @Override diff --git a/core/trino-main/src/main/java/io/trino/type/Json2016Type.java b/core/trino-main/src/main/java/io/trino/type/Json2016Type.java index 5480dc224707..f028551ddc32 100644 --- a/core/trino-main/src/main/java/io/trino/type/Json2016Type.java +++ b/core/trino-main/src/main/java/io/trino/type/Json2016Type.java @@ -21,6 +21,7 @@ import io.trino.operator.scalar.json.JsonOutputConversionError; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; @@ -54,7 +55,9 @@ public Object getObject(Block block, int position) return null; } - String json = block.getSlice(position, 0, block.getSliceLength(position)).toStringUtf8(); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + String json = valueBlock.getSlice(valuePosition).toStringUtf8(); if (json.equals(JSON_ERROR.toString())) { return JSON_ERROR; } diff --git a/core/trino-main/src/main/java/io/trino/type/JsonPath2016Type.java b/core/trino-main/src/main/java/io/trino/type/JsonPath2016Type.java index 02fd29daa304..fc7f6ed88dbc 100644 --- a/core/trino-main/src/main/java/io/trino/type/JsonPath2016Type.java +++ b/core/trino-main/src/main/java/io/trino/type/JsonPath2016Type.java @@ -23,6 +23,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockEncodingSerde; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; @@ -57,8 +58,10 @@ public Object getObject(Block block, int position) return null; } - Slice bytes = block.getSlice(position, 0, 
block.getSliceLength(position)); - return jsonPathCodec.fromJson(bytes.toStringUtf8()); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + String json = valueBlock.getSlice(valuePosition).toStringUtf8(); + return jsonPathCodec.fromJson(json); } @Override diff --git a/core/trino-main/src/main/java/io/trino/type/JsonPathType.java b/core/trino-main/src/main/java/io/trino/type/JsonPathType.java index 767ea5075626..addc7a034f4f 100644 --- a/core/trino-main/src/main/java/io/trino/type/JsonPathType.java +++ b/core/trino-main/src/main/java/io/trino/type/JsonPathType.java @@ -18,6 +18,7 @@ import io.trino.operator.scalar.JsonPath; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; @@ -47,7 +48,10 @@ public Object getObject(Block block, int position) return null; } - return new JsonPath(block.getSlice(position, 0, block.getSliceLength(position)).toStringUtf8()); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + String pattern = valueBlock.getSlice(valuePosition).toStringUtf8(); + return new JsonPath(pattern); } @Override diff --git a/core/trino-main/src/main/java/io/trino/type/JsonType.java b/core/trino-main/src/main/java/io/trino/type/JsonType.java index a2d95bdbe10e..f077f2287fb3 100644 --- a/core/trino-main/src/main/java/io/trino/type/JsonType.java +++ b/core/trino-main/src/main/java/io/trino/type/JsonType.java @@ -17,6 +17,7 @@ import io.airlift.slice.Slices; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import 
io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; @@ -62,13 +63,15 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return block.getSlice(position, 0, block.getSliceLength(position)).toStringUtf8(); + return getSlice(block, position).toStringUtf8(); } @Override public Slice getSlice(Block block, int position) { - return block.getSlice(position, 0, block.getSliceLength(position)); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return valueBlock.getSlice(valuePosition); } public void writeString(BlockBuilder blockBuilder, String value) diff --git a/core/trino-main/src/main/java/io/trino/type/LikePatternType.java b/core/trino-main/src/main/java/io/trino/type/LikePatternType.java index 8f9666e77779..180f9a33fa2a 100644 --- a/core/trino-main/src/main/java/io/trino/type/LikePatternType.java +++ b/core/trino-main/src/main/java/io/trino/type/LikePatternType.java @@ -16,6 +16,7 @@ import io.airlift.slice.Slice; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; @@ -23,8 +24,8 @@ import java.util.Optional; -import static io.airlift.slice.SizeOf.SIZE_OF_INT; import static io.airlift.slice.Slices.utf8Slice; +import static java.nio.charset.StandardCharsets.UTF_8; public class LikePatternType extends AbstractVariableWidthType @@ -50,19 +51,19 @@ public Object getObject(Block block, int position) return null; } + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + Slice slice = valueBlock.getSlice(valuePosition); + // layout is: ? 
- int offset = 0; - int length = block.getInt(position, offset); - offset += SIZE_OF_INT; - String pattern = block.getSlice(position, offset, length).toStringUtf8(); - offset += length; + int length = slice.getInt(0); + String pattern = slice.toString(4, length, UTF_8); - boolean hasEscape = block.getByte(position, offset) != 0; - offset++; + boolean hasEscape = slice.getByte(4 + length) != 0; Optional escape = Optional.empty(); if (hasEscape) { - escape = Optional.of((char) block.getInt(position, offset)); + escape = Optional.of((char) slice.getInt(4 + length + 1)); } return LikePattern.compile(pattern, escape); diff --git a/core/trino-main/src/main/java/io/trino/type/Re2JRegexpType.java b/core/trino-main/src/main/java/io/trino/type/Re2JRegexpType.java index 98b8690725bb..1d6807fd4d71 100644 --- a/core/trino-main/src/main/java/io/trino/type/Re2JRegexpType.java +++ b/core/trino-main/src/main/java/io/trino/type/Re2JRegexpType.java @@ -18,6 +18,7 @@ import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; @@ -54,7 +55,9 @@ public Object getObject(Block block, int position) return null; } - Slice pattern = block.getSlice(position, 0, block.getSliceLength(position)); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + Slice pattern = valueBlock.getSlice(valuePosition); try { return new Re2JRegexp(dfaStatesLimit, dfaRetries, pattern); } diff --git a/core/trino-main/src/main/java/io/trino/type/TDigestType.java b/core/trino-main/src/main/java/io/trino/type/TDigestType.java index a49fafb84622..b37130082a3b 100644 --- a/core/trino-main/src/main/java/io/trino/type/TDigestType.java +++ 
b/core/trino-main/src/main/java/io/trino/type/TDigestType.java @@ -17,6 +17,7 @@ import io.airlift.stats.TDigest; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; @@ -37,7 +38,9 @@ private TDigestType() @Override public Object getObject(Block block, int position) { - return TDigest.deserialize(block.getSlice(position, 0, block.getSliceLength(position))); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return TDigest.deserialize(valueBlock.getSlice(valuePosition)); } @Override @@ -54,6 +57,8 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return new SqlVarbinary(block.getSlice(position, 0, block.getSliceLength(position)).getBytes()); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return new SqlVarbinary(valueBlock.getSlice(valuePosition).getBytes()); } } diff --git a/core/trino-main/src/main/java/io/trino/type/UnknownType.java b/core/trino-main/src/main/java/io/trino/type/UnknownType.java index 6fd9e1c31703..92406f8684f2 100644 --- a/core/trino-main/src/main/java/io/trino/type/UnknownType.java +++ b/core/trino-main/src/main/java/io/trino/type/UnknownType.java @@ -16,6 +16,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.ByteArrayBlock; import io.trino.spi.block.ByteArrayBlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorSession; @@ -51,7 +52,7 @@ private UnknownType() // We never access the native container for UNKNOWN because its null 
check is always true. // The actual native container type does not matter here. // We choose boolean to represent UNKNOWN because it's the smallest primitive type. - super(new TypeSignature(NAME), boolean.class); + super(new TypeSignature(NAME), boolean.class, ByteArrayBlock.class); } @Override @@ -122,8 +123,8 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) @Override public boolean getBoolean(Block block, int position) { - // Ideally, this function should never be invoked for unknown type. - // However, some logic rely on having a default value before the null check. + // Ideally, this function should never be invoked for the unknown type. + // However, some logic relies on having a default value before the null check. checkArgument(block.isNull(position)); return false; } @@ -132,8 +133,8 @@ public boolean getBoolean(Block block, int position) @Override public void writeBoolean(BlockBuilder blockBuilder, boolean value) { - // Ideally, this function should never be invoked for unknown type. - // However, some logic (e.g. AbstractMinMaxBy) rely on writing a default value before the null check. + // Ideally, this function should never be invoked for the unknown type. + // However, some logic (e.g. AbstractMinMaxBy) relies on writing a default value before the null check. 
checkArgument(!value); blockBuilder.appendNull(); } diff --git a/core/trino-main/src/main/java/io/trino/type/setdigest/SetDigestType.java b/core/trino-main/src/main/java/io/trino/type/setdigest/SetDigestType.java index d2094c6c2f77..7d3195992756 100644 --- a/core/trino-main/src/main/java/io/trino/type/setdigest/SetDigestType.java +++ b/core/trino-main/src/main/java/io/trino/type/setdigest/SetDigestType.java @@ -17,6 +17,7 @@ import io.airlift.slice.Slice; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; @@ -44,13 +45,15 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return new SqlVarbinary(block.getSlice(position, 0, block.getSliceLength(position)).getBytes()); + return new SqlVarbinary(getSlice(block, position).getBytes()); } @Override public Slice getSlice(Block block, int position) { - return block.getSlice(position, 0, block.getSliceLength(position)); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return valueBlock.getSlice(valuePosition); } @Override diff --git a/core/trino-main/src/test/java/io/trino/block/AbstractTestBlock.java b/core/trino-main/src/test/java/io/trino/block/AbstractTestBlock.java index 8da257d28d24..986a5fc8cca9 100644 --- a/core/trino-main/src/test/java/io/trino/block/AbstractTestBlock.java +++ b/core/trino-main/src/test/java/io/trino/block/AbstractTestBlock.java @@ -17,7 +17,6 @@ import io.airlift.slice.DynamicSliceOutput; import io.airlift.slice.Slice; import io.airlift.slice.SliceOutput; -import io.airlift.slice.Slices; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; @@ -26,12 +25,13 @@ 
import io.trino.spi.block.DictionaryId; import io.trino.spi.block.MapHashTables; import io.trino.spi.block.TestingBlockEncodingSerde; -import io.trino.spi.block.VariableWidthBlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import java.lang.invoke.MethodHandle; import java.lang.reflect.Array; import java.lang.reflect.Field; +import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.util.Arrays; import java.util.List; @@ -84,6 +84,10 @@ protected void assertBlock(Block block, T[] expectedValues) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Invalid position %d in block with %d positions", block.getPositionCount(), block.getPositionCount()); } + + if (block instanceof ValueBlock valueBlock) { + assertBlockClassImplementation(valueBlock.getClass()); + } } private void assertRetainedSize(Block block) @@ -113,7 +117,7 @@ else if (type == BlockBuilderStatus.class) { retainedSize += BlockBuilderStatus.INSTANCE_SIZE; } } - else if (type == Block.class) { + else if (Block.class.isAssignableFrom(type)) { retainedSize += ((Block) field.get(block)).getRetainedSizeInBytes(); } else if (type == Block[].class) { @@ -295,7 +299,13 @@ protected void assertPositionValue(Block block, int position, T expectedValu if (isSliceAccessSupported()) { assertEquals(block.getSliceLength(position), expectedSliceValue.length()); - assertSlicePosition(block, position, expectedSliceValue); + + int length = block.getSliceLength(position); + assertEquals(length, expectedSliceValue.length()); + + for (int offset = 0; offset < length - 3; offset++) { + assertEquals(block.getSlice(position, offset, 3), expectedSliceValue.slice(offset, 3)); + } } assertPositionEquals(block, position, expectedSliceValue); @@ -326,34 +336,6 @@ else if (expectedValue instanceof long[][] expected) { } } - protected void assertSlicePosition(Block block, int position, Slice expectedSliceValue) - { - int length = block.getSliceLength(position); - 
assertEquals(length, expectedSliceValue.length()); - - Block expectedBlock = toSingeValuedBlock(expectedSliceValue); - for (int offset = 0; offset < length - 3; offset++) { - assertEquals(block.getSlice(position, offset, 3), expectedSliceValue.slice(offset, 3)); - assertTrue(block.bytesEqual(position, offset, expectedSliceValue, offset, 3)); - // if your tests fail here, please change your test to not use this value - assertFalse(block.bytesEqual(position, offset, Slices.utf8Slice("XXX"), 0, 3)); - - assertEquals(block.bytesCompare(position, offset, 3, expectedSliceValue, offset, 3), 0); - assertTrue(block.bytesCompare(position, offset, 3, expectedSliceValue, offset, 2) > 0); - Slice greaterSlice = createGreaterValue(expectedSliceValue, offset, 3); - assertTrue(block.bytesCompare(position, offset, 3, greaterSlice, 0, greaterSlice.length()) < 0); - - assertTrue(block.equals(position, offset, expectedBlock, 0, offset, 3)); - assertEquals(block.compareTo(position, offset, 3, expectedBlock, 0, offset, 3), 0); - - VariableWidthBlockBuilder blockBuilder = VARBINARY.createBlockBuilder(null, 1); - blockBuilder.writeEntry(block.getSlice(position, offset, 3)); - Block segment = blockBuilder.build(); - - assertTrue(block.equals(position, offset, segment, 0, 0, 3)); - } - } - protected boolean isByteAccessSupported() { return true; @@ -498,4 +480,13 @@ protected static void testIncompactBlock(Block block) assertNotCompact(block); testCopyRegionCompactness(block); } + + private void assertBlockClassImplementation(Class clazz) + { + for (Method method : clazz.getMethods()) { + if (method.getReturnType() == ValueBlock.class && !method.isBridge()) { + throw new AssertionError(format("ValueBlock method %s should override return type to be %s", method, clazz.getSimpleName())); + } + } + } } diff --git a/core/trino-main/src/test/java/io/trino/block/BlockAssertions.java b/core/trino-main/src/test/java/io/trino/block/BlockAssertions.java index e420c33be0a6..739716f48e50 100644 --- 
a/core/trino-main/src/test/java/io/trino/block/BlockAssertions.java +++ b/core/trino-main/src/test/java/io/trino/block/BlockAssertions.java @@ -22,6 +22,7 @@ import io.trino.spi.block.RowBlock; import io.trino.spi.block.RowBlockBuilder; import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.ArrayType; import io.trino.spi.type.CharType; import io.trino.spi.type.DecimalType; @@ -138,7 +139,7 @@ public static RunLengthEncodedBlock createRandomRleBlock(Block block, int positi return (RunLengthEncodedBlock) RunLengthEncodedBlock.create(block.getSingleValueBlock(random().nextInt(block.getPositionCount())), positionCount); } - public static Block createRandomBlockForType(Type type, int positionCount, float nullRate) + public static ValueBlock createRandomBlockForType(Type type, int positionCount, float nullRate) { verifyNullRate(nullRate); @@ -191,12 +192,12 @@ public static Block createRandomBlockForType(Type type, int positionCount, float return createRandomBlockForNestedType(type, positionCount, nullRate); } - public static Block createRandomBlockForNestedType(Type type, int positionCount, float nullRate) + public static ValueBlock createRandomBlockForNestedType(Type type, int positionCount, float nullRate) { return createRandomBlockForNestedType(type, positionCount, nullRate, ENTRY_SIZE); } - public static Block createRandomBlockForNestedType(Type type, int positionCount, float nullRate, int maxCardinality) + public static ValueBlock createRandomBlockForNestedType(Type type, int positionCount, float nullRate, int maxCardinality) { // Builds isNull and offsets of size positionCount boolean[] isNull = null; @@ -222,12 +223,12 @@ public static Block createRandomBlockForNestedType(Type type, int positionCount, // Builds the nested block of size offsets[positionCount]. 
if (type instanceof ArrayType) { - Block valuesBlock = createRandomBlockForType(((ArrayType) type).getElementType(), offsets[positionCount], nullRate); + ValueBlock valuesBlock = createRandomBlockForType(((ArrayType) type).getElementType(), offsets[positionCount], nullRate); return fromElementBlock(positionCount, Optional.ofNullable(isNull), offsets, valuesBlock); } if (type instanceof MapType mapType) { - Block keyBlock = createRandomBlockForType(mapType.getKeyType(), offsets[positionCount], 0.0f); - Block valueBlock = createRandomBlockForType(mapType.getValueType(), offsets[positionCount], nullRate); + ValueBlock keyBlock = createRandomBlockForType(mapType.getKeyType(), offsets[positionCount], 0.0f); + ValueBlock valueBlock = createRandomBlockForType(mapType.getValueType(), offsets[positionCount], nullRate); return mapType.createBlockFromKeyValue(Optional.ofNullable(isNull), offsets, keyBlock, valueBlock); } @@ -245,19 +246,19 @@ public static Block createRandomBlockForNestedType(Type type, int positionCount, throw new IllegalArgumentException(format("type %s is not supported.", type)); } - public static Block createRandomBooleansBlock(int positionCount, float nullRate) + public static ValueBlock createRandomBooleansBlock(int positionCount, float nullRate) { Random random = random(); return createBooleansBlock(generateListWithNulls(positionCount, nullRate, random::nextBoolean)); } - public static Block createRandomIntsBlock(int positionCount, float nullRate) + public static ValueBlock createRandomIntsBlock(int positionCount, float nullRate) { Random random = random(); return createIntsBlock(generateListWithNulls(positionCount, nullRate, random::nextInt)); } - public static Block createRandomLongDecimalsBlock(int positionCount, float nullRate) + public static ValueBlock createRandomLongDecimalsBlock(int positionCount, float nullRate) { Random random = random(); return createLongDecimalsBlock(generateListWithNulls( @@ -266,7 +267,7 @@ public static Block 
createRandomLongDecimalsBlock(int positionCount, float nullR () -> String.valueOf(random.nextLong()))); } - public static Block createRandomShortTimestampBlock(TimestampType type, int positionCount, float nullRate) + public static ValueBlock createRandomShortTimestampBlock(TimestampType type, int positionCount, float nullRate) { Random random = random(); return createLongsBlock( @@ -276,7 +277,7 @@ public static Block createRandomShortTimestampBlock(TimestampType type, int posi () -> SqlTimestamp.fromMillis(type.getPrecision(), random.nextLong()).getEpochMicros())); } - public static Block createRandomLongTimestampBlock(TimestampType type, int positionCount, float nullRate) + public static ValueBlock createRandomLongTimestampBlock(TimestampType type, int positionCount, float nullRate) { Random random = random(); return createLongTimestampBlock( @@ -290,7 +291,7 @@ public static Block createRandomLongTimestampBlock(TimestampType type, int posit })); } - public static Block createRandomLongsBlock(int positionCount, int numberOfUniqueValues) + public static ValueBlock createRandomLongsBlock(int positionCount, int numberOfUniqueValues) { checkArgument(positionCount >= numberOfUniqueValues, "numberOfUniqueValues must be between 1 and positionCount: %s but was %s", positionCount, numberOfUniqueValues); int[] uniqueValues = chooseRandomUnique(positionCount, numberOfUniqueValues).stream() @@ -303,13 +304,13 @@ public static Block createRandomLongsBlock(int positionCount, int numberOfUnique .collect(toImmutableList())); } - public static Block createRandomLongsBlock(int positionCount, float nullRate) + public static ValueBlock createRandomLongsBlock(int positionCount, float nullRate) { Random random = random(); return createLongsBlock(generateListWithNulls(positionCount, nullRate, random::nextLong)); } - public static Block createRandomSmallintsBlock(int positionCount, float nullRate) + public static ValueBlock createRandomSmallintsBlock(int positionCount, float nullRate) { 
Random random = random(); return createTypedLongsBlock( @@ -317,43 +318,43 @@ public static Block createRandomSmallintsBlock(int positionCount, float nullRate generateListWithNulls(positionCount, nullRate, () -> (long) (short) random.nextLong())); } - public static Block createRandomStringBlock(int positionCount, float nullRate, int maxStringLength) + public static ValueBlock createRandomStringBlock(int positionCount, float nullRate, int maxStringLength) { return createStringsBlock( generateListWithNulls(positionCount, nullRate, () -> generateRandomStringWithLength(maxStringLength))); } - private static Block createRandomVarbinariesBlock(int positionCount, float nullRate) + private static ValueBlock createRandomVarbinariesBlock(int positionCount, float nullRate) { Random random = random(); return createSlicesBlock(VARBINARY, generateListWithNulls(positionCount, nullRate, () -> Slices.random(16, random))); } - private static Block createRandomUUIDsBlock(int positionCount, float nullRate) + private static ValueBlock createRandomUUIDsBlock(int positionCount, float nullRate) { Random random = random(); return createSlicesBlock(UUID, generateListWithNulls(positionCount, nullRate, () -> Slices.random(16, random))); } - private static Block createRandomIpAddressesBlock(int positionCount, float nullRate) + private static ValueBlock createRandomIpAddressesBlock(int positionCount, float nullRate) { Random random = random(); return createSlicesBlock(IPADDRESS, generateListWithNulls(positionCount, nullRate, () -> Slices.random(16, random))); } - private static Block createRandomTinyintsBlock(int positionCount, float nullRate) + private static ValueBlock createRandomTinyintsBlock(int positionCount, float nullRate) { Random random = random(); return createTypedLongsBlock(TINYINT, generateListWithNulls(positionCount, nullRate, () -> (long) (byte) random.nextLong())); } - public static Block createRandomDoublesBlock(int positionCount, float nullRate) + public static ValueBlock 
createRandomDoublesBlock(int positionCount, float nullRate) { Random random = random(); return createDoublesBlock(generateListWithNulls(positionCount, nullRate, random::nextDouble)); } - public static Block createRandomCharsBlock(CharType charType, int positionCount, float nullRate) + public static ValueBlock createRandomCharsBlock(CharType charType, int positionCount, float nullRate) { return createCharsBlock(charType, generateListWithNulls(positionCount, nullRate, () -> generateRandomStringWithLength(charType.getLength()))); } @@ -379,14 +380,14 @@ public static Set chooseNullPositions(int positionCount, float nullRate return chooseRandomUnique(positionCount, nullCount); } - public static Block createStringsBlock(String... values) + public static ValueBlock createStringsBlock(String... values) { requireNonNull(values, "values is null"); return createStringsBlock(Arrays.asList(values)); } - public static Block createStringsBlock(Iterable values) + public static ValueBlock createStringsBlock(Iterable values) { BlockBuilder builder = VARCHAR.createBlockBuilder(null, 100); @@ -399,26 +400,26 @@ public static Block createStringsBlock(Iterable values) } } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createSlicesBlock(Slice... values) + public static ValueBlock createSlicesBlock(Slice... 
values) { requireNonNull(values, "values is null"); return createSlicesBlock(Arrays.asList(values)); } - public static Block createSlicesBlock(Iterable values) + public static ValueBlock createSlicesBlock(Iterable values) { return createSlicesBlock(VARBINARY, values); } - public static Block createSlicesBlock(Type type, Iterable values) + public static ValueBlock createSlicesBlock(Type type, Iterable values) { return createBlock(type, type::writeSlice, values); } - public static Block createStringSequenceBlock(int start, int end) + public static ValueBlock createStringSequenceBlock(int start, int end) { BlockBuilder builder = VARCHAR.createBlockBuilder(null, 100); @@ -426,7 +427,7 @@ public static Block createStringSequenceBlock(int start, int end) VARCHAR.writeString(builder, String.valueOf(i)); } - return builder.build(); + return builder.buildValueBlock(); } public static Block createStringDictionaryBlock(int start, int length) @@ -445,7 +446,7 @@ public static Block createStringDictionaryBlock(int start, int length) return DictionaryBlock.create(ids.length, builder.build(), ids); } - public static Block createStringArraysBlock(Iterable> values) + public static ValueBlock createStringArraysBlock(Iterable> values) { ArrayType arrayType = new ArrayType(VARCHAR); BlockBuilder builder = arrayType.createBlockBuilder(null, 100); @@ -459,22 +460,22 @@ public static Block createStringArraysBlock(Iterable> } } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createBooleansBlock(Boolean... values) + public static ValueBlock createBooleansBlock(Boolean... 
values) { requireNonNull(values, "values is null"); return createBooleansBlock(Arrays.asList(values)); } - public static Block createBooleansBlock(Boolean value, int count) + public static ValueBlock createBooleansBlock(Boolean value, int count) { return createBooleansBlock(Collections.nCopies(count, value)); } - public static Block createBooleansBlock(Iterable values) + public static ValueBlock createBooleansBlock(Iterable values) { BlockBuilder builder = BOOLEAN.createBlockBuilder(null, 100); @@ -487,17 +488,17 @@ public static Block createBooleansBlock(Iterable values) } } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createShortDecimalsBlock(String... values) + public static ValueBlock createShortDecimalsBlock(String... values) { requireNonNull(values, "values is null"); return createShortDecimalsBlock(Arrays.asList(values)); } - public static Block createShortDecimalsBlock(Iterable values) + public static ValueBlock createShortDecimalsBlock(Iterable values) { DecimalType shortDecimalType = DecimalType.createDecimalType(1); BlockBuilder builder = shortDecimalType.createBlockBuilder(null, 100); @@ -511,17 +512,17 @@ public static Block createShortDecimalsBlock(Iterable values) } } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createLongDecimalsBlock(String... values) + public static ValueBlock createLongDecimalsBlock(String... 
values) { requireNonNull(values, "values is null"); return createLongDecimalsBlock(Arrays.asList(values)); } - public static Block createLongDecimalsBlock(Iterable values) + public static ValueBlock createLongDecimalsBlock(Iterable values) { DecimalType longDecimalType = DecimalType.createDecimalType(MAX_SHORT_PRECISION + 1); BlockBuilder builder = longDecimalType.createBlockBuilder(null, 100); @@ -535,16 +536,16 @@ public static Block createLongDecimalsBlock(Iterable values) } } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createLongTimestampBlock(TimestampType type, LongTimestamp... values) + public static ValueBlock createLongTimestampBlock(TimestampType type, LongTimestamp... values) { requireNonNull(values, "values is null"); return createLongTimestampBlock(type, Arrays.asList(values)); } - public static Block createLongTimestampBlock(TimestampType type, Iterable values) + public static ValueBlock createLongTimestampBlock(TimestampType type, Iterable values) { BlockBuilder builder = type.createBlockBuilder(null, 100); @@ -557,51 +558,51 @@ public static Block createLongTimestampBlock(TimestampType type, Iterable values) + public static ValueBlock createCharsBlock(CharType charType, List values) { return createBlock(charType, charType::writeString, values); } - public static Block createTinyintsBlock(Integer... values) + public static ValueBlock createTinyintsBlock(Integer... values) { requireNonNull(values, "values is null"); return createTinyintsBlock(Arrays.asList(values)); } - public static Block createTinyintsBlock(Iterable values) + public static ValueBlock createTinyintsBlock(Iterable values) { return createBlock(TINYINT, (ValueWriter) TINYINT::writeLong, values); } - public static Block createSmallintsBlock(Integer... values) + public static ValueBlock createSmallintsBlock(Integer... 
values) { requireNonNull(values, "values is null"); return createSmallintsBlock(Arrays.asList(values)); } - public static Block createSmallintsBlock(Iterable values) + public static ValueBlock createSmallintsBlock(Iterable values) { return createBlock(SMALLINT, (ValueWriter) SMALLINT::writeLong, values); } - public static Block createIntsBlock(Integer... values) + public static ValueBlock createIntsBlock(Integer... values) { requireNonNull(values, "values is null"); return createIntsBlock(Arrays.asList(values)); } - public static Block createIntsBlock(Iterable values) + public static ValueBlock createIntsBlock(Iterable values) { return createBlock(INTEGER, (ValueWriter) INTEGER::writeLong, values); } - public static Block createRowBlock(List fieldTypes, Object[]... rows) + public static ValueBlock createRowBlock(List fieldTypes, Object[]... rows) { RowBlockBuilder rowBlockBuilder = new RowBlockBuilder(fieldTypes, null, 1); for (Object[] row : rows) { @@ -647,16 +648,16 @@ else if (fieldValue instanceof Integer) { }); } - return rowBlockBuilder.build(); + return rowBlockBuilder.buildValueBlock(); } - public static Block createEmptyLongsBlock() + public static ValueBlock createEmptyLongsBlock() { - return BIGINT.createFixedSizeBlockBuilder(0).build(); + return BIGINT.createFixedSizeBlockBuilder(0).buildValueBlock(); } // This method makes it easy to create blocks without having to add an L to every value - public static Block createLongsBlock(int... values) + public static ValueBlock createLongsBlock(int... values) { BlockBuilder builder = BIGINT.createBlockBuilder(null, 100); @@ -664,27 +665,27 @@ public static Block createLongsBlock(int... values) BIGINT.writeLong(builder, value); } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createLongsBlock(Long... values) + public static ValueBlock createLongsBlock(Long... 
values) { requireNonNull(values, "values is null"); return createLongsBlock(Arrays.asList(values)); } - public static Block createLongsBlock(Iterable values) + public static ValueBlock createLongsBlock(Iterable values) { return createTypedLongsBlock(BIGINT, values); } - public static Block createTypedLongsBlock(Type type, Iterable values) + public static ValueBlock createTypedLongsBlock(Type type, Iterable values) { return createBlock(type, type::writeLong, values); } - public static Block createLongSequenceBlock(int start, int end) + public static ValueBlock createLongSequenceBlock(int start, int end) { BlockBuilder builder = BIGINT.createFixedSizeBlockBuilder(end - start); @@ -692,7 +693,7 @@ public static Block createLongSequenceBlock(int start, int end) BIGINT.writeLong(builder, i); } - return builder.build(); + return builder.buildValueBlock(); } public static Block createLongDictionaryBlock(int start, int length) @@ -716,34 +717,34 @@ public static Block createLongDictionaryBlock(int start, int length, int diction return DictionaryBlock.create(ids.length, builder.build(), ids); } - public static Block createLongRepeatBlock(int value, int length) + public static ValueBlock createLongRepeatBlock(int value, int length) { BlockBuilder builder = BIGINT.createFixedSizeBlockBuilder(length); for (int i = 0; i < length; i++) { BIGINT.writeLong(builder, value); } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createDoubleRepeatBlock(double value, int length) + public static ValueBlock createDoubleRepeatBlock(double value, int length) { BlockBuilder builder = DOUBLE.createFixedSizeBlockBuilder(length); for (int i = 0; i < length; i++) { DOUBLE.writeDouble(builder, value); } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createTimestampsWithTimeZoneMillisBlock(Long... values) + public static ValueBlock createTimestampsWithTimeZoneMillisBlock(Long... 
values) { BlockBuilder builder = TIMESTAMP_TZ_MILLIS.createFixedSizeBlockBuilder(values.length); for (long value : values) { TIMESTAMP_TZ_MILLIS.writeLong(builder, value); } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createBooleanSequenceBlock(int start, int end) + public static ValueBlock createBooleanSequenceBlock(int start, int end) { BlockBuilder builder = BOOLEAN.createFixedSizeBlockBuilder(end - start); @@ -751,17 +752,17 @@ public static Block createBooleanSequenceBlock(int start, int end) BOOLEAN.writeBoolean(builder, i % 2 == 0); } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createBlockOfReals(Float... values) + public static ValueBlock createBlockOfReals(Float... values) { requireNonNull(values, "values is null"); return createBlockOfReals(Arrays.asList(values)); } - public static Block createBlockOfReals(Iterable values) + public static ValueBlock createBlockOfReals(Iterable values) { BlockBuilder builder = REAL.createBlockBuilder(null, 100); for (Float value : values) { @@ -772,10 +773,10 @@ public static Block createBlockOfReals(Iterable values) REAL.writeLong(builder, floatToRawIntBits(value)); } } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createSequenceBlockOfReal(int start, int end) + public static ValueBlock createSequenceBlockOfReal(int start, int end) { BlockBuilder builder = REAL.createFixedSizeBlockBuilder(end - start); @@ -783,22 +784,22 @@ public static Block createSequenceBlockOfReal(int start, int end) REAL.writeLong(builder, floatToRawIntBits(i)); } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createDoublesBlock(Double... values) + public static ValueBlock createDoublesBlock(Double... 
values) { requireNonNull(values, "values is null"); return createDoublesBlock(Arrays.asList(values)); } - public static Block createDoublesBlock(Iterable values) + public static ValueBlock createDoublesBlock(Iterable values) { return createBlock(DOUBLE, DOUBLE::writeDouble, values); } - public static Block createDoubleSequenceBlock(int start, int end) + public static ValueBlock createDoubleSequenceBlock(int start, int end) { BlockBuilder builder = DOUBLE.createFixedSizeBlockBuilder(end - start); @@ -806,10 +807,10 @@ public static Block createDoubleSequenceBlock(int start, int end) DOUBLE.writeDouble(builder, i); } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createArrayBigintBlock(Iterable> values) + public static ValueBlock createArrayBigintBlock(Iterable> values) { ArrayType arrayType = new ArrayType(BIGINT); BlockBuilder builder = arrayType.createBlockBuilder(null, 100); @@ -823,10 +824,10 @@ public static Block createArrayBigintBlock(Iterable> va } } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createDateSequenceBlock(int start, int end) + public static ValueBlock createDateSequenceBlock(int start, int end) { BlockBuilder builder = DATE.createFixedSizeBlockBuilder(end - start); @@ -834,10 +835,10 @@ public static Block createDateSequenceBlock(int start, int end) DATE.writeLong(builder, i); } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createTimestampSequenceBlock(int start, int end) + public static ValueBlock createTimestampSequenceBlock(int start, int end) { BlockBuilder builder = TIMESTAMP_MILLIS.createFixedSizeBlockBuilder(end - start); @@ -845,10 +846,10 @@ public static Block createTimestampSequenceBlock(int start, int end) TIMESTAMP_MILLIS.writeLong(builder, multiplyExact(i, MICROSECONDS_PER_MILLISECOND)); } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createShortDecimalSequenceBlock(int 
start, int end, DecimalType type) + public static ValueBlock createShortDecimalSequenceBlock(int start, int end, DecimalType type) { BlockBuilder builder = type.createFixedSizeBlockBuilder(end - start); long base = BigInteger.TEN.pow(type.getScale()).longValue(); @@ -857,10 +858,10 @@ public static Block createShortDecimalSequenceBlock(int start, int end, DecimalT type.writeLong(builder, base * i); } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createLongDecimalSequenceBlock(int start, int end, DecimalType type) + public static ValueBlock createLongDecimalSequenceBlock(int start, int end, DecimalType type) { BlockBuilder builder = type.createFixedSizeBlockBuilder(end - start); BigInteger base = BigInteger.TEN.pow(type.getScale()); @@ -869,25 +870,25 @@ public static Block createLongDecimalSequenceBlock(int start, int end, DecimalTy type.writeObject(builder, Int128.valueOf(BigInteger.valueOf(i).multiply(base))); } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createColorRepeatBlock(int value, int length) + public static ValueBlock createColorRepeatBlock(int value, int length) { BlockBuilder builder = COLOR.createFixedSizeBlockBuilder(length); for (int i = 0; i < length; i++) { COLOR.writeLong(builder, value); } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createColorSequenceBlock(int start, int end) + public static ValueBlock createColorSequenceBlock(int start, int end) { BlockBuilder builder = COLOR.createBlockBuilder(null, end - start); for (int i = start; i < end; ++i) { COLOR.writeLong(builder, i); } - return builder.build(); + return builder.buildValueBlock(); } public static Block createRepeatedValuesBlock(double value, int positionCount) @@ -904,7 +905,7 @@ public static Block createRepeatedValuesBlock(long value, int positionCount) return RunLengthEncodedBlock.create(blockBuilder.build(), positionCount); } - private static Block 
createBlock(Type type, ValueWriter valueWriter, Iterable values) + private static ValueBlock createBlock(Type type, ValueWriter valueWriter, Iterable values) { BlockBuilder builder = type.createBlockBuilder(null, 100); @@ -917,7 +918,7 @@ private static Block createBlock(Type type, ValueWriter valueWriter, Iter } } - return builder.build(); + return builder.buildValueBlock(); } private interface ValueWriter diff --git a/core/trino-main/src/test/java/io/trino/block/TestBlockBuilder.java b/core/trino-main/src/test/java/io/trino/block/TestBlockBuilder.java index 6c08e69f4f0a..7eb0dac72e12 100644 --- a/core/trino-main/src/test/java/io/trino/block/TestBlockBuilder.java +++ b/core/trino-main/src/test/java/io/trino/block/TestBlockBuilder.java @@ -138,7 +138,7 @@ private static void assertInvalidPosition(Block block, int[] positions, int offs { assertThatThrownBy(() -> block.getPositions(positions, offset, length).getLong(0, 0)) .isInstanceOfAny(IllegalArgumentException.class, IndexOutOfBoundsException.class) - .hasMessage("Invalid position %d in block with %d positions", positions[0], block.getPositionCount()); + .hasMessage("Invalid position %d and length 1 in block with %d positions", positions[0], block.getPositionCount()); } private static void assertInvalidOffset(Block block, int[] positions, int offset, int length) diff --git a/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java b/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java index d5c81aee99ba..6702a7e5bd2b 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java +++ b/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java @@ -43,6 +43,7 @@ import io.trino.sql.gen.JoinFilterFunctionCompiler; import io.trino.sql.gen.OrderingCompiler; import io.trino.sql.gen.PageFunctionCompiler; +import io.trino.sql.planner.CompilerConfig; import io.trino.sql.planner.LocalExecutionPlanner; import io.trino.sql.planner.NodePartitioningManager; import 
io.trino.sql.planner.Partitioning; @@ -177,7 +178,8 @@ public static LocalExecutionPlanner createTestingPlanner() PLANNER_CONTEXT.getTypeOperators(), new TableExecuteContextManager(), new ExchangeManagerRegistry(), - new NodeVersion("test")); + new NodeVersion("test"), + new CompilerConfig()); } public static TaskInfo updateTask(SqlTask sqlTask, List splitAssignments, OutputBuffers outputBuffers) diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestNoMemoryAwarePartitionMemoryEstimator.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestNoMemoryAwarePartitionMemoryEstimator.java index 5330327109a0..b61ab8959d9d 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestNoMemoryAwarePartitionMemoryEstimator.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestNoMemoryAwarePartitionMemoryEstimator.java @@ -21,6 +21,8 @@ import io.trino.Session; import io.trino.connector.informationschema.InformationSchemaTable; import io.trino.connector.informationschema.InformationSchemaTableHandle; +import io.trino.connector.system.GlobalSystemConnector; +import io.trino.connector.system.SystemTableHandle; import io.trino.cost.StatsAndCosts; import io.trino.metadata.TableHandle; import io.trino.operator.RetryPolicy; @@ -130,6 +132,34 @@ public void testRemoteFromInformationSchemaAndTpchTableScans() assertThat(estimator).isInstanceOf(MockDelegatePatitionMemoryEstimator.class); } + @Test + public void testSystemJdbcTableScan() + { + PartitionMemoryEstimator estimator = createEstimator(tableScanPlanFragment( + "ts", + new TableHandle( + GlobalSystemConnector.CATALOG_HANDLE, + new SystemTableHandle("jdbc", "tables", TupleDomain.all()), + TestingTransactionHandle.create()))); + assertThat(estimator).isInstanceOf(NoMemoryPartitionMemoryEstimator.class); + PartitionMemoryEstimator.MemoryRequirements noMemoryRequirements = new 
PartitionMemoryEstimator.MemoryRequirements(DataSize.ofBytes(0)); + assertThat(estimator.getInitialMemoryRequirements()).isEqualTo(noMemoryRequirements); + } + + @Test + public void testSystemMetadataTableScan() + { + PartitionMemoryEstimator estimator = createEstimator(tableScanPlanFragment( + "ts", + new TableHandle( + GlobalSystemConnector.CATALOG_HANDLE, + new SystemTableHandle("metadata", "blah", TupleDomain.all()), + TestingTransactionHandle.create()))); + assertThat(estimator).isInstanceOf(NoMemoryPartitionMemoryEstimator.class); + PartitionMemoryEstimator.MemoryRequirements noMemoryRequirements = new PartitionMemoryEstimator.MemoryRequirements(DataSize.ofBytes(0)); + assertThat(estimator.getInitialMemoryRequirements()).isEqualTo(noMemoryRequirements); + } + private static PlanFragment getParentFragment(PlanFragment... childFragments) { ImmutableList childFragmentIds = Stream.of(childFragments) @@ -161,13 +191,18 @@ private PartitionMemoryEstimator createEstimator(PlanFragment planFragment, Plan } private static PlanFragment tableScanPlanFragment(String fragmentId, ConnectorTableHandle tableHandle) + { + return tableScanPlanFragment(fragmentId, new TableHandle( + TEST_CATALOG_HANDLE, + tableHandle, + TestingTransactionHandle.create())); + } + + private static PlanFragment tableScanPlanFragment(String fragmentId, TableHandle tableHandle) { TableScanNode informationSchemaViewsTableScan = new TableScanNode( new PlanNodeId("tableScan"), - new TableHandle( - TEST_CATALOG_HANDLE, - tableHandle, - TestingTransactionHandle.create()), + tableHandle, ImmutableList.of(), ImmutableMap.of(), TupleDomain.all(), diff --git a/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForAggregates.java b/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForAggregates.java index 00d91fbe8514..174820181ad1 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForAggregates.java +++ 
b/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForAggregates.java @@ -38,6 +38,7 @@ import io.trino.security.AllowAllAccessControl; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationFunctionMetadata; import io.trino.spi.function.AggregationState; @@ -388,7 +389,7 @@ public static final class BlockInputAggregationFunction @InputFunction public static void input( @AggregationState NullableDoubleState state, - @BlockPosition @SqlType(DOUBLE) Block value, + @BlockPosition @SqlType(DOUBLE) ValueBlock value, @BlockIndex int id) { // noop this is only for annotation testing puproses diff --git a/core/trino-main/src/test/java/io/trino/operator/TestDynamicFilterSourceOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestDynamicFilterSourceOperator.java index 0c788312c62e..0e5c486c22df 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestDynamicFilterSourceOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestDynamicFilterSourceOperator.java @@ -53,6 +53,7 @@ import static io.trino.block.BlockAssertions.createDoubleRepeatBlock; import static io.trino.block.BlockAssertions.createDoubleSequenceBlock; import static io.trino.block.BlockAssertions.createDoublesBlock; +import static io.trino.block.BlockAssertions.createIntsBlock; import static io.trino.block.BlockAssertions.createLongRepeatBlock; import static io.trino.block.BlockAssertions.createLongSequenceBlock; import static io.trino.block.BlockAssertions.createLongsBlock; @@ -280,9 +281,9 @@ public void testCollectWithNulls() OperatorFactory operatorFactory = createOperatorFactory(channel(0, INTEGER)); verifyPassthrough(createOperator(operatorFactory), ImmutableList.of(INTEGER), - new Page(createLongsBlock(1, 2, 3)), + new Page(createIntsBlock(1, 2, 3)), new Page(blockWithNulls), - new Page(createLongsBlock(4, 
5))); + new Page(createIntsBlock(4, 5))); operatorFactory.noMoreOperators(); assertEquals(partitions.build(), ImmutableList.of( diff --git a/core/trino-main/src/test/java/io/trino/operator/TestFlatHashStrategy.java b/core/trino-main/src/test/java/io/trino/operator/TestFlatHashStrategy.java new file mode 100644 index 000000000000..166f8c12ed48 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/operator/TestFlatHashStrategy.java @@ -0,0 +1,124 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.operator; + +import com.google.common.collect.ImmutableList; +import io.trino.spi.block.Block; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.MapType; +import io.trino.spi.type.RowType; +import io.trino.spi.type.Type; +import io.trino.spi.type.TypeOperators; +import io.trino.sql.gen.JoinCompiler; +import io.trino.testing.TestingSession; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static io.trino.block.BlockAssertions.createRandomBlockForType; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.CharType.createCharType; +import static io.trino.spi.type.DecimalType.createDecimalType; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.RealType.REAL; +import static io.trino.spi.type.TimestampType.TIMESTAMP_MICROS; +import static io.trino.spi.type.TimestampType.TIMESTAMP_MILLIS; +import static io.trino.spi.type.TimestampType.TIMESTAMP_NANOS; +import static io.trino.spi.type.TimestampType.TIMESTAMP_PICOS; +import static io.trino.spi.type.TimestampType.TIMESTAMP_SECONDS; +import static io.trino.spi.type.UuidType.UUID; +import static io.trino.spi.type.VarbinaryType.VARBINARY; +import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.type.IpAddressType.IPADDRESS; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.fail; + +public class TestFlatHashStrategy +{ + private static final TypeOperators TYPE_OPERATORS = new TypeOperators(); + private static final JoinCompiler JOIN_COMPILER = new JoinCompiler(TYPE_OPERATORS); + + @Test + public void testBatchedRawHashesMatchSinglePositionHashes() + { + List types = createTestingTypes(); + FlatHashStrategy flatHashStrategy = JOIN_COMPILER.getFlatHashStrategy(types); + + 
int positionCount = 1024; + Block[] blocks = new Block[types.size()]; + for (int i = 0; i < blocks.length; i++) { + blocks[i] = createRandomBlockForType(types.get(i), positionCount, 0.25f); + } + + long[] hashes = new long[positionCount]; + flatHashStrategy.hashBlocksBatched(blocks, hashes, 0, positionCount); + for (int position = 0; position < hashes.length; position++) { + long singleRowHash = flatHashStrategy.hash(blocks, position); + if (hashes[position] != singleRowHash) { + fail("Hash mismatch: %s <> %s at position %s - Values: %s".formatted(hashes[position], singleRowHash, position, singleRowTypesAndValues(types, blocks, position))); + } + } + // Ensure the formatting logic produces a real string and doesn't blow up since otherwise this code wouldn't be exercised + assertNotNull(singleRowTypesAndValues(types, blocks, 0)); + } + + private static List createTestingTypes() + { + List baseTypes = List.of( + BIGINT, + BOOLEAN, + createCharType(5), + createDecimalType(18), + createDecimalType(38), + DOUBLE, + INTEGER, + IPADDRESS, + REAL, + TIMESTAMP_SECONDS, + TIMESTAMP_MILLIS, + TIMESTAMP_MICROS, + TIMESTAMP_NANOS, + TIMESTAMP_PICOS, + UUID, + VARBINARY, + VARCHAR); + + ImmutableList.Builder builder = ImmutableList.builder(); + builder.addAll(baseTypes); + builder.add(RowType.anonymous(baseTypes)); + for (Type baseType : baseTypes) { + builder.add(new ArrayType(baseType)); + builder.add(new MapType(baseType, baseType, TYPE_OPERATORS)); + } + return builder.build(); + } + + private static String singleRowTypesAndValues(List types, Block[] blocks, int position) + { + ConnectorSession connectorSession = TestingSession.testSessionBuilder().build().toConnectorSession(); + StringBuilder builder = new StringBuilder(); + int column = 0; + for (Type type : types) { + builder.append("\n\t"); + builder.append(type); + builder.append(": "); + builder.append(type.getObjectValue(connectorSession, blocks[column], position)); + column++; + } + return builder.toString(); + } +} 
diff --git a/core/trino-main/src/test/java/io/trino/operator/TestPageUtils.java b/core/trino-main/src/test/java/io/trino/operator/TestPageUtils.java index a01b11145539..42f41023daae 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestPageUtils.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestPageUtils.java @@ -14,11 +14,12 @@ package io.trino.operator; import io.trino.spi.Page; +import io.trino.spi.block.ArrayBlock; import io.trino.spi.block.Block; -import io.trino.spi.block.DictionaryBlock; import io.trino.spi.block.LazyBlock; import org.junit.jupiter.api.Test; +import java.util.Optional; import java.util.concurrent.atomic.AtomicLong; import static io.trino.block.BlockAssertions.createIntsBlock; @@ -50,17 +51,20 @@ public void testRecordMaterializedBytes() public void testNestedBlocks() { Block elements = lazyWrapper(createIntsBlock(1, 2, 3)); - Block dictBlock = DictionaryBlock.create(2, elements, new int[] {0, 0}); - Page page = new Page(2, dictBlock); + Block arrayBlock = ArrayBlock.fromElementBlock(2, Optional.empty(), new int[] {0, 1, 3}, elements); + long initialArraySize = arrayBlock.getSizeInBytes(); + Page page = new Page(2, arrayBlock); AtomicLong sizeInBytes = new AtomicLong(); recordMaterializedBytes(page, sizeInBytes::getAndAdd); - assertEquals(sizeInBytes.get(), dictBlock.getSizeInBytes()); + assertEquals(arrayBlock.getSizeInBytes(), initialArraySize); + assertEquals(sizeInBytes.get(), arrayBlock.getSizeInBytes()); // dictionary block caches size in bytes - dictBlock.getLoadedBlock(); - assertEquals(sizeInBytes.get(), dictBlock.getSizeInBytes() + elements.getSizeInBytes()); + arrayBlock.getLoadedBlock(); + assertEquals(sizeInBytes.get(), arrayBlock.getSizeInBytes()); + assertEquals(sizeInBytes.get(), initialArraySize + elements.getSizeInBytes()); } private static LazyBlock lazyWrapper(Block block) diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java 
b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java index 1a106a41c949..0d2561864be8 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java @@ -42,6 +42,7 @@ import java.util.Optional; import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName; +import static io.trino.operator.aggregation.AccumulatorCompiler.generateAccumulatorFactory; import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL; import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE; import static io.trino.operator.aggregation.AggregationFunctionAdapter.normalizeInputMethod; @@ -54,15 +55,28 @@ public class TestAccumulatorCompiler { @Test public void testAccumulatorCompilerForTypeSpecificObjectParameter() + { + testAccumulatorCompilerForTypeSpecificObjectParameter(true); + testAccumulatorCompilerForTypeSpecificObjectParameter(false); + } + + private void testAccumulatorCompilerForTypeSpecificObjectParameter(boolean specializedLoops) { TimestampType parameterType = TimestampType.TIMESTAMP_NANOS; assertThat(parameterType.getJavaType()).isEqualTo(LongTimestamp.class); - assertGenerateAccumulator(LongTimestampAggregation.class, LongTimestampAggregationState.class); + assertGenerateAccumulator(LongTimestampAggregation.class, LongTimestampAggregationState.class, specializedLoops); } @Test public void testAccumulatorCompilerForTypeSpecificObjectParameterSeparateClassLoader() throws Exception + { + testAccumulatorCompilerForTypeSpecificObjectParameterSeparateClassLoader(true); + testAccumulatorCompilerForTypeSpecificObjectParameterSeparateClassLoader(false); + } + + private void testAccumulatorCompilerForTypeSpecificObjectParameterSeparateClassLoader(boolean specializedLoops) + throws Exception { TimestampType 
parameterType = TimestampType.TIMESTAMP_NANOS; assertThat(parameterType.getJavaType()).isEqualTo(LongTimestamp.class); @@ -80,10 +94,10 @@ public void testAccumulatorCompilerForTypeSpecificObjectParameterSeparateClassLo assertThat(aggregation.getCanonicalName()).isEqualTo(LongTimestampAggregation.class.getCanonicalName()); assertThat(aggregation).isNotSameAs(LongTimestampAggregation.class); - assertGenerateAccumulator(aggregation, stateInterface); + assertGenerateAccumulator(aggregation, stateInterface, specializedLoops); } - private static void assertGenerateAccumulator(Class aggregation, Class stateInterface) + private static void assertGenerateAccumulator(Class aggregation, Class stateInterface, boolean specializedLoops) { AccumulatorStateSerializer stateSerializer = StateCompiler.generateStateSerializer(stateInterface); AccumulatorStateFactory stateFactory = StateCompiler.generateStateFactory(stateInterface); @@ -105,7 +119,7 @@ private static void assertGenerateAccumulator(Cl FunctionNullability functionNullability = new FunctionNullability(false, ImmutableList.of(false)); // test if we can compile aggregation - AccumulatorFactory accumulatorFactory = AccumulatorCompiler.generateAccumulatorFactory(signature, implementation, functionNullability); + AccumulatorFactory accumulatorFactory = generateAccumulatorFactory(signature, implementation, functionNullability, specializedLoops); assertThat(accumulatorFactory).isNotNull(); // compile window aggregation diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAggregationLoopBuilder.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAggregationLoopBuilder.java new file mode 100644 index 000000000000..53d55a59df6d --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAggregationLoopBuilder.java @@ -0,0 +1,168 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the 
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation; + +import com.google.common.collect.ImmutableList; +import io.trino.spi.block.Block; +import io.trino.spi.block.DictionaryBlock; +import io.trino.spi.block.IntArrayBlock; +import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; +import io.trino.spi.function.SqlNullable; +import io.trino.spi.function.SqlType; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.lang.invoke.MethodHandle; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import static io.trino.operator.aggregation.AggregationLoopBuilder.buildLoop; +import static java.lang.invoke.MethodHandles.lookup; +import static java.lang.invoke.MethodType.methodType; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestAggregationLoopBuilder +{ + private static final MethodHandle INPUT_FUNCTION; + private static final Object LAMBDA_A = "lambda a"; + private static final Object LAMBDA_B = 1234L; + + static { + try { + INPUT_FUNCTION = lookup().findStatic( + TestAggregationLoopBuilder.class, + "input", + methodType(void.class, InvocationList.class, ValueBlock.class, int.class, ValueBlock.class, int.class, Object.class, Object.class)); + } + catch (ReflectiveOperationException e) { + throw new ExceptionInInitializerError(e); + } + } + + private MethodHandle loop; + private List 
keyBlocks; + private List valueBlocks; + + @BeforeClass + public void setUp() + throws ReflectiveOperationException + { + loop = buildLoop(INPUT_FUNCTION, 1, 2, false); + + ValueBlock keyBasic = new IntArrayBlock(5, Optional.empty(), new int[] {10, 11, 12, 13, 14}); + ValueBlock keyRleValue = new IntArrayBlock(1, Optional.empty(), new int[] {33}); + ValueBlock keyDictionary = new IntArrayBlock(3, Optional.empty(), new int[] {55, 54, 53}); + + keyBlocks = ImmutableList.builder() + .add(new TestParameter(keyBasic, keyBasic, new int[] {0, 1, 2, 3, 4})) + .add(new TestParameter(RunLengthEncodedBlock.create(keyRleValue, 5), keyRleValue, new int[] {0, 0, 0, 0, 0})) + .add(new TestParameter(DictionaryBlock.create(7, keyDictionary, new int[] {9, 9, 2, 1, 0, 1, 2}).getRegion(2, 5), keyDictionary, new int[] {2, 1, 0, 1, 2})) + .build(); + + ValueBlock valueBasic = new IntArrayBlock(5, Optional.empty(), new int[] {10, 11, 12, 13, 14}); + ValueBlock valueRleValue = new IntArrayBlock(1, Optional.empty(), new int[] {44}); + ValueBlock valueDictionary = new IntArrayBlock(3, Optional.empty(), new int[] {66, 65, 64}); + + valueBlocks = ImmutableList.builder() + .add(new TestParameter(valueBasic, valueBasic, new int[] {0, 1, 2, 3, 4})) + .add(new TestParameter(RunLengthEncodedBlock.create(valueRleValue, 5), valueRleValue, new int[] {0, 0, 0, 0, 0})) + .add(new TestParameter(DictionaryBlock.create(7, valueDictionary, new int[] {9, 9, 0, 1, 2, 1, 0}).getRegion(2, 5), valueDictionary, new int[] {0, 1, 2, 1, 0})) + .build(); + } + + @Test + public void testSelectAll() + throws Throwable + { + AggregationMask mask = AggregationMask.createSelectAll(5); + for (TestParameter keyBlock : keyBlocks) { + for (TestParameter valueBlock : valueBlocks) { + InvocationList invocationList = new InvocationList(); + loop.invokeExact(mask, invocationList, keyBlock.inputBlock(), valueBlock.inputBlock(), LAMBDA_A, LAMBDA_B); + 
assertThat(invocationList.getInvocations()).isEqualTo(buildExpectedInvocation(keyBlock, valueBlock, mask).getInvocations()); + } + } + } + + @Test + public void testMasked() + throws Throwable + { + AggregationMask mask = AggregationMask.createSelectedPositions(5, new int[] {1, 2, 4}, 3); + for (TestParameter keyBlock : keyBlocks) { + for (TestParameter valueBlock : valueBlocks) { + InvocationList invocationList = new InvocationList(); + loop.invokeExact(mask, invocationList, keyBlock.inputBlock(), valueBlock.inputBlock(), LAMBDA_A, LAMBDA_B); + assertThat(invocationList.getInvocations()).isEqualTo(buildExpectedInvocation(keyBlock, valueBlock, mask).getInvocations()); + } + } + } + + private static InvocationList buildExpectedInvocation(TestParameter keyBlock, TestParameter valueBlock, AggregationMask mask) + { + InvocationList invocationList = new InvocationList(); + int[] keyPositions = keyBlock.invokedPositions(); + int[] valuePositions = valueBlock.invokedPositions(); + if (mask.isSelectAll()) { + for (int position = 0; position < keyPositions.length; position++) { + invocationList.add(keyBlock.invokedBlock(), keyPositions[position], valueBlock.invokedBlock(), valuePositions[position], LAMBDA_A, LAMBDA_B); + } + } + else { + int[] selectedPositions = mask.getSelectedPositions(); + for (int i = 0; i < mask.getSelectedPositionCount(); i++) { + int position = selectedPositions[i]; + invocationList.add(keyBlock.invokedBlock(), keyPositions[position], valueBlock.invokedBlock(), valuePositions[position], LAMBDA_A, LAMBDA_B); + } + } + return invocationList; + } + + @SuppressWarnings("UnusedVariable") + private record TestParameter(Block inputBlock, ValueBlock invokedBlock, int[] invokedPositions) {} + + public static void input( + @AggregationState InvocationList invocationList, + @BlockPosition @SqlType("K") ValueBlock keyBlock, + @BlockIndex int keyPosition, + @SqlNullable @BlockPosition @SqlType("V") ValueBlock valueBlock, + @BlockIndex int valuePosition, + Object 
lambdaA, + Object lambdaB) + { + invocationList.add(keyBlock, keyPosition, valueBlock, valuePosition, lambdaA, lambdaB); + } + + public static class InvocationList + { + private final List invocations = new ArrayList<>(); + + public void add(ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition, Object lambdaA, Object lambdaB) + { + invocations.add(new Invocation(keyBlock, keyPosition, valueBlock, valuePosition, lambdaA, lambdaB)); + } + + public List getInvocations() + { + return ImmutableList.copyOf(invocations); + } + + public record Invocation(ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition, Object lambdaA, Object lambdaB) {} + } +} diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestCountNullAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestCountNullAggregation.java index a269c53b08b4..835f0f3fb94e 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestCountNullAggregation.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestCountNullAggregation.java @@ -18,6 +18,7 @@ import io.trino.operator.aggregation.state.NullableLongState; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -73,7 +74,7 @@ public static final class CountNull private CountNull() {} @InputFunction - public static void input(@AggregationState NullableLongState state, @BlockPosition @SqlNullable @SqlType(StandardTypes.BIGINT) Block block, @BlockIndex int position) + public static void input(@AggregationState NullableLongState state, @BlockPosition @SqlNullable @SqlType(StandardTypes.BIGINT) ValueBlock block, @BlockIndex int position) { if (block.isNull(position)) { state.setValue(state.getValue() + 1); diff --git 
a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalAverageAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalAverageAggregation.java index c960a3c3462e..4554352ed605 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalAverageAggregation.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalAverageAggregation.java @@ -17,6 +17,7 @@ import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongState; import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongStateFactory; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.Int128ArrayBlock; import io.trino.spi.type.DecimalType; import io.trino.spi.type.Decimals; import io.trino.spi.type.Int128; @@ -233,7 +234,7 @@ private static void addToState(DecimalType type, LongDecimalWithOverflowAndLongS else { BlockBuilder blockBuilder = type.createFixedSizeBlockBuilder(1); type.writeObject(blockBuilder, Int128.valueOf(value)); - DecimalAverageAggregation.inputLongDecimal(state, blockBuilder.build(), 0); + DecimalAverageAggregation.inputLongDecimal(state, (Int128ArrayBlock) blockBuilder.buildValueBlock(), 0); } } diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalSumAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalSumAggregation.java index 7689ab2a3f8d..66ead07005fc 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalSumAggregation.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalSumAggregation.java @@ -16,6 +16,7 @@ import io.trino.operator.aggregation.state.LongDecimalWithOverflowState; import io.trino.operator.aggregation.state.LongDecimalWithOverflowStateFactory; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.Int128ArrayBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import 
io.trino.spi.type.DecimalType; import io.trino.spi.type.Int128; @@ -142,7 +143,7 @@ private static void addToState(LongDecimalWithOverflowState state, BigInteger va else { BlockBuilder blockBuilder = TYPE.createFixedSizeBlockBuilder(1); TYPE.writeObject(blockBuilder, Int128.valueOf(value)); - DecimalSumAggregation.inputLongDecimal(state, blockBuilder.build(), 0); + DecimalSumAggregation.inputLongDecimal(state, (Int128ArrayBlock) blockBuilder.buildValueBlock(), 0); } } diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestTypedHistogram.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestTypedHistogram.java index 24a739c1a597..dcb81d328e2f 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestTypedHistogram.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestTypedHistogram.java @@ -17,6 +17,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.MapBlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.MapType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; @@ -27,8 +28,8 @@ import java.util.stream.IntStream; import static io.trino.block.BlockAssertions.assertBlockEquals; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -58,15 +59,15 @@ private static void testMassive(boolean grouped, Type type, ObjIntConsumer 
IntStream.iterate(value, IntUnaryOperator.identity()).limit(value)) .forEach(value -> writeData.accept(inputBlockBuilder, value)); - Block inputBlock = inputBlockBuilder.build(); + ValueBlock inputBlock = inputBlockBuilder.buildValueBlock(); TypedHistogram typedHistogram = new TypedHistogram( type, TYPE_OPERATORS.getReadValueOperator(type, simpleConvention(BLOCK_BUILDER, FLAT)), - TYPE_OPERATORS.getReadValueOperator(type, simpleConvention(FLAT_RETURN, BLOCK_POSITION_NOT_NULL)), + TYPE_OPERATORS.getReadValueOperator(type, simpleConvention(FLAT_RETURN, VALUE_BLOCK_POSITION_NOT_NULL)), TYPE_OPERATORS.getHashCodeOperator(type, simpleConvention(FAIL_ON_NULL, FLAT)), - TYPE_OPERATORS.getDistinctFromOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, BLOCK_POSITION_NOT_NULL)), - TYPE_OPERATORS.getHashCodeOperator(type, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION_NOT_NULL)), + TYPE_OPERATORS.getDistinctFromOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, VALUE_BLOCK_POSITION_NOT_NULL)), + TYPE_OPERATORS.getHashCodeOperator(type, simpleConvention(FAIL_ON_NULL, VALUE_BLOCK_POSITION_NOT_NULL)), grouped); int groupId = 0; diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestingAggregationFunction.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestingAggregationFunction.java index 57253a903202..63ce7741b279 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestingAggregationFunction.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestingAggregationFunction.java @@ -51,7 +51,7 @@ public TestingAggregationFunction(BoundSignature signature, FunctionNullability .collect(toImmutableList()); intermediateType = (intermediateTypes.size() == 1) ? 
getOnlyElement(intermediateTypes) : RowType.anonymous(intermediateTypes); this.finalType = signature.getReturnType(); - this.factory = generateAccumulatorFactory(signature, aggregationImplementation, functionNullability); + this.factory = generateAccumulatorFactory(signature, aggregationImplementation, functionNullability, true); distinctFactory = new DistinctAccumulatorFactory( factory, parameterTypes, diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/listagg/TestListaggAggregationFunction.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/listagg/TestListaggAggregationFunction.java index a70a662e7c0d..5cc028271641 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/listagg/TestListaggAggregationFunction.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/listagg/TestListaggAggregationFunction.java @@ -16,7 +16,7 @@ import io.airlift.slice.Slice; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.TrinoException; -import io.trino.spi.block.Block; +import io.trino.spi.block.ValueBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.sql.analyzer.TypeSignatureProvider; import org.junit.jupiter.api.Test; @@ -46,17 +46,17 @@ public void testInputEmptyState() SingleListaggAggregationState state = new SingleListaggAggregationState(); String s = "value1"; - Block value = createStringsBlock(s); + ValueBlock value = createStringsBlock(s); Slice separator = utf8Slice(","); Slice overflowFiller = utf8Slice("..."); ListaggAggregationFunction.input( state, value, + 0, separator, false, overflowFiller, - true, - 0); + true); VariableWidthBlockBuilder blockBuilder = VARCHAR.createBlockBuilder(null, 1); state.write(blockBuilder); @@ -74,11 +74,11 @@ public void testInputOverflowOverflowFillerTooLong() assertThatThrownBy(() -> ListaggAggregationFunction.input( state, createStringsBlock("value1"), + 0, utf8Slice(","), false, utf8Slice(overflowFillerTooLong), - 
false, - 0)) + false)) .isInstanceOf(TrinoException.class) .matches(throwable -> ((TrinoException) throwable).getErrorCode() == INVALID_FUNCTION_ARGUMENT.toErrorCode()); } diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxbyn/TestTypedKeyValueHeap.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxbyn/TestTypedKeyValueHeap.java index ea89f880d812..0e6acb76d8ca 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxbyn/TestTypedKeyValueHeap.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxbyn/TestTypedKeyValueHeap.java @@ -16,8 +16,8 @@ import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; import io.trino.spi.type.TypeUtils; @@ -34,8 +34,8 @@ import static io.airlift.slice.Slices.EMPTY_SLICE; import static io.airlift.slice.Slices.utf8Slice; import static io.trino.block.BlockAssertions.assertBlockEquals; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -151,32 +151,32 @@ private static void test(Type keyType, Type valueType, List> private static void test(Type keyType, Type valueType, boolean min, List> testData, Comparator comparator, int capacity) { MethodHandle 
keyReadFlat = TYPE_OPERATORS.getReadValueOperator(keyType, simpleConvention(BLOCK_BUILDER, FLAT)); - MethodHandle keyWriteFlat = TYPE_OPERATORS.getReadValueOperator(keyType, simpleConvention(FLAT_RETURN, BLOCK_POSITION_NOT_NULL)); + MethodHandle keyWriteFlat = TYPE_OPERATORS.getReadValueOperator(keyType, simpleConvention(FLAT_RETURN, VALUE_BLOCK_POSITION_NOT_NULL)); MethodHandle valueReadFlat = TYPE_OPERATORS.getReadValueOperator(valueType, simpleConvention(BLOCK_BUILDER, FLAT)); - MethodHandle valueWriteFlat = TYPE_OPERATORS.getReadValueOperator(valueType, simpleConvention(FLAT_RETURN, BLOCK_POSITION_NOT_NULL)); + MethodHandle valueWriteFlat = TYPE_OPERATORS.getReadValueOperator(valueType, simpleConvention(FLAT_RETURN, VALUE_BLOCK_POSITION_NOT_NULL)); MethodHandle comparisonFlatFlat; MethodHandle comparisonFlatBlock; if (min) { comparisonFlatFlat = TYPE_OPERATORS.getComparisonUnorderedLastOperator(keyType, simpleConvention(FAIL_ON_NULL, FLAT, FLAT)); - comparisonFlatBlock = TYPE_OPERATORS.getComparisonUnorderedLastOperator(keyType, simpleConvention(FAIL_ON_NULL, FLAT, BLOCK_POSITION_NOT_NULL)); + comparisonFlatBlock = TYPE_OPERATORS.getComparisonUnorderedLastOperator(keyType, simpleConvention(FAIL_ON_NULL, FLAT, VALUE_BLOCK_POSITION_NOT_NULL)); } else { comparisonFlatFlat = TYPE_OPERATORS.getComparisonUnorderedFirstOperator(keyType, simpleConvention(FAIL_ON_NULL, FLAT, FLAT)); - comparisonFlatBlock = TYPE_OPERATORS.getComparisonUnorderedFirstOperator(keyType, simpleConvention(FAIL_ON_NULL, FLAT, BLOCK_POSITION_NOT_NULL)); + comparisonFlatBlock = TYPE_OPERATORS.getComparisonUnorderedFirstOperator(keyType, simpleConvention(FAIL_ON_NULL, FLAT, VALUE_BLOCK_POSITION_NOT_NULL)); comparator = comparator.reversed(); } - Block expected = toBlock(valueType, testData.stream() + ValueBlock expected = toBlock(valueType, testData.stream() .sorted(comparing(Entry::key, comparator)) .map(Entry::value) .limit(capacity) .toList()); - Block inputKeys = toBlock(keyType, 
testData.stream().map(Entry::key).toList()); - Block inputValues = toBlock(valueType, testData.stream().map(Entry::value).toList()); + ValueBlock inputKeys = toBlock(keyType, testData.stream().map(Entry::key).toList()); + ValueBlock inputValues = toBlock(valueType, testData.stream().map(Entry::value).toList()); // verify basic build TypedKeyValueHeap heap = new TypedKeyValueHeap(min, keyReadFlat, keyWriteFlat, valueReadFlat, valueWriteFlat, comparisonFlatFlat, comparisonFlatBlock, keyType, valueType, capacity); - heap.addAll(inputKeys, inputValues); + getAddAll(heap, inputKeys, inputValues); assertEqual(heap, valueType, expected); // verify copy constructor @@ -185,44 +185,47 @@ private static void test(Type keyType, Type valueType, boolean min, List< // build in two parts and merge together TypedKeyValueHeap part1 = new TypedKeyValueHeap(min, keyReadFlat, keyWriteFlat, valueReadFlat, valueWriteFlat, comparisonFlatFlat, comparisonFlatBlock, keyType, valueType, capacity); int splitPoint = inputKeys.getPositionCount() / 2; - part1.addAll( - inputKeys.getRegion(0, splitPoint), - inputValues.getRegion(0, splitPoint)); + getAddAll(part1, inputKeys.getRegion(0, splitPoint), inputValues.getRegion(0, splitPoint)); BlockBuilder part1KeyBlockBuilder = keyType.createBlockBuilder(null, part1.getCapacity()); BlockBuilder part1ValueBlockBuilder = valueType.createBlockBuilder(null, part1.getCapacity()); part1.writeAllUnsorted(part1KeyBlockBuilder, part1ValueBlockBuilder); - Block part1KeyBlock = part1KeyBlockBuilder.build(); - Block part1ValueBlock = part1ValueBlockBuilder.build(); + ValueBlock part1KeyBlock = part1KeyBlockBuilder.buildValueBlock(); + ValueBlock part1ValueBlock = part1ValueBlockBuilder.buildValueBlock(); TypedKeyValueHeap part2 = new TypedKeyValueHeap(min, keyReadFlat, keyWriteFlat, valueReadFlat, valueWriteFlat, comparisonFlatFlat, comparisonFlatBlock, keyType, valueType, capacity); - part2.addAll( - inputKeys.getRegion(splitPoint, inputKeys.getPositionCount() - 
splitPoint), - inputValues.getRegion(splitPoint, inputValues.getPositionCount() - splitPoint)); + getAddAll(part2, inputKeys.getRegion(splitPoint, inputKeys.getPositionCount() - splitPoint), inputValues.getRegion(splitPoint, inputValues.getPositionCount() - splitPoint)); BlockBuilder part2KeyBlockBuilder = keyType.createBlockBuilder(null, part2.getCapacity()); BlockBuilder part2ValueBlockBuilder = valueType.createBlockBuilder(null, part2.getCapacity()); part2.writeAllUnsorted(part2KeyBlockBuilder, part2ValueBlockBuilder); - Block part2KeyBlock = part2KeyBlockBuilder.build(); - Block part2ValueBlock = part2ValueBlockBuilder.build(); + ValueBlock part2KeyBlock = part2KeyBlockBuilder.buildValueBlock(); + ValueBlock part2ValueBlock = part2ValueBlockBuilder.buildValueBlock(); TypedKeyValueHeap merged = new TypedKeyValueHeap(min, keyReadFlat, keyWriteFlat, valueReadFlat, valueWriteFlat, comparisonFlatFlat, comparisonFlatBlock, keyType, valueType, capacity); - merged.addAll(part1KeyBlock, part1ValueBlock); - merged.addAll(part2KeyBlock, part2ValueBlock); + getAddAll(merged, part1KeyBlock, part1ValueBlock); + getAddAll(merged, part2KeyBlock, part2ValueBlock); assertEqual(merged, valueType, expected); } - private static void assertEqual(TypedKeyValueHeap heap, Type valueType, Block expected) + private static void getAddAll(TypedKeyValueHeap heap, ValueBlock inputKeys, ValueBlock inputValues) + { + for (int i = 0; i < inputKeys.getPositionCount(); i++) { + heap.add(inputKeys, i, inputValues, i); + } + } + + private static void assertEqual(TypedKeyValueHeap heap, Type valueType, ValueBlock expected) { BlockBuilder resultBlockBuilder = valueType.createBlockBuilder(null, OUTPUT_SIZE); heap.writeValuesSorted(resultBlockBuilder); - Block actual = resultBlockBuilder.build(); + ValueBlock actual = resultBlockBuilder.buildValueBlock(); assertBlockEquals(valueType, actual, expected); } - private static Block toBlock(Type type, List inputStream) + private static ValueBlock 
toBlock(Type type, List inputStream) { BlockBuilder blockBuilder = type.createBlockBuilder(null, INPUT_SIZE); inputStream.forEach(value -> TypeUtils.writeNativeValue(type, blockBuilder, value)); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } // TODO remove this suppression when the error prone checker actually supports records correctly diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxn/TestTypedHeap.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxn/TestTypedHeap.java index ecba7480d853..fe5b3995c55a 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxn/TestTypedHeap.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxn/TestTypedHeap.java @@ -15,8 +15,8 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; import io.trino.spi.type.TypeUtils; @@ -30,8 +30,8 @@ import java.util.stream.IntStream; import static io.trino.block.BlockAssertions.assertBlockEquals; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -95,26 +95,26 @@ private static void test(Type type, List testData, Comparator comparat private static void test(Type type, boolean min, List testData, Comparator comparator) { 
MethodHandle readFlat = TYPE_OPERATORS.getReadValueOperator(type, simpleConvention(BLOCK_BUILDER, FLAT)); - MethodHandle writeFlat = TYPE_OPERATORS.getReadValueOperator(type, simpleConvention(FLAT_RETURN, BLOCK_POSITION_NOT_NULL)); + MethodHandle writeFlat = TYPE_OPERATORS.getReadValueOperator(type, simpleConvention(FLAT_RETURN, VALUE_BLOCK_POSITION_NOT_NULL)); MethodHandle comparisonFlatFlat; MethodHandle comparisonFlatBlock; if (min) { comparisonFlatFlat = TYPE_OPERATORS.getComparisonUnorderedLastOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, FLAT)); - comparisonFlatBlock = TYPE_OPERATORS.getComparisonUnorderedLastOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, BLOCK_POSITION_NOT_NULL)); + comparisonFlatBlock = TYPE_OPERATORS.getComparisonUnorderedLastOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, VALUE_BLOCK_POSITION_NOT_NULL)); } else { comparisonFlatFlat = TYPE_OPERATORS.getComparisonUnorderedFirstOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, FLAT)); - comparisonFlatBlock = TYPE_OPERATORS.getComparisonUnorderedFirstOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, BLOCK_POSITION_NOT_NULL)); + comparisonFlatBlock = TYPE_OPERATORS.getComparisonUnorderedFirstOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, VALUE_BLOCK_POSITION_NOT_NULL)); comparator = comparator.reversed(); } - Block expected = toBlock(type, testData.stream().sorted(comparator).limit(OUTPUT_SIZE).toList()); - Block inputData = toBlock(type, testData); + ValueBlock expected = toBlock(type, testData.stream().sorted(comparator).limit(OUTPUT_SIZE).toList()); + ValueBlock inputData = toBlock(type, testData); // verify basic build TypedHeap heap = new TypedHeap(min, readFlat, writeFlat, comparisonFlatFlat, comparisonFlatBlock, type, OUTPUT_SIZE); - heap.addAll(inputData); + addAll(heap, inputData); assertEqual(heap, type, expected); // verify copy constructor @@ -122,35 +122,42 @@ private static void test(Type type, boolean min, List testData, Comparato // build in two parts 
and merge together TypedHeap part1 = new TypedHeap(min, readFlat, writeFlat, comparisonFlatFlat, comparisonFlatBlock, type, OUTPUT_SIZE); - part1.addAll(inputData.getRegion(0, inputData.getPositionCount() / 2)); + addAll(part1, inputData.getRegion(0, inputData.getPositionCount() / 2)); BlockBuilder part1BlockBuilder = type.createBlockBuilder(null, part1.getCapacity()); part1.writeAllUnsorted(part1BlockBuilder); - Block part1Block = part1BlockBuilder.build(); + ValueBlock part1Block = part1BlockBuilder.buildValueBlock(); TypedHeap part2 = new TypedHeap(min, readFlat, writeFlat, comparisonFlatFlat, comparisonFlatBlock, type, OUTPUT_SIZE); - part2.addAll(inputData.getRegion(inputData.getPositionCount() / 2, inputData.getPositionCount() - (inputData.getPositionCount() / 2))); + addAll(part2, inputData.getRegion(inputData.getPositionCount() / 2, inputData.getPositionCount() - (inputData.getPositionCount() / 2))); BlockBuilder part2BlockBuilder = type.createBlockBuilder(null, part2.getCapacity()); part2.writeAllUnsorted(part2BlockBuilder); - Block part2Block = part2BlockBuilder.build(); + ValueBlock part2Block = part2BlockBuilder.buildValueBlock(); TypedHeap merged = new TypedHeap(min, readFlat, writeFlat, comparisonFlatFlat, comparisonFlatBlock, type, OUTPUT_SIZE); - merged.addAll(part1Block); - merged.addAll(part2Block); + addAll(merged, part1Block); + addAll(merged, part2Block); assertEqual(merged, type, expected); } - private static void assertEqual(TypedHeap heap, Type type, Block expected) + private static void addAll(TypedHeap heap, ValueBlock inputData) + { + for (int i = 0; i < inputData.getPositionCount(); i++) { + heap.add(inputData, i); + } + } + + private static void assertEqual(TypedHeap heap, Type type, ValueBlock expected) { BlockBuilder resultBlockBuilder = type.createBlockBuilder(null, OUTPUT_SIZE); heap.writeAllSorted(resultBlockBuilder); - Block actual = resultBlockBuilder.build(); + ValueBlock actual = resultBlockBuilder.buildValueBlock(); 
assertBlockEquals(type, actual, expected); } - private static Block toBlock(Type type, List inputStream) + private static ValueBlock toBlock(Type type, List inputStream) { BlockBuilder blockBuilder = type.createBlockBuilder(null, INPUT_SIZE); inputStream.forEach(value -> TypeUtils.writeNativeValue(type, blockBuilder, value)); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } } diff --git a/core/trino-main/src/test/java/io/trino/operator/join/TestLookupJoinPageBuilder.java b/core/trino-main/src/test/java/io/trino/operator/join/TestLookupJoinPageBuilder.java index 8c9077e03a06..95247484c0f8 100644 --- a/core/trino-main/src/test/java/io/trino/operator/join/TestLookupJoinPageBuilder.java +++ b/core/trino-main/src/test/java/io/trino/operator/join/TestLookupJoinPageBuilder.java @@ -20,6 +20,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.DictionaryBlock; +import io.trino.spi.block.LongArrayBlock; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -100,7 +101,7 @@ public void testDifferentPositions() JoinProbe probe = joinProbeFactory.createJoinProbe(page); Page output = lookupJoinPageBuilder.build(probe); assertEquals(output.getChannelCount(), 2); - assertTrue(output.getBlock(0) instanceof DictionaryBlock); + assertTrue(output.getBlock(0) instanceof LongArrayBlock); assertEquals(output.getPositionCount(), 0); lookupJoinPageBuilder.reset(); diff --git a/core/trino-main/src/test/java/io/trino/operator/join/unspilled/TestLookupJoinPageBuilder.java b/core/trino-main/src/test/java/io/trino/operator/join/unspilled/TestLookupJoinPageBuilder.java index 184540e20f01..d2214565dc7b 100644 --- a/core/trino-main/src/test/java/io/trino/operator/join/unspilled/TestLookupJoinPageBuilder.java +++ b/core/trino-main/src/test/java/io/trino/operator/join/unspilled/TestLookupJoinPageBuilder.java @@ -21,6 +21,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import 
io.trino.spi.block.DictionaryBlock; +import io.trino.spi.block.LongArrayBlock; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -101,7 +102,7 @@ public void testDifferentPositions() JoinProbe probe = joinProbeFactory.createJoinProbe(page, lookupSource); Page output = lookupJoinPageBuilder.build(probe); assertEquals(output.getChannelCount(), 2); - assertTrue(output.getBlock(0) instanceof DictionaryBlock); + assertTrue(output.getBlock(0) instanceof LongArrayBlock); assertEquals(output.getPositionCount(), 0); lookupJoinPageBuilder.reset(); diff --git a/core/trino-main/src/test/java/io/trino/operator/output/TestPagePartitioner.java b/core/trino-main/src/test/java/io/trino/operator/output/TestPagePartitioner.java index 491949cbf6fe..7dee87dfe24d 100644 --- a/core/trino-main/src/test/java/io/trino/operator/output/TestPagePartitioner.java +++ b/core/trino-main/src/test/java/io/trino/operator/output/TestPagePartitioner.java @@ -459,7 +459,7 @@ private static ImmutableList getTypes() IPADDRESS); } - private Block createBlockForType(Type type, int positionsPerPage) + private static Block createBlockForType(Type type, int positionsPerPage) { return createRandomBlockForType(type, positionsPerPage, 0.2F); } @@ -707,7 +707,7 @@ public Stream getEnqueuedDeserialized(int partition) public List getEnqueued(int partition) { Collection serializedPages = enqueued.get(partition); - return serializedPages == null ? ImmutableList.of() : ImmutableList.copyOf(serializedPages); + return ImmutableList.copyOf(serializedPages); } public void throwOnEnqueue(RuntimeException throwOnEnqueue) @@ -813,31 +813,20 @@ public Optional getFailureCause() } } - private static class SumModuloPartitionFunction + private record SumModuloPartitionFunction(int partitionCount, int... hashChannels) implements PartitionFunction { - private final int[] hashChannels; - private final int partitionCount; - - SumModuloPartitionFunction(int partitionCount, int... 
hashChannels) + private SumModuloPartitionFunction { checkArgument(partitionCount > 0); - this.partitionCount = partitionCount; - this.hashChannels = hashChannels; - } - - @Override - public int getPartitionCount() - { - return partitionCount; } @Override public int getPartition(Page page, int position) { long value = 0; - for (int i = 0; i < hashChannels.length; i++) { - value += page.getBlock(hashChannels[i]).getLong(position, 0); + for (int hashChannel : hashChannels) { + value += page.getBlock(hashChannel).getLong(position, 0); } return toIntExact(Math.abs(value) % partitionCount); diff --git a/core/trino-main/src/test/java/io/trino/operator/output/TestPositionsAppender.java b/core/trino-main/src/test/java/io/trino/operator/output/TestPositionsAppender.java index 2f5760112780..c25649d3edf2 100644 --- a/core/trino-main/src/test/java/io/trino/operator/output/TestPositionsAppender.java +++ b/core/trino-main/src/test/java/io/trino/operator/output/TestPositionsAppender.java @@ -24,6 +24,7 @@ import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.block.RowBlock; import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.type.ArrayType; import io.trino.spi.type.BigintType; @@ -90,8 +91,6 @@ public void testMixedBlockTypes(TestType type) List input = ImmutableList.of( input(emptyBlock(type)), input(nullBlock(type, 3), 0, 2), - input(nullBlock(TestType.UNKNOWN, 3), 0, 2), // a := null projections are handled by UnknownType null block - input(nullBlock(TestType.UNKNOWN, 1), 0), // a := null projections are handled by UnknownType null block, 1 position uses non RLE block input(notNullBlock(type, 3), 1, 2), input(partiallyNullBlock(type, 4), 0, 1, 2, 3), input(partiallyNullBlock(type, 4)), // empty position list @@ -169,7 +168,7 @@ public static Object[][] differentValues() {TestType.INTEGER, createIntsBlock(0), createIntsBlock(1)}, {TestType.CHAR_10, 
createStringsBlock("0"), createStringsBlock("1")}, {TestType.VARCHAR, createStringsBlock("0"), createStringsBlock("1")}, - {TestType.DOUBLE, createDoublesBlock(0D), createDoublesBlock(1D)}, + {TestType.DOUBLE, createDoublesBlock(0.0), createDoublesBlock(1.0)}, {TestType.SMALLINT, createSmallintsBlock(0), createSmallintsBlock(1)}, {TestType.TINYINT, createTinyintsBlock(0), createTinyintsBlock(1)}, {TestType.VARBINARY, createSlicesBlock(Slices.allocate(Long.BYTES)), createSlicesBlock(Slices.allocate(Long.BYTES).getOutput().appendLong(1).slice())}, @@ -184,7 +183,7 @@ public static Object[][] differentValues() @Test(dataProvider = "types") public void testMultipleRleWithTheSameValueProduceRle(TestType type) { - PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); Block value = notNullBlock(type, 1); positionsAppender.append(allPositions(3), rleBlock(value, 3)); @@ -199,7 +198,7 @@ public void testMultipleRleWithTheSameValueProduceRle(TestType type) public void testRleAppendForComplexTypeWithNullElement(TestType type, Block value) { checkArgument(value.getPositionCount() == 1); - PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); positionsAppender.append(allPositions(3), rleBlock(value, 3)); positionsAppender.append(allPositions(2), rleBlock(value, 2)); @@ -213,7 +212,7 @@ public void testRleAppendForComplexTypeWithNullElement(TestType type, Block valu @Test(dataProvider = "types") public void testRleAppendedWithSinglePositionDoesNotProduceRle(TestType type) { - PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, 
DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); Block value = notNullBlock(type, 1); positionsAppender.append(allPositions(3), rleBlock(value, 3)); @@ -226,16 +225,16 @@ public void testRleAppendedWithSinglePositionDoesNotProduceRle(TestType type) } @Test(dataProvider = "types") - public void testMultipleTheSameDictionariesProduceDictionary(TestType type) + public static void testMultipleTheSameDictionariesProduceDictionary(TestType type) { - PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); testMultipleTheSameDictionariesProduceDictionary(type, positionsAppender); // test if appender can accept different dictionary after a build testMultipleTheSameDictionariesProduceDictionary(type, positionsAppender); } - private void testMultipleTheSameDictionariesProduceDictionary(TestType type, PositionsAppender positionsAppender) + private static void testMultipleTheSameDictionariesProduceDictionary(TestType type, UnnestingPositionsAppender positionsAppender) { Block dictionary = createRandomBlockForType(type, 4, 0); positionsAppender.append(allPositions(3), createRandomDictionaryBlock(dictionary, 3)); @@ -279,11 +278,11 @@ public void testDictionarySingleThenFlat(TestType type) { BlockView firstInput = input(dictionaryBlock(type, 1, 4, 0), 0); BlockView secondInput = input(dictionaryBlock(type, 2, 4, 0), 0, 1); - PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); long initialRetainedSize = positionsAppender.getRetainedSizeInBytes(); - 
firstInput.getPositions().forEach((int position) -> positionsAppender.append(position, firstInput.getBlock())); - positionsAppender.append(secondInput.getPositions(), secondInput.getBlock()); + firstInput.positions().forEach((int position) -> positionsAppender.append(position, firstInput.block())); + positionsAppender.append(secondInput.positions(), secondInput.block()); assertBuildResult(type, ImmutableList.of(firstInput, secondInput), positionsAppender, initialRetainedSize); } @@ -291,7 +290,7 @@ public void testDictionarySingleThenFlat(TestType type) @Test(dataProvider = "types") public void testConsecutiveBuilds(TestType type) { - PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); // empty block positionsAppender.append(positions(), emptyBlock(type)); @@ -329,17 +328,17 @@ public void testConsecutiveBuilds(TestType type) } // testcase for jit bug described https://github.com/trinodb/trino/issues/12821. - // this test needs to be run first (hence lowest priority) as order of tests - // influence jit compilation making this problem to not occur if other tests are run first. + // this test needs to be run first (hence the lowest priority) as the test order + // influences jit compilation, preventing this problem from occurring if other tests are run first.
@Test(priority = Integer.MIN_VALUE) public void testSliceRle() { - PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(VARCHAR, 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(VARCHAR, 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); // first append some not empty value to avoid RleAwarePositionsAppender for the empty value positionsAppender.appendRle(singleValueBlock("some value"), 1); // append empty value multiple times to trigger jit compilation - Block emptyStringBlock = singleValueBlock(""); + ValueBlock emptyStringBlock = singleValueBlock(""); for (int i = 0; i < 1000; i++) { positionsAppender.appendRle(emptyStringBlock, 2000); } @@ -355,7 +354,7 @@ public void testRowWithNestedFields() rleBlock(TestType.VARCHAR, 2) }); - PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type, 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type, 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); positionsAppender.append(allPositions(2), rowBLock); Block actual = positionsAppender.build(); @@ -375,24 +374,24 @@ public static Object[][] complexTypesWithNullElementBlock() public static Object[][] types() { return Arrays.stream(TestType.values()) - .filter(testType -> !testType.equals(TestType.UNKNOWN)) + .filter(testType -> testType != TestType.UNKNOWN) .map(type -> new Object[] {type}) .toArray(Object[][]::new); } - private static Block singleValueBlock(String value) + private static ValueBlock singleValueBlock(String value) { BlockBuilder blockBuilder = VARCHAR.createBlockBuilder(null, 1); VARCHAR.writeSlice(blockBuilder, Slices.utf8Slice(value)); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } - private IntArrayList allPositions(int count) + private static IntArrayList allPositions(int count) { return new IntArrayList(IntStream.range(0, count).toArray()); } - private BlockView input(Block 
block, int... positions) + private static BlockView input(Block block, int... positions) { return new BlockView(block, new IntArrayList(positions)); } @@ -402,53 +401,53 @@ private static IntArrayList positions(int... positions) return new IntArrayList(positions); } - private Block dictionaryBlock(Block dictionary, int positionCount) + private static Block dictionaryBlock(Block dictionary, int positionCount) { return createRandomDictionaryBlock(dictionary, positionCount); } - private Block dictionaryBlock(Block dictionary, int[] ids) + private static Block dictionaryBlock(Block dictionary, int[] ids) { return DictionaryBlock.create(ids.length, dictionary, ids); } - private Block dictionaryBlock(TestType type, int positionCount, int dictionarySize, float nullRate) + private static Block dictionaryBlock(TestType type, int positionCount, int dictionarySize, float nullRate) { Block dictionary = createRandomBlockForType(type, dictionarySize, nullRate); return createRandomDictionaryBlock(dictionary, positionCount); } - private RunLengthEncodedBlock rleBlock(Block value, int positionCount) + private static RunLengthEncodedBlock rleBlock(Block value, int positionCount) { checkArgument(positionCount >= 2); return (RunLengthEncodedBlock) RunLengthEncodedBlock.create(value, positionCount); } - private RunLengthEncodedBlock rleBlock(TestType type, int positionCount) + private static RunLengthEncodedBlock rleBlock(TestType type, int positionCount) { checkArgument(positionCount >= 2); Block rleValue = createRandomBlockForType(type, 1, 0); return (RunLengthEncodedBlock) RunLengthEncodedBlock.create(rleValue, positionCount); } - private RunLengthEncodedBlock nullRleBlock(TestType type, int positionCount) + private static RunLengthEncodedBlock nullRleBlock(TestType type, int positionCount) { checkArgument(positionCount >= 2); Block rleValue = nullBlock(type, 1); return (RunLengthEncodedBlock) RunLengthEncodedBlock.create(rleValue, positionCount); } - private Block 
partiallyNullBlock(TestType type, int positionCount) + private static Block partiallyNullBlock(TestType type, int positionCount) { return createRandomBlockForType(type, positionCount, 0.5F); } - private Block notNullBlock(TestType type, int positionCount) + private static Block notNullBlock(TestType type, int positionCount) { return createRandomBlockForType(type, positionCount, 0); } - private Block nullBlock(TestType type, int positionCount) + private static Block nullBlock(TestType type, int positionCount) { BlockBuilder blockBuilder = type.getType().createBlockBuilder(null, positionCount); for (int i = 0; i < positionCount; i++) { @@ -466,19 +465,19 @@ private static Block nullBlock(Type type, int positionCount) return blockBuilder.build(); } - private Block emptyBlock(TestType type) + private static Block emptyBlock(TestType type) { return type.adapt(type.getType().createBlockBuilder(null, 0).build()); } - private Block createRandomBlockForType(TestType type, int positionCount, float nullRate) + private static Block createRandomBlockForType(TestType type, int positionCount, float nullRate) { return type.adapt(BlockAssertions.createRandomBlockForType(type.getType(), positionCount, nullRate)); } - private void testNullRle(Type type, Block source) + private static void testNullRle(Type type, Block source) { - PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type, 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type, 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); // extract null positions IntArrayList positions = new IntArrayList(source.getPositionCount()); for (int i = 0; i < source.getPositionCount(); i++) { @@ -495,22 +494,22 @@ private void testNullRle(Type type, Block source) assertInstanceOf(actual, RunLengthEncodedBlock.class); } - private void testAppend(TestType type, List inputs) + private static void testAppend(TestType type, List inputs) { testAppendBatch(type, inputs); 
testAppendSingle(type, inputs); } - private void testAppendBatch(TestType type, List inputs) + private static void testAppendBatch(TestType type, List inputs) { - PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); long initialRetainedSize = positionsAppender.getRetainedSizeInBytes(); - inputs.forEach(input -> positionsAppender.append(input.getPositions(), input.getBlock())); + inputs.forEach(input -> positionsAppender.append(input.positions(), input.block())); assertBuildResult(type, inputs, positionsAppender, initialRetainedSize); } - private void assertBuildResult(TestType type, List inputs, PositionsAppender positionsAppender, long initialRetainedSize) + private static void assertBuildResult(TestType type, List inputs, UnnestingPositionsAppender positionsAppender, long initialRetainedSize) { long sizeInBytes = positionsAppender.getSizeInBytes(); assertGreaterThanOrEqual(positionsAppender.getRetainedSizeInBytes(), sizeInBytes); @@ -524,12 +523,12 @@ private void assertBuildResult(TestType type, List inputs, PositionsA assertEquals(secondBlock.getPositionCount(), 0); } - private void testAppendSingle(TestType type, List inputs) + private static void testAppendSingle(TestType type, List inputs) { - PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); long initialRetainedSize = positionsAppender.getRetainedSizeInBytes(); - inputs.forEach(input -> input.getPositions().forEach((int position) -> positionsAppender.append(position, input.getBlock()))); + inputs.forEach(input -> input.positions().forEach((int position) -> positionsAppender.append(position, 
input.block()))); long sizeInBytes = positionsAppender.getSizeInBytes(); assertGreaterThanOrEqual(positionsAppender.getRetainedSizeInBytes(), sizeInBytes); Block actual = positionsAppender.build(); @@ -542,7 +541,7 @@ private void testAppendSingle(TestType type, List inputs) assertEquals(secondBlock.getPositionCount(), 0); } - private void assertBlockIsValid(Block actual, long sizeInBytes, Type type, List inputs) + private static void assertBlockIsValid(Block actual, long sizeInBytes, Type type, List inputs) { PageBuilderStatus pageBuilderStatus = new PageBuilderStatus(); BlockBuilderStatus blockBuilderStatus = pageBuilderStatus.createBlockBuilderStatus(); @@ -552,12 +551,12 @@ private void assertBlockIsValid(Block actual, long sizeInBytes, Type type, List< assertEquals(sizeInBytes, pageBuilderStatus.getSizeInBytes()); } - private Block buildBlock(Type type, List inputs, BlockBuilderStatus blockBuilderStatus) + private static Block buildBlock(Type type, List inputs, BlockBuilderStatus blockBuilderStatus) { BlockBuilder blockBuilder = type.createBlockBuilder(blockBuilderStatus, 10); for (BlockView input : inputs) { - for (int position : input.getPositions()) { - type.appendTo(input.getBlock(), position, blockBuilder); + for (int position : input.positions()) { + type.appendTo(input.block(), position, blockBuilder); } } return blockBuilder.build(); @@ -606,30 +605,12 @@ public Type getType() } } - private static class BlockView + private record BlockView(Block block, IntArrayList positions) { - private final Block block; - private final IntArrayList positions; - - private BlockView(Block block, IntArrayList positions) - { - this.block = requireNonNull(block, "block is null"); - this.positions = requireNonNull(positions, "positions is null"); - } - - public Block getBlock() - { - return block; - } - - public IntArrayList getPositions() - { - return positions; - } - - public void appendTo(PositionsAppender positionsAppender) + private BlockView { - 
positionsAppender.append(getPositions(), getBlock()); + requireNonNull(block, "block is null"); + requireNonNull(positions, "positions is null"); } } diff --git a/core/trino-main/src/test/java/io/trino/operator/output/TestSkewedPartitionRebalancer.java b/core/trino-main/src/test/java/io/trino/operator/output/TestSkewedPartitionRebalancer.java index 3f00f88f3130..5a61bf221287 100644 --- a/core/trino-main/src/test/java/io/trino/operator/output/TestSkewedPartitionRebalancer.java +++ b/core/trino-main/src/test/java/io/trino/operator/output/TestSkewedPartitionRebalancer.java @@ -28,14 +28,14 @@ import static io.trino.spi.type.BigintType.BIGINT; import static org.assertj.core.api.Assertions.assertThat; -public class TestSkewedPartitionRebalancer +class TestSkewedPartitionRebalancer { private static final long MIN_PARTITION_DATA_PROCESSED_REBALANCE_THRESHOLD = DataSize.of(1, MEGABYTE).toBytes(); private static final long MIN_DATA_PROCESSED_REBALANCE_THRESHOLD = DataSize.of(50, MEGABYTE).toBytes(); private static final int MAX_REBALANCED_PARTITIONS = 30; @Test - public void testRebalanceWithSkewness() + void testRebalanceWithSkewness() { int partitionCount = 3; SkewedPartitionRebalancer rebalancer = new SkewedPartitionRebalancer( @@ -51,7 +51,7 @@ public void testRebalanceWithSkewness() rebalancer.addPartitionRowCount(1, 1000); rebalancer.addPartitionRowCount(2, 1000); rebalancer.addDataProcessed(DataSize.of(40, MEGABYTE).toBytes()); - // No rebalancing will happen since data processed is less than 50MB limit + // No rebalancing will happen since the data processed is less than 50MB rebalancer.rebalance(); assertThat(getPartitionPositions(function, 17)) @@ -96,7 +96,7 @@ public void testRebalanceWithSkewness() } @Test - public void testRebalanceWithoutSkewness() + void testRebalanceWithoutSkewness() { int partitionCount = 6; SkewedPartitionRebalancer rebalancer = new SkewedPartitionRebalancer( @@ -128,7 +128,7 @@ public void testRebalanceWithoutSkewness() } @Test - public 
void testNoRebalanceWhenDataWrittenIsLessThanTheRebalanceLimit() + void testNoRebalanceWhenDataWrittenIsLessThanTheRebalanceLimit() { int partitionCount = 3; SkewedPartitionRebalancer rebalancer = new SkewedPartitionRebalancer( @@ -157,7 +157,7 @@ public void testNoRebalanceWhenDataWrittenIsLessThanTheRebalanceLimit() } @Test - public void testNoRebalanceWhenDataWrittenByThePartitionIsLessThanWriterScalingMinDataProcessed() + void testNoRebalanceWhenDataWrittenByThePartitionIsLessThanWriterScalingMinDataProcessed() { int partitionCount = 3; long minPartitionDataProcessedRebalanceThreshold = DataSize.of(50, MEGABYTE).toBytes(); @@ -187,7 +187,7 @@ public void testNoRebalanceWhenDataWrittenByThePartitionIsLessThanWriterScalingM } @Test - public void testRebalancePartitionToSingleTaskInARebalancingLoop() + void testRebalancePartitionToSingleTaskInARebalancingLoop() { int partitionCount = 3; SkewedPartitionRebalancer rebalancer = new SkewedPartitionRebalancer( @@ -204,7 +204,7 @@ public void testRebalancePartitionToSingleTaskInARebalancingLoop() rebalancer.addPartitionRowCount(2, 0); rebalancer.addDataProcessed(DataSize.of(60, MEGABYTE).toBytes()); - // rebalancing will only happen to single task even though two tasks are available + // rebalancing will only happen to a single task even though two tasks are available rebalancer.rebalance(); assertThat(getPartitionPositions(function, 17)) @@ -344,10 +344,10 @@ public void testRebalancePartitionWithMaxRebalancedPartitionsPerTask() .containsExactly(ImmutableList.of(0, 1), ImmutableList.of(1, 0), ImmutableList.of(2)); } - private List> getPartitionPositions(PartitionFunction function, int maxPosition) + private static List> getPartitionPositions(PartitionFunction function, int maxPosition) { List> partitionPositions = new ArrayList<>(); - for (int partition = 0; partition < function.getPartitionCount(); partition++) { + for (int partition = 0; partition < function.partitionCount(); partition++) { partitionPositions.add(new 
ArrayList<>()); } @@ -364,22 +364,9 @@ private static Page dummyPage() return SequencePageBuilder.createSequencePage(ImmutableList.of(BIGINT), 100, 0); } - private static class TestPartitionFunction + private record TestPartitionFunction(int partitionCount) implements PartitionFunction { - private final int partitionCount; - - private TestPartitionFunction(int partitionCount) - { - this.partitionCount = partitionCount; - } - - @Override - public int getPartitionCount() - { - return partitionCount; - } - @Override public int getPartition(Page page, int position) { diff --git a/core/trino-main/src/test/java/io/trino/operator/output/TestSlicePositionsAppender.java b/core/trino-main/src/test/java/io/trino/operator/output/TestSlicePositionsAppender.java index f9470fe93eaf..c90ffe5234b0 100644 --- a/core/trino-main/src/test/java/io/trino/operator/output/TestSlicePositionsAppender.java +++ b/core/trino-main/src/test/java/io/trino/operator/output/TestSlicePositionsAppender.java @@ -17,6 +17,7 @@ import io.airlift.slice.Slices; import io.trino.spi.block.Block; import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; import org.junit.jupiter.api.Test; import java.util.Arrays; @@ -34,7 +35,7 @@ public void testAppendEmptySliceRle() { // test SlicePositionAppender.appendRle with empty value (Slice with length 0) PositionsAppender positionsAppender = new SlicePositionsAppender(1, 100); - Block value = createStringsBlock(""); + ValueBlock value = createStringsBlock(""); positionsAppender.appendRle(value, 10); Block actualBlock = positionsAppender.build(); diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestBlockAndPositionNullConvention.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestBlockAndPositionNullConvention.java index 0181173f5903..40534ddbd362 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestBlockAndPositionNullConvention.java +++ 
b/core/trino-main/src/test/java/io/trino/operator/scalar/TestBlockAndPositionNullConvention.java @@ -15,7 +15,7 @@ import io.airlift.slice.Slice; import io.trino.metadata.InternalFunctionBundle; -import io.trino.spi.block.Block; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.BlockIndex; import io.trino.spi.function.BlockPosition; import io.trino.spi.function.ScalarFunction; @@ -51,6 +51,7 @@ public void init() assertions = new QueryAssertions(); assertions.addFunctions(InternalFunctionBundle.builder() .scalar(FunctionWithBlockAndPositionConvention.class) + .scalar(FunctionWithValueBlockAndPositionConvention.class) .build()); } @@ -105,7 +106,7 @@ public static Object generic(@TypeParameter("E") Type type, @SqlNullable @SqlTyp @TypeParameter("E") @SqlNullable @SqlType("E") - public static Object generic(@TypeParameter("E") Type type, @SqlNullable @BlockPosition @SqlType("E") Block block, @BlockIndex int position) + public static Object generic(@TypeParameter("E") Type type, @SqlNullable @BlockPosition @SqlType("E") ValueBlock block, @BlockIndex int position) { hitBlockPositionObject.set(true); return readNativeValue(type, block, position); @@ -124,7 +125,7 @@ public static Slice specializedSlice(@TypeParameter("E") Type type, @SqlNullable @TypeParameter("E") @SqlNullable @SqlType("E") - public static Slice specializedSlice(@TypeParameter("E") Type type, @SqlNullable @BlockPosition @SqlType(value = "E", nativeContainerType = Slice.class) Block block, @BlockIndex int position) + public static Slice specializedSlice(@TypeParameter("E") Type type, @SqlNullable @BlockPosition @SqlType(value = "E", nativeContainerType = Slice.class) ValueBlock block, @BlockIndex int position) { hitBlockPositionSlice.set(true); return type.getSlice(block, position); @@ -141,7 +142,7 @@ public static Boolean speciailizedBoolean(@TypeParameter("E") Type type, @SqlNul @TypeParameter("E") @SqlNullable @SqlType("E") - public static Boolean 
speciailizedBoolean(@TypeParameter("E") Type type, @SqlNullable @BlockPosition @SqlType(value = "E", nativeContainerType = boolean.class) Block block, @BlockIndex int position) + public static Boolean speciailizedBoolean(@TypeParameter("E") Type type, @SqlNullable @BlockPosition @SqlType(value = "E", nativeContainerType = boolean.class) ValueBlock block, @BlockIndex int position) { hitBlockPositionBoolean.set(true); return type.getBoolean(block, position); @@ -158,7 +159,7 @@ public static Long getLong(@SqlNullable @SqlType(StandardTypes.BIGINT) Long numb @SqlType(StandardTypes.BIGINT) @SqlNullable - public static Long getBlockPosition(@SqlNullable @BlockPosition @SqlType(value = StandardTypes.BIGINT, nativeContainerType = long.class) Block block, @BlockIndex int position) + public static Long getBlockPosition(@SqlNullable @BlockPosition @SqlType(value = StandardTypes.BIGINT, nativeContainerType = long.class) ValueBlock block, @BlockIndex int position) { hitBlockPositionBigint.set(true); return BIGINT.getLong(block, position); @@ -173,7 +174,126 @@ public static Double getDouble(@SqlNullable @SqlType(StandardTypes.DOUBLE) Doubl @SqlType(StandardTypes.DOUBLE) @SqlNullable - public static Double getDouble(@SqlNullable @BlockPosition @SqlType(value = StandardTypes.DOUBLE, nativeContainerType = double.class) Block block, @BlockIndex int position) + public static Double getDouble(@SqlNullable @BlockPosition @SqlType(value = StandardTypes.DOUBLE, nativeContainerType = double.class) ValueBlock block, @BlockIndex int position) + { + hitBlockPositionDouble.set(true); + return DOUBLE.getDouble(block, position); + } + } + + @Test + public void testValueBlockPosition() + { + assertThat(assertions.function("test_value_block_position", "BIGINT '1234'")) + .isEqualTo(1234L); + assertTrue(FunctionWithValueBlockAndPositionConvention.hitBlockPositionBigint.get()); + + assertThat(assertions.function("test_value_block_position", "12.34e0")) + .isEqualTo(12.34); + 
assertTrue(FunctionWithValueBlockAndPositionConvention.hitBlockPositionDouble.get()); + + assertThat(assertions.function("test_value_block_position", "'hello'")) + .hasType(createVarcharType(5)) + .isEqualTo("hello"); + assertTrue(FunctionWithValueBlockAndPositionConvention.hitBlockPositionSlice.get()); + + assertThat(assertions.function("test_value_block_position", "true")) + .isEqualTo(true); + assertTrue(FunctionWithValueBlockAndPositionConvention.hitBlockPositionBoolean.get()); + } + + @ScalarFunction("test_value_block_position") + public static final class FunctionWithValueBlockAndPositionConvention + { + private static final AtomicBoolean hitBlockPositionBigint = new AtomicBoolean(); + private static final AtomicBoolean hitBlockPositionDouble = new AtomicBoolean(); + private static final AtomicBoolean hitBlockPositionSlice = new AtomicBoolean(); + private static final AtomicBoolean hitBlockPositionBoolean = new AtomicBoolean(); + private static final AtomicBoolean hitBlockPositionObject = new AtomicBoolean(); + + // generic implementations + // these will not work right now because MethodHandle is not properly adapted + + @TypeParameter("E") + @SqlNullable + @SqlType("E") + public static Object generic(@TypeParameter("E") Type type, @SqlNullable @SqlType("E") Object object) + { + return object; + } + + @TypeParameter("E") + @SqlNullable + @SqlType("E") + public static Object generic(@TypeParameter("E") Type type, @SqlNullable @BlockPosition @SqlType("E") ValueBlock block, @BlockIndex int position) + { + hitBlockPositionObject.set(true); + return readNativeValue(type, block, position); + } + + // specialized + + @TypeParameter("E") + @SqlNullable + @SqlType("E") + public static Slice specializedSlice(@TypeParameter("E") Type type, @SqlNullable @SqlType("E") Slice slice) + { + return slice; + } + + @TypeParameter("E") + @SqlNullable + @SqlType("E") + public static Slice specializedSlice(@TypeParameter("E") Type type, @SqlNullable @BlockPosition @SqlType(value = 
"E", nativeContainerType = Slice.class) ValueBlock block, @BlockIndex int position) + { + hitBlockPositionSlice.set(true); + return type.getSlice(block, position); + } + + @TypeParameter("E") + @SqlNullable + @SqlType("E") + public static Boolean speciailizedBoolean(@TypeParameter("E") Type type, @SqlNullable @SqlType("E") Boolean bool) + { + return bool; + } + + @TypeParameter("E") + @SqlNullable + @SqlType("E") + public static Boolean speciailizedBoolean(@TypeParameter("E") Type type, @SqlNullable @BlockPosition @SqlType(value = "E", nativeContainerType = boolean.class) ValueBlock block, @BlockIndex int position) + { + hitBlockPositionBoolean.set(true); + return type.getBoolean(block, position); + } + + // exact + + @SqlType(StandardTypes.BIGINT) + @SqlNullable + public static Long getLong(@SqlNullable @SqlType(StandardTypes.BIGINT) Long number) + { + return number; + } + + @SqlType(StandardTypes.BIGINT) + @SqlNullable + public static Long getBlockPosition(@SqlNullable @BlockPosition @SqlType(value = StandardTypes.BIGINT, nativeContainerType = long.class) ValueBlock block, @BlockIndex int position) + { + hitBlockPositionBigint.set(true); + return BIGINT.getLong(block, position); + } + + @SqlType(StandardTypes.DOUBLE) + @SqlNullable + public static Double getDouble(@SqlNullable @SqlType(StandardTypes.DOUBLE) Double number) + { + return number; + } + + @SqlType(StandardTypes.DOUBLE) + @SqlNullable + public static Double getDouble(@SqlNullable @BlockPosition @SqlType(value = StandardTypes.DOUBLE, nativeContainerType = double.class) ValueBlock block, @BlockIndex int position) { hitBlockPositionDouble.set(true); return DOUBLE.getDouble(block, position); diff --git a/core/trino-main/src/test/java/io/trino/spiller/TestGenericPartitioningSpiller.java b/core/trino-main/src/test/java/io/trino/spiller/TestGenericPartitioningSpiller.java index 13a9443c1277..ee754fa020ce 100644 --- a/core/trino-main/src/test/java/io/trino/spiller/TestGenericPartitioningSpiller.java +++ 
b/core/trino-main/src/test/java/io/trino/spiller/TestGenericPartitioningSpiller.java @@ -238,7 +238,7 @@ private static class FourFixedPartitionsPartitionFunction } @Override - public int getPartitionCount() + public int partitionCount() { return 4; } @@ -274,7 +274,7 @@ private static class ModuloPartitionFunction } @Override - public int getPartitionCount() + public int partitionCount() { return partitionCount; } diff --git a/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor.java b/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor.java index 62775fc986d4..e0b1ae1407cc 100644 --- a/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor.java +++ b/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor.java @@ -22,6 +22,7 @@ import io.trino.spi.Page; import io.trino.spi.PageBuilder; import io.trino.spi.block.Block; +import io.trino.spi.block.VariableWidthBlock; import io.trino.sql.relational.CallExpression; import io.trino.sql.relational.RowExpression; import io.trino.sql.relational.SpecialForm; @@ -175,14 +176,20 @@ private static boolean filter(int position, Block discountBlock, Block shipDateB private static boolean lessThan(Block left, int leftPosition, Slice right) { + VariableWidthBlock leftBlock = (VariableWidthBlock) left.getUnderlyingValueBlock(); + Slice leftSlice = leftBlock.getRawSlice(); + int leftOffset = leftBlock.getRawSliceOffset(leftPosition); int leftLength = left.getSliceLength(leftPosition); - return left.bytesCompare(leftPosition, 0, leftLength, right, 0, right.length()) < 0; + return leftSlice.compareTo(leftOffset, leftLength, right, 0, right.length()) < 0; } private static boolean greaterThanOrEqual(Block left, int leftPosition, Slice right) { + VariableWidthBlock leftBlock = (VariableWidthBlock) left.getUnderlyingValueBlock(); + Slice leftSlice = leftBlock.getRawSlice(); + int leftOffset = leftBlock.getRawSliceOffset(leftPosition); int leftLength = 
left.getSliceLength(leftPosition); - return left.bytesCompare(leftPosition, 0, leftLength, right, 0, right.length()) >= 0; + return leftSlice.compareTo(leftOffset, leftLength, right, 0, right.length()) >= 0; } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestCompilerConfig.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestCompilerConfig.java index 51c01c21caa9..77ee738f4f1d 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestCompilerConfig.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestCompilerConfig.java @@ -28,16 +28,21 @@ public class TestCompilerConfig public void testDefaults() { assertRecordedDefaults(recordDefaults(CompilerConfig.class) - .setExpressionCacheSize(10_000)); + .setExpressionCacheSize(10_000) + .setSpecializeAggregationLoops(true)); } @Test public void testExplicitPropertyMappings() { - Map properties = ImmutableMap.of("compiler.expression-cache-size", "52"); + Map properties = ImmutableMap.builder() + .put("compiler.expression-cache-size", "52") + .put("compiler.specialized-aggregation-loops", "false") + .buildOrThrow(); CompilerConfig expected = new CompilerConfig() - .setExpressionCacheSize(52); + .setExpressionCacheSize(52) + .setSpecializeAggregationLoops(false); assertFullMapping(properties, expected); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestDeleteAndInsertMergeProcessor.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestDeleteAndInsertMergeProcessor.java index 608102463fbb..4583fb3792e7 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestDeleteAndInsertMergeProcessor.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestDeleteAndInsertMergeProcessor.java @@ -81,7 +81,7 @@ public void testSimpleDeletedRowMerge() // Show that the row to be deleted is rowId 0, e.g. 
('Dave', 11, 'Devon') SqlRow rowIdRow = outputPage.getBlock(4).getObject(0, SqlRow.class); - assertThat(INTEGER.getInt(rowIdRow.getRawFieldBlock(1), rowIdRow.getRawIndex())).isEqualTo(0); + assertThat(BIGINT.getLong(rowIdRow.getRawFieldBlock(1), rowIdRow.getRawIndex())).isEqualTo(0); } @Test diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformUncorrelatedSubqueryToJoin.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformUncorrelatedSubqueryToJoin.java index 45b2fd1eff8b..f88bde2a1f88 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformUncorrelatedSubqueryToJoin.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformUncorrelatedSubqueryToJoin.java @@ -120,7 +120,7 @@ public void testRewriteRightCorrelatedJoin() .matches( project( ImmutableMap.of( - "a", expression("if(b > a, a, null)"), + "a", expression("if(b > a, a, cast(null AS BIGINT))"), "b", expression("b")), join(Type.INNER, builder -> builder .left(values("a")) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestEliminateSorts.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestEliminateSorts.java index 224403807f3d..f1892a5edc47 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestEliminateSorts.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestEliminateSorts.java @@ -36,7 +36,6 @@ import static io.trino.sql.planner.TypeAnalyzer.createTestingTypeAnalyzer; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.functionCall; -import static io.trino.sql.planner.assertions.PlanMatchPattern.output; import static io.trino.sql.planner.assertions.PlanMatchPattern.sort; import static io.trino.sql.planner.assertions.PlanMatchPattern.specification; import static 
io.trino.sql.planner.assertions.PlanMatchPattern.tableScan; @@ -56,21 +55,6 @@ public class TestEliminateSorts "lineitem", ImmutableMap.of(QUANTITY_ALIAS, "quantity")); - @Test - public void testEliminateSorts() - { - @Language("SQL") String sql = "SELECT quantity, row_number() OVER (ORDER BY quantity) FROM lineitem ORDER BY quantity"; - - PlanMatchPattern pattern = - output( - window(windowMatcherBuilder -> windowMatcherBuilder - .specification(windowSpec) - .addFunction(functionCall("row_number", Optional.empty(), ImmutableList.of())), - anyTree(LINEITEM_TABLESCAN_Q))); - - assertUnitPlan(sql, pattern); - } - @Test public void testNotEliminateSorts() { diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestCorrelatedAggregation.java b/core/trino-main/src/test/java/io/trino/sql/query/TestCorrelatedAggregation.java index f25dcb53af9b..c3dc4fae2a49 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestCorrelatedAggregation.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestCorrelatedAggregation.java @@ -482,4 +482,34 @@ public void testChecksum() "ON TRUE")) .matches("VALUES (1, null), (2, x'd0f70cebd131ec61')"); } + + @Test + public void testCorrelatedSubqueryWithGroupedAggregation() + { + assertThat(assertions.query("WITH" + + " t(k, v) AS (VALUES ('A', 1), ('B', NULL), ('C', 2), ('D', 3)), " + + " u(k, v) AS (VALUES (1, 10), (1, 20), (2, 30)) " + + "SELECT" + + " k," + + " (" + + " SELECT max(v) FROM u WHERE t.v = u.k GROUP BY k" + + " ) AS cols " + + "FROM t")) + .matches("VALUES ('A', 20), ('B', NULL), ('C', 30), ('D', NULL)"); + } + + @Test + public void testCorrelatedSubqueryWithGlobalAggregation() + { + assertThat(assertions.query("WITH" + + " t(k, v) AS (VALUES ('A', 1), ('B', NULL), ('C', 2), ('D', 3)), " + + " u(k, v) AS (VALUES (1, 10), (1, 20), (2, 30)) " + + "SELECT" + + " k," + + " (" + + " SELECT max(v) FROM u WHERE t.v = u.k" + + " ) AS cols " + + "FROM t")) + .matches("VALUES ('A', 20), ('B', NULL), ('C', 30), 
('D', NULL)"); + } } diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestSubqueries.java b/core/trino-main/src/test/java/io/trino/sql/query/TestSubqueries.java index e6a933517cf9..e3e31117fda2 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestSubqueries.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestSubqueries.java @@ -46,7 +46,7 @@ import static io.trino.sql.planner.plan.AggregationNode.Step.FINAL; import static io.trino.sql.planner.plan.AggregationNode.Step.PARTIAL; import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE; -import static io.trino.sql.planner.plan.JoinNode.Type.INNER; +import static io.trino.sql.planner.plan.JoinNode.Type.LEFT; import static io.trino.testing.TestingHandles.TEST_CATALOG_NAME; import static io.trino.testing.TestingSession.testSessionBuilder; import static java.lang.String.format; @@ -204,7 +204,7 @@ public void testCorrelatedSubqueriesWithTopN() "SELECT (SELECT t.a FROM (VALUES 1, 2, 3) t(a) WHERE t.a = t2.b ORDER BY a LIMIT 1) FROM (VALUES 1.0, 2.0) t2(b)", "VALUES 1, 2", output( - join(INNER, builder -> builder + join(LEFT, builder -> builder .equiCriteria("cast_b", "cast_a") .left( project( @@ -228,7 +228,7 @@ public void testCorrelatedSubqueriesWithTopN() "SELECT (SELECT t.a FROM (VALUES 1, 2, 3, 4, 5) t(a) WHERE t.a = t2.b * t2.c - 1 ORDER BY a LIMIT 1) FROM (VALUES (1, 2), (2, 3)) t2(b, c)", "VALUES 1, 5", output( - join(INNER, builder -> builder + join(LEFT, builder -> builder .equiCriteria("expr", "a") .left( project( diff --git a/core/trino-main/src/test/java/io/trino/type/AbstractTestType.java b/core/trino-main/src/test/java/io/trino/type/AbstractTestType.java index 0ea4a710916d..19d8600f960d 100644 --- a/core/trino-main/src/test/java/io/trino/type/AbstractTestType.java +++ b/core/trino-main/src/test/java/io/trino/type/AbstractTestType.java @@ -23,6 +23,7 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockEncodingSerde; import 
io.trino.spi.block.TestingBlockEncodingSerde; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.ArrayType; import io.trino.spi.type.LongTimestamp; import io.trino.spi.type.LongTimestampWithTimeZone; @@ -59,9 +60,10 @@ import static io.trino.spi.connector.SortOrder.DESC_NULLS_FIRST; import static io.trino.spi.connector.SortOrder.DESC_NULLS_LAST; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.DEFAULT_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; @@ -89,7 +91,7 @@ public abstract class AbstractTestType private final BlockEncodingSerde blockEncodingSerde = new TestingBlockEncodingSerde(); private final Class objectValueType; - private final Block testBlock; + private final ValueBlock testBlock; protected final Type type; private final TypeOperators typeOperators; @@ -116,35 +118,35 @@ public abstract class AbstractTestType private final BlockPositionIsDistinctFrom distinctFromOperator; private final SortedMap expectedStackValues; private final SortedMap expectedObjectValues; - private final Block testBlockWithNulls; + private final ValueBlock testBlockWithNulls; - protected AbstractTestType(Type type, Class objectValueType, Block testBlock) + 
protected AbstractTestType(Type type, Class objectValueType, ValueBlock testBlock) { this(type, objectValueType, testBlock, testBlock); } - protected AbstractTestType(Type type, Class objectValueType, Block testBlock, Block expectedValues) + protected AbstractTestType(Type type, Class objectValueType, ValueBlock testBlock, ValueBlock expectedValues) { this.type = requireNonNull(type, "type is null"); typeOperators = new TypeOperators(); - readBlockMethod = typeOperators.getReadValueOperator(type, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION_NOT_NULL)); + readBlockMethod = typeOperators.getReadValueOperator(type, simpleConvention(FAIL_ON_NULL, VALUE_BLOCK_POSITION_NOT_NULL)); writeBlockMethod = typeOperators.getReadValueOperator(type, simpleConvention(BLOCK_BUILDER, NEVER_NULL)); writeFlatToBlockMethod = typeOperators.getReadValueOperator(type, simpleConvention(BLOCK_BUILDER, FLAT)); readFlatMethod = typeOperators.getReadValueOperator(type, simpleConvention(FAIL_ON_NULL, FLAT)); writeFlatMethod = typeOperators.getReadValueOperator(type, simpleConvention(FLAT_RETURN, NEVER_NULL)); - writeBlockToFlatMethod = typeOperators.getReadValueOperator(type, simpleConvention(FLAT_RETURN, BLOCK_POSITION)); + writeBlockToFlatMethod = typeOperators.getReadValueOperator(type, simpleConvention(FLAT_RETURN, VALUE_BLOCK_POSITION)); blockTypeOperators = new BlockTypeOperators(typeOperators); if (type.isComparable()) { stackStackEqualOperator = typeOperators.getEqualOperator(type, simpleConvention(NULLABLE_RETURN, NEVER_NULL, NEVER_NULL)); flatFlatEqualOperator = typeOperators.getEqualOperator(type, simpleConvention(NULLABLE_RETURN, FLAT, FLAT)); - flatBlockPositionEqualOperator = typeOperators.getEqualOperator(type, simpleConvention(NULLABLE_RETURN, FLAT, BLOCK_POSITION)); - blockPositionFlatEqualOperator = typeOperators.getEqualOperator(type, simpleConvention(NULLABLE_RETURN, BLOCK_POSITION, FLAT)); + flatBlockPositionEqualOperator = typeOperators.getEqualOperator(type, 
simpleConvention(NULLABLE_RETURN, FLAT, VALUE_BLOCK_POSITION)); + blockPositionFlatEqualOperator = typeOperators.getEqualOperator(type, simpleConvention(NULLABLE_RETURN, VALUE_BLOCK_POSITION, FLAT)); flatHashCodeOperator = typeOperators.getHashCodeOperator(type, simpleConvention(FAIL_ON_NULL, FLAT)); flatXxHash64Operator = typeOperators.getXxHash64Operator(type, simpleConvention(FAIL_ON_NULL, FLAT)); flatFlatDistinctFromOperator = typeOperators.getDistinctFromOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, FLAT)); - flatBlockPositionDistinctFromOperator = typeOperators.getDistinctFromOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, BLOCK_POSITION)); - blockPositionFlatDistinctFromOperator = typeOperators.getDistinctFromOperator(type, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION, FLAT)); + flatBlockPositionDistinctFromOperator = typeOperators.getDistinctFromOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, VALUE_BLOCK_POSITION)); + blockPositionFlatDistinctFromOperator = typeOperators.getDistinctFromOperator(type, simpleConvention(FAIL_ON_NULL, VALUE_BLOCK_POSITION, FLAT)); equalOperator = blockTypeOperators.getEqualOperator(type); hashCodeOperator = blockTypeOperators.getHashCodeOperator(type); @@ -176,7 +178,7 @@ protected AbstractTestType(Type type, Class objectValueType, Block testBlock, this.testBlockWithNulls = createAlternatingNullsBlock(testBlock); } - private Block createAlternatingNullsBlock(Block testBlock) + private ValueBlock createAlternatingNullsBlock(Block testBlock) { BlockBuilder nullsBlockBuilder = type.createBlockBuilder(null, testBlock.getPositionCount()); for (int position = 0; position < testBlock.getPositionCount(); position++) { @@ -202,7 +204,7 @@ else if (type.getJavaType() == Slice.class) { } nullsBlockBuilder.appendNull(); } - return nullsBlockBuilder.build(); + return nullsBlockBuilder.buildValueBlock(); } @Test @@ -337,7 +339,7 @@ else if (stackStackEqualOperator != null) { assertFalse((boolean) 
flatBlockPositionDistinctFromOperator.invokeExact(fixed, elementFixedOffset, variable, testBlock, i)); assertFalse((boolean) blockPositionFlatDistinctFromOperator.invokeExact(testBlock, i, fixed, elementFixedOffset, variable)); - Block nullValue = type.createBlockBuilder(null, 1).appendNull().build(); + ValueBlock nullValue = type.createBlockBuilder(null, 1).appendNull().buildValueBlock(); assertTrue((boolean) flatBlockPositionDistinctFromOperator.invokeExact(fixed, elementFixedOffset, variable, nullValue, 0)); assertTrue((boolean) blockPositionFlatDistinctFromOperator.invokeExact(nullValue, 0, fixed, elementFixedOffset, variable)); } @@ -349,7 +351,7 @@ protected Object getSampleValue() return requireNonNull(Iterables.get(expectedStackValues.values(), 0), "sample value is null"); } - protected void assertPositionEquals(Block block, int position, Object expectedStackValue, Object expectedObjectValue) + protected void assertPositionEquals(ValueBlock block, int position, Object expectedStackValue, Object expectedObjectValue) throws Throwable { long hash = 0; @@ -364,16 +366,16 @@ protected void assertPositionEquals(Block block, int position, Object expectedSt BlockBuilder blockBuilder = type.createBlockBuilder(null, 1); type.appendTo(block, position, blockBuilder); - assertPositionValue(blockBuilder.build(), 0, expectedStackValue, hash, expectedObjectValue); + assertPositionValue(blockBuilder.buildValueBlock(), 0, expectedStackValue, hash, expectedObjectValue); if (expectedStackValue != null) { blockBuilder = type.createBlockBuilder(null, 1); writeBlockMethod.invoke(expectedStackValue, blockBuilder); - assertPositionValue(blockBuilder.build(), 0, expectedStackValue, hash, expectedObjectValue); + assertPositionValue(blockBuilder.buildValueBlock(), 0, expectedStackValue, hash, expectedObjectValue); } } - private void assertPositionValue(Block block, int position, Object expectedStackValue, long expectedHash, Object expectedObjectValue) + private void 
assertPositionValue(ValueBlock block, int position, Object expectedStackValue, long expectedHash, Object expectedObjectValue) throws Throwable { assertEquals(block.isNull(position), expectedStackValue == null); @@ -643,7 +645,7 @@ else if (javaType == Slice.class) { else { type.writeObject(blockBuilder, value); } - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } /** @@ -727,7 +729,7 @@ else if (javaType == Slice.class) { else { type.writeObject(blockBuilder, value); } - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } private static SortedMap indexStackValues(Type type, Block block) diff --git a/core/trino-main/src/test/java/io/trino/type/TestArrayOfMapOfBigintVarcharType.java b/core/trino-main/src/test/java/io/trino/type/TestArrayOfMapOfBigintVarcharType.java index 5b6e95354165..7c04b7e78363 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestArrayOfMapOfBigintVarcharType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestArrayOfMapOfBigintVarcharType.java @@ -14,8 +14,8 @@ package io.trino.type; import com.google.common.collect.ImmutableMap; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.ArrayType; import org.junit.jupiter.api.Test; @@ -39,7 +39,7 @@ public TestArrayOfMapOfBigintVarcharType() super(TYPE, List.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = TYPE.createBlockBuilder(null, 4); TYPE.writeObject(blockBuilder, arrayBlockOf(TYPE.getElementType(), @@ -51,7 +51,7 @@ public static Block createTestBlock() TYPE.writeObject(blockBuilder, arrayBlockOf(TYPE.getElementType(), sqlMapOf(BIGINT, VARCHAR, ImmutableMap.of(100, "hundred")), sqlMapOf(BIGINT, VARCHAR, ImmutableMap.of(200, "two hundred")))); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git 
a/core/trino-main/src/test/java/io/trino/type/TestBigintArrayType.java b/core/trino-main/src/test/java/io/trino/type/TestBigintArrayType.java index d486da1a824c..7f466ba8e8bc 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestBigintArrayType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestBigintArrayType.java @@ -15,6 +15,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -34,14 +35,14 @@ public TestBigintArrayType() super(TESTING_TYPE_MANAGER.getType(arrayType(BIGINT.getTypeSignature())), List.class, createTestBlock(TESTING_TYPE_MANAGER.getType(arrayType(BIGINT.getTypeSignature())))); } - public static Block createTestBlock(Type arrayType) + public static ValueBlock createTestBlock(Type arrayType) { BlockBuilder blockBuilder = arrayType.createBlockBuilder(null, 4); arrayType.writeObject(blockBuilder, arrayBlockOf(BIGINT, 1, 2)); arrayType.writeObject(blockBuilder, arrayBlockOf(BIGINT, 1, 2, 3)); arrayType.writeObject(blockBuilder, arrayBlockOf(BIGINT, 1, 2, 3)); arrayType.writeObject(blockBuilder, arrayBlockOf(BIGINT, 100, 200, 300)); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override @@ -54,7 +55,7 @@ protected Object getGreaterValue(Object value) } BIGINT.writeLong(blockBuilder, 1L); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Test diff --git a/core/trino-main/src/test/java/io/trino/type/TestBigintType.java b/core/trino-main/src/test/java/io/trino/type/TestBigintType.java index 91ca30931c58..c22a8800f2ad 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestBigintType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestBigintType.java @@ -13,8 +13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type.Range; import 
org.junit.jupiter.api.Test; @@ -32,7 +32,7 @@ public TestBigintType() super(BIGINT, Long.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = BIGINT.createBlockBuilder(null, 15); BIGINT.writeLong(blockBuilder, 1111); @@ -46,7 +46,7 @@ public static Block createTestBlock() BIGINT.writeLong(blockBuilder, 3333); BIGINT.writeLong(blockBuilder, 3333); BIGINT.writeLong(blockBuilder, 4444); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestBigintVarcharMapType.java b/core/trino-main/src/test/java/io/trino/type/TestBigintVarcharMapType.java index 1ba0f98ce4dd..88f279d6dd75 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestBigintVarcharMapType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestBigintVarcharMapType.java @@ -14,8 +14,8 @@ package io.trino.type; import com.google.common.collect.ImmutableMap; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -36,13 +36,13 @@ public TestBigintVarcharMapType() super(mapType(BIGINT, VARCHAR), Map.class, createTestBlock(mapType(BIGINT, VARCHAR))); } - public static Block createTestBlock(Type mapType) + public static ValueBlock createTestBlock(Type mapType) { BlockBuilder blockBuilder = mapType.createBlockBuilder(null, 2); mapType.writeObject(blockBuilder, sqlMapOf(BIGINT, VARCHAR, ImmutableMap.of(1, "hi"))); mapType.writeObject(blockBuilder, sqlMapOf(BIGINT, VARCHAR, ImmutableMap.of(1, "2", 2, "hello"))); mapType.writeObject(blockBuilder, sqlMapOf(BIGINT, VARCHAR, ImmutableMap.of(1, "123456789012345", 2, "hello-world-hello-world-hello-world"))); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git 
a/core/trino-main/src/test/java/io/trino/type/TestBooleanType.java b/core/trino-main/src/test/java/io/trino/type/TestBooleanType.java index c26e18547374..b267d02f8c9c 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestBooleanType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestBooleanType.java @@ -16,6 +16,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.ByteArrayBlock; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.BooleanType; import org.junit.jupiter.api.Test; @@ -66,7 +67,7 @@ public void testBooleanBlockWithSingleNonNullValue() assertFalse(BooleanType.createBlockForSingleNonNullValue(false).mayHaveNull()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = BOOLEAN.createBlockBuilder(null, 15); BOOLEAN.writeBoolean(blockBuilder, true); @@ -80,7 +81,7 @@ public static Block createTestBlock() BOOLEAN.writeBoolean(blockBuilder, true); BOOLEAN.writeBoolean(blockBuilder, true); BOOLEAN.writeBoolean(blockBuilder, false); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestBoundedVarcharType.java b/core/trino-main/src/test/java/io/trino/type/TestBoundedVarcharType.java index e482515905a0..7874cea36276 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestBoundedVarcharType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestBoundedVarcharType.java @@ -15,8 +15,8 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; import org.junit.jupiter.api.Test; @@ -34,7 +34,7 @@ public TestBoundedVarcharType() super(createVarcharType(6), String.class, createTestBlock(createVarcharType(6))); } - private static Block 
createTestBlock(VarcharType type) + private static ValueBlock createTestBlock(VarcharType type) { BlockBuilder blockBuilder = type.createBlockBuilder(null, 15); type.writeString(blockBuilder, "apple"); @@ -48,7 +48,7 @@ private static Block createTestBlock(VarcharType type) type.writeString(blockBuilder, "cherry"); type.writeString(blockBuilder, "cherry"); type.writeString(blockBuilder, "date"); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestCharType.java b/core/trino-main/src/test/java/io/trino/type/TestCharType.java index 77b55e254241..333fa0086fb3 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestCharType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestCharType.java @@ -18,6 +18,7 @@ import io.airlift.slice.Slices; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.type.CharType; import io.trino.spi.type.Type; @@ -45,7 +46,7 @@ public TestCharType() super(CHAR_TYPE, String.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = CHAR_TYPE.createBlockBuilder(null, 15); CHAR_TYPE.writeString(blockBuilder, "apple"); @@ -59,7 +60,7 @@ public static Block createTestBlock() CHAR_TYPE.writeString(blockBuilder, "cherry"); CHAR_TYPE.writeString(blockBuilder, "cherry"); CHAR_TYPE.writeString(blockBuilder, "date"); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestColorArrayType.java b/core/trino-main/src/test/java/io/trino/type/TestColorArrayType.java index 01b283cd73aa..f2d81fe09119 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestColorArrayType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestColorArrayType.java @@ -13,8 
+13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -35,14 +35,14 @@ public TestColorArrayType() super(TESTING_TYPE_MANAGER.getType(arrayType(COLOR.getTypeSignature())), List.class, createTestBlock(TESTING_TYPE_MANAGER.getType(arrayType(COLOR.getTypeSignature())))); } - public static Block createTestBlock(Type arrayType) + public static ValueBlock createTestBlock(Type arrayType) { BlockBuilder blockBuilder = arrayType.createBlockBuilder(null, 4); arrayType.writeObject(blockBuilder, arrayBlockOf(COLOR, 1, 2)); arrayType.writeObject(blockBuilder, arrayBlockOf(COLOR, 1, 2, 3)); arrayType.writeObject(blockBuilder, arrayBlockOf(COLOR, 1, 2, 3)); arrayType.writeObject(blockBuilder, arrayBlockOf(COLOR, 100, 200, 300)); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestColorType.java b/core/trino-main/src/test/java/io/trino/type/TestColorType.java index 6b4557f9487f..3d640f0d3f80 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestColorType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestColorType.java @@ -15,6 +15,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import org.junit.jupiter.api.Test; import static io.trino.operator.scalar.ColorFunctions.rgb; @@ -54,7 +55,7 @@ public void testGetObjectValue() } } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = COLOR.createBlockBuilder(null, 15); COLOR.writeLong(blockBuilder, rgb(1, 1, 1)); @@ -68,7 +69,7 @@ public static Block createTestBlock() COLOR.writeLong(blockBuilder, rgb(3, 3, 3)); COLOR.writeLong(blockBuilder, rgb(3, 3, 3)); COLOR.writeLong(blockBuilder, rgb(4, 4, 4)); - return blockBuilder.build(); + return 
blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestConventionDependencies.java b/core/trino-main/src/test/java/io/trino/type/TestConventionDependencies.java index fa1ebfd0f75d..309179070ac6 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestConventionDependencies.java +++ b/core/trino-main/src/test/java/io/trino/type/TestConventionDependencies.java @@ -16,6 +16,8 @@ import io.trino.metadata.InternalFunctionBundle; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; +import io.trino.spi.block.IntArrayBlock; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.BlockIndex; import io.trino.spi.function.BlockPosition; import io.trino.spi.function.Convention; @@ -37,6 +39,7 @@ import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.type.IntegerType.INTEGER; import static org.assertj.core.api.Assertions.assertThat; @@ -55,6 +58,7 @@ public void init() assertions.addFunctions(InternalFunctionBundle.builder() .scalar(RegularConvention.class) .scalar(BlockPositionConvention.class) + .scalar(ValueBlockPositionConvention.class) .scalar(Add.class) .build()); @@ -88,6 +92,15 @@ public void testConventionDependencies() assertThat(assertions.function("block_position_convention", "ARRAY[56, 275, 36]")) .isEqualTo(367); + + assertThat(assertions.function("value_block_position_convention", "ARRAY[1, 2, 3]")) + .isEqualTo(6); + + assertThat(assertions.function("value_block_position_convention", "ARRAY[25, 0, 5]")) + .isEqualTo(30); + + 
assertThat(assertions.function("value_block_position_convention", "ARRAY[56, 275, 36]")) + .isEqualTo(367); } @ScalarFunction("regular_convention") @@ -138,6 +151,34 @@ public static long testBlockPositionConvention( } } + @ScalarFunction("value_block_position_convention") + public static final class ValueBlockPositionConvention + { + @SqlType(StandardTypes.INTEGER) + public static long testBlockPositionConvention( + @FunctionDependency( + name = "add", + argumentTypes = {StandardTypes.INTEGER, StandardTypes.INTEGER}, + convention = @Convention(arguments = {NEVER_NULL, VALUE_BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle function, + @SqlType("array(integer)") Block array) + { + ValueBlock arrayValues = array.getUnderlyingValueBlock(); + + long sum = 0; + for (int i = 0; i < array.getPositionCount(); i++) { + try { + sum = (long) function.invokeExact(sum, arrayValues, array.getUnderlyingValuePosition(i)); + } + catch (Throwable t) { + throwIfInstanceOf(t, Error.class); + throwIfInstanceOf(t, TrinoException.class); + throw new TrinoException(GENERIC_INTERNAL_ERROR, t); + } + } + return sum; + } + } + @ScalarFunction("add") public static final class Add { @@ -152,7 +193,7 @@ public static long add( @SqlType(StandardTypes.INTEGER) public static long addBlockPosition( @SqlType(StandardTypes.INTEGER) long first, - @BlockPosition @SqlType(value = StandardTypes.INTEGER, nativeContainerType = long.class) Block block, + @BlockPosition @SqlType(value = StandardTypes.INTEGER, nativeContainerType = long.class) IntArrayBlock block, @BlockIndex int position) { return Math.addExact((int) first, INTEGER.getInt(block, position)); diff --git a/core/trino-main/src/test/java/io/trino/type/TestDateType.java b/core/trino-main/src/test/java/io/trino/type/TestDateType.java index 8d27947777ff..9e3565cfcb86 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestDateType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestDateType.java @@ -13,8 +13,8 @@ */ package 
io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.SqlDate; import io.trino.spi.type.Type.Range; import org.junit.jupiter.api.Test; @@ -33,7 +33,7 @@ public TestDateType() super(DATE, SqlDate.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = DATE.createBlockBuilder(null, 15); DATE.writeLong(blockBuilder, 1111); @@ -47,7 +47,7 @@ public static Block createTestBlock() DATE.writeLong(blockBuilder, 3333); DATE.writeLong(blockBuilder, 3333); DATE.writeLong(blockBuilder, 4444); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestDoubleType.java b/core/trino-main/src/test/java/io/trino/type/TestDoubleType.java index 73c366d7e3d0..d13a3eee46c1 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestDoubleType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestDoubleType.java @@ -16,6 +16,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.LongArrayBlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.type.BlockTypeOperators.BlockPositionHashCode; import io.trino.type.BlockTypeOperators.BlockPositionXxHash64; import org.junit.jupiter.api.Test; @@ -34,7 +35,7 @@ public TestDoubleType() super(DOUBLE, Double.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = DOUBLE.createBlockBuilder(null, 15); DOUBLE.writeDouble(blockBuilder, 11.11); @@ -48,7 +49,7 @@ public static Block createTestBlock() DOUBLE.writeDouble(blockBuilder, 33.33); DOUBLE.writeDouble(blockBuilder, 33.33); DOUBLE.writeDouble(blockBuilder, 44.44); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git 
a/core/trino-main/src/test/java/io/trino/type/TestIntegerArrayType.java b/core/trino-main/src/test/java/io/trino/type/TestIntegerArrayType.java index 5d7853f2aa44..40dda9308495 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestIntegerArrayType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestIntegerArrayType.java @@ -15,6 +15,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -34,14 +35,14 @@ public TestIntegerArrayType() super(TESTING_TYPE_MANAGER.getType(arrayType(INTEGER.getTypeSignature())), List.class, createTestBlock(TESTING_TYPE_MANAGER.getType(arrayType(INTEGER.getTypeSignature())))); } - public static Block createTestBlock(Type arrayType) + public static ValueBlock createTestBlock(Type arrayType) { BlockBuilder blockBuilder = arrayType.createBlockBuilder(null, 4); arrayType.writeObject(blockBuilder, arrayBlockOf(INTEGER, 1, 2)); arrayType.writeObject(blockBuilder, arrayBlockOf(INTEGER, 1, 2, 3)); arrayType.writeObject(blockBuilder, arrayBlockOf(INTEGER, 1, 2, 3)); arrayType.writeObject(blockBuilder, arrayBlockOf(INTEGER, 100, 200, 300)); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override @@ -54,7 +55,7 @@ protected Object getGreaterValue(Object value) } INTEGER.writeLong(blockBuilder, 1L); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Test diff --git a/core/trino-main/src/test/java/io/trino/type/TestIntegerType.java b/core/trino-main/src/test/java/io/trino/type/TestIntegerType.java index 798a449d58b5..2a76b84a3c3a 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestIntegerType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestIntegerType.java @@ -13,8 +13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import 
io.trino.spi.type.Type.Range; import org.junit.jupiter.api.Test; @@ -32,7 +32,7 @@ public TestIntegerType() super(INTEGER, Integer.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = INTEGER.createBlockBuilder(null, 15); INTEGER.writeLong(blockBuilder, 1111); @@ -46,7 +46,7 @@ public static Block createTestBlock() INTEGER.writeLong(blockBuilder, 3333); INTEGER.writeLong(blockBuilder, 3333); INTEGER.writeLong(blockBuilder, 4444); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestIntegerVarcharMapType.java b/core/trino-main/src/test/java/io/trino/type/TestIntegerVarcharMapType.java index 69287415be6f..5349a0475f2a 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestIntegerVarcharMapType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestIntegerVarcharMapType.java @@ -14,8 +14,8 @@ package io.trino.type; import com.google.common.collect.ImmutableMap; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -36,13 +36,13 @@ public TestIntegerVarcharMapType() super(mapType(INTEGER, VARCHAR), Map.class, createTestBlock(mapType(INTEGER, VARCHAR))); } - public static Block createTestBlock(Type mapType) + public static ValueBlock createTestBlock(Type mapType) { BlockBuilder blockBuilder = mapType.createBlockBuilder(null, 2); mapType.writeObject(blockBuilder, sqlMapOf(INTEGER, VARCHAR, ImmutableMap.of(1, "hi"))); mapType.writeObject(blockBuilder, sqlMapOf(INTEGER, VARCHAR, ImmutableMap.of(1, "2", 2, "hello"))); mapType.writeObject(blockBuilder, sqlMapOf(INTEGER, VARCHAR, ImmutableMap.of(1, "123456789012345", 2, "hello-world-hello-world-hello-world"))); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git 
a/core/trino-main/src/test/java/io/trino/type/TestIntervalDayTimeType.java b/core/trino-main/src/test/java/io/trino/type/TestIntervalDayTimeType.java index eddb22084a4b..e967ec08bff4 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestIntervalDayTimeType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestIntervalDayTimeType.java @@ -13,8 +13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import org.junit.jupiter.api.Test; import static io.trino.type.IntervalDayTimeType.INTERVAL_DAY_TIME; @@ -28,7 +28,7 @@ public TestIntervalDayTimeType() super(INTERVAL_DAY_TIME, SqlIntervalDayTime.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = INTERVAL_DAY_TIME.createBlockBuilder(null, 15); INTERVAL_DAY_TIME.writeLong(blockBuilder, 1111); @@ -42,7 +42,7 @@ public static Block createTestBlock() INTERVAL_DAY_TIME.writeLong(blockBuilder, 3333); INTERVAL_DAY_TIME.writeLong(blockBuilder, 3333); INTERVAL_DAY_TIME.writeLong(blockBuilder, 4444); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestIntervalYearMonthType.java b/core/trino-main/src/test/java/io/trino/type/TestIntervalYearMonthType.java index f16d0d577828..108ad544ead9 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestIntervalYearMonthType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestIntervalYearMonthType.java @@ -13,8 +13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import org.junit.jupiter.api.Test; import static io.trino.type.IntervalYearMonthType.INTERVAL_YEAR_MONTH; @@ -28,7 +28,7 @@ public TestIntervalYearMonthType() super(INTERVAL_YEAR_MONTH, SqlIntervalYearMonth.class, createTestBlock()); } - public 
static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = INTERVAL_YEAR_MONTH.createBlockBuilder(null, 15); INTERVAL_YEAR_MONTH.writeLong(blockBuilder, 1111); @@ -42,7 +42,7 @@ public static Block createTestBlock() INTERVAL_YEAR_MONTH.writeLong(blockBuilder, 3333); INTERVAL_YEAR_MONTH.writeLong(blockBuilder, 3333); INTERVAL_YEAR_MONTH.writeLong(blockBuilder, 4444); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestIpAddressType.java b/core/trino-main/src/test/java/io/trino/type/TestIpAddressType.java index a1f057f3c415..6c4a4c7d42ee 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestIpAddressType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestIpAddressType.java @@ -16,8 +16,8 @@ import com.google.common.net.InetAddresses; import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import org.junit.jupiter.api.Test; import static com.google.common.base.Preconditions.checkState; @@ -33,7 +33,7 @@ public TestIpAddressType() super(IPADDRESS, String.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = IPADDRESS.createBlockBuilder(null, 1); IPADDRESS.writeSlice(blockBuilder, getSliceForAddress("2001:db8::ff00:42:8320")); @@ -46,7 +46,7 @@ public static Block createTestBlock() IPADDRESS.writeSlice(blockBuilder, getSliceForAddress("2001:db8::ff00:42:8327")); IPADDRESS.writeSlice(blockBuilder, getSliceForAddress("2001:db8::ff00:42:8328")); IPADDRESS.writeSlice(blockBuilder, getSliceForAddress("2001:db8::ff00:42:8329")); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestJsonType.java 
b/core/trino-main/src/test/java/io/trino/type/TestJsonType.java index 26bd09677460..20f14a17b824 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestJsonType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestJsonType.java @@ -15,8 +15,8 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import org.junit.jupiter.api.Test; import static io.trino.type.JsonType.JSON; @@ -31,12 +31,12 @@ public TestJsonType() super(JSON, String.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = JSON.createBlockBuilder(null, 1); Slice slice = Slices.utf8Slice("{\"x\":1, \"y\":2}"); JSON.writeSlice(blockBuilder, slice); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestLongDecimalType.java b/core/trino-main/src/test/java/io/trino/type/TestLongDecimalType.java index ec01f82524b7..1dd5ebb9a895 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestLongDecimalType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestLongDecimalType.java @@ -13,8 +13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.DecimalType; import io.trino.spi.type.Decimals; import io.trino.spi.type.Int128; @@ -36,7 +36,7 @@ public TestLongDecimalType() super(LONG_DECIMAL_TYPE, SqlDecimal.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = LONG_DECIMAL_TYPE.createBlockBuilder(null, 15); writeBigDecimal(LONG_DECIMAL_TYPE, blockBuilder, new BigDecimal("-12345678901234567890.1234567890")); @@ -50,7 +50,7 @@ public static Block createTestBlock() 
writeBigDecimal(LONG_DECIMAL_TYPE, blockBuilder, new BigDecimal("32345678901234567890.1234567890")); writeBigDecimal(LONG_DECIMAL_TYPE, blockBuilder, new BigDecimal("32345678901234567890.1234567890")); writeBigDecimal(LONG_DECIMAL_TYPE, blockBuilder, new BigDecimal("42345678901234567890.1234567890")); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestLongTimestampType.java b/core/trino-main/src/test/java/io/trino/type/TestLongTimestampType.java index d1aeb3cef4bb..6e194fafc0fc 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestLongTimestampType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestLongTimestampType.java @@ -14,8 +14,8 @@ package io.trino.type; import com.google.common.collect.ImmutableList; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.LongTimestamp; import io.trino.spi.type.SqlTimestamp; import io.trino.spi.type.Type.Range; @@ -36,7 +36,7 @@ public TestLongTimestampType() super(TIMESTAMP_NANOS, SqlTimestamp.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = TIMESTAMP_NANOS.createBlockBuilder(null, 15); TIMESTAMP_NANOS.writeObject(blockBuilder, new LongTimestamp(1111_123, 123_000)); @@ -50,7 +50,7 @@ public static Block createTestBlock() TIMESTAMP_NANOS.writeObject(blockBuilder, new LongTimestamp(3333_123, 123_000)); TIMESTAMP_NANOS.writeObject(blockBuilder, new LongTimestamp(3333_123, 123_000)); TIMESTAMP_NANOS.writeObject(blockBuilder, new LongTimestamp(4444_123, 123_000)); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestLongTimestampWithTimeZoneType.java b/core/trino-main/src/test/java/io/trino/type/TestLongTimestampWithTimeZoneType.java index 
9016e994ada4..ce610e7fcf63 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestLongTimestampWithTimeZoneType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestLongTimestampWithTimeZoneType.java @@ -13,8 +13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.LongTimestampWithTimeZone; import io.trino.spi.type.SqlTimestampWithTimeZone; import io.trino.spi.type.Type; @@ -40,7 +40,7 @@ public TestLongTimestampWithTimeZoneType() super(TIMESTAMP_TZ_MICROS, SqlTimestampWithTimeZone.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = TIMESTAMP_TZ_MICROS.createBlockBuilder(null, 15); TIMESTAMP_TZ_MICROS.writeObject(blockBuilder, LongTimestampWithTimeZone.fromEpochMillisAndFraction(1111, 0, getTimeZoneKeyForOffset(0))); @@ -54,7 +54,7 @@ public static Block createTestBlock() TIMESTAMP_TZ_MICROS.writeObject(blockBuilder, LongTimestampWithTimeZone.fromEpochMillisAndFraction(3333, 0, getTimeZoneKeyForOffset(8))); TIMESTAMP_TZ_MICROS.writeObject(blockBuilder, LongTimestampWithTimeZone.fromEpochMillisAndFraction(3333, 0, getTimeZoneKeyForOffset(9))); TIMESTAMP_TZ_MICROS.writeObject(blockBuilder, LongTimestampWithTimeZone.fromEpochMillisAndFraction(4444, 0, getTimeZoneKeyForOffset(10))); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestRealType.java b/core/trino-main/src/test/java/io/trino/type/TestRealType.java index 1349b5697130..7fbfa28d67c1 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestRealType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestRealType.java @@ -16,6 +16,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.IntArrayBlockBuilder; +import io.trino.spi.block.ValueBlock; 
import io.trino.type.BlockTypeOperators.BlockPositionHashCode; import org.junit.jupiter.api.Test; @@ -34,7 +35,7 @@ public TestRealType() super(REAL, Float.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = REAL.createBlockBuilder(null, 30); REAL.writeLong(blockBuilder, floatToRawIntBits(11.11F)); @@ -48,7 +49,7 @@ public static Block createTestBlock() REAL.writeLong(blockBuilder, floatToRawIntBits(33.33F)); REAL.writeLong(blockBuilder, floatToRawIntBits(33.33F)); REAL.writeLong(blockBuilder, floatToRawIntBits(44.44F)); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestShortDecimalType.java b/core/trino-main/src/test/java/io/trino/type/TestShortDecimalType.java index 2807bdc50ed1..712e83d682fd 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestShortDecimalType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestShortDecimalType.java @@ -13,8 +13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.DecimalType; import io.trino.spi.type.SqlDecimal; import org.junit.jupiter.api.Test; @@ -32,7 +32,7 @@ public TestShortDecimalType() super(SHORT_DECIMAL_TYPE, SqlDecimal.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = SHORT_DECIMAL_TYPE.createBlockBuilder(null, 15); SHORT_DECIMAL_TYPE.writeLong(blockBuilder, -1234); @@ -46,7 +46,7 @@ public static Block createTestBlock() SHORT_DECIMAL_TYPE.writeLong(blockBuilder, 3321); SHORT_DECIMAL_TYPE.writeLong(blockBuilder, 3321); SHORT_DECIMAL_TYPE.writeLong(blockBuilder, 4321); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git 
a/core/trino-main/src/test/java/io/trino/type/TestShortTimestampType.java b/core/trino-main/src/test/java/io/trino/type/TestShortTimestampType.java index f5b113ab4d7f..a80089c14bca 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestShortTimestampType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestShortTimestampType.java @@ -13,8 +13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.SqlTimestamp; import io.trino.spi.type.Type; import io.trino.spi.type.Type.Range; @@ -39,7 +39,7 @@ public TestShortTimestampType() super(TIMESTAMP_MILLIS, SqlTimestamp.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = TIMESTAMP_MILLIS.createBlockBuilder(null, 15); TIMESTAMP_MILLIS.writeLong(blockBuilder, 1111_000); @@ -53,7 +53,7 @@ public static Block createTestBlock() TIMESTAMP_MILLIS.writeLong(blockBuilder, 3333_000); TIMESTAMP_MILLIS.writeLong(blockBuilder, 3333_000); TIMESTAMP_MILLIS.writeLong(blockBuilder, 4444_000); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestShortTimestampWithTimeZoneType.java b/core/trino-main/src/test/java/io/trino/type/TestShortTimestampWithTimeZoneType.java index 848a60321307..2fbf03b4963e 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestShortTimestampWithTimeZoneType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestShortTimestampWithTimeZoneType.java @@ -13,8 +13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.SqlTimestampWithTimeZone; import org.junit.jupiter.api.Test; @@ -32,7 +32,7 @@ public TestShortTimestampWithTimeZoneType() super(TIMESTAMP_TZ_MILLIS, 
SqlTimestampWithTimeZone.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = TIMESTAMP_TZ_MILLIS.createBlockBuilder(null, 15); TIMESTAMP_TZ_MILLIS.writeLong(blockBuilder, packDateTimeWithZone(1111, getTimeZoneKeyForOffset(0))); @@ -46,7 +46,7 @@ public static Block createTestBlock() TIMESTAMP_TZ_MILLIS.writeLong(blockBuilder, packDateTimeWithZone(3333, getTimeZoneKeyForOffset(8))); TIMESTAMP_TZ_MILLIS.writeLong(blockBuilder, packDateTimeWithZone(3333, getTimeZoneKeyForOffset(9))); TIMESTAMP_TZ_MILLIS.writeLong(blockBuilder, packDateTimeWithZone(4444, getTimeZoneKeyForOffset(10))); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestSimpleRowType.java b/core/trino-main/src/test/java/io/trino/type/TestSimpleRowType.java index 4051d8275b1b..a31cadf58b20 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestSimpleRowType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestSimpleRowType.java @@ -14,9 +14,9 @@ package io.trino.type; import com.google.common.collect.ImmutableList; -import io.trino.spi.block.Block; import io.trino.spi.block.RowBlockBuilder; import io.trino.spi.block.SqlRow; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.RowType; import org.junit.jupiter.api.Test; @@ -41,7 +41,7 @@ public TestSimpleRowType() super(TYPE, List.class, createTestBlock()); } - private static Block createTestBlock() + private static ValueBlock createTestBlock() { RowBlockBuilder blockBuilder = TYPE.createBlockBuilder(null, 3); @@ -60,7 +60,7 @@ private static Block createTestBlock() VARCHAR.writeSlice(fieldBuilders.get(1), utf8Slice("dog")); }); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestSmallintArrayType.java 
b/core/trino-main/src/test/java/io/trino/type/TestSmallintArrayType.java index 40f97a4426f2..7874f8defe9a 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestSmallintArrayType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestSmallintArrayType.java @@ -15,6 +15,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -34,14 +35,14 @@ public TestSmallintArrayType() super(TESTING_TYPE_MANAGER.getType(arrayType(SMALLINT.getTypeSignature())), List.class, createTestBlock(TESTING_TYPE_MANAGER.getType(arrayType(SMALLINT.getTypeSignature())))); } - public static Block createTestBlock(Type arrayType) + public static ValueBlock createTestBlock(Type arrayType) { BlockBuilder blockBuilder = arrayType.createBlockBuilder(null, 4); arrayType.writeObject(blockBuilder, arrayBlockOf(SMALLINT, 1, 2)); arrayType.writeObject(blockBuilder, arrayBlockOf(SMALLINT, 1, 2, 3)); arrayType.writeObject(blockBuilder, arrayBlockOf(SMALLINT, 1, 2, 3)); arrayType.writeObject(blockBuilder, arrayBlockOf(SMALLINT, 100, 200, 300)); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override @@ -54,7 +55,7 @@ protected Object getGreaterValue(Object value) } SMALLINT.writeLong(blockBuilder, 1L); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Test diff --git a/core/trino-main/src/test/java/io/trino/type/TestSmallintType.java b/core/trino-main/src/test/java/io/trino/type/TestSmallintType.java index bfbf5f76d737..46a163aa9d2b 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestSmallintType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestSmallintType.java @@ -13,8 +13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type.Range; import org.junit.jupiter.api.Test; @@ -32,7 
+32,7 @@ public TestSmallintType() super(SMALLINT, Short.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = SMALLINT.createBlockBuilder(null, 15); SMALLINT.writeLong(blockBuilder, 1111); @@ -46,7 +46,7 @@ public static Block createTestBlock() SMALLINT.writeLong(blockBuilder, 3333); SMALLINT.writeLong(blockBuilder, 3333); SMALLINT.writeLong(blockBuilder, 4444); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestSmallintVarcharMapType.java b/core/trino-main/src/test/java/io/trino/type/TestSmallintVarcharMapType.java index 813611cbef70..15eead1a0eb6 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestSmallintVarcharMapType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestSmallintVarcharMapType.java @@ -14,8 +14,8 @@ package io.trino.type; import com.google.common.collect.ImmutableMap; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -36,13 +36,13 @@ public TestSmallintVarcharMapType() super(mapType(SMALLINT, VARCHAR), Map.class, createTestBlock(mapType(SMALLINT, VARCHAR))); } - public static Block createTestBlock(Type mapType) + public static ValueBlock createTestBlock(Type mapType) { BlockBuilder blockBuilder = mapType.createBlockBuilder(null, 2); mapType.writeObject(blockBuilder, sqlMapOf(SMALLINT, VARCHAR, ImmutableMap.of(1, "hi"))); mapType.writeObject(blockBuilder, sqlMapOf(SMALLINT, VARCHAR, ImmutableMap.of(1, "2", 2, "hello"))); mapType.writeObject(blockBuilder, sqlMapOf(SMALLINT, VARCHAR, ImmutableMap.of(1, "123456789012345", 2, "hello-world-hello-world-hello-world"))); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestTimeType.java 
b/core/trino-main/src/test/java/io/trino/type/TestTimeType.java index 99bc510a2984..2454aa80d464 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestTimeType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestTimeType.java @@ -13,8 +13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.SqlTime; import org.junit.jupiter.api.Test; @@ -29,7 +29,7 @@ public TestTimeType() super(TIME_MILLIS, SqlTime.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = TIME_MILLIS.createBlockBuilder(null, 15); TIME_MILLIS.writeLong(blockBuilder, 1_111_000_000_000L); @@ -43,7 +43,7 @@ public static Block createTestBlock() TIME_MILLIS.writeLong(blockBuilder, 3_333_000_000_000L); TIME_MILLIS.writeLong(blockBuilder, 3_333_000_000_000L); TIME_MILLIS.writeLong(blockBuilder, 4_444_000_000_000L); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestTimeWithTimeZoneType.java b/core/trino-main/src/test/java/io/trino/type/TestTimeWithTimeZoneType.java index 35657579b923..df18e4c47593 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestTimeWithTimeZoneType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestTimeWithTimeZoneType.java @@ -13,8 +13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.SqlTimeWithTimeZone; import org.junit.jupiter.api.Test; @@ -32,7 +32,7 @@ public TestTimeWithTimeZoneType() super(TIME_TZ_MILLIS, SqlTimeWithTimeZone.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = TIME_TZ_MILLIS.createBlockBuilder(null, 15); 
TIME_TZ_MILLIS.writeLong(blockBuilder, packTimeWithTimeZone(1_111_000_000L, 0)); @@ -46,7 +46,7 @@ public static Block createTestBlock() TIME_TZ_MILLIS.writeLong(blockBuilder, packTimeWithTimeZone(3_333_000_000L, 8)); TIME_TZ_MILLIS.writeLong(blockBuilder, packTimeWithTimeZone(3_333_000_000L, 9)); TIME_TZ_MILLIS.writeLong(blockBuilder, packTimeWithTimeZone(4_444_000_000L, 10)); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestTinyintArrayType.java b/core/trino-main/src/test/java/io/trino/type/TestTinyintArrayType.java index e78bb757d25d..327622dd2c28 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestTinyintArrayType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestTinyintArrayType.java @@ -15,6 +15,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -34,14 +35,14 @@ public TestTinyintArrayType() super(TESTING_TYPE_MANAGER.getType(arrayType(TINYINT.getTypeSignature())), List.class, createTestBlock(TESTING_TYPE_MANAGER.getType(arrayType(TINYINT.getTypeSignature())))); } - public static Block createTestBlock(Type arrayType) + public static ValueBlock createTestBlock(Type arrayType) { BlockBuilder blockBuilder = arrayType.createBlockBuilder(null, 4); arrayType.writeObject(blockBuilder, arrayBlockOf(TINYINT, 1, 2)); arrayType.writeObject(blockBuilder, arrayBlockOf(TINYINT, 1, 2, 3)); arrayType.writeObject(blockBuilder, arrayBlockOf(TINYINT, 1, 2, 3)); arrayType.writeObject(blockBuilder, arrayBlockOf(TINYINT, 100, 110, 127)); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override @@ -54,7 +55,7 @@ protected Object getGreaterValue(Object value) } TINYINT.writeLong(blockBuilder, 1L); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Test diff --git 
a/core/trino-main/src/test/java/io/trino/type/TestTinyintType.java b/core/trino-main/src/test/java/io/trino/type/TestTinyintType.java index a4d3b0c75cab..c4987a648156 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestTinyintType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestTinyintType.java @@ -13,8 +13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type.Range; import org.junit.jupiter.api.Test; @@ -32,7 +32,7 @@ public TestTinyintType() super(TINYINT, Byte.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = TINYINT.createBlockBuilder(null, 15); TINYINT.writeLong(blockBuilder, 111); @@ -46,7 +46,7 @@ public static Block createTestBlock() TINYINT.writeLong(blockBuilder, 33); TINYINT.writeLong(blockBuilder, 33); TINYINT.writeLong(blockBuilder, 44); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestTinyintVarcharMapType.java b/core/trino-main/src/test/java/io/trino/type/TestTinyintVarcharMapType.java index 4f34ce2ec4d9..522bc24c44d8 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestTinyintVarcharMapType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestTinyintVarcharMapType.java @@ -14,8 +14,8 @@ package io.trino.type; import com.google.common.collect.ImmutableMap; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -36,13 +36,13 @@ public TestTinyintVarcharMapType() super(mapType(TINYINT, VARCHAR), Map.class, createTestBlock(mapType(TINYINT, VARCHAR))); } - public static Block createTestBlock(Type mapType) + public static ValueBlock createTestBlock(Type mapType) { BlockBuilder blockBuilder = 
mapType.createBlockBuilder(null, 2); mapType.writeObject(blockBuilder, sqlMapOf(TINYINT, VARCHAR, ImmutableMap.of(1, "hi"))); mapType.writeObject(blockBuilder, sqlMapOf(TINYINT, VARCHAR, ImmutableMap.of(1, "2", 2, "hello"))); mapType.writeObject(blockBuilder, sqlMapOf(TINYINT, VARCHAR, ImmutableMap.of(1, "123456789012345", 2, "hello-world-hello-world-hello-world"))); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestUnboundedVarcharType.java b/core/trino-main/src/test/java/io/trino/type/TestUnboundedVarcharType.java index 36911a64fd90..4fb2e02eab98 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestUnboundedVarcharType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestUnboundedVarcharType.java @@ -15,8 +15,8 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import org.junit.jupiter.api.Test; import static io.trino.spi.type.VarcharType.VARCHAR; @@ -30,7 +30,7 @@ public TestUnboundedVarcharType() super(VARCHAR, String.class, createTestBlock()); } - private static Block createTestBlock() + private static ValueBlock createTestBlock() { BlockBuilder blockBuilder = VARCHAR.createBlockBuilder(null, 15); VARCHAR.writeString(blockBuilder, "apple"); @@ -44,7 +44,7 @@ private static Block createTestBlock() VARCHAR.writeString(blockBuilder, "cherry"); VARCHAR.writeString(blockBuilder, "cherry"); VARCHAR.writeString(blockBuilder, "date"); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestUnknownType.java b/core/trino-main/src/test/java/io/trino/type/TestUnknownType.java index ff22b3c1d43e..6a2a0ce364ac 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestUnknownType.java +++ 
b/core/trino-main/src/test/java/io/trino/type/TestUnknownType.java @@ -29,7 +29,7 @@ public TestUnknownType() .appendNull() .appendNull() .appendNull() - .build()); + .buildValueBlock()); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestUuidType.java b/core/trino-main/src/test/java/io/trino/type/TestUuidType.java index a83ef738ba16..c10478f35c53 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestUuidType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestUuidType.java @@ -15,8 +15,8 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.TypeOperators; import org.junit.jupiter.api.Test; @@ -44,14 +44,14 @@ public TestUuidType() super(UUID, String.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = UUID.createBlockBuilder(null, 1); for (int i = 0; i < 10; i++) { String uuid = "6b5f5b65-67e4-43b0-8ee3-586cd49f58a" + i; UUID.writeSlice(blockBuilder, castFromVarcharToUuid(utf8Slice(uuid))); } - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestVarbinaryType.java b/core/trino-main/src/test/java/io/trino/type/TestVarbinaryType.java index c9468a5421db..4b3e2c92dcd2 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestVarbinaryType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestVarbinaryType.java @@ -15,8 +15,8 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.SqlVarbinary; import org.junit.jupiter.api.Test; @@ -31,7 +31,7 @@ public TestVarbinaryType() super(VARBINARY, SqlVarbinary.class, createTestBlock()); } - public static Block 
createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = VARBINARY.createBlockBuilder(null, 15); VARBINARY.writeSlice(blockBuilder, Slices.utf8Slice("apple")); @@ -45,7 +45,7 @@ public static Block createTestBlock() VARBINARY.writeSlice(blockBuilder, Slices.utf8Slice("cherry")); VARBINARY.writeSlice(blockBuilder, Slices.utf8Slice("cherry")); VARBINARY.writeSlice(blockBuilder, Slices.utf8Slice("date")); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestVarcharArrayType.java b/core/trino-main/src/test/java/io/trino/type/TestVarcharArrayType.java index 2c403c4ec5f9..0015ab005c09 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestVarcharArrayType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestVarcharArrayType.java @@ -15,6 +15,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -35,13 +36,13 @@ public TestVarcharArrayType() super(TESTING_TYPE_MANAGER.getType(arrayType(VARCHAR.getTypeSignature())), List.class, createTestBlock(TESTING_TYPE_MANAGER.getType(arrayType(VARCHAR.getTypeSignature())))); } - public static Block createTestBlock(Type arrayType) + public static ValueBlock createTestBlock(Type arrayType) { BlockBuilder blockBuilder = arrayType.createBlockBuilder(null, 4); arrayType.writeObject(blockBuilder, arrayBlockOf(VARCHAR, "1", "2")); arrayType.writeObject(blockBuilder, arrayBlockOf(VARCHAR, "the", "quick", "brown", "fox")); arrayType.writeObject(blockBuilder, arrayBlockOf(VARCHAR, "one-two-three-four-five", "123456789012345", "the quick brown fox", "hello-world-hello-world-hello-world")); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestVarcharVarcharMapType.java 
b/core/trino-main/src/test/java/io/trino/type/TestVarcharVarcharMapType.java index dc8858afb405..42a666cda7e2 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestVarcharVarcharMapType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestVarcharVarcharMapType.java @@ -14,8 +14,8 @@ package io.trino.type; import com.google.common.collect.ImmutableMap; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -35,13 +35,13 @@ public TestVarcharVarcharMapType() super(mapType(VARCHAR, VARCHAR), Map.class, createTestBlock(mapType(VARCHAR, VARCHAR))); } - public static Block createTestBlock(Type mapType) + public static ValueBlock createTestBlock(Type mapType) { BlockBuilder blockBuilder = mapType.createBlockBuilder(null, 2); mapType.writeObject(blockBuilder, sqlMapOf(VARCHAR, VARCHAR, ImmutableMap.of("hi", "there"))); mapType.writeObject(blockBuilder, sqlMapOf(VARCHAR, VARCHAR, ImmutableMap.of("one", "1", "hello", "world"))); mapType.writeObject(blockBuilder, sqlMapOf(VARCHAR, VARCHAR, ImmutableMap.of("one-two-three-four-five", "123456789012345", "the quick brown fox", "hello-world-hello-world-hello-world"))); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-parser/pom.xml b/core/trino-parser/pom.xml index 7b09350efcb3..9414b1114be2 100644 --- a/core/trino-parser/pom.xml +++ b/core/trino-parser/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/core/trino-parser/src/main/java/io/trino/sql/QueryUtil.java b/core/trino-parser/src/main/java/io/trino/sql/QueryUtil.java index 19fe8278990d..b8b86d57f99c 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/QueryUtil.java +++ b/core/trino-parser/src/main/java/io/trino/sql/QueryUtil.java @@ -25,7 +25,6 @@ import io.trino.sql.tree.Identifier; import 
io.trino.sql.tree.LogicalExpression; import io.trino.sql.tree.Node; -import io.trino.sql.tree.NullLiteral; import io.trino.sql.tree.Offset; import io.trino.sql.tree.OrderBy; import io.trino.sql.tree.QualifiedName; @@ -268,25 +267,6 @@ public static Query singleValueQuery(String columnName, boolean value) aliased(values, "t", ImmutableList.of(columnName))); } - // TODO pass column types - public static Query emptyQuery(List columns) - { - Select select = selectList(columns.stream() - .map(column -> new SingleColumn(new NullLiteral(), QueryUtil.identifier(column))) - .toArray(SelectItem[]::new)); - Optional where = Optional.of(FALSE_LITERAL); - return query(new QuerySpecification( - select, - Optional.empty(), - where, - Optional.empty(), - Optional.empty(), - ImmutableList.of(), - Optional.empty(), - Optional.empty(), - Optional.empty())); - } - public static Query query(QueryBody body) { return new Query( diff --git a/core/trino-server-main/pom.xml b/core/trino-server-main/pom.xml index cb5bd9e870e2..00147e4307dd 100644 --- a/core/trino-server-main/pom.xml +++ b/core/trino-server-main/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/core/trino-server-rpm/pom.xml b/core/trino-server-rpm/pom.xml index 68917f402522..7324022ceed6 100644 --- a/core/trino-server-rpm/pom.xml +++ b/core/trino-server-rpm/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/core/trino-server/pom.xml b/core/trino-server/pom.xml index 429556d2f918..ce844e8b6999 100644 --- a/core/trino-server/pom.xml +++ b/core/trino-server/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/core/trino-spi/pom.xml b/core/trino-spi/pom.xml index 7eb925b18135..365627801ea9 100644 --- a/core/trino-spi/pom.xml +++ b/core/trino-spi/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml @@ -464,6 +464,741 @@ class 
io.trino.spi.block.VariableWidthBlock class io.trino.spi.block.VariableWidthBlock + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.AbstractArrayBlock::copyPositions(int[], int, int) @ io.trino.spi.block.ArrayBlock + method io.trino.spi.block.ArrayBlock io.trino.spi.block.ArrayBlock::copyPositions(int[], int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.AbstractArrayBlock::copyRegion(int, int) @ io.trino.spi.block.ArrayBlock + method io.trino.spi.block.ArrayBlock io.trino.spi.block.ArrayBlock::copyRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ArrayBlock::copyWithAppendedNull() + method io.trino.spi.block.ArrayBlock io.trino.spi.block.ArrayBlock::copyWithAppendedNull() + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ArrayBlock::getLoadedBlock() + method io.trino.spi.block.ArrayBlock io.trino.spi.block.ArrayBlock::getLoadedBlock() + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.AbstractArrayBlock::getRegion(int, int) @ io.trino.spi.block.ArrayBlock + method io.trino.spi.block.ArrayBlock io.trino.spi.block.ArrayBlock::getRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.AbstractArrayBlock::getSingleValueBlock(int) @ io.trino.spi.block.ArrayBlock + method io.trino.spi.block.ArrayBlock io.trino.spi.block.ArrayBlock::getSingleValueBlock(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ByteArrayBlock::copyPositions(int[], int, int) + method io.trino.spi.block.ByteArrayBlock io.trino.spi.block.ByteArrayBlock::copyPositions(int[], int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ByteArrayBlock::copyRegion(int, int) 
+ method io.trino.spi.block.ByteArrayBlock io.trino.spi.block.ByteArrayBlock::copyRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ByteArrayBlock::copyWithAppendedNull() + method io.trino.spi.block.ByteArrayBlock io.trino.spi.block.ByteArrayBlock::copyWithAppendedNull() + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ByteArrayBlock::getRegion(int, int) + method io.trino.spi.block.ByteArrayBlock io.trino.spi.block.ByteArrayBlock::getRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ByteArrayBlock::getSingleValueBlock(int) + method io.trino.spi.block.ByteArrayBlock io.trino.spi.block.ByteArrayBlock::getSingleValueBlock(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.Fixed12Block::copyPositions(int[], int, int) + method io.trino.spi.block.Fixed12Block io.trino.spi.block.Fixed12Block::copyPositions(int[], int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.Fixed12Block::copyRegion(int, int) + method io.trino.spi.block.Fixed12Block io.trino.spi.block.Fixed12Block::copyRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.Fixed12Block::copyWithAppendedNull() + method io.trino.spi.block.Fixed12Block io.trino.spi.block.Fixed12Block::copyWithAppendedNull() + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.Fixed12Block::getRegion(int, int) + method io.trino.spi.block.Fixed12Block io.trino.spi.block.Fixed12Block::getRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.Fixed12Block::getSingleValueBlock(int) + method io.trino.spi.block.Fixed12Block io.trino.spi.block.Fixed12Block::getSingleValueBlock(int) + 
+ + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.Int128ArrayBlock::copyPositions(int[], int, int) + method io.trino.spi.block.Int128ArrayBlock io.trino.spi.block.Int128ArrayBlock::copyPositions(int[], int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.Int128ArrayBlock::copyRegion(int, int) + method io.trino.spi.block.Int128ArrayBlock io.trino.spi.block.Int128ArrayBlock::copyRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.Int128ArrayBlock::copyWithAppendedNull() + method io.trino.spi.block.Int128ArrayBlock io.trino.spi.block.Int128ArrayBlock::copyWithAppendedNull() + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.Int128ArrayBlock::getRegion(int, int) + method io.trino.spi.block.Int128ArrayBlock io.trino.spi.block.Int128ArrayBlock::getRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.Int128ArrayBlock::getSingleValueBlock(int) + method io.trino.spi.block.Int128ArrayBlock io.trino.spi.block.Int128ArrayBlock::getSingleValueBlock(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.IntArrayBlock::copyPositions(int[], int, int) + method io.trino.spi.block.IntArrayBlock io.trino.spi.block.IntArrayBlock::copyPositions(int[], int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.IntArrayBlock::copyRegion(int, int) + method io.trino.spi.block.IntArrayBlock io.trino.spi.block.IntArrayBlock::copyRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.IntArrayBlock::copyWithAppendedNull() + method io.trino.spi.block.IntArrayBlock io.trino.spi.block.IntArrayBlock::copyWithAppendedNull() + + + 
java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.IntArrayBlock::getRegion(int, int) + method io.trino.spi.block.IntArrayBlock io.trino.spi.block.IntArrayBlock::getRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.IntArrayBlock::getSingleValueBlock(int) + method io.trino.spi.block.IntArrayBlock io.trino.spi.block.IntArrayBlock::getSingleValueBlock(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.LongArrayBlock::copyPositions(int[], int, int) + method io.trino.spi.block.LongArrayBlock io.trino.spi.block.LongArrayBlock::copyPositions(int[], int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.LongArrayBlock::copyRegion(int, int) + method io.trino.spi.block.LongArrayBlock io.trino.spi.block.LongArrayBlock::copyRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.LongArrayBlock::copyWithAppendedNull() + method io.trino.spi.block.LongArrayBlock io.trino.spi.block.LongArrayBlock::copyWithAppendedNull() + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.LongArrayBlock::getRegion(int, int) + method io.trino.spi.block.LongArrayBlock io.trino.spi.block.LongArrayBlock::getRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.LongArrayBlock::getSingleValueBlock(int) + method io.trino.spi.block.LongArrayBlock io.trino.spi.block.LongArrayBlock::getSingleValueBlock(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.AbstractMapBlock::copyPositions(int[], int, int) @ io.trino.spi.block.MapBlock + method io.trino.spi.block.MapBlock io.trino.spi.block.MapBlock::copyPositions(int[], int, int) + + + java.method.returnTypeChangedCovariantly 
+ method io.trino.spi.block.Block io.trino.spi.block.AbstractMapBlock::copyRegion(int, int) @ io.trino.spi.block.MapBlock + method io.trino.spi.block.MapBlock io.trino.spi.block.MapBlock::copyRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.MapBlock::copyWithAppendedNull() + method io.trino.spi.block.MapBlock io.trino.spi.block.MapBlock::copyWithAppendedNull() + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.AbstractMapBlock::getRegion(int, int) @ io.trino.spi.block.MapBlock + method io.trino.spi.block.MapBlock io.trino.spi.block.MapBlock::getRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.AbstractMapBlock::getSingleValueBlock(int) @ io.trino.spi.block.MapBlock + method io.trino.spi.block.MapBlock io.trino.spi.block.MapBlock::getSingleValueBlock(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.AbstractRowBlock::copyPositions(int[], int, int) @ io.trino.spi.block.RowBlock + method io.trino.spi.block.RowBlock io.trino.spi.block.RowBlock::copyPositions(int[], int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.AbstractRowBlock::copyRegion(int, int) @ io.trino.spi.block.RowBlock + method io.trino.spi.block.RowBlock io.trino.spi.block.RowBlock::copyRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.RowBlock::copyWithAppendedNull() + method io.trino.spi.block.RowBlock io.trino.spi.block.RowBlock::copyWithAppendedNull() + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.AbstractRowBlock::getRegion(int, int) @ io.trino.spi.block.RowBlock + method io.trino.spi.block.RowBlock io.trino.spi.block.RowBlock::getRegion(int, int) + + + 
java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.AbstractRowBlock::getSingleValueBlock(int) @ io.trino.spi.block.RowBlock + method io.trino.spi.block.RowBlock io.trino.spi.block.RowBlock::getSingleValueBlock(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ShortArrayBlock::copyPositions(int[], int, int) + method io.trino.spi.block.ShortArrayBlock io.trino.spi.block.ShortArrayBlock::copyPositions(int[], int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ShortArrayBlock::copyRegion(int, int) + method io.trino.spi.block.ShortArrayBlock io.trino.spi.block.ShortArrayBlock::copyRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ShortArrayBlock::copyWithAppendedNull() + method io.trino.spi.block.ShortArrayBlock io.trino.spi.block.ShortArrayBlock::copyWithAppendedNull() + Narrowing the return type to the concrete block class is a covariant change, which is source and binary compatible for existing callers + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ShortArrayBlock::getRegion(int, int) + method io.trino.spi.block.ShortArrayBlock io.trino.spi.block.ShortArrayBlock::getRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ShortArrayBlock::getSingleValueBlock(int) + method io.trino.spi.block.ShortArrayBlock io.trino.spi.block.ShortArrayBlock::getSingleValueBlock(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.VariableWidthBlock::copyPositions(int[], int, int) + method io.trino.spi.block.VariableWidthBlock io.trino.spi.block.VariableWidthBlock::copyPositions(int[], int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.VariableWidthBlock::copyRegion(int, int) + method io.trino.spi.block.VariableWidthBlock
io.trino.spi.block.VariableWidthBlock::copyRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.VariableWidthBlock::copyWithAppendedNull() + method io.trino.spi.block.VariableWidthBlock io.trino.spi.block.VariableWidthBlock::copyWithAppendedNull() + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.VariableWidthBlock::getRegion(int, int) + method io.trino.spi.block.VariableWidthBlock io.trino.spi.block.VariableWidthBlock::getRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.AbstractVariableWidthBlock::getSingleValueBlock(int) @ io.trino.spi.block.VariableWidthBlock + method io.trino.spi.block.VariableWidthBlock io.trino.spi.block.VariableWidthBlock::getSingleValueBlock(int) + + + java.method.addedToInterface + method io.trino.spi.block.ValueBlock io.trino.spi.block.Block::getUnderlyingValueBlock() + + + java.method.addedToInterface + method int io.trino.spi.block.Block::getUnderlyingValuePosition(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.RunLengthEncodedBlock::getSingleValueBlock(int) + method io.trino.spi.block.ValueBlock io.trino.spi.block.RunLengthEncodedBlock::getSingleValueBlock(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.RunLengthEncodedBlock::getValue() + method io.trino.spi.block.ValueBlock io.trino.spi.block.RunLengthEncodedBlock::getValue() + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.Block::getSingleValueBlock(int) + method io.trino.spi.block.ValueBlock io.trino.spi.block.Block::getSingleValueBlock(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.DictionaryBlock::getSingleValueBlock(int) + method io.trino.spi.block.ValueBlock 
io.trino.spi.block.DictionaryBlock::getSingleValueBlock(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.LazyBlock::getSingleValueBlock(int) + method io.trino.spi.block.ValueBlock io.trino.spi.block.LazyBlock::getSingleValueBlock(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.DictionaryBlock::getDictionary() + method io.trino.spi.block.ValueBlock io.trino.spi.block.DictionaryBlock::getDictionary() + + + java.method.addedToInterface + method io.trino.spi.block.ValueBlock io.trino.spi.block.BlockBuilder::buildValueBlock() + + + java.method.numberOfParametersChanged + method void io.trino.spi.type.AbstractType::<init>(io.trino.spi.type.TypeSignature, java.lang.Class<?>) + method void io.trino.spi.type.AbstractType::<init>(io.trino.spi.type.TypeSignature, java.lang.Class<?>, java.lang.Class<? extends io.trino.spi.block.ValueBlock>) + + + java.method.numberOfParametersChanged + method void io.trino.spi.type.TimeWithTimeZoneType::<init>(int, java.lang.Class<?>) + method void io.trino.spi.type.TimeWithTimeZoneType::<init>(int, java.lang.Class<?>, java.lang.Class<? extends io.trino.spi.block.ValueBlock>) + + + java.method.addedToInterface + method java.lang.Class<? 
extends io.trino.spi.block.ValueBlock> io.trino.spi.type.Type::getValueBlockType() + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.type.TypeUtils::writeNativeValue(io.trino.spi.type.Type, java.lang.Object) + method io.trino.spi.block.ValueBlock io.trino.spi.type.TypeUtils::writeNativeValue(io.trino.spi.type.Type, java.lang.Object) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ArrayBlock::fromElementBlock(int, java.util.Optional<boolean[]>, int[], io.trino.spi.block.Block) + method io.trino.spi.block.ArrayBlock io.trino.spi.block.ArrayBlock::fromElementBlock(int, java.util.Optional<boolean[]>, int[], io.trino.spi.block.Block) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.RowBlock::fromFieldBlocks(int, java.util.Optional<boolean[]>, io.trino.spi.block.Block[]) + method io.trino.spi.block.RowBlock io.trino.spi.block.RowBlock::fromFieldBlocks(int, java.util.Optional<boolean[]>, io.trino.spi.block.Block[]) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.type.MapType::createBlockFromKeyValue(java.util.Optional<boolean[]>, int[], io.trino.spi.block.Block, io.trino.spi.block.Block) + method io.trino.spi.block.MapBlock io.trino.spi.type.MapType::createBlockFromKeyValue(java.util.Optional<boolean[]>, int[], io.trino.spi.block.Block, io.trino.spi.block.Block) + + + java.method.visibilityIncreased + method int[] io.trino.spi.block.DictionaryBlock::getRawIds() + method int[] io.trino.spi.block.DictionaryBlock::getRawIds() + package + public + + + java.method.visibilityIncreased + method int io.trino.spi.block.DictionaryBlock::getRawIdsOffset() + method int io.trino.spi.block.DictionaryBlock::getRawIdsOffset() + package + public + + + java.method.removed + method int io.trino.spi.block.Block::bytesCompare(int, int, int, io.airlift.slice.Slice, int, int) + + + java.method.removed + 
method boolean io.trino.spi.block.Block::bytesEqual(int, int, io.airlift.slice.Slice, int, int) + + + java.method.removed + method int io.trino.spi.block.Block::compareTo(int, int, int, io.trino.spi.block.Block, int, int, int) + + + java.method.removed + method boolean io.trino.spi.block.Block::equals(int, int, io.trino.spi.block.Block, int, int, int) + + + java.method.removed + method long io.trino.spi.block.Block::hash(int, int, int) + + + java.method.removed + method int io.trino.spi.block.DictionaryBlock::bytesCompare(int, int, int, io.airlift.slice.Slice, int, int) + + + java.method.removed + method boolean io.trino.spi.block.DictionaryBlock::bytesEqual(int, int, io.airlift.slice.Slice, int, int) + + + java.method.removed + method int io.trino.spi.block.DictionaryBlock::compareTo(int, int, int, io.trino.spi.block.Block, int, int, int) + + + java.method.removed + method boolean io.trino.spi.block.DictionaryBlock::equals(int, int, io.trino.spi.block.Block, int, int, int) + + + java.method.removed + method long io.trino.spi.block.DictionaryBlock::hash(int, int, int) + + + java.method.removed + method int io.trino.spi.block.LazyBlock::bytesCompare(int, int, int, io.airlift.slice.Slice, int, int) + + + java.method.removed + method boolean io.trino.spi.block.LazyBlock::bytesEqual(int, int, io.airlift.slice.Slice, int, int) + + + java.method.removed + method int io.trino.spi.block.LazyBlock::compareTo(int, int, int, io.trino.spi.block.Block, int, int, int) + + + java.method.removed + method boolean io.trino.spi.block.LazyBlock::equals(int, int, io.trino.spi.block.Block, int, int, int) + + + java.method.removed + method long io.trino.spi.block.LazyBlock::hash(int, int, int) + + + java.method.removed + method int io.trino.spi.block.RunLengthEncodedBlock::bytesCompare(int, int, int, io.airlift.slice.Slice, int, int) + + + java.method.removed + method boolean io.trino.spi.block.RunLengthEncodedBlock::bytesEqual(int, int, io.airlift.slice.Slice, int, int) + The
deprecated comparison, hashing, and slice-write methods were removed from the Block SPI; the corresponding Type operators should be used instead + + + java.method.removed + method int io.trino.spi.block.RunLengthEncodedBlock::compareTo(int, int, int, io.trino.spi.block.Block, int, int, int) + + + java.method.removed + method boolean io.trino.spi.block.RunLengthEncodedBlock::equals(int, int, io.trino.spi.block.Block, int, int, int) + + + java.method.removed + method long io.trino.spi.block.RunLengthEncodedBlock::hash(int, int, int) + + + java.method.removed + method void io.trino.spi.block.Block::writeSliceTo(int, int, int, io.airlift.slice.SliceOutput) + + + java.method.removed + method void io.trino.spi.block.DictionaryBlock::writeSliceTo(int, int, int, io.airlift.slice.SliceOutput) + + + java.method.removed + method void io.trino.spi.block.LazyBlock::writeSliceTo(int, int, int, io.airlift.slice.SliceOutput) + + + java.method.removed + method void io.trino.spi.block.RunLengthEncodedBlock::writeSliceTo(int, int, int, io.airlift.slice.SliceOutput) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ArrayBlockEncoding::readBlock(io.trino.spi.block.BlockEncodingSerde, io.airlift.slice.SliceInput) + method io.trino.spi.block.ArrayBlock io.trino.spi.block.ArrayBlockEncoding::readBlock(io.trino.spi.block.BlockEncodingSerde, io.airlift.slice.SliceInput) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ByteArrayBlockEncoding::readBlock(io.trino.spi.block.BlockEncodingSerde, io.airlift.slice.SliceInput) + method io.trino.spi.block.ByteArrayBlock io.trino.spi.block.ByteArrayBlockEncoding::readBlock(io.trino.spi.block.BlockEncodingSerde, io.airlift.slice.SliceInput) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.Fixed12BlockEncoding::readBlock(io.trino.spi.block.BlockEncodingSerde, io.airlift.slice.SliceInput) + method io.trino.spi.block.Fixed12Block
io.trino.spi.block.Fixed12BlockEncoding::readBlock(io.trino.spi.block.BlockEncodingSerde, io.airlift.slice.SliceInput) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.Int128ArrayBlockEncoding::readBlock(io.trino.spi.block.BlockEncodingSerde, io.airlift.slice.SliceInput) + method io.trino.spi.block.Int128ArrayBlock io.trino.spi.block.Int128ArrayBlockEncoding::readBlock(io.trino.spi.block.BlockEncodingSerde, io.airlift.slice.SliceInput) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.IntArrayBlockEncoding::readBlock(io.trino.spi.block.BlockEncodingSerde, io.airlift.slice.SliceInput) + method io.trino.spi.block.IntArrayBlock io.trino.spi.block.IntArrayBlockEncoding::readBlock(io.trino.spi.block.BlockEncodingSerde, io.airlift.slice.SliceInput) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.LongArrayBlockEncoding::readBlock(io.trino.spi.block.BlockEncodingSerde, io.airlift.slice.SliceInput) + method io.trino.spi.block.LongArrayBlock io.trino.spi.block.LongArrayBlockEncoding::readBlock(io.trino.spi.block.BlockEncodingSerde, io.airlift.slice.SliceInput) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ShortArrayBlockEncoding::readBlock(io.trino.spi.block.BlockEncodingSerde, io.airlift.slice.SliceInput) + method io.trino.spi.block.ShortArrayBlock io.trino.spi.block.ShortArrayBlockEncoding::readBlock(io.trino.spi.block.BlockEncodingSerde, io.airlift.slice.SliceInput) + + + java.method.nowStatic + method void io.trino.spi.type.AbstractIntType::checkValueValid(long) + method void io.trino.spi.type.AbstractIntType::checkValueValid(long) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ArrayBlock::copyPositions(int[], int, int) + method io.trino.spi.block.ArrayBlock io.trino.spi.block.ArrayBlock::copyPositions(int[], int, 
int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ArrayBlock::copyRegion(int, int) + method io.trino.spi.block.ArrayBlock io.trino.spi.block.ArrayBlock::copyRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ArrayBlock::getRegion(int, int) + method io.trino.spi.block.ArrayBlock io.trino.spi.block.ArrayBlock::getRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ArrayBlock::getSingleValueBlock(int) + method io.trino.spi.block.ArrayBlock io.trino.spi.block.ArrayBlock::getSingleValueBlock(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.MapBlock::copyPositions(int[], int, int) + method io.trino.spi.block.MapBlock io.trino.spi.block.MapBlock::copyPositions(int[], int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.MapBlock::copyRegion(int, int) + method io.trino.spi.block.MapBlock io.trino.spi.block.MapBlock::copyRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.MapBlock::getRegion(int, int) + method io.trino.spi.block.MapBlock io.trino.spi.block.MapBlock::getRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.MapBlock::getSingleValueBlock(int) + method io.trino.spi.block.MapBlock io.trino.spi.block.MapBlock::getSingleValueBlock(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.RowBlock::copyPositions(int[], int, int) + method io.trino.spi.block.RowBlock io.trino.spi.block.RowBlock::copyPositions(int[], int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.RowBlock::copyRegion(int, int) + method io.trino.spi.block.RowBlock 
io.trino.spi.block.RowBlock::copyRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.RowBlock::getRegion(int, int) + method io.trino.spi.block.RowBlock io.trino.spi.block.RowBlock::getRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.RowBlock::getSingleValueBlock(int) + method io.trino.spi.block.RowBlock io.trino.spi.block.RowBlock::getSingleValueBlock(int) + + + java.method.removed + method int io.trino.spi.block.VariableWidthBlock::bytesCompare(int, int, int, io.airlift.slice.Slice, int, int) + + + java.method.removed + method boolean io.trino.spi.block.VariableWidthBlock::bytesEqual(int, int, io.airlift.slice.Slice, int, int) + + + java.method.removed + method int io.trino.spi.block.VariableWidthBlock::compareTo(int, int, int, io.trino.spi.block.Block, int, int, int) + + + java.method.removed + method boolean io.trino.spi.block.VariableWidthBlock::equals(int, int, io.trino.spi.block.Block, int, int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.VariableWidthBlock::getSingleValueBlock(int) + method io.trino.spi.block.VariableWidthBlock io.trino.spi.block.VariableWidthBlock::getSingleValueBlock(int) + + + java.method.removed + method long io.trino.spi.block.VariableWidthBlock::hash(int, int, int) + + + java.method.removed + method void io.trino.spi.block.VariableWidthBlock::writeSliceTo(int, int, int, io.airlift.slice.SliceOutput) + + + java.class.nowFinal + class io.trino.spi.block.ArrayBlock + class io.trino.spi.block.ArrayBlock + + + java.class.nowFinal + class io.trino.spi.block.ByteArrayBlock + class io.trino.spi.block.ByteArrayBlock + + + java.class.nowFinal + class io.trino.spi.block.DictionaryBlock + class io.trino.spi.block.DictionaryBlock + + + java.class.nowFinal + class io.trino.spi.block.Fixed12Block + class io.trino.spi.block.Fixed12Block + + + 
java.class.nowFinal + class io.trino.spi.block.Int128ArrayBlock + class io.trino.spi.block.Int128ArrayBlock + + + java.class.nowFinal + class io.trino.spi.block.IntArrayBlock + class io.trino.spi.block.IntArrayBlock + + + java.class.nowFinal + class io.trino.spi.block.LazyBlock + class io.trino.spi.block.LazyBlock + + + java.class.nowFinal + class io.trino.spi.block.LongArrayBlock + class io.trino.spi.block.LongArrayBlock + + + java.class.nowFinal + class io.trino.spi.block.MapBlock + class io.trino.spi.block.MapBlock + + + java.class.nowFinal + class io.trino.spi.block.RowBlock + class io.trino.spi.block.RowBlock + + + java.class.nowFinal + class io.trino.spi.block.RunLengthEncodedBlock + class io.trino.spi.block.RunLengthEncodedBlock + + + java.class.nowFinal + class io.trino.spi.block.ShortArrayBlock + class io.trino.spi.block.ShortArrayBlock + + + java.class.nowFinal + class io.trino.spi.block.VariableWidthBlock + class io.trino.spi.block.VariableWidthBlock + + + java.method.visibilityReduced + method int io.trino.spi.block.ArrayBlock::getOffsetBase() + method int io.trino.spi.block.ArrayBlock::getOffsetBase() + protected + package + + + java.method.visibilityReduced + method int[] io.trino.spi.block.ArrayBlock::getOffsets() + method int[] io.trino.spi.block.ArrayBlock::getOffsets() + protected + package + + + java.method.visibilityReduced + method io.trino.spi.block.Block io.trino.spi.block.ArrayBlock::getRawElementBlock() + method io.trino.spi.block.Block io.trino.spi.block.ArrayBlock::getRawElementBlock() + protected + package + + + java.method.visibilityReduced + method void io.trino.spi.block.MapBlock::ensureHashTableLoaded() + method void io.trino.spi.block.MapBlock::ensureHashTableLoaded() + protected + package + + + java.method.visibilityReduced + method io.trino.spi.block.MapHashTables io.trino.spi.block.MapBlock::getHashTables() + method io.trino.spi.block.MapHashTables io.trino.spi.block.MapBlock::getHashTables() + protected + package + + + 
java.method.visibilityReduced + method io.trino.spi.type.MapType io.trino.spi.block.MapBlock::getMapType() + method io.trino.spi.type.MapType io.trino.spi.block.MapBlock::getMapType() + protected + package + + + java.method.visibilityReduced + method int io.trino.spi.block.MapBlock::getOffsetBase() + method int io.trino.spi.block.MapBlock::getOffsetBase() + protected + package + + + java.method.visibilityReduced + method int[] io.trino.spi.block.MapBlock::getOffsets() + method int[] io.trino.spi.block.MapBlock::getOffsets() + protected + package + + + java.method.visibilityReduced + method io.trino.spi.block.Block io.trino.spi.block.MapBlock::getRawKeyBlock() + method io.trino.spi.block.Block io.trino.spi.block.MapBlock::getRawKeyBlock() + protected + package + + + java.method.visibilityReduced + method io.trino.spi.block.Block io.trino.spi.block.MapBlock::getRawValueBlock() + method io.trino.spi.block.Block io.trino.spi.block.MapBlock::getRawValueBlock() + protected + package + + + java.method.visibilityReduced + method int[] io.trino.spi.block.RowBlock::getFieldBlockOffsets() + method int[] io.trino.spi.block.RowBlock::getFieldBlockOffsets() + protected + package + + + java.method.visibilityReduced + method int io.trino.spi.block.RowBlock::getOffsetBase() + method int io.trino.spi.block.RowBlock::getOffsetBase() + protected + package + + + java.method.visibilityReduced + method java.util.List<io.trino.spi.block.Block> io.trino.spi.block.RowBlock::getRawFieldBlocks() + method java.util.List<io.trino.spi.block.Block> io.trino.spi.block.RowBlock::getRawFieldBlocks() + protected + package + + + java.method.visibilityReduced + method int io.trino.spi.block.VariableWidthBlock::getPositionOffset(int) + method int io.trino.spi.block.VariableWidthBlock::getPositionOffset(int) + protected + package + diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlock.java index 16d2eda205a0..9aad1d9da67e 
100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlock.java @@ -36,8 +36,8 @@ import static java.util.Collections.singletonList; import static java.util.Objects.requireNonNull; -public class ArrayBlock - implements Block +public final class ArrayBlock + implements ValueBlock { private static final int INSTANCE_SIZE = instanceSize(ArrayBlock.class); @@ -54,7 +54,7 @@ public class ArrayBlock * Create an array block directly from columnar nulls, values, and offsets into the values. * A null array must have no entries. */ - public static Block fromElementBlock(int positionCount, Optional valueIsNullOptional, int[] arrayOffset, Block values) + public static ArrayBlock fromElementBlock(int positionCount, Optional valueIsNullOptional, int[] arrayOffset, Block values) { boolean[] valueIsNull = valueIsNullOptional.orElse(null); validateConstructorArguments(0, positionCount, valueIsNull, arrayOffset, values); @@ -73,7 +73,7 @@ public static Block fromElementBlock(int positionCount, Optional valu } /** - * Create an array block directly without per element validations. + * Create an array block directly without per-element validations. 
*/ static ArrayBlock createArrayBlockInternal(int arrayOffset, int positionCount, @Nullable boolean[] valueIsNull, int[] offsets, Block values) { @@ -167,23 +167,23 @@ public void retainedBytesForEachPart(ObjLongConsumer consumer) consumer.accept(this, INSTANCE_SIZE); } - protected Block getRawElementBlock() + Block getRawElementBlock() { return values; } - protected int[] getOffsets() + int[] getOffsets() { return offsets; } - protected int getOffsetBase() + int getOffsetBase() { return arrayOffset; } @Override - public final List getChildren() + public List getChildren() { return singletonList(values); } @@ -216,7 +216,7 @@ public boolean isLoaded() } @Override - public Block getLoadedBlock() + public ArrayBlock getLoadedBlock() { Block loadedValuesBlock = values.getLoadedBlock(); @@ -232,7 +232,7 @@ public Block getLoadedBlock() } @Override - public Block copyWithAppendedNull() + public ArrayBlock copyWithAppendedNull() { boolean[] newValueIsNull = copyIsNullAndAppendNull(valueIsNull, arrayOffset, getPositionCount()); int[] newOffsets = copyOffsetsAndAppendNull(offsets, arrayOffset, getPositionCount()); @@ -246,7 +246,7 @@ public Block copyWithAppendedNull() } @Override - public Block copyPositions(int[] positions, int offset, int length) + public ArrayBlock copyPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); @@ -278,7 +278,7 @@ public Block copyPositions(int[] positions, int offset, int length) } @Override - public Block getRegion(int position, int length) + public ArrayBlock getRegion(int position, int length) { int positionCount = getPositionCount(); checkValidRegion(positionCount, position, length); @@ -310,7 +310,7 @@ public long getRegionSizeInBytes(int position, int length) } @Override - public final long getPositionsSizeInBytes(boolean[] positions, int selectedArrayPositions) + public long getPositionsSizeInBytes(boolean[] positions, int selectedArrayPositions) { int positionCount = getPositionCount(); 
checkValidPositions(positions, positionCount); @@ -343,7 +343,7 @@ else if (rawElementBlock instanceof RunLengthEncodedBlock) { } @Override - public Block copyRegion(int position, int length) + public ArrayBlock copyRegion(int position, int length) { int positionCount = getPositionCount(); checkValidRegion(positionCount, position, length); @@ -369,15 +369,19 @@ public T getObject(int position, Class clazz) if (clazz != Block.class) { throw new IllegalArgumentException("clazz must be Block.class"); } - checkReadablePosition(this, position); + return clazz.cast(getArray(position)); + } + public Block getArray(int position) + { + checkReadablePosition(this, position); int startValueOffset = offsets[position + arrayOffset]; int endValueOffset = offsets[position + 1 + arrayOffset]; - return clazz.cast(values.getRegion(startValueOffset, endValueOffset - startValueOffset)); + return values.getRegion(startValueOffset, endValueOffset - startValueOffset); } @Override - public Block getSingleValueBlock(int position) + public ArrayBlock getSingleValueBlock(int position) { checkReadablePosition(this, position); @@ -421,6 +425,12 @@ public boolean isNull(int position) return valueIsNull != null && valueIsNull[position + arrayOffset]; } + @Override + public ArrayBlock getUnderlyingValueBlock() + { + return this; + } + public T apply(ArrayBlockFunction function, int position) { checkReadablePosition(this, position); diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlockBuilder.java b/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlockBuilder.java index c23ae4c325a4..df28648003e7 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlockBuilder.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlockBuilder.java @@ -176,6 +176,15 @@ public Block build() if (!hasNonNullRow) { return nullRle(positionCount); } + return buildValueBlock(); + } + + @Override + public ValueBlock buildValueBlock() + { + if (currentEntryOpened) { + throw new 
IllegalStateException("Current entry must be closed before the block can be built"); + } return createArrayBlockInternal(0, positionCount, hasNullValue ? valueIsNull : null, offsets, values.build()); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlockEncoding.java b/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlockEncoding.java index bf3a3a9d26c2..ad01d4128132 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlockEncoding.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlockEncoding.java @@ -50,11 +50,11 @@ public void writeBlock(BlockEncodingSerde blockEncodingSerde, SliceOutput sliceO for (int position = 0; position < positionCount + 1; position++) { sliceOutput.writeInt(offsets[offsetBase + position] - valuesStartOffset); } - encodeNullsAsBits(sliceOutput, block); + encodeNullsAsBits(sliceOutput, arrayBlock); } @Override - public Block readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) + public ArrayBlock readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) { Block values = blockEncodingSerde.readBlock(sliceInput); diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/Block.java b/core/trino-spi/src/main/java/io/trino/spi/block/Block.java index 3281bb7e6727..5cf302fd2d44 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/Block.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/Block.java @@ -14,7 +14,6 @@ package io.trino.spi.block; import io.airlift.slice.Slice; -import io.airlift.slice.SliceOutput; import java.util.Collections; import java.util.List; @@ -22,8 +21,10 @@ import java.util.function.ObjLongConsumer; import static io.trino.spi.block.BlockUtil.checkArrayRange; +import static io.trino.spi.block.DictionaryId.randomDictionaryId; -public interface Block +public sealed interface Block + permits DictionaryBlock, RunLengthEncodedBlock, LazyBlock, ValueBlock { /** * Gets the length of the value at the {@code position}. 
@@ -74,14 +75,6 @@ default Slice getSlice(int position, int offset, int length) throw new UnsupportedOperationException(getClass().getName()); } - /** - * Writes a slice at {@code offset} in the value at {@code position} into the {@code output} slice output. - */ - default void writeSliceTo(int position, int offset, int length, SliceOutput output) - { - throw new UnsupportedOperationException(getClass().getName()); - } - /** * Gets an object in the value at {@code position}. */ @@ -90,58 +83,6 @@ default T getObject(int position, Class clazz) throw new UnsupportedOperationException(getClass().getName()); } - /** - * Is the byte sequences at {@code offset} in the value at {@code position} equal - * to the byte sequence at {@code otherOffset} in {@code otherSlice}. - * This method must be implemented if @{code getSlice} is implemented. - */ - default boolean bytesEqual(int position, int offset, Slice otherSlice, int otherOffset, int length) - { - throw new UnsupportedOperationException(getClass().getName()); - } - - /** - * Compares the byte sequences at {@code offset} in the value at {@code position} - * to the byte sequence at {@code otherOffset} in {@code otherSlice}. - * This method must be implemented if @{code getSlice} is implemented. - */ - default int bytesCompare(int position, int offset, int length, Slice otherSlice, int otherOffset, int otherLength) - { - throw new UnsupportedOperationException(getClass().getName()); - } - - /** - * Is the byte sequences at {@code offset} in the value at {@code position} equal - * to the byte sequence at {@code otherOffset} in the value at {@code otherPosition} - * in {@code otherBlock}. - * This method must be implemented if @{code getSlice} is implemented. 
- */ - default boolean equals(int position, int offset, Block otherBlock, int otherPosition, int otherOffset, int length) - { - throw new UnsupportedOperationException(getClass().getName()); - } - - /** - * Calculates the hash code the byte sequences at {@code offset} in the - * value at {@code position}. - * This method must be implemented if @{code getSlice} is implemented. - */ - default long hash(int position, int offset, int length) - { - throw new UnsupportedOperationException(getClass().getName()); - } - - /** - * Compares the byte sequences at {@code offset} in the value at {@code position} - * to the byte sequence at {@code otherOffset} in the value at {@code otherPosition} - * in {@code otherBlock}. - * This method must be implemented if @{code getSlice} is implemented. - */ - default int compareTo(int leftPosition, int leftOffset, int leftLength, Block rightBlock, int rightPosition, int rightOffset, int rightLength) - { - throw new UnsupportedOperationException(getClass().getName()); - } - /** * Gets the value at the specified position as a single element block. The method * must copy the data into a new block. @@ -151,7 +92,7 @@ default int compareTo(int leftPosition, int leftOffset, int leftLength, Block ri * * @throws IllegalArgumentException if this position is not valid */ - Block getSingleValueBlock(int position); + ValueBlock getSingleValueBlock(int position); /** * Returns the number of positions in this block. @@ -243,7 +184,7 @@ default Block getPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); - return new DictionaryBlock(offset, length, this, positions); + return DictionaryBlock.createInternal(offset, length, this, positions, randomDictionaryId()); } /** @@ -334,4 +275,14 @@ default List getChildren() * i.e. not on in-progress block builders. */ Block copyWithAppendedNull(); + + /** + * Returns the value block underlying this block. 
+ */ + ValueBlock getUnderlyingValueBlock(); + + /** + * Returns the position in the underlying value block corresponding to the specified position in this block. + */ + int getUnderlyingValuePosition(int position); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/BlockBuilder.java b/core/trino-spi/src/main/java/io/trino/spi/block/BlockBuilder.java index 79f78dca5634..7d458991497e 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/BlockBuilder.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/BlockBuilder.java @@ -43,9 +43,15 @@ public interface BlockBuilder /** * Builds the block. This method can be called multiple times. + * The return value may be a block such as RLE to allow for optimizations when all block values are the same. */ Block build(); + /** + * Builds a ValueBlock. This method can be called multiple times. + */ + ValueBlock buildValueBlock(); + /** * Creates a new block builder of the same type based on the current usage statistics of this block builder. 
*/ diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlock.java index da643b1aa370..744c6753445c 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlock.java @@ -30,8 +30,8 @@ import static io.trino.spi.block.BlockUtil.copyIsNullAndAppendNull; import static io.trino.spi.block.BlockUtil.ensureCapacity; -public class ByteArrayBlock - implements Block +public final class ByteArrayBlock + implements ValueBlock { private static final int INSTANCE_SIZE = instanceSize(ByteArrayBlock.class); public static final int SIZE_IN_BYTES_PER_POSITION = Byte.BYTES + Byte.BYTES; @@ -128,10 +128,15 @@ public int getPositionCount() @Override public byte getByte(int position, int offset) { - checkReadablePosition(this, position); if (offset != 0) { throw new IllegalArgumentException("offset must be zero"); } + return getByte(position); + } + + public byte getByte(int position) + { + checkReadablePosition(this, position); return values[position + arrayOffset]; } @@ -149,7 +154,7 @@ public boolean isNull(int position) } @Override - public Block getSingleValueBlock(int position) + public ByteArrayBlock getSingleValueBlock(int position) { checkReadablePosition(this, position); return new ByteArrayBlock( @@ -160,7 +165,7 @@ public Block getSingleValueBlock(int position) } @Override - public Block copyPositions(int[] positions, int offset, int length) + public ByteArrayBlock copyPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); @@ -181,7 +186,7 @@ public Block copyPositions(int[] positions, int offset, int length) } @Override - public Block getRegion(int positionOffset, int length) + public ByteArrayBlock getRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); @@ -189,7 +194,7 @@ public Block getRegion(int 
positionOffset, int length) } @Override - public Block copyRegion(int positionOffset, int length) + public ByteArrayBlock copyRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); @@ -210,7 +215,7 @@ public String getEncodingName() } @Override - public Block copyWithAppendedNull() + public ByteArrayBlock copyWithAppendedNull() { boolean[] newValueIsNull = copyIsNullAndAppendNull(valueIsNull, arrayOffset, positionCount); byte[] newValues = ensureCapacity(values, arrayOffset + positionCount + 1); @@ -218,6 +223,12 @@ public Block copyWithAppendedNull() return new ByteArrayBlock(arrayOffset, positionCount + 1, newValueIsNull, newValues); } + @Override + public ByteArrayBlock getUnderlyingValueBlock() + { + return this; + } + @Override public String toString() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlockBuilder.java b/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlockBuilder.java index 0d5592813c5e..559ead304ead 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlockBuilder.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlockBuilder.java @@ -13,8 +13,6 @@ */ package io.trino.spi.block; -import io.airlift.slice.Slice; -import io.airlift.slice.Slices; import jakarta.annotation.Nullable; import java.util.Arrays; @@ -91,6 +89,12 @@ public Block build() if (!hasNonNullValue) { return RunLengthEncodedBlock.create(NULL_VALUE_BLOCK, positionCount); } + return buildValueBlock(); + } + + @Override + public ByteArrayBlock buildValueBlock() + { return new ByteArrayBlock(0, positionCount, hasNullValue ? 
valueIsNull : null, values); } @@ -150,9 +154,4 @@ public String toString() sb.append('}'); return sb.toString(); } - - Slice getValuesSlice() - { - return Slices.wrappedBuffer(values, 0, positionCount); - } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlockEncoding.java b/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlockEncoding.java index 0fc86d4549d1..17f346f4e440 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlockEncoding.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlockEncoding.java @@ -13,7 +13,6 @@ */ package io.trino.spi.block; -import io.airlift.slice.Slice; import io.airlift.slice.SliceInput; import io.airlift.slice.SliceOutput; import io.airlift.slice.Slices; @@ -37,20 +36,21 @@ public String getName() @Override public void writeBlock(BlockEncodingSerde blockEncodingSerde, SliceOutput sliceOutput, Block block) { - int positionCount = block.getPositionCount(); + ByteArrayBlock byteArrayBlock = (ByteArrayBlock) block; + int positionCount = byteArrayBlock.getPositionCount(); sliceOutput.appendInt(positionCount); - encodeNullsAsBits(sliceOutput, block); + encodeNullsAsBits(sliceOutput, byteArrayBlock); - if (!block.mayHaveNull()) { - sliceOutput.writeBytes(getValuesSlice(block)); + if (!byteArrayBlock.mayHaveNull()) { + sliceOutput.writeBytes(byteArrayBlock.getValuesSlice()); } else { byte[] valuesWithoutNull = new byte[positionCount]; int nonNullPositionCount = 0; for (int i = 0; i < positionCount; i++) { - valuesWithoutNull[nonNullPositionCount] = block.getByte(i, 0); - if (!block.isNull(i)) { + valuesWithoutNull[nonNullPositionCount] = byteArrayBlock.getByte(i); + if (!byteArrayBlock.isNull(i)) { nonNullPositionCount++; } } @@ -61,7 +61,7 @@ public void writeBlock(BlockEncodingSerde blockEncodingSerde, SliceOutput sliceO } @Override - public Block readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) + public ByteArrayBlock 
readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) { int positionCount = sliceInput.readInt(); @@ -105,16 +105,4 @@ else if (packed != -1) { // At least one non-null } return new ByteArrayBlock(0, positionCount, valueIsNull, values); } - - private Slice getValuesSlice(Block block) - { - if (block instanceof ByteArrayBlock) { - return ((ByteArrayBlock) block).getValuesSlice(); - } - if (block instanceof ByteArrayBlockBuilder) { - return ((ByteArrayBlockBuilder) block).getValuesSlice(); - } - - throw new IllegalArgumentException("Unexpected block type " + block.getClass().getSimpleName()); - } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ColumnarRow.java b/core/trino-spi/src/main/java/io/trino/spi/block/ColumnarRow.java index 44f51b1e9e30..d88feeb95c39 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/ColumnarRow.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ColumnarRow.java @@ -17,6 +17,7 @@ import java.util.List; +import static io.trino.spi.block.DictionaryId.randomDictionaryId; import static java.util.Objects.requireNonNull; public final class ColumnarRow @@ -103,11 +104,12 @@ private static ColumnarRow toColumnarRowFromDictionaryWithoutNulls(DictionaryBlo Block[] fields = new Block[columnarRow.getFieldCount()]; for (int i = 0; i < fields.length; i++) { // Reuse the dictionary ids array directly since no nulls are present - fields[i] = new DictionaryBlock( + fields[i] = DictionaryBlock.createInternal( dictionaryBlock.getRawIdsOffset(), dictionaryBlock.getPositionCount(), columnarRow.getField(i), - dictionaryBlock.getRawIds()); + dictionaryBlock.getRawIds(), + randomDictionaryId()); } return new ColumnarRow(dictionaryBlock.getPositionCount(), null, fields); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/DictionaryBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/DictionaryBlock.java index dd28bbe541c6..cfb7a67de253 100644 --- 
a/core/trino-spi/src/main/java/io/trino/spi/block/DictionaryBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/DictionaryBlock.java @@ -14,7 +14,6 @@ package io.trino.spi.block; import io.airlift.slice.Slice; -import io.airlift.slice.SliceOutput; import java.util.ArrayList; import java.util.Arrays; @@ -33,14 +32,14 @@ import static java.util.Collections.singletonList; import static java.util.Objects.requireNonNull; -public class DictionaryBlock +public final class DictionaryBlock implements Block { private static final int INSTANCE_SIZE = instanceSize(DictionaryBlock.class) + instanceSize(DictionaryId.class); private static final int NULL_NOT_FOUND = -1; private final int positionCount; - private final Block dictionary; + private final ValueBlock dictionary; private final int idsOffset; private final int[] ids; private final long retainedSizeInBytes; @@ -54,7 +53,7 @@ public class DictionaryBlock public static Block create(int positionCount, Block dictionary, int[] ids) { - return createInternal(positionCount, dictionary, ids, randomDictionaryId()); + return createInternal(0, positionCount, dictionary, ids, randomDictionaryId()); } /** @@ -62,16 +61,16 @@ public static Block create(int positionCount, Block dictionary, int[] ids) */ public static Block createProjectedDictionaryBlock(int positionCount, Block dictionary, int[] ids, DictionaryId dictionarySourceId) { - return createInternal(positionCount, dictionary, ids, dictionarySourceId); + return createInternal(0, positionCount, dictionary, ids, dictionarySourceId); } - private static Block createInternal(int positionCount, Block dictionary, int[] ids, DictionaryId dictionarySourceId) + static Block createInternal(int idsOffset, int positionCount, Block dictionary, int[] ids, DictionaryId dictionarySourceId) { if (positionCount == 0) { return dictionary.copyRegion(0, 0); } if (positionCount == 1) { - return dictionary.getRegion(ids[0], 1); + return dictionary.getRegion(ids[idsOffset], 1); } // if 
dictionary is an RLE then this can just be a new RLE @@ -79,25 +78,19 @@ private static Block createInternal(int positionCount, Block dictionary, int[] i return RunLengthEncodedBlock.create(rle.getValue(), positionCount); } - // unwrap dictionary in dictionary - if (dictionary instanceof DictionaryBlock dictionaryBlock) { - int[] newIds = new int[positionCount]; - for (int position = 0; position < positionCount; position++) { - newIds[position] = dictionaryBlock.getId(ids[position]); - } - dictionary = dictionaryBlock.getDictionary(); - dictionarySourceId = randomDictionaryId(); - ids = newIds; + if (dictionary instanceof ValueBlock valueBlock) { + return new DictionaryBlock(idsOffset, positionCount, valueBlock, ids, false, false, dictionarySourceId); } - return new DictionaryBlock(0, positionCount, dictionary, ids, false, false, dictionarySourceId); - } - DictionaryBlock(int idsOffset, int positionCount, Block dictionary, int[] ids) - { - this(idsOffset, positionCount, dictionary, ids, false, false, randomDictionaryId()); + // unwrap dictionary in dictionary + int[] newIds = new int[positionCount]; + for (int position = 0; position < positionCount; position++) { + newIds[position] = dictionary.getUnderlyingValuePosition(ids[idsOffset + position]); + } + return new DictionaryBlock(0, positionCount, dictionary.getUnderlyingValueBlock(), newIds, false, false, randomDictionaryId()); } - private DictionaryBlock(int idsOffset, int positionCount, Block dictionary, int[] ids, boolean dictionaryIsCompacted, boolean isSequentialIds, DictionaryId dictionarySourceId) + private DictionaryBlock(int idsOffset, int positionCount, ValueBlock dictionary, int[] ids, boolean dictionaryIsCompacted, boolean isSequentialIds, DictionaryId dictionarySourceId) { requireNonNull(dictionary, "dictionary is null"); requireNonNull(ids, "ids is null"); @@ -130,12 +123,12 @@ private DictionaryBlock(int idsOffset, int positionCount, Block dictionary, int[ this.isSequentialIds = isSequentialIds; } 
- int[] getRawIds() + public int[] getRawIds() { return ids; } - int getRawIdsOffset() + public int getRawIdsOffset() { return idsOffset; } @@ -176,12 +169,6 @@ public Slice getSlice(int position, int offset, int length) return dictionary.getSlice(getId(position), offset, length); } - @Override - public void writeSliceTo(int position, int offset, int length, SliceOutput output) - { - dictionary.writeSliceTo(getId(position), offset, length, output); - } - @Override public T getObject(int position, Class clazz) { @@ -189,37 +176,7 @@ public T getObject(int position, Class clazz) } @Override - public boolean bytesEqual(int position, int offset, Slice otherSlice, int otherOffset, int length) - { - return dictionary.bytesEqual(getId(position), offset, otherSlice, otherOffset, length); - } - - @Override - public int bytesCompare(int position, int offset, int length, Slice otherSlice, int otherOffset, int otherLength) - { - return dictionary.bytesCompare(getId(position), offset, length, otherSlice, otherOffset, otherLength); - } - - @Override - public boolean equals(int position, int offset, Block otherBlock, int otherPosition, int otherOffset, int length) - { - return dictionary.equals(getId(position), offset, otherBlock, otherPosition, otherOffset, length); - } - - @Override - public long hash(int position, int offset, int length) - { - return dictionary.hash(getId(position), offset, length); - } - - @Override - public int compareTo(int leftPosition, int leftOffset, int leftLength, Block rightBlock, int rightPosition, int rightOffset, int rightLength) - { - return dictionary.compareTo(getId(leftPosition), leftOffset, leftLength, rightBlock, rightPosition, rightOffset, rightLength); - } - - @Override - public Block getSingleValueBlock(int position) + public ValueBlock getSingleValueBlock(int position) { return dictionary.getSingleValueBlock(getId(position)); } @@ -431,7 +388,7 @@ public Block copyPositions(int[] positions, int offset, int length) } newIds[i] = newId; } - 
Block compactDictionary = dictionary.copyPositions(positionsToCopy.elements(), 0, positionsToCopy.size()); + ValueBlock compactDictionary = dictionary.copyPositions(positionsToCopy.elements(), 0, positionsToCopy.size()); if (positionsToCopy.size() == length) { // discovered that all positions are unique, so return the unwrapped underlying dictionary directly return compactDictionary; @@ -534,7 +491,7 @@ public Block copyWithAppendedNull() { int desiredLength = idsOffset + positionCount + 1; int[] newIds = Arrays.copyOf(ids, desiredLength); - Block newDictionary = dictionary; + ValueBlock newDictionary = dictionary; int nullIndex = NULL_NOT_FOUND; @@ -569,29 +526,24 @@ public String toString() } @Override - public boolean isLoaded() + public List getChildren() { - return dictionary.isLoaded(); + return singletonList(getDictionary()); } @Override - public Block getLoadedBlock() + public ValueBlock getUnderlyingValueBlock() { - Block loadedDictionary = dictionary.getLoadedBlock(); - - if (loadedDictionary == dictionary) { - return this; - } - return new DictionaryBlock(idsOffset, getPositionCount(), loadedDictionary, ids, false, false, randomDictionaryId()); + return dictionary; } @Override - public final List getChildren() + public int getUnderlyingValuePosition(int position) { - return singletonList(getDictionary()); + return getId(position); } - public Block getDictionary() + public ValueBlock getDictionary() { return dictionary; } @@ -675,7 +627,7 @@ public DictionaryBlock compact() newIds[i] = newId; } try { - Block compactDictionary = dictionary.copyPositions(dictionaryPositionsToCopy.elements(), 0, dictionaryPositionsToCopy.size()); + ValueBlock compactDictionary = dictionary.copyPositions(dictionaryPositionsToCopy.elements(), 0, dictionaryPositionsToCopy.size()); return new DictionaryBlock( 0, positionCount, @@ -736,13 +688,13 @@ public static List compactRelatedBlocks(List b } try { - Block compactDictionary = 
dictionaryBlock.getDictionary().copyPositions(dictionaryPositionsToCopy, 0, numberOfIndexes); + ValueBlock compactDictionary = dictionaryBlock.getDictionary().copyPositions(dictionaryPositionsToCopy, 0, numberOfIndexes); outputDictionaryBlocks.add(new DictionaryBlock( 0, positionCount, compactDictionary, newIds, - !(compactDictionary instanceof DictionaryBlock), + true, false, newDictionaryId)); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/Fixed12Block.java b/core/trino-spi/src/main/java/io/trino/spi/block/Fixed12Block.java index b716a5a6935a..f38e059ad2f6 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/Fixed12Block.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/Fixed12Block.java @@ -28,8 +28,8 @@ import static io.trino.spi.block.BlockUtil.copyIsNullAndAppendNull; import static io.trino.spi.block.BlockUtil.ensureCapacity; -public class Fixed12Block - implements Block +public final class Fixed12Block + implements ValueBlock { private static final int INSTANCE_SIZE = instanceSize(Fixed12Block.class); public static final int FIXED12_BYTES = Long.BYTES + Integer.BYTES; @@ -127,12 +127,11 @@ public int getPositionCount() @Override public long getLong(int position, int offset) { - checkReadablePosition(this, position); if (offset != 0) { // If needed, we can add support for offset 4 throw new IllegalArgumentException("offset must be 0"); } - return decodeFixed12First(values, position + positionOffset); + return getFixed12First(position); } @Override @@ -151,6 +150,17 @@ public int getInt(int position, int offset) throw new IllegalArgumentException("offset must be 0, 4, or 8"); } + public long getFixed12First(int position) + { + checkReadablePosition(this, position); + return decodeFixed12First(values, position + positionOffset); + } + + public int getFixed12Second(int position) + { + return decodeFixed12Second(values, position + positionOffset); + } + @Override public boolean mayHaveNull() { @@ -165,7 +175,7 @@ public boolean 
isNull(int position) } @Override - public Block getSingleValueBlock(int position) + public Fixed12Block getSingleValueBlock(int position) { checkReadablePosition(this, position); int index = (position + positionOffset) * 3; @@ -177,7 +187,7 @@ public Block getSingleValueBlock(int position) } @Override - public Block copyPositions(int[] positions, int offset, int length) + public Fixed12Block copyPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); @@ -202,7 +212,7 @@ public Block copyPositions(int[] positions, int offset, int length) } @Override - public Block getRegion(int positionOffset, int length) + public Fixed12Block getRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); @@ -210,7 +220,7 @@ public Block getRegion(int positionOffset, int length) } @Override - public Block copyRegion(int positionOffset, int length) + public Fixed12Block copyRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); @@ -231,13 +241,19 @@ public String getEncodingName() } @Override - public Block copyWithAppendedNull() + public Fixed12Block copyWithAppendedNull() { boolean[] newValueIsNull = copyIsNullAndAppendNull(valueIsNull, positionOffset, positionCount); int[] newValues = ensureCapacity(values, (positionOffset + positionCount + 1) * 3); return new Fixed12Block(positionOffset, positionCount + 1, newValueIsNull, newValues); } + @Override + public Fixed12Block getUnderlyingValueBlock() + { + return this; + } + @Override public String toString() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/Fixed12BlockBuilder.java b/core/trino-spi/src/main/java/io/trino/spi/block/Fixed12BlockBuilder.java index 1ad36b3fe132..f0d9e278510f 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/Fixed12BlockBuilder.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/Fixed12BlockBuilder.java @@ -90,6 +90,12 @@ public Block build() if 
(!hasNonNullValue) { return RunLengthEncodedBlock.create(NULL_VALUE_BLOCK, positionCount); } + return buildValueBlock(); + } + + @Override + public Fixed12Block buildValueBlock() + { return new Fixed12Block(0, positionCount, hasNullValue ? valueIsNull : null, values); } @@ -149,9 +155,4 @@ public String toString() sb.append('}'); return sb.toString(); } - - int[] getRawValues() - { - return values; - } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/Fixed12BlockEncoding.java b/core/trino-spi/src/main/java/io/trino/spi/block/Fixed12BlockEncoding.java index e33239743cf7..131837f74c86 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/Fixed12BlockEncoding.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/Fixed12BlockEncoding.java @@ -33,30 +33,23 @@ public String getName() @Override public void writeBlock(BlockEncodingSerde blockEncodingSerde, SliceOutput sliceOutput, Block block) { - int positionCount = block.getPositionCount(); + Fixed12Block fixed12Block = (Fixed12Block) block; + int positionCount = fixed12Block.getPositionCount(); sliceOutput.appendInt(positionCount); - encodeNullsAsBits(sliceOutput, block); + encodeNullsAsBits(sliceOutput, fixed12Block); - if (!block.mayHaveNull()) { - if (block instanceof Fixed12Block valueBlock) { - sliceOutput.writeInts(valueBlock.getRawValues(), valueBlock.getPositionOffset() * 3, valueBlock.getPositionCount() * 3); - } - else if (block instanceof Fixed12BlockBuilder blockBuilder) { - sliceOutput.writeInts(blockBuilder.getRawValues(), 0, blockBuilder.getPositionCount() * 3); - } - else { - throw new IllegalArgumentException("Unexpected block type " + block.getClass().getSimpleName()); - } + if (!fixed12Block.mayHaveNull()) { + sliceOutput.writeInts(fixed12Block.getRawValues(), fixed12Block.getPositionOffset() * 3, fixed12Block.getPositionCount() * 3); } else { int[] valuesWithoutNull = new int[positionCount * 3]; int nonNullPositionCount = 0; for (int i = 0; i < positionCount; i++) { - 
valuesWithoutNull[nonNullPositionCount] = block.getInt(i, 0); - valuesWithoutNull[nonNullPositionCount + 1] = block.getInt(i, 4); - valuesWithoutNull[nonNullPositionCount + 2] = block.getInt(i, 8); - if (!block.isNull(i)) { + valuesWithoutNull[nonNullPositionCount] = fixed12Block.getInt(i, 0); + valuesWithoutNull[nonNullPositionCount + 1] = fixed12Block.getInt(i, 4); + valuesWithoutNull[nonNullPositionCount + 2] = fixed12Block.getInt(i, 8); + if (!fixed12Block.isNull(i)) { nonNullPositionCount += 3; } } @@ -67,7 +60,7 @@ else if (block instanceof Fixed12BlockBuilder blockBuilder) { } @Override - public Block readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) + public Fixed12Block readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) { int positionCount = sliceInput.readInt(); diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlock.java index 311fd731982f..57b8e4ac9bd4 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlock.java @@ -13,6 +13,7 @@ */ package io.trino.spi.block; +import io.trino.spi.type.Int128; import jakarta.annotation.Nullable; import java.util.Optional; @@ -28,8 +29,8 @@ import static io.trino.spi.block.BlockUtil.copyIsNullAndAppendNull; import static io.trino.spi.block.BlockUtil.ensureCapacity; -public class Int128ArrayBlock - implements Block +public final class Int128ArrayBlock + implements ValueBlock { private static final int INSTANCE_SIZE = instanceSize(Int128ArrayBlock.class); public static final int INT128_BYTES = Long.BYTES + Long.BYTES; @@ -127,16 +128,34 @@ public int getPositionCount() @Override public long getLong(int position, int offset) { - checkReadablePosition(this, position); if (offset == 0) { - return values[(position + positionOffset) * 2]; + return getInt128High(position); } if (offset == 8) { - 
return values[((position + positionOffset) * 2) + 1]; + return getInt128Low(position); } throw new IllegalArgumentException("offset must be 0 or 8"); } + public Int128 getInt128(int position) + { + checkReadablePosition(this, position); + int offset = (position + positionOffset) * 2; + return Int128.valueOf(values[offset], values[offset + 1]); + } + + public long getInt128High(int position) + { + checkReadablePosition(this, position); + return values[(position + positionOffset) * 2]; + } + + public long getInt128Low(int position) + { + checkReadablePosition(this, position); + return values[((position + positionOffset) * 2) + 1]; + } + @Override public boolean mayHaveNull() { @@ -151,7 +170,7 @@ public boolean isNull(int position) } @Override - public Block getSingleValueBlock(int position) + public Int128ArrayBlock getSingleValueBlock(int position) { checkReadablePosition(this, position); return new Int128ArrayBlock( @@ -164,7 +183,7 @@ public Block getSingleValueBlock(int position) } @Override - public Block copyPositions(int[] positions, int offset, int length) + public Int128ArrayBlock copyPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); @@ -186,7 +205,7 @@ public Block copyPositions(int[] positions, int offset, int length) } @Override - public Block getRegion(int positionOffset, int length) + public Int128ArrayBlock getRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); @@ -194,7 +213,7 @@ public Block getRegion(int positionOffset, int length) } @Override - public Block copyRegion(int positionOffset, int length) + public Int128ArrayBlock copyRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); @@ -215,13 +234,19 @@ public String getEncodingName() } @Override - public Block copyWithAppendedNull() + public Int128ArrayBlock copyWithAppendedNull() { boolean[] newValueIsNull = copyIsNullAndAppendNull(valueIsNull, 
positionOffset, positionCount); long[] newValues = ensureCapacity(values, (positionOffset + positionCount + 1) * 2); return new Int128ArrayBlock(positionOffset, positionCount + 1, newValueIsNull, newValues); } + @Override + public Int128ArrayBlock getUnderlyingValueBlock() + { + return this; + } + @Override public String toString() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlockBuilder.java b/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlockBuilder.java index a3b7dc78dff1..f22ae8951fea 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlockBuilder.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlockBuilder.java @@ -90,6 +90,12 @@ public Block build() if (!hasNonNullValue) { return RunLengthEncodedBlock.create(NULL_VALUE_BLOCK, positionCount); } + return buildValueBlock(); + } + + @Override + public Int128ArrayBlock buildValueBlock() + { return new Int128ArrayBlock(0, positionCount, hasNullValue ? 
valueIsNull : null, values); } @@ -149,9 +155,4 @@ public String toString() sb.append('}'); return sb.toString(); } - - long[] getRawValues() - { - return values; - } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlockEncoding.java b/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlockEncoding.java index 889d7814716c..78e8191202e5 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlockEncoding.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlockEncoding.java @@ -33,29 +33,22 @@ public String getName() @Override public void writeBlock(BlockEncodingSerde blockEncodingSerde, SliceOutput sliceOutput, Block block) { - int positionCount = block.getPositionCount(); + Int128ArrayBlock int128ArrayBlock = (Int128ArrayBlock) block; + int positionCount = int128ArrayBlock.getPositionCount(); sliceOutput.appendInt(positionCount); - encodeNullsAsBits(sliceOutput, block); + encodeNullsAsBits(sliceOutput, int128ArrayBlock); - if (!block.mayHaveNull()) { - if (block instanceof Int128ArrayBlock valueBlock) { - sliceOutput.writeLongs(valueBlock.getRawValues(), valueBlock.getPositionOffset() * 2, valueBlock.getPositionCount() * 2); - } - else if (block instanceof Int128ArrayBlockBuilder blockBuilder) { - sliceOutput.writeLongs(blockBuilder.getRawValues(), 0, blockBuilder.getPositionCount() * 2); - } - else { - throw new IllegalArgumentException("Unexpected block type " + block.getClass().getSimpleName()); - } + if (!int128ArrayBlock.mayHaveNull()) { + sliceOutput.writeLongs(int128ArrayBlock.getRawValues(), int128ArrayBlock.getPositionOffset() * 2, int128ArrayBlock.getPositionCount() * 2); } else { long[] valuesWithoutNull = new long[positionCount * 2]; int nonNullPositionCount = 0; for (int i = 0; i < positionCount; i++) { - valuesWithoutNull[nonNullPositionCount] = block.getLong(i, 0); - valuesWithoutNull[nonNullPositionCount + 1] = block.getLong(i, 8); - if (!block.isNull(i)) { + 
valuesWithoutNull[nonNullPositionCount] = int128ArrayBlock.getInt128High(i); + valuesWithoutNull[nonNullPositionCount + 1] = int128ArrayBlock.getInt128Low(i); + if (!int128ArrayBlock.isNull(i)) { nonNullPositionCount += 2; } } @@ -66,7 +59,7 @@ else if (block instanceof Int128ArrayBlockBuilder blockBuilder) { } @Override - public Block readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) + public Int128ArrayBlock readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) { int positionCount = sliceInput.readInt(); diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlock.java index 2160d585d96c..93fa86da8456 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlock.java @@ -29,8 +29,8 @@ import static io.trino.spi.block.BlockUtil.copyIsNullAndAppendNull; import static io.trino.spi.block.BlockUtil.ensureCapacity; -public class IntArrayBlock - implements Block +public final class IntArrayBlock + implements ValueBlock { private static final int INSTANCE_SIZE = instanceSize(IntArrayBlock.class); public static final int SIZE_IN_BYTES_PER_POSITION = Integer.BYTES + Byte.BYTES; @@ -127,10 +127,15 @@ public int getPositionCount() @Override public int getInt(int position, int offset) { - checkReadablePosition(this, position); if (offset != 0) { throw new IllegalArgumentException("offset must be zero"); } + return getInt(position); + } + + public int getInt(int position) + { + checkReadablePosition(this, position); return values[position + arrayOffset]; } @@ -148,7 +153,7 @@ public boolean isNull(int position) } @Override - public Block getSingleValueBlock(int position) + public IntArrayBlock getSingleValueBlock(int position) { checkReadablePosition(this, position); return new IntArrayBlock( @@ -159,7 +164,7 @@ public Block getSingleValueBlock(int position) } @Override - 
public Block copyPositions(int[] positions, int offset, int length) + public IntArrayBlock copyPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); @@ -180,7 +185,7 @@ public Block copyPositions(int[] positions, int offset, int length) } @Override - public Block getRegion(int positionOffset, int length) + public IntArrayBlock getRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); @@ -188,7 +193,7 @@ public Block getRegion(int positionOffset, int length) } @Override - public Block copyRegion(int positionOffset, int length) + public IntArrayBlock copyRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); @@ -209,7 +214,7 @@ public String getEncodingName() } @Override - public Block copyWithAppendedNull() + public IntArrayBlock copyWithAppendedNull() { boolean[] newValueIsNull = copyIsNullAndAppendNull(valueIsNull, arrayOffset, positionCount); int[] newValues = ensureCapacity(values, arrayOffset + positionCount + 1); @@ -217,6 +222,12 @@ public Block copyWithAppendedNull() return new IntArrayBlock(arrayOffset, positionCount + 1, newValueIsNull, newValues); } + @Override + public IntArrayBlock getUnderlyingValueBlock() + { + return this; + } + @Override public String toString() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlockBuilder.java b/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlockBuilder.java index 52d8ae115b0c..bf124103418b 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlockBuilder.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlockBuilder.java @@ -89,6 +89,12 @@ public Block build() if (!hasNonNullValue) { return RunLengthEncodedBlock.create(NULL_VALUE_BLOCK, positionCount); } + return buildValueBlock(); + } + + @Override + public IntArrayBlock buildValueBlock() + { return new IntArrayBlock(0, positionCount, hasNullValue ? 
valueIsNull : null, values); } @@ -148,9 +154,4 @@ public String toString() sb.append('}'); return sb.toString(); } - - int[] getRawValues() - { - return values; - } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlockEncoding.java b/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlockEncoding.java index ffcf3b87060c..408475020e9a 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlockEncoding.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlockEncoding.java @@ -35,28 +35,21 @@ public String getName() @Override public void writeBlock(BlockEncodingSerde blockEncodingSerde, SliceOutput sliceOutput, Block block) { - int positionCount = block.getPositionCount(); + IntArrayBlock intArrayBlock = (IntArrayBlock) block; + int positionCount = intArrayBlock.getPositionCount(); sliceOutput.appendInt(positionCount); - encodeNullsAsBits(sliceOutput, block); + encodeNullsAsBits(sliceOutput, intArrayBlock); - if (!block.mayHaveNull()) { - if (block instanceof IntArrayBlock valueBlock) { - sliceOutput.writeInts(valueBlock.getRawValues(), valueBlock.getRawValuesOffset(), valueBlock.getPositionCount()); - } - else if (block instanceof IntArrayBlockBuilder blockBuilder) { - sliceOutput.writeInts(blockBuilder.getRawValues(), 0, blockBuilder.getPositionCount()); - } - else { - throw new IllegalArgumentException("Unexpected block type " + block.getClass().getSimpleName()); - } + if (!intArrayBlock.mayHaveNull()) { + sliceOutput.writeInts(intArrayBlock.getRawValues(), intArrayBlock.getRawValuesOffset(), intArrayBlock.getPositionCount()); } else { int[] valuesWithoutNull = new int[positionCount]; int nonNullPositionCount = 0; for (int i = 0; i < positionCount; i++) { - valuesWithoutNull[nonNullPositionCount] = block.getInt(i, 0); - if (!block.isNull(i)) { + valuesWithoutNull[nonNullPositionCount] = intArrayBlock.getInt(i); + if (!intArrayBlock.isNull(i)) { nonNullPositionCount++; } } @@ -67,7 +60,7 @@ else if (block 
instanceof IntArrayBlockBuilder blockBuilder) { } @Override - public Block readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) + public IntArrayBlock readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) { int positionCount = sliceInput.readInt(); diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/LazyBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/LazyBlock.java index dc6fb5f00ddb..8579f7ca93fb 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/LazyBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/LazyBlock.java @@ -14,7 +14,6 @@ package io.trino.spi.block; import io.airlift.slice.Slice; -import io.airlift.slice.SliceOutput; import jakarta.annotation.Nullable; import java.util.ArrayList; @@ -31,7 +30,7 @@ import static java.util.Objects.requireNonNull; // This class is not considered thread-safe. -public class LazyBlock +public final class LazyBlock implements Block { private static final int INSTANCE_SIZE = instanceSize(LazyBlock.class) + instanceSize(LazyData.class); @@ -87,12 +86,6 @@ public Slice getSlice(int position, int offset, int length) return getBlock().getSlice(position, offset, length); } - @Override - public void writeSliceTo(int position, int offset, int length, SliceOutput output) - { - getBlock().writeSliceTo(position, offset, length, output); - } - @Override public T getObject(int position, Class clazz) { @@ -100,56 +93,7 @@ public T getObject(int position, Class clazz) } @Override - public boolean bytesEqual(int position, int offset, Slice otherSlice, int otherOffset, int length) - { - return getBlock().bytesEqual(position, offset, otherSlice, otherOffset, length); - } - - @Override - public int bytesCompare(int position, int offset, int length, Slice otherSlice, int otherOffset, int otherLength) - { - return getBlock().bytesCompare( - position, - offset, - length, - otherSlice, - otherOffset, - otherLength); - } - - @Override - public boolean equals(int 
position, int offset, Block otherBlock, int otherPosition, int otherOffset, int length) - { - return getBlock().equals( - position, - offset, - otherBlock, - otherPosition, - otherOffset, - length); - } - - @Override - public long hash(int position, int offset, int length) - { - return getBlock().hash(position, offset, length); - } - - @Override - public int compareTo(int leftPosition, int leftOffset, int leftLength, Block rightBlock, int rightPosition, int rightOffset, int rightLength) - { - return getBlock().compareTo( - leftPosition, - leftOffset, - leftLength, - rightBlock, - rightPosition, - rightOffset, - rightLength); - } - - @Override - public Block getSingleValueBlock(int position) + public ValueBlock getSingleValueBlock(int position) { return getBlock().getSingleValueBlock(position); } @@ -269,7 +213,7 @@ public boolean mayHaveNull() } @Override - public final List getChildren() + public List getChildren() { return singletonList(getBlock()); } @@ -291,6 +235,18 @@ public Block getLoadedBlock() return lazyData.getFullyLoadedBlock(); } + @Override + public ValueBlock getUnderlyingValueBlock() + { + return getBlock().getUnderlyingValueBlock(); + } + + @Override + public int getUnderlyingValuePosition(int position) + { + return getBlock().getUnderlyingValuePosition(position); + } + public static void listenForLoads(Block block, Consumer listener) { requireNonNull(block, "block is null"); @@ -434,7 +390,7 @@ private void load(boolean recursive) } /** - * If block is unloaded, add the listeners; otherwise call this method on child blocks + * If the block is unloaded, add the listeners; otherwise call this method on child blocks */ @SuppressWarnings("AccessingNonPublicFieldOfAnotherObject") private static void addListenersRecursive(Block block, List> listeners) diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/LazyBlockEncoding.java b/core/trino-spi/src/main/java/io/trino/spi/block/LazyBlockEncoding.java index a472cab833f3..99a3df02b65a 100644 --- 
a/core/trino-spi/src/main/java/io/trino/spi/block/LazyBlockEncoding.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/LazyBlockEncoding.java @@ -23,8 +23,6 @@ public class LazyBlockEncoding { public static final String NAME = "LAZY"; - public LazyBlockEncoding() {} - @Override public String getName() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlock.java index abba78907863..2b9aec633844 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlock.java @@ -29,8 +29,8 @@ import static io.trino.spi.block.BlockUtil.ensureCapacity; import static java.lang.Math.toIntExact; -public class LongArrayBlock - implements Block +public final class LongArrayBlock + implements ValueBlock { private static final int INSTANCE_SIZE = instanceSize(LongArrayBlock.class); public static final int SIZE_IN_BYTES_PER_POSITION = Long.BYTES + Byte.BYTES; @@ -127,10 +127,15 @@ public int getPositionCount() @Override public long getLong(int position, int offset) { - checkReadablePosition(this, position); if (offset != 0) { throw new IllegalArgumentException("offset must be zero"); } + return getLong(position); + } + + public long getLong(int position) + { + checkReadablePosition(this, position); return values[position + arrayOffset]; } @@ -194,7 +199,7 @@ public boolean isNull(int position) } @Override - public Block getSingleValueBlock(int position) + public LongArrayBlock getSingleValueBlock(int position) { checkReadablePosition(this, position); return new LongArrayBlock( @@ -205,7 +210,7 @@ public Block getSingleValueBlock(int position) } @Override - public Block copyPositions(int[] positions, int offset, int length) + public LongArrayBlock copyPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); @@ -226,7 +231,7 @@ public Block copyPositions(int[] 
positions, int offset, int length) } @Override - public Block getRegion(int positionOffset, int length) + public LongArrayBlock getRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); @@ -234,7 +239,7 @@ public Block getRegion(int positionOffset, int length) } @Override - public Block copyRegion(int positionOffset, int length) + public LongArrayBlock copyRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); @@ -255,7 +260,7 @@ public String getEncodingName() } @Override - public Block copyWithAppendedNull() + public LongArrayBlock copyWithAppendedNull() { boolean[] newValueIsNull = copyIsNullAndAppendNull(valueIsNull, arrayOffset, positionCount); long[] newValues = ensureCapacity(values, arrayOffset + positionCount + 1); @@ -263,6 +268,12 @@ public Block copyWithAppendedNull() return new LongArrayBlock(arrayOffset, positionCount + 1, newValueIsNull, newValues); } + @Override + public LongArrayBlock getUnderlyingValueBlock() + { + return this; + } + @Override public String toString() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlockBuilder.java b/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlockBuilder.java index eaa16a21057b..09a530971ac1 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlockBuilder.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlockBuilder.java @@ -89,6 +89,12 @@ public Block build() if (!hasNonNullValue) { return RunLengthEncodedBlock.create(NULL_VALUE_BLOCK, positionCount); } + return buildValueBlock(); + } + + @Override + public LongArrayBlock buildValueBlock() + { return new LongArrayBlock(0, positionCount, hasNullValue ? 
valueIsNull : null, values); } @@ -148,9 +154,4 @@ public String toString() sb.append('}'); return sb.toString(); } - - long[] getRawValues() - { - return values; - } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlockEncoding.java b/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlockEncoding.java index 0d6ce7d14679..5167fca68087 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlockEncoding.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlockEncoding.java @@ -35,28 +35,21 @@ public String getName() @Override public void writeBlock(BlockEncodingSerde blockEncodingSerde, SliceOutput sliceOutput, Block block) { - int positionCount = block.getPositionCount(); + LongArrayBlock longArrayBlock = (LongArrayBlock) block; + int positionCount = longArrayBlock.getPositionCount(); sliceOutput.appendInt(positionCount); - encodeNullsAsBits(sliceOutput, block); + encodeNullsAsBits(sliceOutput, longArrayBlock); - if (!block.mayHaveNull()) { - if (block instanceof LongArrayBlock valueBlock) { - sliceOutput.writeLongs(valueBlock.getRawValues(), valueBlock.getRawValuesOffset(), valueBlock.getPositionCount()); - } - else if (block instanceof LongArrayBlockBuilder blockBuilder) { - sliceOutput.writeLongs(blockBuilder.getRawValues(), 0, blockBuilder.getPositionCount()); - } - else { - throw new IllegalArgumentException("Unexpected block type " + block.getClass().getSimpleName()); - } + if (!longArrayBlock.mayHaveNull()) { + sliceOutput.writeLongs(longArrayBlock.getRawValues(), longArrayBlock.getRawValuesOffset(), longArrayBlock.getPositionCount()); } else { long[] valuesWithoutNull = new long[positionCount]; int nonNullPositionCount = 0; for (int i = 0; i < positionCount; i++) { - valuesWithoutNull[nonNullPositionCount] = block.getLong(i, 0); - if (!block.isNull(i)) { + valuesWithoutNull[nonNullPositionCount] = longArrayBlock.getLong(i); + if (!longArrayBlock.isNull(i)) { nonNullPositionCount++; } } @@ -67,7 
+60,7 @@ else if (block instanceof LongArrayBlockBuilder blockBuilder) { } @Override - public Block readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) + public LongArrayBlock readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) { int positionCount = sliceInput.readInt(); diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/MapBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/MapBlock.java index da07f1521019..d7c78df45d46 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/MapBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/MapBlock.java @@ -40,8 +40,8 @@ import static java.lang.String.format; import static java.util.Objects.requireNonNull; -public class MapBlock - implements Block +public final class MapBlock + implements ValueBlock { private static final int INSTANCE_SIZE = instanceSize(MapBlock.class); @@ -187,27 +187,27 @@ private MapBlock( this.retainedSizeInBytes = INSTANCE_SIZE + sizeOf(offsets) + sizeOf(mapIsNull); } - protected Block getRawKeyBlock() + Block getRawKeyBlock() { return keyBlock; } - protected Block getRawValueBlock() + Block getRawValueBlock() { return valueBlock; } - protected MapHashTables getHashTables() + MapHashTables getHashTables() { return hashTables; } - protected int[] getOffsets() + int[] getOffsets() { return offsets; } - protected int getOffsetBase() + int getOffsetBase() { return startOffset; } @@ -302,13 +302,13 @@ public Block getLoadedBlock() hashTables); } - protected void ensureHashTableLoaded() + void ensureHashTableLoaded() { hashTables.buildAllHashTablesIfNecessary(keyBlock, offsets, mapIsNull); } @Override - public Block copyWithAppendedNull() + public MapBlock copyWithAppendedNull() { boolean[] newMapIsNull = copyIsNullAndAppendNull(mapIsNull, startOffset, getPositionCount()); int[] newOffsets = copyOffsetsAndAppendNull(offsets, startOffset, getPositionCount()); @@ -325,12 +325,12 @@ public Block copyWithAppendedNull() } @Override - 
public final List getChildren() + public List getChildren() { return List.of(keyBlock, valueBlock); } - protected MapType getMapType() + MapType getMapType() { return mapType; } @@ -347,7 +347,7 @@ public String getEncodingName() } @Override - public Block copyPositions(int[] positions, int offset, int length) + public MapBlock copyPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); @@ -407,7 +407,7 @@ public Block copyPositions(int[] positions, int offset, int length) } @Override - public Block getRegion(int position, int length) + public MapBlock getRegion(int position, int length) { int positionCount = getPositionCount(); checkValidRegion(positionCount, position, length); @@ -460,7 +460,7 @@ private OptionalInt keyAndValueFixedSizeInBytesPerRow() } @Override - public final long getPositionsSizeInBytes(boolean[] positions, int selectedMapPositions) + public long getPositionsSizeInBytes(boolean[] positions, int selectedMapPositions) { int positionCount = getPositionCount(); checkValidPositions(positions, positionCount); @@ -500,7 +500,7 @@ public final long getPositionsSizeInBytes(boolean[] positions, int selectedMapPo } @Override - public Block copyRegion(int position, int length) + public MapBlock copyRegion(int position, int length) { int positionCount = getPositionCount(); checkValidRegion(positionCount, position, length); @@ -541,21 +541,25 @@ public T getObject(int position, Class clazz) if (clazz != SqlMap.class) { throw new IllegalArgumentException("clazz must be SqlMap.class"); } - checkReadablePosition(this, position); + return clazz.cast(getMap(position)); + } + public SqlMap getMap(int position) + { + checkReadablePosition(this, position); int startEntryOffset = getOffset(position); int endEntryOffset = getOffset(position + 1); - return clazz.cast(new SqlMap( + return new SqlMap( mapType, keyBlock, valueBlock, new SqlMap.HashTableSupplier(this), startEntryOffset, - (endEntryOffset - startEntryOffset))); + 
(endEntryOffset - startEntryOffset)); } @Override - public Block getSingleValueBlock(int position) + public MapBlock getSingleValueBlock(int position) { checkReadablePosition(this, position); @@ -611,6 +615,12 @@ public boolean isNull(int position) return mapIsNull != null && mapIsNull[position + startOffset]; } + @Override + public MapBlock getUnderlyingValueBlock() + { + return this; + } + // only visible for testing public boolean isHashTablesPresent() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/MapBlockBuilder.java b/core/trino-spi/src/main/java/io/trino/spi/block/MapBlockBuilder.java index ba0f9e819e72..477ae48f4ce1 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/MapBlockBuilder.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/MapBlockBuilder.java @@ -166,6 +166,12 @@ private void entryAdded(boolean isNull) @Override public Block build() + { + return buildValueBlock(); + } + + @Override + public MapBlock buildValueBlock() { if (currentEntryOpened) { throw new IllegalStateException("Current entry must be closed before the block can be built"); diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/RowBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/RowBlock.java index 90efc72c22cb..ba373813e9d8 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/RowBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/RowBlock.java @@ -35,8 +35,8 @@ import static java.lang.String.format; import static java.util.Objects.requireNonNull; -public class RowBlock - implements Block +public final class RowBlock + implements ValueBlock { private static final int INSTANCE_SIZE = instanceSize(RowBlock.class); private final int numFields; @@ -55,7 +55,7 @@ public class RowBlock /** * Create a row block directly from columnar nulls and field blocks. 
*/ - public static Block fromFieldBlocks(int positionCount, Optional rowIsNullOptional, Block[] fieldBlocks) + public static RowBlock fromFieldBlocks(int positionCount, Optional rowIsNullOptional, Block[] fieldBlocks) { boolean[] rowIsNull = rowIsNullOptional.orElse(null); int[] fieldBlockOffsets = null; @@ -144,18 +144,18 @@ private RowBlock(int startOffset, int positionCount, @Nullable boolean[] rowIsNu this.retainedSizeInBytes = INSTANCE_SIZE + sizeOf(fieldBlockOffsets) + sizeOf(rowIsNull); } - protected List getRawFieldBlocks() + List getRawFieldBlocks() { return fieldBlocksList; } @Nullable - protected int[] getFieldBlockOffsets() + int[] getFieldBlockOffsets() { return fieldBlockOffsets; } - protected int getOffsetBase() + int getOffsetBase() { return startOffset; } @@ -256,7 +256,7 @@ public Block getLoadedBlock() } @Override - public Block copyWithAppendedNull() + public RowBlock copyWithAppendedNull() { boolean[] newRowIsNull = copyIsNullAndAppendNull(rowIsNull, startOffset, getPositionCount()); @@ -281,13 +281,13 @@ public Block copyWithAppendedNull() } @Override - public final List getChildren() + public List getChildren() { return List.of(fieldBlocks); } // the offset in each field block, it can also be viewed as the "entry-based" offset in the RowBlock - public final int getFieldBlockOffset(int position) + public int getFieldBlockOffset(int position) { int[] offsets = fieldBlockOffsets; if (offsets != null) { @@ -305,7 +305,7 @@ public String getEncodingName() } @Override - public Block copyPositions(int[] positions, int offset, int length) + public RowBlock copyPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); @@ -354,7 +354,7 @@ public Block copyPositions(int[] positions, int offset, int length) } @Override - public Block getRegion(int position, int length) + public RowBlock getRegion(int position, int length) { int positionCount = getPositionCount(); checkValidRegion(positionCount, position, length); @@ 
-363,14 +363,14 @@ public Block getRegion(int position, int length) } @Override - public final OptionalInt fixedSizeInBytesPerPosition() + public OptionalInt fixedSizeInBytesPerPosition() { if (!mayHaveNull()) { // when null rows are present, we can't use the fixed field sizes to infer the correct // size for arbitrary position selection OptionalInt fieldSize = fixedSizeInBytesPerFieldPosition(); if (fieldSize.isPresent()) { - // must include the row block overhead in addition to the per position size in bytes + // must include the row block overhead in addition to the per-position size in bytes return OptionalInt.of(fieldSize.getAsInt() + (Integer.BYTES + Byte.BYTES)); // offsets + rowIsNull } } @@ -415,7 +415,7 @@ public long getRegionSizeInBytes(int position, int length) } @Override - public final long getPositionsSizeInBytes(boolean[] positions, int selectedRowPositions) + public long getPositionsSizeInBytes(boolean[] positions, int selectedRowPositions) { int positionCount = getPositionCount(); checkValidPositions(positions, positionCount); @@ -432,7 +432,7 @@ public final long getPositionsSizeInBytes(boolean[] positions, int selectedRowPo int selectedFieldPositionCount = selectedRowPositions; boolean[] rowIsNull = this.rowIsNull; if (rowIsNull != null) { - // Some positions in usedPositions may be null which must be removed from the selectedFieldPositionCount + // Some positions of usedPositions may be null, and these must be removed from the selectedFieldPositionCount int offsetBase = startOffset; for (int i = 0; i < positions.length; i++) { if (positions[i] && rowIsNull[i + offsetBase]) { @@ -492,7 +492,7 @@ private long getSpecificPositionsSizeInBytes(boolean[] positions, int selectedRo } @Override - public Block copyRegion(int position, int length) + public RowBlock copyRegion(int position, int length) { int positionCount = getPositionCount(); checkValidRegion(positionCount, position, length); @@ -533,13 +533,17 @@ public T getObject(int position, Class 
clazz) if (clazz != SqlRow.class) { throw new IllegalArgumentException("clazz must be SqlRow.class"); } - checkReadablePosition(this, position); + return clazz.cast(getRow(position)); + } - return clazz.cast(new SqlRow(getFieldBlockOffset(position), fieldBlocks)); + public SqlRow getRow(int position) + { + checkReadablePosition(this, position); + return new SqlRow(getFieldBlockOffset(position), fieldBlocks); } @Override - public Block getSingleValueBlock(int position) + public RowBlock getSingleValueBlock(int position) { checkReadablePosition(this, position); @@ -583,4 +587,10 @@ public boolean isNull(int position) } return rowIsNull[position + startOffset]; } + + @Override + public RowBlock getUnderlyingValueBlock() + { + return this; + } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/RowBlockBuilder.java b/core/trino-spi/src/main/java/io/trino/spi/block/RowBlockBuilder.java index 77e2d70d5825..cbbbb65e869b 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/RowBlockBuilder.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/RowBlockBuilder.java @@ -164,6 +164,16 @@ public Block build() if (!hasNonNullRow) { return nullRle(positionCount); } + return buildValueBlock(); + } + + @Override + public RowBlock buildValueBlock() + { + if (currentEntryOpened) { + throw new IllegalStateException("Current entry must be closed before the block can be built"); + } + Block[] fieldBlocks = new Block[fieldBlockBuilders.length]; for (int i = 0; i < fieldBlockBuilders.length; i++) { fieldBlocks[i] = fieldBlockBuilders[i].build(); diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/RunLengthEncodedBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/RunLengthEncodedBlock.java index 996d35398c87..6ad42a33d77c 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/RunLengthEncodedBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/RunLengthEncodedBlock.java @@ -14,7 +14,6 @@ package io.trino.spi.block; import 
io.airlift.slice.Slice; -import io.airlift.slice.SliceOutput; import io.trino.spi.predicate.Utils; import io.trino.spi.type.Type; import jakarta.annotation.Nullable; @@ -32,7 +31,7 @@ import static java.util.Collections.singletonList; import static java.util.Objects.requireNonNull; -public class RunLengthEncodedBlock +public final class RunLengthEncodedBlock implements Block { private static final int INSTANCE_SIZE = instanceSize(RunLengthEncodedBlock.class); @@ -59,13 +58,25 @@ public static Block create(Block value, int positionCount) if (positionCount == 1) { return value; } - return new RunLengthEncodedBlock(value, positionCount); + + if (value instanceof ValueBlock valueBlock) { + return new RunLengthEncodedBlock(valueBlock, positionCount); + } + + // unwrap the value + ValueBlock valueBlock = value.getUnderlyingValueBlock(); + int valuePosition = value.getUnderlyingValuePosition(0); + if (valueBlock.getPositionCount() == 1 && valuePosition == 0) { + return new RunLengthEncodedBlock(valueBlock, positionCount); + } + + return new RunLengthEncodedBlock(valueBlock.getRegion(valuePosition, 1), positionCount); } - private final Block value; + private final ValueBlock value; private final int positionCount; - private RunLengthEncodedBlock(Block value, int positionCount) + private RunLengthEncodedBlock(ValueBlock value, int positionCount) { requireNonNull(value, "value is null"); if (positionCount < 0) { @@ -75,40 +86,23 @@ private RunLengthEncodedBlock(Block value, int positionCount) throw new IllegalArgumentException("positionCount must be at least 2"); } - // do not nest an RLE or Dictionary in an RLE - if (value instanceof RunLengthEncodedBlock block) { - this.value = block.getValue(); - } - else if (value instanceof DictionaryBlock block) { - Block dictionary = block.getDictionary(); - int id = block.getId(0); - if (dictionary.getPositionCount() == 1 && id == 0) { - this.value = dictionary; - } - else { - this.value = dictionary.getRegion(id, 1); - } - } - else 
{ - this.value = value; - } - + this.value = value; this.positionCount = positionCount; } @Override - public final List getChildren() + public List getChildren() { return singletonList(value); } - public Block getValue() + public ValueBlock getValue() { return value; } /** - * Positions count will always be at least 2 + * Position count will always be at least 2 */ @Override public int getPositionCount() @@ -247,13 +241,6 @@ public Slice getSlice(int position, int offset, int length) return value.getSlice(0, offset, length); } - @Override - public void writeSliceTo(int position, int offset, int length, SliceOutput output) - { - checkReadablePosition(this, position); - value.writeSliceTo(0, offset, length, output); - } - @Override public T getObject(int position, Class clazz) { @@ -262,42 +249,7 @@ public T getObject(int position, Class clazz) } @Override - public boolean bytesEqual(int position, int offset, Slice otherSlice, int otherOffset, int length) - { - checkReadablePosition(this, position); - return value.bytesEqual(0, offset, otherSlice, otherOffset, length); - } - - @Override - public int bytesCompare(int position, int offset, int length, Slice otherSlice, int otherOffset, int otherLength) - { - checkReadablePosition(this, position); - return value.bytesCompare(0, offset, length, otherSlice, otherOffset, otherLength); - } - - @Override - public boolean equals(int position, int offset, Block otherBlock, int otherPosition, int otherOffset, int length) - { - checkReadablePosition(this, position); - return value.equals(0, offset, otherBlock, otherPosition, otherOffset, length); - } - - @Override - public long hash(int position, int offset, int length) - { - checkReadablePosition(this, position); - return value.hash(0, offset, length); - } - - @Override - public int compareTo(int leftPosition, int leftOffset, int leftLength, Block rightBlock, int rightPosition, int rightOffset, int rightLength) - { - checkReadablePosition(this, leftPosition); - return 
value.compareTo(0, leftOffset, leftLength, rightBlock, rightPosition, rightOffset, rightLength); - } - - @Override - public Block getSingleValueBlock(int position) + public ValueBlock getSingleValueBlock(int position) { checkReadablePosition(this, position); return value; @@ -323,7 +275,7 @@ public Block copyWithAppendedNull() return create(value, positionCount + 1); } - Block dictionary = value.copyWithAppendedNull(); + ValueBlock dictionary = value.copyWithAppendedNull(); int[] ids = new int[positionCount + 1]; ids[positionCount] = 1; return DictionaryBlock.create(ids.length, dictionary, ids); @@ -340,19 +292,14 @@ public String toString() } @Override - public boolean isLoaded() + public ValueBlock getUnderlyingValueBlock() { - return value.isLoaded(); + return value; } @Override - public Block getLoadedBlock() + public int getUnderlyingValuePosition(int position) { - Block loadedValueBlock = value.getLoadedBlock(); - - if (loadedValueBlock == value) { - return this; - } - return create(loadedValueBlock, positionCount); + return 0; } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlock.java index a34ac450556a..336a5b845539 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlock.java @@ -28,8 +28,8 @@ import static io.trino.spi.block.BlockUtil.copyIsNullAndAppendNull; import static io.trino.spi.block.BlockUtil.ensureCapacity; -public class ShortArrayBlock - implements Block +public final class ShortArrayBlock + implements ValueBlock { private static final int INSTANCE_SIZE = instanceSize(ShortArrayBlock.class); public static final int SIZE_IN_BYTES_PER_POSITION = Short.BYTES + Byte.BYTES; @@ -126,10 +126,15 @@ public int getPositionCount() @Override public short getShort(int position, int offset) { - checkReadablePosition(this, position); if (offset != 0) { throw new 
IllegalArgumentException("offset must be zero"); } + return getShort(position); + } + + public short getShort(int position) + { + checkReadablePosition(this, position); return values[position + arrayOffset]; } @@ -147,7 +152,7 @@ public boolean isNull(int position) } @Override - public Block getSingleValueBlock(int position) + public ShortArrayBlock getSingleValueBlock(int position) { checkReadablePosition(this, position); return new ShortArrayBlock( @@ -158,7 +163,7 @@ public Block getSingleValueBlock(int position) } @Override - public Block copyPositions(int[] positions, int offset, int length) + public ShortArrayBlock copyPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); @@ -179,7 +184,7 @@ public Block copyPositions(int[] positions, int offset, int length) } @Override - public Block getRegion(int positionOffset, int length) + public ShortArrayBlock getRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); @@ -187,7 +192,7 @@ public Block getRegion(int positionOffset, int length) } @Override - public Block copyRegion(int positionOffset, int length) + public ShortArrayBlock copyRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); @@ -208,13 +213,19 @@ public String getEncodingName() } @Override - public Block copyWithAppendedNull() + public ShortArrayBlock copyWithAppendedNull() { boolean[] newValueIsNull = copyIsNullAndAppendNull(valueIsNull, arrayOffset, positionCount); short[] newValues = ensureCapacity(values, arrayOffset + positionCount + 1); return new ShortArrayBlock(arrayOffset, positionCount + 1, newValueIsNull, newValues); } + @Override + public ShortArrayBlock getUnderlyingValueBlock() + { + return this; + } + @Override public String toString() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlockBuilder.java b/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlockBuilder.java 
index aa3db4bcf4b1..ee44b44b6dc2 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlockBuilder.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlockBuilder.java @@ -89,6 +89,12 @@ public Block build() if (!hasNonNullValue) { return RunLengthEncodedBlock.create(NULL_VALUE_BLOCK, positionCount); } + return buildValueBlock(); + } + + @Override + public ShortArrayBlock buildValueBlock() + { return new ShortArrayBlock(0, positionCount, hasNullValue ? valueIsNull : null, values); } @@ -148,9 +154,4 @@ public String toString() sb.append('}'); return sb.toString(); } - - short[] getRawValues() - { - return values; - } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlockEncoding.java b/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlockEncoding.java index 0aa79f278376..15813a428f74 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlockEncoding.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlockEncoding.java @@ -35,28 +35,21 @@ public String getName() @Override public void writeBlock(BlockEncodingSerde blockEncodingSerde, SliceOutput sliceOutput, Block block) { - int positionCount = block.getPositionCount(); + ShortArrayBlock shortArrayBlock = (ShortArrayBlock) block; + int positionCount = shortArrayBlock.getPositionCount(); sliceOutput.appendInt(positionCount); - encodeNullsAsBits(sliceOutput, block); + encodeNullsAsBits(sliceOutput, shortArrayBlock); - if (!block.mayHaveNull()) { - if (block instanceof ShortArrayBlock valueBlock) { - sliceOutput.writeShorts(valueBlock.getRawValues(), valueBlock.getRawValuesOffset(), valueBlock.getPositionCount()); - } - else if (block instanceof ShortArrayBlockBuilder blockBuilder) { - sliceOutput.writeShorts(blockBuilder.getRawValues(), 0, blockBuilder.getPositionCount()); - } - else { - throw new IllegalArgumentException("Unexpected block type " + block.getClass().getSimpleName()); - } + if 
(!shortArrayBlock.mayHaveNull()) { + sliceOutput.writeShorts(shortArrayBlock.getRawValues(), shortArrayBlock.getRawValuesOffset(), shortArrayBlock.getPositionCount()); } else { short[] valuesWithoutNull = new short[positionCount]; int nonNullPositionCount = 0; for (int i = 0; i < positionCount; i++) { - valuesWithoutNull[nonNullPositionCount] = block.getShort(i, 0); - if (!block.isNull(i)) { + valuesWithoutNull[nonNullPositionCount] = shortArrayBlock.getShort(i); + if (!shortArrayBlock.isNull(i)) { nonNullPositionCount++; } } @@ -67,7 +60,7 @@ else if (block instanceof ShortArrayBlockBuilder blockBuilder) { } @Override - public Block readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) + public ShortArrayBlock readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) { int positionCount = sliceInput.readInt(); diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/SqlMap.java b/core/trino-spi/src/main/java/io/trino/spi/block/SqlMap.java index 3a9335a05cbd..29f4bc1988c4 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/SqlMap.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/SqlMap.java @@ -98,6 +98,26 @@ public Block getRawValueBlock() return rawValueBlock; } + public int getUnderlyingKeyPosition(int position) + { + return rawKeyBlock.getUnderlyingValuePosition(offset + position); + } + + public ValueBlock getUnderlyingKeyBlock() + { + return rawKeyBlock.getUnderlyingValueBlock(); + } + + public int getUnderlyingValuePosition(int position) + { + return rawValueBlock.getUnderlyingValuePosition(offset + position); + } + + public ValueBlock getUnderlyingValueBlock() + { + return rawValueBlock.getUnderlyingValueBlock(); + } + @Override public String toString() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/SqlRow.java b/core/trino-spi/src/main/java/io/trino/spi/block/SqlRow.java index 2acc019fe596..fcb7e2950e86 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/SqlRow.java +++ 
b/core/trino-spi/src/main/java/io/trino/spi/block/SqlRow.java @@ -82,6 +82,16 @@ public void retainedBytesForEachPart(ObjLongConsumer consumer) consumer.accept(this, INSTANCE_SIZE); } + public int getUnderlyingFieldPosition(int fieldIndex) + { + return fieldBlocks[fieldIndex].getUnderlyingValuePosition(rawIndex); + } + + public ValueBlock getUnderlyingFieldBlock(int fieldIndex) + { + return fieldBlocks[fieldIndex].getUnderlyingValueBlock(); + } + @Override public String toString() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ValueBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/ValueBlock.java new file mode 100644 index 000000000000..500769a29a13 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ValueBlock.java @@ -0,0 +1,42 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.spi.block; + +public non-sealed interface ValueBlock + extends Block +{ + @Override + ValueBlock copyPositions(int[] positions, int offset, int length); + + @Override + ValueBlock getRegion(int positionOffset, int length); + + @Override + ValueBlock copyRegion(int position, int length); + + @Override + ValueBlock copyWithAppendedNull(); + + @Override + default ValueBlock getUnderlyingValueBlock() + { + return this; + } + + @Override + default int getUnderlyingValuePosition(int position) + { + return position; + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlock.java index 9c2d53940554..931d4cae70eb 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlock.java @@ -16,7 +16,6 @@ import io.airlift.slice.Slice; import io.airlift.slice.SliceOutput; import io.airlift.slice.Slices; -import io.airlift.slice.XxHash64; import jakarta.annotation.Nullable; import java.util.Optional; @@ -35,8 +34,8 @@ import static io.trino.spi.block.BlockUtil.copyIsNullAndAppendNull; import static io.trino.spi.block.BlockUtil.copyOffsetsAndAppendNull; -public class VariableWidthBlock - implements Block +public final class VariableWidthBlock + implements ValueBlock { private static final int INSTANCE_SIZE = instanceSize(VariableWidthBlock.class); @@ -102,7 +101,7 @@ public int getRawSliceOffset(int position) return getPositionOffset(position); } - protected final int getPositionOffset(int position) + int getPositionOffset(int position) { return offsets[position + arrayOffset]; } @@ -214,54 +213,12 @@ public Slice getSlice(int position, int offset, int length) return slice.slice(getPositionOffset(position) + offset, length); } - @Override - public void writeSliceTo(int position, int offset, int length, SliceOutput output) - { - checkReadablePosition(this, position); - 
output.writeBytes(slice, getPositionOffset(position) + offset, length); - } - - @Override - public boolean equals(int position, int offset, Block otherBlock, int otherPosition, int otherOffset, int length) - { - checkReadablePosition(this, position); - Slice rawSlice = slice; - if (getSliceLength(position) < length) { - return false; - } - return otherBlock.bytesEqual(otherPosition, otherOffset, rawSlice, getPositionOffset(position) + offset, length); - } - - @Override - public boolean bytesEqual(int position, int offset, Slice otherSlice, int otherOffset, int length) - { - checkReadablePosition(this, position); - return slice.equals(getPositionOffset(position) + offset, length, otherSlice, otherOffset, length); - } - - @Override - public long hash(int position, int offset, int length) - { - checkReadablePosition(this, position); - return XxHash64.hash(slice, getPositionOffset(position) + offset, length); - } - - @Override - public int compareTo(int position, int offset, int length, Block otherBlock, int otherPosition, int otherOffset, int otherLength) - { - checkReadablePosition(this, position); - Slice rawSlice = slice; - if (getSliceLength(position) < length) { - throw new IllegalArgumentException("Length longer than value length"); - } - return -otherBlock.bytesCompare(otherPosition, otherOffset, otherLength, rawSlice, getPositionOffset(position) + offset, length); - } - - @Override - public int bytesCompare(int position, int offset, int length, Slice otherSlice, int otherOffset, int otherLength) + public Slice getSlice(int position) { checkReadablePosition(this, position); - return slice.compareTo(getPositionOffset(position) + offset, length, otherSlice, otherOffset, otherLength); + int offset = offsets[position + arrayOffset]; + int length = offsets[position + 1 + arrayOffset] - offset; + return slice.slice(offset, length); } @Override @@ -278,7 +235,7 @@ public boolean isNull(int position) } @Override - public Block getSingleValueBlock(int position) + public 
VariableWidthBlock getSingleValueBlock(int position) { if (isNull(position)) { return new VariableWidthBlock(0, 1, EMPTY_SLICE, new int[] {0, 0}, new boolean[] {true}); @@ -293,7 +250,7 @@ public Block getSingleValueBlock(int position) } @Override - public Block copyPositions(int[] positions, int offset, int length) + public VariableWidthBlock copyPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); if (length == 0) { @@ -337,7 +294,7 @@ public Block copyPositions(int[] positions, int offset, int length) } @Override - public Block getRegion(int positionOffset, int length) + public VariableWidthBlock getRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); @@ -345,7 +302,7 @@ public Block getRegion(int positionOffset, int length) } @Override - public Block copyRegion(int positionOffset, int length) + public VariableWidthBlock copyRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); positionOffset += arrayOffset; @@ -367,7 +324,7 @@ public String getEncodingName() } @Override - public Block copyWithAppendedNull() + public VariableWidthBlock copyWithAppendedNull() { boolean[] newValueIsNull = copyIsNullAndAppendNull(valueIsNull, arrayOffset, positionCount); int[] newOffsets = copyOffsetsAndAppendNull(offsets, arrayOffset, positionCount); @@ -375,6 +332,12 @@ public Block copyWithAppendedNull() return new VariableWidthBlock(arrayOffset, positionCount + 1, slice, newOffsets, newValueIsNull); } + @Override + public VariableWidthBlock getUnderlyingValueBlock() + { + return this; + } + @Override public String toString() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlockBuilder.java b/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlockBuilder.java index 201f0029a4e7..6ec063828cf1 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlockBuilder.java +++ 
b/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlockBuilder.java @@ -190,6 +190,12 @@ public Block build() if (!hasNonNullValue) { return RunLengthEncodedBlock.create(NULL_VALUE_BLOCK, positions); } + return buildValueBlock(); + } + + @Override + public VariableWidthBlock buildValueBlock() + { return new VariableWidthBlock(0, positions, sliceOutput.slice(), offsets, hasNullValue ? valueIsNull : null); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlockEncoding.java b/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlockEncoding.java index 6218a859aec0..6e8af40a5b44 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlockEncoding.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlockEncoding.java @@ -38,7 +38,6 @@ public String getName() @Override public void writeBlock(BlockEncodingSerde blockEncodingSerde, SliceOutput sliceOutput, Block block) { - // The down casts here are safe because it is the block itself the provides this encoding implementation. 
VariableWidthBlock variableWidthBlock = (VariableWidthBlock) block; int positionCount = variableWidthBlock.getPositionCount(); diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java index d7dd4e7da624..bbb6202aba72 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java @@ -1405,6 +1405,11 @@ default Optional> applyAggreg Map assignments, List> groupingSets) { + // Global aggregation is represented by [[]] + if (groupingSets.isEmpty()) { + throw new IllegalArgumentException("No grouping sets provided"); + } + return Optional.empty(); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/InvocationConvention.java b/core/trino-spi/src/main/java/io/trino/spi/function/InvocationConvention.java index 935488f83f41..a494839c5253 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/function/InvocationConvention.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/InvocationConvention.java @@ -115,6 +115,13 @@ public enum InvocationArgumentConvention * results are undefined. */ BLOCK_POSITION_NOT_NULL(false, 2), + /** + * Argument is passed a ValueBlock followed by the integer position in the block. + * The actual block parameter may be any subtype of ValueBlock, and the scalar function + * adapter will convert the parameter to ValueBlock. If the actual block position + * passed to the function argument is null, the results are undefined. + */ + VALUE_BLOCK_POSITION_NOT_NULL(false, 2), /** * Argument is always an object type. An SQL null will be passed a Java null. */ @@ -125,10 +132,16 @@ public enum InvocationArgumentConvention */ NULL_FLAG(true, 2), /** - * Argument is passed a Block followed by the integer position in the block. The + * Argument is passed a Block followed by the integer position in the block. 
The * sql value may be null. */ BLOCK_POSITION(true, 2), + /** + * Argument is passed a ValueBlock followed by the integer position in the block. + * The actual block parameter may be any subtype of ValueBlock, and the scalar function + * adapter will convert the parameter to ValueBlock. The sql value may be null. + */ + VALUE_BLOCK_POSITION(true, 2), /** * Argument is passed as a flat slice. The sql value may not be null. */ diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/OperatorMethodHandle.java b/core/trino-spi/src/main/java/io/trino/spi/function/OperatorMethodHandle.java index 1304c318894d..0575a08fe4f5 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/function/OperatorMethodHandle.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/OperatorMethodHandle.java @@ -15,6 +15,8 @@ import java.lang.invoke.MethodHandle; +import static java.util.Objects.requireNonNull; + public class OperatorMethodHandle { private final InvocationConvention callingConvention; @@ -22,8 +24,8 @@ public class OperatorMethodHandle public OperatorMethodHandle(InvocationConvention callingConvention, MethodHandle methodHandle) { - this.callingConvention = callingConvention; - this.methodHandle = methodHandle; + this.callingConvention = requireNonNull(callingConvention, "callingConvention is null"); + this.methodHandle = requireNonNull(methodHandle, "methodHandle is null"); } public InvocationConvention getCallingConvention() diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java b/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java index e1de94a16597..27b3e9385181 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java @@ -19,12 +19,14 @@ import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; 
import io.trino.spi.function.InvocationConvention.InvocationArgumentConvention; import io.trino.spi.function.InvocationConvention.InvocationReturnConvention; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodType; import java.util.List; import java.util.Objects; @@ -38,6 +40,8 @@ import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.IN_OUT; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NULL_FLAG; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.DEFAULT_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; @@ -53,18 +57,43 @@ import static java.lang.invoke.MethodHandles.insertArguments; import static java.lang.invoke.MethodHandles.lookup; import static java.lang.invoke.MethodHandles.permuteArguments; -import static java.lang.invoke.MethodHandles.publicLookup; import static java.lang.invoke.MethodHandles.throwException; import static java.lang.invoke.MethodType.methodType; import static java.util.Objects.requireNonNull; public final class ScalarFunctionAdapter { - private static final MethodHandle IS_NULL_METHOD = lookupIsNullMethod(); - private static final MethodHandle APPEND_NULL_METHOD = lookupAppendNullMethod(); + private static final MethodHandle OBJECT_IS_NULL_METHOD; + private static final MethodHandle APPEND_NULL_METHOD; + private static final MethodHandle 
BLOCK_IS_NULL_METHOD; + private static final MethodHandle IN_OUT_IS_NULL_METHOD; + private static final MethodHandle GET_UNDERLYING_VALUE_BLOCK_METHOD; + private static final MethodHandle GET_UNDERLYING_VALUE_POSITION_METHOD; + private static final MethodHandle NEW_NEVER_NULL_IS_NULL_EXCEPTION; // This is needed to convert flat arguments to stack types private static final TypeOperators READ_VALUE_TYPE_OPERATORS = new TypeOperators(); + static { + try { + MethodHandles.Lookup lookup = lookup(); + OBJECT_IS_NULL_METHOD = lookup.findStatic(Objects.class, "isNull", methodType(boolean.class, Object.class)); + APPEND_NULL_METHOD = lookup.findVirtual(BlockBuilder.class, "appendNull", methodType(BlockBuilder.class)) + .asType(methodType(void.class, BlockBuilder.class)); + BLOCK_IS_NULL_METHOD = lookup.findVirtual(Block.class, "isNull", methodType(boolean.class, int.class)); + IN_OUT_IS_NULL_METHOD = lookup.findVirtual(InOut.class, "isNull", methodType(boolean.class)); + + GET_UNDERLYING_VALUE_BLOCK_METHOD = lookup().findVirtual(Block.class, "getUnderlyingValueBlock", methodType(ValueBlock.class)); + GET_UNDERLYING_VALUE_POSITION_METHOD = lookup().findVirtual(Block.class, "getUnderlyingValuePosition", methodType(int.class, int.class)); + + NEW_NEVER_NULL_IS_NULL_EXCEPTION = lookup.findConstructor(TrinoException.class, methodType(void.class, ErrorCodeSupplier.class, String.class)) + .bindTo(StandardErrorCode.INVALID_FUNCTION_ARGUMENT) + .bindTo("A never null argument is null"); + } + catch (ReflectiveOperationException e) { + throw new ExceptionInInitializerError(e); + } + } + private ScalarFunctionAdapter() {} /** @@ -136,17 +165,26 @@ private static boolean canAdaptParameter( return switch (actualArgumentConvention) { case NEVER_NULL -> switch (expectedArgumentConvention) { - case BLOCK_POSITION_NOT_NULL, FLAT -> true; + case BLOCK_POSITION_NOT_NULL, VALUE_BLOCK_POSITION_NOT_NULL, FLAT -> true; case BOXED_NULLABLE, NULL_FLAG -> returnConvention != FAIL_ON_NULL; - case 
BLOCK_POSITION, IN_OUT -> true; // todo only support these if the return convention is nullable + case BLOCK_POSITION, VALUE_BLOCK_POSITION, IN_OUT -> true; // todo only support these if the return convention is nullable case FUNCTION -> throw new IllegalStateException("Unexpected value: " + expectedArgumentConvention); // this is not needed as the case where actual and expected are the same is covered above, // but this means we will get a compile time error if a new convention is added in the future //noinspection DataFlowIssue case NEVER_NULL -> true; }; - case BLOCK_POSITION_NOT_NULL -> expectedArgumentConvention == BLOCK_POSITION && (returnConvention.isNullable() || returnConvention == DEFAULT_ON_NULL); - case BLOCK_POSITION -> expectedArgumentConvention == BLOCK_POSITION_NOT_NULL; + case BLOCK_POSITION_NOT_NULL, VALUE_BLOCK_POSITION_NOT_NULL -> switch (expectedArgumentConvention) { + case BLOCK_POSITION_NOT_NULL, VALUE_BLOCK_POSITION_NOT_NULL -> true; + case BLOCK_POSITION, VALUE_BLOCK_POSITION -> returnConvention.isNullable() || returnConvention == DEFAULT_ON_NULL; + case NEVER_NULL, NULL_FLAG, BOXED_NULLABLE, FLAT, IN_OUT -> false; + case FUNCTION -> throw new IllegalStateException("Unexpected value: " + expectedArgumentConvention); + }; + case BLOCK_POSITION, VALUE_BLOCK_POSITION -> switch (expectedArgumentConvention) { + case BLOCK_POSITION_NOT_NULL, VALUE_BLOCK_POSITION_NOT_NULL, BLOCK_POSITION, VALUE_BLOCK_POSITION -> true; + case NEVER_NULL, NULL_FLAG, BOXED_NULLABLE, FLAT, IN_OUT -> false; + case FUNCTION -> throw new IllegalStateException("Unexpected value: " + expectedArgumentConvention); + }; case BOXED_NULLABLE, NULL_FLAG -> true; case FLAT, IN_OUT -> false; case FUNCTION -> throw new IllegalArgumentException("Unsupported argument convention: " + actualArgumentConvention); @@ -263,6 +301,11 @@ private static MethodHandle adaptParameter( InvocationArgumentConvention expectedArgumentConvention, InvocationReturnConvention returnConvention) { + // For 
value block, cast specialized parameter to ValueBlock + if (actualArgumentConvention == VALUE_BLOCK_POSITION || actualArgumentConvention == VALUE_BLOCK_POSITION_NOT_NULL && methodHandle.type().parameterType(parameterIndex) != ValueBlock.class) { + methodHandle = methodHandle.asType(methodHandle.type().changeParameterType(parameterIndex, ValueBlock.class)); + } + if (actualArgumentConvention == expectedArgumentConvention) { return methodHandle; } @@ -324,7 +367,7 @@ private static MethodHandle adaptParameter( methodHandle = filterArguments( methodHandle, parameterIndex + 1, - explicitCastArguments(IS_NULL_METHOD, methodType(boolean.class, wrap(parameterType)))); + explicitCastArguments(OBJECT_IS_NULL_METHOD, methodType(boolean.class, wrap(parameterType)))); // 1. Duplicate the argument, so we have two copies of the value // Long, Long => Long @@ -359,88 +402,45 @@ private static MethodHandle adaptParameter( } if (expectedArgumentConvention == BLOCK_POSITION_NOT_NULL) { - if (actualArgumentConvention == BLOCK_POSITION) { - return methodHandle; + if (actualArgumentConvention == VALUE_BLOCK_POSITION_NOT_NULL || actualArgumentConvention == VALUE_BLOCK_POSITION) { + return adaptValueBlockArgumentToBlock(methodHandle, parameterIndex); } - MethodHandle getBlockValue = getBlockValue(argumentType, methodHandle.type().parameterType(parameterIndex)); - if (actualArgumentConvention == NEVER_NULL) { - return collectArguments(methodHandle, parameterIndex, getBlockValue); - } - if (actualArgumentConvention == BOXED_NULLABLE) { - MethodType targetType = getBlockValue.type().changeReturnType(wrap(getBlockValue.type().returnType())); - return collectArguments(methodHandle, parameterIndex, explicitCastArguments(getBlockValue, targetType)); - } - if (actualArgumentConvention == NULL_FLAG) { - // actual method takes value and null flag, so change method handles to not have the flag and always pass false to the actual method - return collectArguments(insertArguments(methodHandle, 
parameterIndex + 1, false), parameterIndex, getBlockValue); + return adaptParameterToBlockPositionNotNull(methodHandle, parameterIndex, argumentType, actualArgumentConvention, expectedArgumentConvention, returnConvention); + } + + if (expectedArgumentConvention == VALUE_BLOCK_POSITION_NOT_NULL) { + if (actualArgumentConvention == VALUE_BLOCK_POSITION) { + return methodHandle; } + + methodHandle = adaptParameterToBlockPositionNotNull(methodHandle, parameterIndex, argumentType, actualArgumentConvention, expectedArgumentConvention, returnConvention); + methodHandle = methodHandle.asType(methodHandle.type().changeParameterType(parameterIndex, ValueBlock.class)); + return methodHandle; } // caller passes block and position which may contain a null if (expectedArgumentConvention == BLOCK_POSITION) { - MethodHandle getBlockValue = getBlockValue(argumentType, methodHandle.type().parameterType(parameterIndex)); - - if (actualArgumentConvention == NEVER_NULL) { - if (returnConvention != FAIL_ON_NULL) { - // if caller sets the null flag, return null, otherwise invoke target - methodHandle = collectArguments(methodHandle, parameterIndex, getBlockValue); - - return guardWithTest( - isBlockPositionNull(methodHandle.type(), parameterIndex), - getNullShortCircuitResult(methodHandle, returnConvention), - methodHandle); - } - - MethodHandle adapter = guardWithTest( - isBlockPositionNull(getBlockValue.type(), 0), - throwTrinoNullArgumentException(getBlockValue.type()), - getBlockValue); - - return collectArguments(methodHandle, parameterIndex, adapter); + // convert ValueBlock argument to Block + if (actualArgumentConvention == VALUE_BLOCK_POSITION || actualArgumentConvention == VALUE_BLOCK_POSITION_NOT_NULL) { + adaptValueBlockArgumentToBlock(methodHandle, parameterIndex); + methodHandle = adaptValueBlockArgumentToBlock(methodHandle, parameterIndex); } - if (actualArgumentConvention == BOXED_NULLABLE) { - getBlockValue = explicitCastArguments(getBlockValue, 
getBlockValue.type().changeReturnType(wrap(getBlockValue.type().returnType()))); - getBlockValue = guardWithTest( - isBlockPositionNull(getBlockValue.type(), 0), - empty(getBlockValue.type()), - getBlockValue); - methodHandle = collectArguments(methodHandle, parameterIndex, getBlockValue); + if (actualArgumentConvention == VALUE_BLOCK_POSITION) { return methodHandle; } - if (actualArgumentConvention == NULL_FLAG) { - // long, boolean => long, Block, int - MethodHandle isNull = isBlockPositionNull(getBlockValue.type(), 0); - methodHandle = collectArguments(methodHandle, parameterIndex + 1, isNull); - - // convert get block value to be null safe - getBlockValue = guardWithTest( - isBlockPositionNull(getBlockValue.type(), 0), - empty(getBlockValue.type()), - getBlockValue); - - // long, Block, int => Block, int, Block, int - methodHandle = collectArguments(methodHandle, parameterIndex, getBlockValue); + return adaptParameterToBlockPosition(methodHandle, parameterIndex, argumentType, actualArgumentConvention, expectedArgumentConvention, returnConvention); + } - int[] reorder = IntStream.range(0, methodHandle.type().parameterCount()) - .map(i -> i <= parameterIndex + 1 ? 
i : i - 2) - .toArray(); - MethodType newType = methodHandle.type().dropParameterTypes(parameterIndex + 2, parameterIndex + 4); - methodHandle = permuteArguments(methodHandle, newType, reorder); - return methodHandle; - } - - if (actualArgumentConvention == BLOCK_POSITION_NOT_NULL) { - if (returnConvention != FAIL_ON_NULL) { - MethodHandle nullReturnValue = getNullShortCircuitResult(methodHandle, returnConvention); - return guardWithTest( - isBlockPositionNull(methodHandle.type(), parameterIndex), - nullReturnValue, - methodHandle); - } + // caller passes value block and position which may contain a null + if (expectedArgumentConvention == VALUE_BLOCK_POSITION) { + if (actualArgumentConvention != BLOCK_POSITION) { + methodHandle = adaptParameterToBlockPosition(methodHandle, parameterIndex, argumentType, actualArgumentConvention, expectedArgumentConvention, returnConvention); } + methodHandle = methodHandle.asType(methodHandle.type().changeParameterType(parameterIndex, ValueBlock.class)); + return methodHandle; } // caller will pass boolean true in the next argument for SQL null @@ -523,7 +523,118 @@ private static MethodHandle adaptParameter( } } - throw new IllegalArgumentException("Cannot convert argument %s to %s with return convention %s".formatted(actualArgumentConvention, expectedArgumentConvention, returnConvention)); + throw unsupportedArgumentAdaptation(actualArgumentConvention, expectedArgumentConvention, returnConvention); + } + + private static MethodHandle adaptParameterToBlockPosition(MethodHandle methodHandle, int parameterIndex, Type argumentType, InvocationArgumentConvention actualArgumentConvention, InvocationArgumentConvention expectedArgumentConvention, InvocationReturnConvention returnConvention) + { + MethodHandle getBlockValue = getBlockValue(argumentType, methodHandle.type().parameterType(parameterIndex)); + if (actualArgumentConvention == NEVER_NULL) { + if (returnConvention != FAIL_ON_NULL) { + // if caller sets the null flag, return null, 
otherwise invoke target + methodHandle = collectArguments(methodHandle, parameterIndex, getBlockValue); + + return guardWithTest( + isBlockPositionNull(methodHandle.type(), parameterIndex), + getNullShortCircuitResult(methodHandle, returnConvention), + methodHandle); + } + + MethodHandle adapter = guardWithTest( + isBlockPositionNull(getBlockValue.type(), 0), + throwTrinoNullArgumentException(getBlockValue.type()), + getBlockValue); + + return collectArguments(methodHandle, parameterIndex, adapter); + } + + if (actualArgumentConvention == BOXED_NULLABLE) { + getBlockValue = explicitCastArguments(getBlockValue, getBlockValue.type().changeReturnType(wrap(getBlockValue.type().returnType()))); + getBlockValue = guardWithTest( + isBlockPositionNull(getBlockValue.type(), 0), + empty(getBlockValue.type()), + getBlockValue); + methodHandle = collectArguments(methodHandle, parameterIndex, getBlockValue); + return methodHandle; + } + + if (actualArgumentConvention == NULL_FLAG) { + // long, boolean => long, Block, int + MethodHandle isNull = isBlockPositionNull(getBlockValue.type(), 0); + methodHandle = collectArguments(methodHandle, parameterIndex + 1, isNull); + + // convert get block value to be null safe + getBlockValue = guardWithTest( + isBlockPositionNull(getBlockValue.type(), 0), + empty(getBlockValue.type()), + getBlockValue); + + // long, Block, int => Block, int, Block, int + methodHandle = collectArguments(methodHandle, parameterIndex, getBlockValue); + + int[] reorder = IntStream.range(0, methodHandle.type().parameterCount()) + .map(i -> i <= parameterIndex + 1 ? 
i : i - 2) + .toArray(); + MethodType newType = methodHandle.type().dropParameterTypes(parameterIndex + 2, parameterIndex + 4); + methodHandle = permuteArguments(methodHandle, newType, reorder); + return methodHandle; + } + + if (actualArgumentConvention == BLOCK_POSITION_NOT_NULL || actualArgumentConvention == VALUE_BLOCK_POSITION_NOT_NULL) { + if (returnConvention != FAIL_ON_NULL) { + MethodHandle nullReturnValue = getNullShortCircuitResult(methodHandle, returnConvention); + return guardWithTest( + isBlockPositionNull(methodHandle.type(), parameterIndex), + nullReturnValue, + methodHandle); + } + } + throw unsupportedArgumentAdaptation(actualArgumentConvention, expectedArgumentConvention, returnConvention); + } + + private static MethodHandle adaptParameterToBlockPositionNotNull(MethodHandle methodHandle, int parameterIndex, Type argumentType, InvocationArgumentConvention actualArgumentConvention, InvocationArgumentConvention expectedArgumentConvention, InvocationReturnConvention returnConvention) + { + if (actualArgumentConvention == BLOCK_POSITION || actualArgumentConvention == BLOCK_POSITION_NOT_NULL) { + return methodHandle; + } + + MethodHandle getBlockValue = getBlockValue(argumentType, methodHandle.type().parameterType(parameterIndex)); + if (actualArgumentConvention == NEVER_NULL) { + return collectArguments(methodHandle, parameterIndex, getBlockValue); + } + if (actualArgumentConvention == BOXED_NULLABLE) { + MethodType targetType = getBlockValue.type().changeReturnType(wrap(getBlockValue.type().returnType())); + return collectArguments(methodHandle, parameterIndex, explicitCastArguments(getBlockValue, targetType)); + } + if (actualArgumentConvention == NULL_FLAG) { + // actual method takes value and null flag, so change method handles to not have the flag and always pass false to the actual method + return collectArguments(insertArguments(methodHandle, parameterIndex + 1, false), parameterIndex, getBlockValue); + } + throw 
unsupportedArgumentAdaptation(actualArgumentConvention, expectedArgumentConvention, returnConvention); + } + + private static IllegalArgumentException unsupportedArgumentAdaptation(InvocationArgumentConvention actualArgumentConvention, InvocationArgumentConvention expectedArgumentConvention, InvocationReturnConvention returnConvention) + { + return new IllegalArgumentException("Cannot convert argument %s to %s with return convention %s".formatted(actualArgumentConvention, expectedArgumentConvention, returnConvention)); + } + + private static MethodHandle adaptValueBlockArgumentToBlock(MethodHandle methodHandle, int parameterIndex) + { + // someValueBlock, position => valueBlock, position + methodHandle = explicitCastArguments(methodHandle, methodHandle.type().changeParameterType(parameterIndex, ValueBlock.class)); + // valueBlock, position => block, position + methodHandle = collectArguments(methodHandle, parameterIndex, GET_UNDERLYING_VALUE_BLOCK_METHOD); + // block, position => block, block, position + methodHandle = collectArguments(methodHandle, parameterIndex + 1, GET_UNDERLYING_VALUE_POSITION_METHOD); + + // block, block, position => block, position + methodHandle = permuteArguments( + methodHandle, + methodHandle.type().dropParameterTypes(parameterIndex, parameterIndex + 1), + IntStream.range(0, methodHandle.type().parameterCount()) + .map(i -> i <= parameterIndex ? 
i : i - 1) + .toArray()); + return methodHandle; } private static MethodHandle getBlockValue(Type argumentType, Class expectedType) @@ -647,7 +758,7 @@ private static MethodHandle isTrueNullFlag(MethodType methodType, int index) private static MethodHandle isNullArgument(MethodType methodType, int index) { // Start with Objects.isNull(Object):boolean - MethodHandle isNull = IS_NULL_METHOD; + MethodHandle isNull = OBJECT_IS_NULL_METHOD; // Cast in incoming type: isNull(T):boolean isNull = explicitCastArguments(isNull, methodType(boolean.class, methodType.parameterType(index))); // Add extra argument to match the expected method type @@ -657,40 +768,15 @@ private static MethodHandle isNullArgument(MethodType methodType, int index) private static MethodHandle isBlockPositionNull(MethodType methodType, int index) { - // Start with Objects.isNull(Object):boolean - MethodHandle isNull; - try { - isNull = lookup().findVirtual(Block.class, "isNull", methodType(boolean.class, int.class)); - } - catch (ReflectiveOperationException e) { - throw new AssertionError(e); - } - // Add extra argument to match the expected method type - isNull = permuteArguments(isNull, methodType.changeReturnType(boolean.class), index, index + 1); - return isNull; + // Add extra argument to Block.isNull(int):boolean match the expected method type + MethodHandle blockIsNull = BLOCK_IS_NULL_METHOD.asType(BLOCK_IS_NULL_METHOD.type().changeParameterType(0, methodType.parameterType(index))); + return permuteArguments(blockIsNull, methodType.changeReturnType(boolean.class), index, index + 1); } private static MethodHandle isInOutNull(MethodType methodType, int index) { - MethodHandle isNull; - try { - isNull = lookup().findVirtual(InOut.class, "isNull", methodType(boolean.class)); - } - catch (ReflectiveOperationException e) { - throw new AssertionError(e); - } - isNull = permuteArguments(isNull, methodType.changeReturnType(boolean.class), index); - return isNull; - } - - private static MethodHandle 
lookupIsNullMethod() - { - try { - return lookup().findStatic(Objects.class, "isNull", methodType(boolean.class, Object.class)); - } - catch (ReflectiveOperationException e) { - throw new AssertionError(e); - } + // Add extra argument to InOut.isNull(int):boolean match the expected method type + return permuteArguments(IN_OUT_IS_NULL_METHOD, methodType.changeReturnType(boolean.class), index); } private static MethodHandle getNullShortCircuitResult(MethodHandle methodHandle, InvocationReturnConvention returnConvention) @@ -701,35 +787,12 @@ private static MethodHandle getNullShortCircuitResult(MethodHandle methodHandle, return empty(methodHandle.type()); } - private static MethodHandle lookupAppendNullMethod() - { - try { - return lookup().findVirtual(BlockBuilder.class, "appendNull", methodType(BlockBuilder.class)) - .asType(methodType(void.class, BlockBuilder.class)); - } - catch (ReflectiveOperationException e) { - throw new AssertionError(e); - } - } - private static MethodHandle throwTrinoNullArgumentException(MethodType type) { - MethodHandle throwException = collectArguments(throwException(type.returnType(), TrinoException.class), 0, trinoNullArgumentException()); + MethodHandle throwException = collectArguments(throwException(type.returnType(), TrinoException.class), 0, NEW_NEVER_NULL_IS_NULL_EXCEPTION); return permuteArguments(throwException, type); } - private static MethodHandle trinoNullArgumentException() - { - try { - return publicLookup().findConstructor(TrinoException.class, methodType(void.class, ErrorCodeSupplier.class, String.class)) - .bindTo(StandardErrorCode.INVALID_FUNCTION_ARGUMENT) - .bindTo("A never null argument is null"); - } - catch (ReflectiveOperationException e) { - throw new AssertionError(e); - } - } - private static boolean isWrapperType(Class type) { return type != unwrap(type); diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/AbstractIntType.java b/core/trino-spi/src/main/java/io/trino/spi/type/AbstractIntType.java 
index 360a8fd67cea..d9adc24d9650 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/AbstractIntType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/AbstractIntType.java @@ -13,14 +13,16 @@ */ package io.trino.spi.type; -import io.airlift.slice.Slice; import io.airlift.slice.XxHash64; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.IntArrayBlock; import io.trino.spi.block.IntArrayBlockBuilder; import io.trino.spi.block.PageBuilderStatus; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; import io.trino.spi.function.FlatFixed; import io.trino.spi.function.FlatFixedOffset; import io.trino.spi.function.FlatVariableWidth; @@ -51,7 +53,7 @@ public abstract class AbstractIntType protected AbstractIntType(TypeSignature signature) { - super(signature, long.class); + super(signature, long.class, IntArrayBlock.class); } @Override @@ -86,13 +88,7 @@ public final long getLong(Block block, int position) public final int getInt(Block block, int position) { - return block.getInt(position, 0); - } - - @Override - public final Slice getSlice(Block block, int position) - { - return block.getSlice(position, 0, getFixedSize()); + return readInt((IntArrayBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override @@ -107,7 +103,7 @@ public BlockBuilder writeInt(BlockBuilder blockBuilder, int value) return ((IntArrayBlockBuilder) blockBuilder).writeInt(value); } - protected void checkValueValid(long value) + protected static void checkValueValid(long value) { if (value > Integer.MAX_VALUE) { throw new TrinoException(GENERIC_INTERNAL_ERROR, format("Value %d exceeds MAX_INT", value)); @@ -124,7 +120,7 @@ public final void appendTo(Block block, int position, BlockBuilder blockBuilder) blockBuilder.appendNull(); } else { - writeInt(blockBuilder, block.getInt(position, 
0)); + writeInt(blockBuilder, getInt(block, position)); } } @@ -161,6 +157,17 @@ public final BlockBuilder createFixedSizeBlockBuilder(int positionCount) return new IntArrayBlockBuilder(null, positionCount); } + @ScalarOperator(READ_VALUE) + private static long read(@BlockPosition IntArrayBlock block, @BlockIndex int position) + { + return readInt(block, position); + } + + private static int readInt(IntArrayBlock block, int position) + { + return block.getInt(position); + } + @ScalarOperator(READ_VALUE) private static long readFlat( @FlatFixed byte[] fixedSizeSlice, diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/AbstractLongType.java b/core/trino-spi/src/main/java/io/trino/spi/type/AbstractLongType.java index 2caeb91e708b..030c2ce4fe92 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/AbstractLongType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/AbstractLongType.java @@ -13,13 +13,15 @@ */ package io.trino.spi.type; -import io.airlift.slice.Slice; import io.airlift.slice.XxHash64; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.LongArrayBlock; import io.trino.spi.block.LongArrayBlockBuilder; import io.trino.spi.block.PageBuilderStatus; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; import io.trino.spi.function.FlatFixed; import io.trino.spi.function.FlatFixedOffset; import io.trino.spi.function.FlatVariableWidth; @@ -49,7 +51,7 @@ public abstract class AbstractLongType public AbstractLongType(TypeSignature signature) { - super(signature, long.class); + super(signature, long.class, LongArrayBlock.class); } @Override @@ -79,13 +81,7 @@ public TypeOperatorDeclaration getTypeOperatorDeclaration(TypeOperators typeOper @Override public final long getLong(Block block, int position) { - return block.getLong(position, 0); - } - - @Override - public final Slice getSlice(Block block, int position) - { - 
return block.getSlice(position, 0, getFixedSize()); + return read((LongArrayBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override @@ -144,6 +140,12 @@ public static long hash(long value) return rotateLeft(value * 0xC2B2AE3D27D4EB4FL, 31) * 0x9E3779B185EBCA87L; } + @ScalarOperator(READ_VALUE) + private static long read(@BlockPosition LongArrayBlock block, @BlockIndex int position) + { + return block.getLong(position); + } + @ScalarOperator(READ_VALUE) private static long readFlat( @FlatFixed byte[] fixedSizeSlice, diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/AbstractType.java b/core/trino-spi/src/main/java/io/trino/spi/type/AbstractType.java index 2719f3a5ea0e..a7c0b9fe15e1 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/AbstractType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/AbstractType.java @@ -16,6 +16,7 @@ import io.airlift.slice.Slice; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import java.util.List; @@ -24,11 +25,13 @@ public abstract class AbstractType { private final TypeSignature signature; private final Class javaType; + private final Class valueBlockType; - protected AbstractType(TypeSignature signature, Class javaType) + protected AbstractType(TypeSignature signature, Class javaType, Class valueBlockType) { this.signature = signature; this.javaType = javaType; + this.valueBlockType = valueBlockType; } @Override @@ -49,6 +52,12 @@ public final Class getJavaType() return javaType; } + @Override + public Class getValueBlockType() + { + return valueBlockType; + } + @Override public List getTypeParameters() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/AbstractVariableWidthType.java b/core/trino-spi/src/main/java/io/trino/spi/type/AbstractVariableWidthType.java index 5d5d839792c5..be1cf7dd70ae 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/AbstractVariableWidthType.java +++ 
b/core/trino-spi/src/main/java/io/trino/spi/type/AbstractVariableWidthType.java @@ -19,6 +19,7 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; import io.trino.spi.block.PageBuilderStatus; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.function.BlockIndex; import io.trino.spi.function.BlockPosition; @@ -54,7 +55,7 @@ public abstract class AbstractVariableWidthType protected AbstractVariableWidthType(TypeSignature signature, Class javaType) { - super(signature, javaType); + super(signature, javaType, VariableWidthBlock.class); } @Override @@ -72,7 +73,7 @@ public VariableWidthBlockBuilder createBlockBuilder(BlockBuilderStatus blockBuil int expectedBytes = (int) min((long) expectedEntries * expectedBytesPerEntry, maxBlockSizeInBytes); return new VariableWidthBlockBuilder( blockBuilderStatus, - expectedBytesPerEntry == 0 ? expectedEntries : Math.min(expectedEntries, maxBlockSizeInBytes / expectedBytesPerEntry), + expectedBytesPerEntry == 0 ? 
expectedEntries : min(expectedEntries, maxBlockSizeInBytes / expectedBytesPerEntry), expectedBytes); } @@ -89,7 +90,12 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) blockBuilder.appendNull(); } else { - ((VariableWidthBlockBuilder) blockBuilder).buildEntry(valueBuilder -> block.writeSliceTo(position, 0, block.getSliceLength(position), valueBuilder)); + VariableWidthBlock variableWidthBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + position = block.getUnderlyingValuePosition(position); + Slice slice = variableWidthBlock.getRawSlice(); + int offset = variableWidthBlock.getRawSliceOffset(position); + int length = variableWidthBlock.getSliceLength(position); + ((VariableWidthBlockBuilder) blockBuilder).writeEntry(slice, offset, length); } } @@ -190,21 +196,24 @@ private static void writeFlatFromStack( @ScalarOperator(READ_VALUE) private static void writeFlatFromBlock( - @BlockPosition Block block, + @BlockPosition VariableWidthBlock block, @BlockIndex int position, byte[] fixedSizeSlice, int fixedSizeOffset, byte[] variableSizeSlice, int variableSizeOffset) { + Slice rawSlice = block.getRawSlice(); + int rawSliceOffset = block.getRawSliceOffset(position); int length = block.getSliceLength(position); + INT_HANDLE.set(fixedSizeSlice, fixedSizeOffset, length); if (length <= 12) { - block.writeSliceTo(position, 0, length, wrappedBuffer(fixedSizeSlice, fixedSizeOffset + Integer.BYTES, length).getOutput()); + rawSlice.getBytes(rawSliceOffset, fixedSizeSlice, fixedSizeOffset + Integer.BYTES, length); } else { INT_HANDLE.set(fixedSizeSlice, fixedSizeOffset + Integer.BYTES + Long.BYTES, variableSizeOffset); - block.writeSliceTo(position, 0, length, wrappedBuffer(variableSizeSlice, variableSizeOffset, length).getOutput()); + rawSlice.getBytes(rawSliceOffset, variableSizeSlice, variableSizeOffset, length); } } } @@ -218,31 +227,33 @@ private static boolean equalOperator(Slice left, Slice right) } @ScalarOperator(EQUAL) - private 
static boolean equalOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static boolean equalOperator(@BlockPosition VariableWidthBlock leftBlock, @BlockIndex int leftPosition, @BlockPosition VariableWidthBlock rightBlock, @BlockIndex int rightPosition) { + Slice leftRawSlice = leftBlock.getRawSlice(); + int leftRawSliceOffset = leftBlock.getRawSliceOffset(leftPosition); int leftLength = leftBlock.getSliceLength(leftPosition); + + Slice rightRawSlice = rightBlock.getRawSlice(); + int rightRawSliceOffset = rightBlock.getRawSliceOffset(rightPosition); int rightLength = rightBlock.getSliceLength(rightPosition); - if (leftLength != rightLength) { - return false; - } - return leftBlock.equals(leftPosition, 0, rightBlock, rightPosition, 0, leftLength); + + return leftRawSlice.equals(leftRawSliceOffset, leftLength, rightRawSlice, rightRawSliceOffset, rightLength); } @ScalarOperator(EQUAL) - private static boolean equalOperator(Slice left, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static boolean equalOperator(Slice left, @BlockPosition VariableWidthBlock rightBlock, @BlockIndex int rightPosition) { return equalOperator(rightBlock, rightPosition, left); } @ScalarOperator(EQUAL) - private static boolean equalOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, Slice right) + private static boolean equalOperator(@BlockPosition VariableWidthBlock leftBlock, @BlockIndex int leftPosition, Slice right) { + Slice leftRawSlice = leftBlock.getRawSlice(); + int leftRawSliceOffset = leftBlock.getRawSliceOffset(leftPosition); int leftLength = leftBlock.getSliceLength(leftPosition); - int rightLength = right.length(); - if (leftLength != rightLength) { - return false; - } - return leftBlock.bytesEqual(leftPosition, 0, right, 0, leftLength); + + return leftRawSlice.equals(leftRawSliceOffset, leftLength, right, 0, right.length()); } 
@ScalarOperator(EQUAL) @@ -283,7 +294,7 @@ private static boolean equalOperator( @ScalarOperator(EQUAL) private static boolean equalOperator( - @BlockPosition Block leftBlock, + @BlockPosition VariableWidthBlock leftBlock, @BlockIndex int leftPosition, @FlatFixed byte[] rightFixedSizeSlice, @FlatFixedOffset int rightFixedSizeOffset, @@ -302,19 +313,24 @@ private static boolean equalOperator( @FlatFixed byte[] leftFixedSizeSlice, @FlatFixedOffset int leftFixedSizeOffset, @FlatVariableWidth byte[] leftVariableSizeSlice, - @BlockPosition Block rightBlock, + @BlockPosition VariableWidthBlock rightBlock, @BlockIndex int rightPosition) { int leftLength = (int) INT_HANDLE.get(leftFixedSizeSlice, leftFixedSizeOffset); - if (rightBlock.isNull(rightPosition) || leftLength != rightBlock.getSliceLength(rightPosition)) { + + Slice rightRawSlice = rightBlock.getRawSlice(); + int rightRawSliceOffset = rightBlock.getRawSliceOffset(rightPosition); + int rightLength = rightBlock.getSliceLength(rightPosition); + + if (leftLength != rightLength) { return false; } if (leftLength <= 12) { - return rightBlock.bytesEqual(rightPosition, 0, wrappedBuffer(leftFixedSizeSlice, leftFixedSizeOffset + Integer.BYTES, leftLength), 0, leftLength); + return rightRawSlice.equals(rightRawSliceOffset, rightLength, wrappedBuffer(leftFixedSizeSlice, leftFixedSizeOffset + Integer.BYTES, leftLength), 0, leftLength); } else { int leftVariableSizeOffset = (int) INT_HANDLE.get(leftFixedSizeSlice, leftFixedSizeOffset + Integer.BYTES + Long.BYTES); - return rightBlock.bytesEqual(rightPosition, 0, wrappedBuffer(leftVariableSizeSlice, leftVariableSizeOffset, leftLength), 0, leftLength); + return rightRawSlice.equals(rightRawSliceOffset, rightLength, wrappedBuffer(leftVariableSizeSlice, leftVariableSizeOffset, leftLength), 0, leftLength); } } @@ -325,9 +341,9 @@ private static long xxHash64Operator(Slice value) } @ScalarOperator(XX_HASH_64) - private static long xxHash64Operator(@BlockPosition Block block, 
@BlockIndex int position) + private static long xxHash64Operator(@BlockPosition VariableWidthBlock block, @BlockIndex int position) { - return block.hash(position, 0, block.getSliceLength(position)); + return XxHash64.hash(block.getRawSlice(), block.getRawSliceOffset(position), block.getSliceLength(position)); } @ScalarOperator(XX_HASH_64) @@ -354,25 +370,37 @@ private static long comparisonOperator(Slice left, Slice right) } @ScalarOperator(COMPARISON_UNORDERED_LAST) - private static long comparisonOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static long comparisonOperator(@BlockPosition VariableWidthBlock leftBlock, @BlockIndex int leftPosition, @BlockPosition VariableWidthBlock rightBlock, @BlockIndex int rightPosition) { + Slice leftRawSlice = leftBlock.getRawSlice(); + int leftRawSliceOffset = leftBlock.getRawSliceOffset(leftPosition); int leftLength = leftBlock.getSliceLength(leftPosition); + + Slice rightRawSlice = rightBlock.getRawSlice(); + int rightRawSliceOffset = rightBlock.getRawSliceOffset(rightPosition); int rightLength = rightBlock.getSliceLength(rightPosition); - return leftBlock.compareTo(leftPosition, 0, leftLength, rightBlock, rightPosition, 0, rightLength); + + return leftRawSlice.compareTo(leftRawSliceOffset, leftLength, rightRawSlice, rightRawSliceOffset, rightLength); } @ScalarOperator(COMPARISON_UNORDERED_LAST) - private static long comparisonOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, Slice right) + private static long comparisonOperator(@BlockPosition VariableWidthBlock leftBlock, @BlockIndex int leftPosition, Slice right) { + Slice leftRawSlice = leftBlock.getRawSlice(); + int leftRawSliceOffset = leftBlock.getRawSliceOffset(leftPosition); int leftLength = leftBlock.getSliceLength(leftPosition); - return leftBlock.bytesCompare(leftPosition, 0, leftLength, right, 0, right.length()); + + return 
leftRawSlice.compareTo(leftRawSliceOffset, leftLength, right, 0, right.length()); } @ScalarOperator(COMPARISON_UNORDERED_LAST) - private static long comparisonOperator(Slice left, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static long comparisonOperator(Slice left, @BlockPosition VariableWidthBlock rightBlock, @BlockIndex int rightPosition) { + Slice rightRawSlice = rightBlock.getRawSlice(); + int rightRawSliceOffset = rightBlock.getRawSliceOffset(rightPosition); int rightLength = rightBlock.getSliceLength(rightPosition); - return -rightBlock.bytesCompare(rightPosition, 0, rightLength, left, 0, left.length()); + + return left.compareTo(0, left.length(), rightRawSlice, rightRawSliceOffset, rightLength); } } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/ArrayType.java b/core/trino-spi/src/main/java/io/trino/spi/type/ArrayType.java index ba02ad691d89..5b904f16943c 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/ArrayType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/ArrayType.java @@ -18,6 +18,9 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.DictionaryBlock; +import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.function.InvocationConvention; import io.trino.spi.function.OperatorMethodHandle; @@ -33,12 +36,11 @@ import java.util.List; import java.util.function.BiFunction; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BOXED_NULLABLE; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; import static 
io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NULL_FLAG; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -97,13 +99,13 @@ public class ArrayType private final Type elementType; - // this field is used in double checked locking + // this field is used in double-checked locking @SuppressWarnings("FieldAccessedSynchronizedAndUnsynchronized") private volatile TypeOperatorDeclaration operatorDeclaration; public ArrayType(Type elementType) { - super(new TypeSignature(ARRAY, TypeSignatureParameter.typeParameter(elementType.getTypeSignature())), Block.class); + super(new TypeSignature(ARRAY, TypeSignatureParameter.typeParameter(elementType.getTypeSignature())), Block.class, ArrayBlock.class); this.elementType = requireNonNull(elementType, "elementType is null"); } @@ -139,7 +141,7 @@ private static List getReadValueOperatorMethodHandles(Type MethodHandle readFlat = insertArguments(READ_FLAT, 0, elementType, elementReadOperator, elementType.getFlatFixedSize()); MethodHandle readFlatToBlock = insertArguments(READ_FLAT_TO_BLOCK, 0, elementReadOperator, elementType.getFlatFixedSize()); - MethodHandle elementWriteOperator = typeOperators.getReadValueOperator(elementType, simpleConvention(FLAT_RETURN, BLOCK_POSITION)); + MethodHandle elementWriteOperator = typeOperators.getReadValueOperator(elementType, simpleConvention(FLAT_RETURN, VALUE_BLOCK_POSITION_NOT_NULL)); MethodHandle writeFlatToBlock = insertArguments(WRITE_FLAT, 0, elementType, elementWriteOperator, elementType.getFlatFixedSize(), 
elementType.isFlatVariableWidth()); return List.of( new OperatorMethodHandle(READ_FLAT_CONVENTION, readFlat), @@ -152,7 +154,7 @@ private static List getEqualOperatorMethodHandles(TypeOper if (!elementType.isComparable()) { return emptyList(); } - MethodHandle equalOperator = typeOperators.getEqualOperator(elementType, simpleConvention(NULLABLE_RETURN, BLOCK_POSITION_NOT_NULL, BLOCK_POSITION_NOT_NULL)); + MethodHandle equalOperator = typeOperators.getEqualOperator(elementType, simpleConvention(NULLABLE_RETURN, VALUE_BLOCK_POSITION_NOT_NULL, VALUE_BLOCK_POSITION_NOT_NULL)); return singletonList(new OperatorMethodHandle(EQUAL_CONVENTION, EQUAL.bindTo(equalOperator))); } @@ -161,7 +163,7 @@ private static List getHashCodeOperatorMethodHandles(TypeO if (!elementType.isComparable()) { return emptyList(); } - MethodHandle elementHashCodeOperator = typeOperators.getHashCodeOperator(elementType, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION_NOT_NULL)); + MethodHandle elementHashCodeOperator = typeOperators.getHashCodeOperator(elementType, simpleConvention(FAIL_ON_NULL, VALUE_BLOCK_POSITION_NOT_NULL)); return singletonList(new OperatorMethodHandle(HASH_CODE_CONVENTION, HASH_CODE.bindTo(elementHashCodeOperator))); } @@ -170,7 +172,7 @@ private static List getXxHash64OperatorMethodHandles(TypeO if (!elementType.isComparable()) { return emptyList(); } - MethodHandle elementHashCodeOperator = typeOperators.getXxHash64Operator(elementType, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION_NOT_NULL)); + MethodHandle elementHashCodeOperator = typeOperators.getXxHash64Operator(elementType, simpleConvention(FAIL_ON_NULL, VALUE_BLOCK_POSITION_NOT_NULL)); return singletonList(new OperatorMethodHandle(HASH_CODE_CONVENTION, HASH_CODE.bindTo(elementHashCodeOperator))); } @@ -179,7 +181,7 @@ private static List getDistinctFromOperatorInvokers(TypeOp if (!elementType.isComparable()) { return emptyList(); } - MethodHandle elementDistinctFromOperator = 
typeOperators.getDistinctFromOperator(elementType, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION, BLOCK_POSITION)); + MethodHandle elementDistinctFromOperator = typeOperators.getDistinctFromOperator(elementType, simpleConvention(FAIL_ON_NULL, VALUE_BLOCK_POSITION_NOT_NULL, VALUE_BLOCK_POSITION_NOT_NULL)); return singletonList(new OperatorMethodHandle(DISTINCT_FROM_CONVENTION, DISTINCT_FROM.bindTo(elementDistinctFromOperator))); } @@ -188,7 +190,7 @@ private static List getIndeterminateOperatorInvokers(TypeO if (!elementType.isComparable()) { return emptyList(); } - MethodHandle elementIndeterminateOperator = typeOperators.getIndeterminateOperator(elementType, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION_NOT_NULL)); + MethodHandle elementIndeterminateOperator = typeOperators.getIndeterminateOperator(elementType, simpleConvention(FAIL_ON_NULL, VALUE_BLOCK_POSITION_NOT_NULL)); return singletonList(new OperatorMethodHandle(INDETERMINATE_CONVENTION, INDETERMINATE.bindTo(elementIndeterminateOperator))); } @@ -197,7 +199,7 @@ private static List getComparisonOperatorInvokers(BiFuncti if (!elementType.isOrderable()) { return emptyList(); } - MethodHandle elementComparisonOperator = comparisonOperatorFactory.apply(elementType, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION_NOT_NULL, BLOCK_POSITION_NOT_NULL)); + MethodHandle elementComparisonOperator = comparisonOperatorFactory.apply(elementType, simpleConvention(FAIL_ON_NULL, VALUE_BLOCK_POSITION_NOT_NULL, VALUE_BLOCK_POSITION_NOT_NULL)); return singletonList(new OperatorMethodHandle(COMPARISON_CONVENTION, COMPARISON.bindTo(elementComparisonOperator))); } @@ -228,7 +230,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position if (block instanceof ArrayBlock) { return ((ArrayBlock) block).apply((valuesBlock, start, length) -> arrayBlockToObjectValues(session, valuesBlock, start, length), position); } - Block arrayBlock = block.getObject(position, Block.class); + Block arrayBlock = 
getObject(block, position); return arrayBlockToObjectValues(session, arrayBlock, 0, arrayBlock.getPositionCount()); } @@ -257,7 +259,7 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) @Override public Block getObject(Block block, int position) { - return block.getObject(position, Block.class); + return read((ArrayBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override @@ -386,6 +388,11 @@ public String getDisplayName() return ARRAY + "(" + elementType.getDisplayName() + ")"; } + private static Block read(ArrayBlock block, int position) + { + return block.getArray(position); + } + private static Block readFlat( Type elementType, MethodHandle elementReadFlat, @@ -457,34 +464,59 @@ private static void writeFlat( private static void writeFlatElements(Type elementType, MethodHandle elementWriteFlat, int elementFixedSize, boolean elementVariableWidth, Block array, byte[] slice, int offset) throws Throwable { + array = array.getLoadedBlock(); + int positionCount = array.getPositionCount(); // variable width data starts after fixed width data // there is one extra byte per position for the null flag int writeVariableWidthOffset = offset + positionCount * (1 + elementFixedSize); - for (int index = 0; index < positionCount; index++) { - if (array.isNull(index)) { - slice[offset] = 1; - offset++; + if (array instanceof ValueBlock valuesBlock) { + for (int index = 0; index < positionCount; index++) { + writeVariableWidthOffset = writeFlatElement(elementType, elementWriteFlat, elementVariableWidth, valuesBlock, index, slice, offset, writeVariableWidthOffset); + offset += 1 + elementFixedSize; } - else { - // skip null byte - offset++; + } + else if (array instanceof RunLengthEncodedBlock rleBlock) { + ValueBlock valuesBlock = rleBlock.getValue(); + for (int index = 0; index < positionCount; index++) { + writeVariableWidthOffset = writeFlatElement(elementType, elementWriteFlat, elementVariableWidth, 
valuesBlock, 0, slice, offset, writeVariableWidthOffset); + offset += 1 + elementFixedSize; + } + } + else if (array instanceof DictionaryBlock dictionaryBlock) { + ValueBlock valuesBlock = dictionaryBlock.getDictionary(); + for (int position = 0; position < positionCount; position++) { + int index = dictionaryBlock.getId(position); + writeVariableWidthOffset = writeFlatElement(elementType, elementWriteFlat, elementVariableWidth, valuesBlock, index, slice, offset, writeVariableWidthOffset); + offset += 1 + elementFixedSize; + } + } + else { + throw new IllegalArgumentException("Unsupported block type: " + array.getClass().getName()); + } + } - int elementVariableSize = 0; - if (elementVariableWidth) { - elementVariableSize = elementType.getFlatVariableWidthSize(array, index); - } - elementWriteFlat.invokeExact( - array, - index, - slice, - offset, - slice, - writeVariableWidthOffset); - writeVariableWidthOffset += elementVariableSize; + private static int writeFlatElement(Type elementType, MethodHandle elementWriteFlat, boolean elementVariableWidth, ValueBlock array, int index, byte[] slice, int offset, int writeVariableWidthOffset) + throws Throwable + { + if (array.isNull(index)) { + slice[offset] = 1; + } + else { + int elementVariableSize = 0; + if (elementVariableWidth) { + elementVariableSize = elementType.getFlatVariableWidthSize(array, index); } - offset += elementFixedSize; + elementWriteFlat.invokeExact( + array, + index, + slice, + offset + 1, // skip null byte + slice, + writeVariableWidthOffset); + writeVariableWidthOffset += elementVariableSize; } + return writeVariableWidthOffset; } private static Boolean equalOperator(MethodHandle equalOperator, Block leftArray, Block rightArray) @@ -494,13 +526,21 @@ private static Boolean equalOperator(MethodHandle equalOperator, Block leftArray return false; } + leftArray = leftArray.getLoadedBlock(); + rightArray = rightArray.getLoadedBlock(); + + ValueBlock leftValues = leftArray.getUnderlyingValueBlock(); + 
ValueBlock rightValues = rightArray.getUnderlyingValueBlock(); + boolean unknown = false; for (int position = 0; position < leftArray.getPositionCount(); position++) { - if (leftArray.isNull(position) || rightArray.isNull(position)) { + int leftIndex = leftArray.getUnderlyingValuePosition(position); + int rightIndex = rightArray.getUnderlyingValuePosition(position); + if (leftValues.isNull(leftIndex) || rightValues.isNull(rightIndex)) { unknown = true; continue; } - Boolean result = (Boolean) equalOperator.invokeExact(leftArray, position, rightArray, position); + Boolean result = (Boolean) equalOperator.invokeExact(leftValues, leftIndex, rightValues, rightIndex); if (result == null) { unknown = true; } @@ -515,15 +555,43 @@ else if (!result) { return true; } - private static long hashOperator(MethodHandle hashOperator, Block block) + private static long hashOperator(MethodHandle hashOperator, Block array) throws Throwable { - long hash = 0; - for (int position = 0; position < block.getPositionCount(); position++) { - long elementHash = block.isNull(position) ? NULL_HASH_CODE : (long) hashOperator.invokeExact(block, position); - hash = 31 * hash + elementHash; + array = array.getLoadedBlock(); + + if (array instanceof ValueBlock valuesBlock) { + long hash = 0; + for (int index = 0; index < valuesBlock.getPositionCount(); index++) { + long elementHash = valuesBlock.isNull(index) ? NULL_HASH_CODE : (long) hashOperator.invokeExact(valuesBlock, index); + hash = 31 * hash + elementHash; + } + return hash; + } + + if (array instanceof RunLengthEncodedBlock rleBlock) { + ValueBlock valuesBlock = rleBlock.getValue(); + long elementHash = valuesBlock.isNull(0) ? 
NULL_HASH_CODE : (long) hashOperator.invokeExact(valuesBlock, 0); + + long hash = 0; + for (int position = 0; position < rleBlock.getPositionCount(); position++) { + hash = 31 * hash + elementHash; + } + return hash; } - return hash; + + if (array instanceof DictionaryBlock dictionaryBlock) { + ValueBlock valuesBlock = dictionaryBlock.getDictionary(); + long hash = 0; + for (int position = 0; position < dictionaryBlock.getPositionCount(); position++) { + int index = dictionaryBlock.getId(position); + long elementHash = valuesBlock.isNull(index) ? NULL_HASH_CODE : (long) hashOperator.invokeExact(valuesBlock, index); + hash = 31 * hash + elementHash; + } + return hash; + } + + throw new IllegalArgumentException("Unsupported block type: " + array.getClass().getName()); } private static boolean distinctFromOperator(MethodHandle distinctFromOperator, Block leftArray, Block rightArray) @@ -539,8 +607,26 @@ private static boolean distinctFromOperator(MethodHandle distinctFromOperator, B return true; } + leftArray = leftArray.getLoadedBlock(); + rightArray = rightArray.getLoadedBlock(); + + ValueBlock leftValues = leftArray.getUnderlyingValueBlock(); + ValueBlock rightValues = rightArray.getUnderlyingValueBlock(); + for (int position = 0; position < leftArray.getPositionCount(); position++) { - boolean result = (boolean) distinctFromOperator.invokeExact(leftArray, position, rightArray, position); + int leftIndex = leftArray.getUnderlyingValuePosition(position); + int rightIndex = rightArray.getUnderlyingValuePosition(position); + + boolean leftValueIsNull = leftValues.isNull(leftIndex); + boolean rightValueIsNull = rightValues.isNull(rightIndex); + if (leftValueIsNull != rightValueIsNull) { + return true; + } + if (leftValueIsNull) { + continue; + } + + boolean result = (boolean) distinctFromOperator.invokeExact(leftValues, leftIndex, rightValues, rightIndex); if (result) { return true; } @@ -549,33 +635,73 @@ private static boolean distinctFromOperator(MethodHandle 
distinctFromOperator, B return false; } - private static boolean indeterminateOperator(MethodHandle elementIndeterminateFunction, Block block, boolean isNull) + private static boolean indeterminateOperator(MethodHandle elementIndeterminateFunction, Block array, boolean isNull) throws Throwable { if (isNull) { return true; } - for (int position = 0; position < block.getPositionCount(); position++) { - if (block.isNull(position)) { + array = array.getLoadedBlock(); + + if (array instanceof ValueBlock valuesBlock) { + for (int index = 0; index < valuesBlock.getPositionCount(); index++) { + if (valuesBlock.isNull(index)) { + return true; + } + if ((boolean) elementIndeterminateFunction.invoke(valuesBlock, index)) { + return true; + } + } + return false; + } + + if (array instanceof RunLengthEncodedBlock rleBlock) { + ValueBlock valuesBlock = rleBlock.getValue(); + if (valuesBlock.isNull(0)) { return true; } - if ((boolean) elementIndeterminateFunction.invoke(block, position)) { + if ((boolean) elementIndeterminateFunction.invoke(valuesBlock, 0)) { return true; } + return false; } - return false; + + if (array instanceof DictionaryBlock dictionaryBlock) { + ValueBlock valuesBlock = dictionaryBlock.getDictionary(); + for (int position = 0; position < dictionaryBlock.getPositionCount(); position++) { + int index = dictionaryBlock.getId(position); + if (valuesBlock.isNull(index)) { + return true; + } + if ((boolean) elementIndeterminateFunction.invoke(valuesBlock, index)) { + return true; + } + } + return false; + } + + throw new IllegalArgumentException("Unsupported block type: " + array.getClass().getName()); } private static long comparisonOperator(MethodHandle comparisonOperator, Block leftArray, Block rightArray) throws Throwable { + leftArray = leftArray.getLoadedBlock(); + rightArray = rightArray.getLoadedBlock(); + + ValueBlock leftValues = leftArray.getUnderlyingValueBlock(); + ValueBlock rightValues = rightArray.getUnderlyingValueBlock(); + int len = 
Math.min(leftArray.getPositionCount(), rightArray.getPositionCount()); for (int position = 0; position < len; position++) { checkElementNotNull(leftArray.isNull(position), ARRAY_NULL_ELEMENT_MSG); checkElementNotNull(rightArray.isNull(position), ARRAY_NULL_ELEMENT_MSG); - long result = (long) comparisonOperator.invokeExact(leftArray, position, rightArray, position); + int leftIndex = leftArray.getUnderlyingValuePosition(position); + int rightIndex = rightArray.getUnderlyingValuePosition(position); + + long result = (long) comparisonOperator.invokeExact(leftValues, leftIndex, rightValues, rightIndex); if (result != 0) { return result; } diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/BigintType.java b/core/trino-spi/src/main/java/io/trino/spi/type/BigintType.java index 8c6bd5dc61ac..7fb7b46fed51 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/BigintType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/BigintType.java @@ -37,7 +37,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return block.getLong(position, 0); + return getLong(block, position); } @Override diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/BooleanType.java b/core/trino-spi/src/main/java/io/trino/spi/type/BooleanType.java index cab45b0f1599..d2195f2c1619 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/BooleanType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/BooleanType.java @@ -21,6 +21,8 @@ import io.trino.spi.block.ByteArrayBlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; import io.trino.spi.function.FlatFixed; import io.trino.spi.function.FlatFixedOffset; import io.trino.spi.function.FlatVariableWidth; @@ -67,7 +69,7 @@ public static Block createBlockForSingleNonNullValue(boolean value) private BooleanType() { - super(new 
TypeSignature(StandardTypes.BOOLEAN), boolean.class); + super(new TypeSignature(StandardTypes.BOOLEAN), boolean.class, ByteArrayBlock.class); } @Override @@ -128,7 +130,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return block.getByte(position, 0) != 0; + return getBoolean(block, position); } @Override @@ -138,14 +140,14 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) blockBuilder.appendNull(); } else { - ((ByteArrayBlockBuilder) blockBuilder).writeByte(block.getByte(position, 0)); + ((ByteArrayBlockBuilder) blockBuilder).writeByte(getBoolean(block, position) ? (byte) 1 : 0); } } @Override public boolean getBoolean(Block block, int position) { - return block.getByte(position, 0) != 0; + return read((ByteArrayBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override @@ -172,6 +174,12 @@ public int hashCode() return getClass().hashCode(); } + @ScalarOperator(READ_VALUE) + private static boolean read(@BlockPosition ByteArrayBlock block, @BlockIndex int position) + { + return block.getByte(position) != 0; + } + @ScalarOperator(READ_VALUE) private static boolean readFlat( @FlatFixed byte[] fixedSizeSlice, diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/CharType.java b/core/trino-spi/src/main/java/io/trino/spi/type/CharType.java index ed48d409e2b0..b9e1967a848a 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/CharType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/CharType.java @@ -19,6 +19,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.function.ScalarOperator; @@ -125,7 +126,7 @@ public Optional getRange() if (!cachedRangePresent) { if (length > 100) { // The max/min 
values may be materialized in the plan, so we don't want them to be too large. - // Range comparison against large values are usually nonsensical, too, so no need to support them + // Range comparison against large values is usually nonsensical, too, so no need to support them // beyond a certain size. They specific choice above is arbitrary and can be adjusted if needed. range = Optional.empty(); } @@ -158,7 +159,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - Slice slice = block.getSlice(position, 0, block.getSliceLength(position)); + Slice slice = getSlice(block, position); if (slice.length() > 0) { if (countCodePoints(slice) > length) { throw new IllegalArgumentException(format("Character count exceeds length limit %s: %s", length, sliceRepresentation(slice))); @@ -182,7 +183,9 @@ public VariableWidthBlockBuilder createBlockBuilder(BlockBuilderStatus blockBuil @Override public Slice getSlice(Block block, int position) { - return block.getSlice(position, 0, block.getSliceLength(position)); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return valueBlock.getSlice(valuePosition); } public void writeString(BlockBuilder blockBuilder, String value) diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/DateType.java b/core/trino-spi/src/main/java/io/trino/spi/type/DateType.java index be74bb309a49..5dad5193338f 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/DateType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/DateType.java @@ -23,7 +23,7 @@ // // Note: when dealing with a java.sql.Date it is important to remember that the value is stored // as the number of milliseconds from 1970-01-01T00:00:00 in UTC but time must be midnight in -// the local time zone. This mean when converting between a java.sql.Date and this +// the local time zone. 
This means when converting between a java.sql.Date and this // type, the time zone offset must be added or removed to keep the time at midnight in UTC. // public final class DateType @@ -43,7 +43,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - int days = block.getInt(position, 0); + int days = getInt(block, position); return new SqlDate(days); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/DecimalType.java b/core/trino-spi/src/main/java/io/trino/spi/type/DecimalType.java index efaaba6a2f51..828780d5ba16 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/DecimalType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/DecimalType.java @@ -14,6 +14,7 @@ package io.trino.spi.type; import io.trino.spi.TrinoException; +import io.trino.spi.block.ValueBlock; import java.util.List; @@ -60,9 +61,9 @@ public static DecimalType createDecimalType() private final int precision; private final int scale; - DecimalType(int precision, int scale, Class javaType) + DecimalType(int precision, int scale, Class javaType, Class valueBlockType) { - super(new TypeSignature(StandardTypes.DECIMAL, buildTypeParameters(precision, scale)), javaType); + super(new TypeSignature(StandardTypes.DECIMAL, buildTypeParameters(precision, scale)), javaType, valueBlockType); this.precision = precision; this.scale = scale; } diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/DoubleType.java b/core/trino-spi/src/main/java/io/trino/spi/type/DoubleType.java index 25e506abb701..4a9175875e70 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/DoubleType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/DoubleType.java @@ -17,9 +17,12 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.LongArrayBlock; import io.trino.spi.block.LongArrayBlockBuilder; import io.trino.spi.block.PageBuilderStatus; import 
io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; import io.trino.spi.function.FlatFixed; import io.trino.spi.function.FlatFixedOffset; import io.trino.spi.function.FlatVariableWidth; @@ -56,7 +59,7 @@ public final class DoubleType private DoubleType() { - super(new TypeSignature(StandardTypes.DOUBLE), double.class); + super(new TypeSignature(StandardTypes.DOUBLE), double.class, LongArrayBlock.class); } @Override @@ -89,7 +92,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position if (block.isNull(position)) { return null; } - return longBitsToDouble(block.getLong(position, 0)); + return getDouble(block, position); } @Override @@ -99,14 +102,16 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) blockBuilder.appendNull(); } else { - ((LongArrayBlockBuilder) blockBuilder).writeLong(block.getLong(position, 0)); + LongArrayBlock valueBlock = (LongArrayBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + ((LongArrayBlockBuilder) blockBuilder).writeLong(valueBlock.getLong(valuePosition)); } } @Override public double getDouble(Block block, int position) { - return longBitsToDouble(block.getLong(position, 0)); + return read((LongArrayBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override @@ -169,6 +174,12 @@ public Optional getRange() return Optional.empty(); } + @ScalarOperator(READ_VALUE) + private static double read(@BlockPosition LongArrayBlock block, @BlockIndex int position) + { + return longBitsToDouble(block.getLong(position)); + } + @ScalarOperator(READ_VALUE) private static double readFlat( @FlatFixed byte[] fixedSizeSlice, @@ -189,6 +200,7 @@ private static void writeFlat( DOUBLE_HANDLE.set(fixedSizeSlice, fixedSizeOffset, value); } + @SuppressWarnings("FloatingPointEquality") @ScalarOperator(EQUAL) private static boolean 
equalOperator(double left, double right) { @@ -213,6 +225,7 @@ public static long xxHash64(double value) return XxHash64.hash(doubleToLongBits(value)); } + @SuppressWarnings("FloatingPointEquality") @ScalarOperator(IS_DISTINCT_FROM) private static boolean distinctFromOperator(double left, @IsNull boolean leftNull, double right, @IsNull boolean rightNull) { diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/HyperLogLogType.java b/core/trino-spi/src/main/java/io/trino/spi/type/HyperLogLogType.java index a128184e4a75..15be6c253302 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/HyperLogLogType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/HyperLogLogType.java @@ -17,6 +17,7 @@ import io.airlift.slice.Slice; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; @@ -37,7 +38,9 @@ public HyperLogLogType() @Override public Slice getSlice(Block block, int position) { - return block.getSlice(position, 0, block.getSliceLength(position)); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return valueBlock.getSlice(valuePosition); } @Override @@ -59,6 +62,6 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return new SqlVarbinary(block.getSlice(position, 0, block.getSliceLength(position)).getBytes()); + return new SqlVarbinary(getSlice(block, position).getBytes()); } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/IntegerType.java b/core/trino-spi/src/main/java/io/trino/spi/type/IntegerType.java index 0726dabb941d..7d9da63bbaba 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/IntegerType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/IntegerType.java @@ -37,7 +37,7 @@ public Object 
getObjectValue(ConnectorSession session, Block block, int position return null; } - return block.getInt(position, 0); + return getInt(block, position); } @Override diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/LongDecimalType.java b/core/trino-spi/src/main/java/io/trino/spi/type/LongDecimalType.java index 42b31d537f9c..1519ecacecfd 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/LongDecimalType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/LongDecimalType.java @@ -17,6 +17,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.Int128ArrayBlock; import io.trino.spi.block.Int128ArrayBlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorSession; @@ -49,7 +50,7 @@ final class LongDecimalType LongDecimalType(int precision, int scale) { - super(precision, scale, Int128.class); + super(precision, scale, Int128.class, Int128ArrayBlock.class); checkArgument(Decimals.MAX_SHORT_PRECISION < precision && precision <= Decimals.MAX_PRECISION, "Invalid precision: %s", precision); checkArgument(0 <= scale && scale <= precision, "Invalid scale for precision %s: %s", precision, scale); } @@ -99,7 +100,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position if (block.isNull(position)) { return null; } - Int128 value = (Int128) getObject(block, position); + Int128 value = getObject(block, position); BigInteger unscaledValue = value.toBigInteger(); return new SqlDecimal(unscaledValue, getPrecision(), getScale()); } @@ -111,9 +112,9 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) blockBuilder.appendNull(); } else { - ((Int128ArrayBlockBuilder) blockBuilder).writeInt128( - block.getLong(position, 0), - block.getLong(position, SIZE_OF_LONG)); + Int128ArrayBlock valueBlock = (Int128ArrayBlock) block.getUnderlyingValueBlock(); + int valuePosition = 
block.getUnderlyingValuePosition(position); + ((Int128ArrayBlockBuilder) blockBuilder).writeInt128(valueBlock.getInt128High(valuePosition), valueBlock.getInt128Low(valuePosition)); } } @@ -125,11 +126,9 @@ public void writeObject(BlockBuilder blockBuilder, Object value) } @Override - public Object getObject(Block block, int position) + public Int128 getObject(Block block, int position) { - return Int128.valueOf( - block.getLong(position, 0), - block.getLong(position, SIZE_OF_LONG)); + return read((Int128ArrayBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override @@ -138,6 +137,12 @@ public int getFlatFixedSize() return INT128_BYTES; } + @ScalarOperator(READ_VALUE) + private static Int128 read(@BlockPosition Int128ArrayBlock block, @BlockIndex int position) + { + return block.getInt128(position); + } + @ScalarOperator(READ_VALUE) private static Int128 readFlat( @FlatFixed byte[] fixedSizeSlice, @@ -175,15 +180,15 @@ private static void writeFlat( @ScalarOperator(READ_VALUE) private static void writeBlockToFlat( - @BlockPosition Block block, + @BlockPosition Int128ArrayBlock block, @BlockIndex int position, byte[] fixedSizeSlice, int fixedSizeOffset, byte[] unusedVariableSizeSlice, int unusedVariableSizeOffset) { - LONG_HANDLE.set(fixedSizeSlice, fixedSizeOffset, block.getLong(position, 0)); - LONG_HANDLE.set(fixedSizeSlice, fixedSizeOffset + SIZE_OF_LONG, block.getLong(position, SIZE_OF_LONG)); + LONG_HANDLE.set(fixedSizeSlice, fixedSizeOffset, block.getInt128High(position)); + LONG_HANDLE.set(fixedSizeSlice, fixedSizeOffset + SIZE_OF_LONG, block.getInt128Low(position)); } @ScalarOperator(EQUAL) @@ -193,10 +198,10 @@ private static boolean equalOperator(Int128 left, Int128 right) } @ScalarOperator(EQUAL) - private static boolean equalOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static boolean equalOperator(@BlockPosition 
Int128ArrayBlock leftBlock, @BlockIndex int leftPosition, @BlockPosition Int128ArrayBlock rightBlock, @BlockIndex int rightPosition) { - return leftBlock.getLong(leftPosition, 0) == rightBlock.getLong(rightPosition, 0) && - leftBlock.getLong(leftPosition, SIZE_OF_LONG) == rightBlock.getLong(rightPosition, SIZE_OF_LONG); + return leftBlock.getInt128High(leftPosition) == rightBlock.getInt128High(rightPosition) && + leftBlock.getInt128Low(leftPosition) == rightBlock.getInt128Low(rightPosition); } @ScalarOperator(XX_HASH_64) @@ -206,9 +211,9 @@ private static long xxHash64Operator(Int128 value) } @ScalarOperator(XX_HASH_64) - private static long xxHash64Operator(@BlockPosition Block block, @BlockIndex int position) + private static long xxHash64Operator(@BlockPosition Int128ArrayBlock block, @BlockIndex int position) { - return xxHash64(block.getLong(position, 0), block.getLong(position, SIZE_OF_LONG)); + return xxHash64(block.getInt128High(position), block.getInt128Low(position)); } private static long xxHash64(long low, long high) @@ -223,12 +228,12 @@ private static long comparisonOperator(Int128 left, Int128 right) } @ScalarOperator(COMPARISON_UNORDERED_LAST) - private static long comparisonOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static long comparisonOperator(@BlockPosition Int128ArrayBlock leftBlock, @BlockIndex int leftPosition, @BlockPosition Int128ArrayBlock rightBlock, @BlockIndex int rightPosition) { return Int128.compare( - leftBlock.getLong(leftPosition, 0), - leftBlock.getLong(leftPosition, SIZE_OF_LONG), - rightBlock.getLong(rightPosition, 0), - rightBlock.getLong(rightPosition, SIZE_OF_LONG)); + leftBlock.getInt128High(leftPosition), + leftBlock.getInt128Low(leftPosition), + rightBlock.getInt128High(rightPosition), + rightBlock.getInt128Low(rightPosition)); } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/LongTimeWithTimeZoneType.java 
b/core/trino-spi/src/main/java/io/trino/spi/type/LongTimeWithTimeZoneType.java index 57e6573fdce5..9fc12b124a0f 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/LongTimeWithTimeZoneType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/LongTimeWithTimeZoneType.java @@ -17,6 +17,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.Fixed12Block; import io.trino.spi.block.Fixed12BlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorSession; @@ -53,7 +54,7 @@ final class LongTimeWithTimeZoneType public LongTimeWithTimeZoneType(int precision) { - super(precision, LongTimeWithTimeZone.class); + super(precision, LongTimeWithTimeZone.class, Fixed12Block.class); if (precision < MAX_SHORT_PRECISION + 1 || precision > MAX_PRECISION) { throw new IllegalArgumentException(format("Precision must be in the range [%s, %s]", MAX_SHORT_PRECISION + 1, MAX_PRECISION)); @@ -106,14 +107,18 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) blockBuilder.appendNull(); } else { - write(blockBuilder, getPicos(block, position), getOffsetMinutes(block, position)); + Fixed12Block valueBlock = (Fixed12Block) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + write(blockBuilder, getPicos(valueBlock, valuePosition), getOffsetMinutes(valueBlock, valuePosition)); } } @Override - public Object getObject(Block block, int position) + public LongTimeWithTimeZone getObject(Block block, int position) { - return new LongTimeWithTimeZone(getPicos(block, position), getOffsetMinutes(block, position)); + Fixed12Block valueBlock = (Fixed12Block) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return new LongTimeWithTimeZone(getPicos(valueBlock, valuePosition), getOffsetMinutes(valueBlock, valuePosition)); } @Override @@ 
-135,7 +140,9 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return SqlTimeWithTimeZone.newInstance(getPrecision(), getPicos(block, position), getOffsetMinutes(block, position)); + Fixed12Block valueBlock = (Fixed12Block) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return SqlTimeWithTimeZone.newInstance(getPrecision(), getPicos(valueBlock, valuePosition), getOffsetMinutes(valueBlock, valuePosition)); } @Override @@ -144,14 +151,14 @@ public int getFlatFixedSize() return Long.BYTES + Integer.BYTES; } - private static long getPicos(Block block, int position) + private static long getPicos(Fixed12Block block, int position) { - return block.getLong(position, 0); + return block.getFixed12First(position); } - private static int getOffsetMinutes(Block block, int position) + private static int getOffsetMinutes(Fixed12Block block, int position) { - return block.getInt(position, SIZE_OF_LONG); + return block.getFixed12Second(position); } @ScalarOperator(READ_VALUE) @@ -191,7 +198,7 @@ private static void writeFlat( @ScalarOperator(READ_VALUE) private static void writeBlockFlat( - @BlockPosition Block block, + @BlockPosition Fixed12Block block, @BlockIndex int position, byte[] fixedSizeSlice, int fixedSizeOffset, @@ -213,7 +220,7 @@ private static boolean equalOperator(LongTimeWithTimeZone left, LongTimeWithTime } @ScalarOperator(EQUAL) - private static boolean equalOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static boolean equalOperator(@BlockPosition Fixed12Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Fixed12Block rightBlock, @BlockIndex int rightPosition) { return equal( getPicos(leftBlock, leftPosition), @@ -234,7 +241,7 @@ private static long hashCodeOperator(LongTimeWithTimeZone value) } @ScalarOperator(HASH_CODE) - private static long 
hashCodeOperator(@BlockPosition Block block, @BlockIndex int position) + private static long hashCodeOperator(@BlockPosition Fixed12Block block, @BlockIndex int position) { return hashCodeOperator(getPicos(block, position), getOffsetMinutes(block, position)); } @@ -251,7 +258,7 @@ private static long xxHash64Operator(LongTimeWithTimeZone value) } @ScalarOperator(XX_HASH_64) - private static long xxHash64Operator(@BlockPosition Block block, @BlockIndex int position) + private static long xxHash64Operator(@BlockPosition Fixed12Block block, @BlockIndex int position) { return xxHash64(getPicos(block, position), getOffsetMinutes(block, position)); } @@ -272,7 +279,7 @@ private static long comparisonOperator(LongTimeWithTimeZone left, LongTimeWithTi } @ScalarOperator(COMPARISON_UNORDERED_LAST) - private static long comparisonOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static long comparisonOperator(@BlockPosition Fixed12Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Fixed12Block rightBlock, @BlockIndex int rightPosition) { return comparison( getPicos(leftBlock, leftPosition), @@ -297,7 +304,7 @@ private static boolean lessThanOperator(LongTimeWithTimeZone left, LongTimeWithT } @ScalarOperator(LESS_THAN) - private static boolean lessThanOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static boolean lessThanOperator(@BlockPosition Fixed12Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Fixed12Block rightBlock, @BlockIndex int rightPosition) { return lessThan( getPicos(leftBlock, leftPosition), @@ -322,7 +329,7 @@ private static boolean lessThanOrEqualOperator(LongTimeWithTimeZone left, LongTi } @ScalarOperator(LESS_THAN_OR_EQUAL) - private static boolean lessThanOrEqualOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, 
@BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static boolean lessThanOrEqualOperator(@BlockPosition Fixed12Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Fixed12Block rightBlock, @BlockIndex int rightPosition) { return lessThanOrEqual( getPicos(leftBlock, leftPosition), diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/LongTimestampType.java b/core/trino-spi/src/main/java/io/trino/spi/type/LongTimestampType.java index c7c5d0e6426c..0e13a7b7331e 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/LongTimestampType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/LongTimestampType.java @@ -17,6 +17,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.Fixed12Block; import io.trino.spi.block.Fixed12BlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorSession; @@ -51,7 +52,7 @@ * in the first long and the fractional increment in the remaining integer, as * a number of picoseconds additional to the epoch microsecond. */ -class LongTimestampType +final class LongTimestampType extends TimestampType { private static final TypeOperatorDeclaration TYPE_OPERATOR_DECLARATION = extractOperatorDeclaration(LongTimestampType.class, lookup(), LongTimestamp.class); @@ -61,13 +62,13 @@ class LongTimestampType public LongTimestampType(int precision) { - super(precision, LongTimestamp.class); + super(precision, LongTimestamp.class, Fixed12Block.class); if (precision < MAX_SHORT_PRECISION + 1 || precision > MAX_PRECISION) { throw new IllegalArgumentException(format("Precision must be in the range [%s, %s]", MAX_SHORT_PRECISION + 1, MAX_PRECISION)); } - // ShortTimestampType instances are created eagerly and shared so it's OK to precompute some things. + // ShortTimestampType instances are created eagerly and shared, so it's OK to precompute some things. 
int picosOfMicroMax = toIntExact(PICOSECONDS_PER_MICROSECOND - rescale(1, 0, 12 - getPrecision())); range = new Range(new LongTimestamp(Long.MIN_VALUE, 0), new LongTimestamp(Long.MAX_VALUE, picosOfMicroMax)); } @@ -118,16 +119,18 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) blockBuilder.appendNull(); } else { - ((Fixed12BlockBuilder) blockBuilder).writeFixed12( - getEpochMicros(block, position), - getFraction(block, position)); + Fixed12Block valueBlock = (Fixed12Block) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + ((Fixed12BlockBuilder) blockBuilder).writeFixed12(getEpochMicros(valueBlock, valuePosition), getFraction(valueBlock, valuePosition)); } } @Override public Object getObject(Block block, int position) { - return new LongTimestamp(getEpochMicros(block, position), getFraction(block, position)); + Fixed12Block valueBlock = (Fixed12Block) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return new LongTimestamp(getEpochMicros(valueBlock, valuePosition), getFraction(valueBlock, valuePosition)); } @Override @@ -149,10 +152,9 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - long epochMicros = getEpochMicros(block, position); - int fraction = getFraction(block, position); - - return SqlTimestamp.newInstance(getPrecision(), epochMicros, fraction); + Fixed12Block valueBlock = (Fixed12Block) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return SqlTimestamp.newInstance(getPrecision(), getEpochMicros(valueBlock, valuePosition), getFraction(valueBlock, valuePosition)); } @Override @@ -161,14 +163,14 @@ public int getFlatFixedSize() return Long.BYTES + Integer.BYTES; } - private static long getEpochMicros(Block block, int position) + private static long getEpochMicros(Fixed12Block block, int position) { - return 
block.getLong(position, 0); + return block.getFixed12First(position); } - private static int getFraction(Block block, int position) + private static int getFraction(Fixed12Block block, int position) { - return block.getInt(position, SIZE_OF_LONG); + return block.getFixed12Second(position); } @Override @@ -214,7 +216,7 @@ private static void writeFlat( @ScalarOperator(READ_VALUE) private static void writeBlockFlat( - @BlockPosition Block block, + @BlockPosition Fixed12Block block, @BlockIndex int position, byte[] fixedSizeSlice, int fixedSizeOffset, @@ -236,7 +238,7 @@ private static boolean equalOperator(LongTimestamp left, LongTimestamp right) } @ScalarOperator(EQUAL) - private static boolean equalOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static boolean equalOperator(@BlockPosition Fixed12Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Fixed12Block rightBlock, @BlockIndex int rightPosition) { return equal( getEpochMicros(leftBlock, leftPosition), @@ -257,7 +259,7 @@ private static long xxHash64Operator(LongTimestamp value) } @ScalarOperator(XX_HASH_64) - private static long xxHash64Operator(@BlockPosition Block block, @BlockIndex int position) + private static long xxHash64Operator(@BlockPosition Fixed12Block block, @BlockIndex int position) { return xxHash64( getEpochMicros(block, position), @@ -276,7 +278,7 @@ private static long comparisonOperator(LongTimestamp left, LongTimestamp right) } @ScalarOperator(COMPARISON_UNORDERED_LAST) - private static long comparisonOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static long comparisonOperator(@BlockPosition Fixed12Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Fixed12Block rightBlock, @BlockIndex int rightPosition) { return comparison( getEpochMicros(leftBlock, leftPosition), @@ -301,7 
+303,7 @@ private static boolean lessThanOperator(LongTimestamp left, LongTimestamp right) } @ScalarOperator(LESS_THAN) - private static boolean lessThanOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static boolean lessThanOperator(@BlockPosition Fixed12Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Fixed12Block rightBlock, @BlockIndex int rightPosition) { return lessThan( getEpochMicros(leftBlock, leftPosition), @@ -323,7 +325,7 @@ private static boolean lessThanOrEqualOperator(LongTimestamp left, LongTimestamp } @ScalarOperator(LESS_THAN_OR_EQUAL) - private static boolean lessThanOrEqualOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static boolean lessThanOrEqualOperator(@BlockPosition Fixed12Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Fixed12Block rightBlock, @BlockIndex int rightPosition) { return lessThanOrEqual( getEpochMicros(leftBlock, leftPosition), diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/LongTimestampWithTimeZoneType.java b/core/trino-spi/src/main/java/io/trino/spi/type/LongTimestampWithTimeZoneType.java index 1e131f13b7ce..cd58e49d993c 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/LongTimestampWithTimeZoneType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/LongTimestampWithTimeZoneType.java @@ -17,6 +17,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.Fixed12Block; import io.trino.spi.block.Fixed12BlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorSession; @@ -64,7 +65,7 @@ final class LongTimestampWithTimeZoneType public LongTimestampWithTimeZoneType(int precision) { - super(precision, LongTimestampWithTimeZone.class); + super(precision, 
LongTimestampWithTimeZone.class, Fixed12Block.class); if (precision < MAX_SHORT_PRECISION + 1 || precision > MAX_PRECISION) { throw new IllegalArgumentException(format("Precision must be in the range [%s, %s]", MAX_SHORT_PRECISION + 1, MAX_PRECISION)); @@ -117,15 +118,19 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) blockBuilder.appendNull(); } else { - write(blockBuilder, getPackedEpochMillis(block, position), getPicosOfMilli(block, position)); + Fixed12Block valueBlock = (Fixed12Block) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + write(blockBuilder, getPackedEpochMillis(valueBlock, valuePosition), getPicosOfMilli(valueBlock, valuePosition)); } } @Override public Object getObject(Block block, int position) { - long packedEpochMillis = getPackedEpochMillis(block, position); - int picosOfMilli = getPicosOfMilli(block, position); + Fixed12Block valueBlock = (Fixed12Block) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + long packedEpochMillis = getPackedEpochMillis(valueBlock, valuePosition); + int picosOfMilli = getPicosOfMilli(valueBlock, valuePosition); return LongTimestampWithTimeZone.fromEpochMillisAndFraction(unpackMillisUtc(packedEpochMillis), picosOfMilli, unpackZoneKey(packedEpochMillis)); } @@ -152,8 +157,10 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - long packedEpochMillis = getPackedEpochMillis(block, position); - int picosOfMilli = getPicosOfMilli(block, position); + Fixed12Block valueBlock = (Fixed12Block) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + long packedEpochMillis = getPackedEpochMillis(valueBlock, valuePosition); + int picosOfMilli = getPicosOfMilli(valueBlock, valuePosition); return SqlTimestampWithTimeZone.newInstance(getPrecision(), unpackMillisUtc(packedEpochMillis), picosOfMilli, 
unpackZoneKey(packedEpochMillis)); } @@ -200,19 +207,19 @@ public Optional getNextValue(Object value) return Optional.of(LongTimestampWithTimeZone.fromEpochMillisAndFraction(epochMillis, picosOfMilli, UTC_KEY)); } - private static long getPackedEpochMillis(Block block, int position) + private static long getPackedEpochMillis(Fixed12Block block, int position) { - return block.getLong(position, 0); + return block.getFixed12First(position); } - private static long getEpochMillis(Block block, int position) + private static long getEpochMillis(Fixed12Block block, int position) { return unpackMillisUtc(getPackedEpochMillis(block, position)); } - private static int getPicosOfMilli(Block block, int position) + private static int getPicosOfMilli(Fixed12Block block, int position) { - return block.getInt(position, SIZE_OF_LONG); + return block.getFixed12Second(position); } @ScalarOperator(READ_VALUE) @@ -252,7 +259,7 @@ private static void writeFlat( @ScalarOperator(READ_VALUE) private static void writeBlockFlat( - @BlockPosition Block block, + @BlockPosition Fixed12Block block, @BlockIndex int position, byte[] fixedSizeSlice, int fixedSizeOffset, @@ -274,7 +281,7 @@ private static boolean equalOperator(LongTimestampWithTimeZone left, LongTimesta } @ScalarOperator(EQUAL) - private static boolean equalOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static boolean equalOperator(@BlockPosition Fixed12Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Fixed12Block rightBlock, @BlockIndex int rightPosition) { return equal( getEpochMillis(leftBlock, leftPosition), @@ -296,7 +303,7 @@ private static long xxHash64Operator(LongTimestampWithTimeZone value) } @ScalarOperator(XX_HASH_64) - private static long xxHash64Operator(@BlockPosition Block block, @BlockIndex int position) + private static long xxHash64Operator(@BlockPosition Fixed12Block block, @BlockIndex int position) { 
return xxHash64( getEpochMillis(block, position), @@ -315,7 +322,7 @@ private static long comparisonOperator(LongTimestampWithTimeZone left, LongTimes } @ScalarOperator(COMPARISON_UNORDERED_LAST) - private static long comparisonOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static long comparisonOperator(@BlockPosition Fixed12Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Fixed12Block rightBlock, @BlockIndex int rightPosition) { return comparison( getEpochMillis(leftBlock, leftPosition), @@ -340,7 +347,7 @@ private static boolean lessThanOperator(LongTimestampWithTimeZone left, LongTime } @ScalarOperator(LESS_THAN) - private static boolean lessThanOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static boolean lessThanOperator(@BlockPosition Fixed12Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Fixed12Block rightBlock, @BlockIndex int rightPosition) { return lessThan( getEpochMillis(leftBlock, leftPosition), @@ -362,7 +369,7 @@ private static boolean lessThanOrEqualOperator(LongTimestampWithTimeZone left, L } @ScalarOperator(LESS_THAN_OR_EQUAL) - private static boolean lessThanOrEqualOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static boolean lessThanOrEqualOperator(@BlockPosition Fixed12Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Fixed12Block rightBlock, @BlockIndex int rightPosition) { return lessThanOrEqual( getEpochMillis(leftBlock, leftPosition), diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/MapType.java b/core/trino-spi/src/main/java/io/trino/spi/type/MapType.java index 2e2283a13746..049a0594db92 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/MapType.java +++ 
b/core/trino-spi/src/main/java/io/trino/spi/type/MapType.java @@ -111,7 +111,7 @@ public class MapType private final MethodHandle keyBlockNativeEqual; private final MethodHandle keyBlockEqual; - // this field is used in double checked locking + // this field is used in double-checked locking @SuppressWarnings("FieldAccessedSynchronizedAndUnsynchronized") private volatile TypeOperatorDeclaration typeOperatorDeclaration; @@ -122,7 +122,8 @@ public MapType(Type keyType, Type valueType, TypeOperators typeOperators) StandardTypes.MAP, TypeSignatureParameter.typeParameter(keyType.getTypeSignature()), TypeSignatureParameter.typeParameter(valueType.getTypeSignature())), - SqlMap.class); + SqlMap.class, + MapBlock.class); if (!keyType.isComparable()) { throw new IllegalArgumentException(format("key type must be comparable, got %s", keyType)); } @@ -291,7 +292,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - SqlMap sqlMap = block.getObject(position, SqlMap.class); + SqlMap sqlMap = getObject(block, position); int rawOffset = sqlMap.getRawOffset(); Block rawKeyBlock = sqlMap.getRawKeyBlock(); Block rawValueBlock = sqlMap.getRawValueBlock(); @@ -318,7 +319,7 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) @Override public SqlMap getObject(Block block, int position) { - return block.getObject(position, SqlMap.class); + return read((MapBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override @@ -463,7 +464,7 @@ public String getDisplayName() return "map(" + keyType.getDisplayName() + ", " + valueType.getDisplayName() + ")"; } - public Block createBlockFromKeyValue(Optional mapIsNull, int[] offsets, Block keyBlock, Block valueBlock) + public MapBlock createBlockFromKeyValue(Optional mapIsNull, int[] offsets, Block keyBlock, Block valueBlock) { return MapBlock.fromKeyValueBlock( mapIsNull, @@ -544,6 +545,11 @@ private static long 
invokeHashOperator(MethodHandle hashOperator, Block block, i return (long) hashOperator.invokeExact((Block) block, position); } + private static SqlMap read(MapBlock block, int position) + { + return block.getMap(position); + } + private static SqlMap readFlat( MapType mapType, MethodHandle keyReadOperator, @@ -825,7 +831,7 @@ private static boolean indeterminate(MethodHandle valueIndeterminateFunction, Sq Block rawValueBlock = sqlMap.getRawValueBlock(); for (int i = 0; i < sqlMap.getSize(); i++) { - // since maps are not allowed to have indeterminate keys we only check values here + // since maps are not allowed to have indeterminate keys, we only check values here if (rawValueBlock.isNull(rawOffset + i)) { return true; } diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/QuantileDigestType.java b/core/trino-spi/src/main/java/io/trino/spi/type/QuantileDigestType.java index 043a54ed99a3..7f3bb4cb785e 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/QuantileDigestType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/QuantileDigestType.java @@ -17,6 +17,7 @@ import io.airlift.slice.Slice; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; @@ -39,7 +40,9 @@ public QuantileDigestType(Type valueType) @Override public Slice getSlice(Block block, int position) { - return block.getSlice(position, 0, block.getSliceLength(position)); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return valueBlock.getSlice(valuePosition); } @Override @@ -61,7 +64,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return new SqlVarbinary(block.getSlice(position, 0, block.getSliceLength(position)).getBytes()); + return new 
SqlVarbinary(getSlice(block, position).getBytes()); } public Type getValueType() diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/RealType.java b/core/trino-spi/src/main/java/io/trino/spi/type/RealType.java index 50385018b295..da26b556cefd 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/RealType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/RealType.java @@ -76,7 +76,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position public float getFloat(Block block, int position) { - return intBitsToFloat(block.getInt(position, 0)); + return intBitsToFloat(getInt(block, position)); } @Override @@ -137,6 +137,7 @@ private static void writeFlat( INT_HANDLE.set(fixedSizeSlice, fixedSizeOffset, (int) value); } + @SuppressWarnings("FloatingPointEquality") @ScalarOperator(EQUAL) private static boolean equalOperator(long left, long right) { @@ -163,6 +164,7 @@ private static long xxHash64Operator(long value) return XxHash64.hash(floatToIntBits(realValue)); } + @SuppressWarnings("FloatingPointEquality") @ScalarOperator(IS_DISTINCT_FROM) private static boolean distinctFromOperator(long left, @IsNull boolean leftNull, long right, @IsNull boolean rightNull) { diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/RowType.java b/core/trino-spi/src/main/java/io/trino/spi/type/RowType.java index 65ccdfa9cab1..a5115075e4d5 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/RowType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/RowType.java @@ -18,6 +18,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.RowBlock; import io.trino.spi.block.RowBlockBuilder; import io.trino.spi.block.SqlRow; import io.trino.spi.connector.ConnectorSession; @@ -127,7 +128,7 @@ public class RowType private RowType(TypeSignature typeSignature, List originalFields) { - super(typeSignature, SqlRow.class); + super(typeSignature, 
SqlRow.class, RowBlock.class); this.fields = List.copyOf(originalFields); this.fieldTypes = fields.stream() @@ -268,7 +269,7 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) @Override public SqlRow getObject(Block block, int position) { - return block.getObject(position, SqlRow.class); + return read((RowBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override @@ -425,6 +426,11 @@ private List getReadValueOperatorMethodHandles(TypeOperato new OperatorMethodHandle(WRITE_FLAT_CONVENTION, writeFlat)); } + private static SqlRow read(RowBlock block, int position) + { + return block.getRow(position); + } + private static SqlRow megamorphicReadFlat( RowType rowType, List fieldReadFlatMethods, diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/ShortDecimalType.java b/core/trino-spi/src/main/java/io/trino/spi/type/ShortDecimalType.java index 0460e018420a..adda87f36d0c 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/ShortDecimalType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/ShortDecimalType.java @@ -17,9 +17,12 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.LongArrayBlock; import io.trino.spi.block.LongArrayBlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; import io.trino.spi.function.FlatFixed; import io.trino.spi.function.FlatFixedOffset; import io.trino.spi.function.FlatVariableWidth; @@ -69,8 +72,8 @@ static ShortDecimalType getInstance(int precision, int scale) private ShortDecimalType(int precision, int scale) { - super(precision, scale, long.class); - checkArgument(0 < precision && precision <= Decimals.MAX_SHORT_PRECISION, "Invalid precision: %s", precision); + super(precision, scale, long.class, LongArrayBlock.class); + 
checkArgument(0 < precision && precision <= MAX_SHORT_PRECISION, "Invalid precision: %s", precision); checkArgument(0 <= scale && scale <= precision, "Invalid scale for precision %s: %s", precision, scale); } @@ -119,8 +122,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position if (block.isNull(position)) { return null; } - long unscaledValue = block.getLong(position, 0); - return new SqlDecimal(BigInteger.valueOf(unscaledValue), getPrecision(), getScale()); + return new SqlDecimal(BigInteger.valueOf(getLong(block, position)), getPrecision(), getScale()); } @Override @@ -130,14 +132,14 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) blockBuilder.appendNull(); } else { - writeLong(blockBuilder, block.getLong(position, 0)); + writeLong(blockBuilder, getLong(block, position)); } } @Override public long getLong(Block block, int position) { - return block.getLong(position, 0); + return read((LongArrayBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override @@ -158,6 +160,12 @@ public Optional> getDiscreteValues(Range range) return Optional.of(LongStream.rangeClosed((long) range.getMin(), (long) range.getMax()).boxed()); } + @ScalarOperator(READ_VALUE) + private static long read(@BlockPosition LongArrayBlock block, @BlockIndex int position) + { + return block.getLong(position); + } + @ScalarOperator(READ_VALUE) private static long readFlat( @FlatFixed byte[] fixedSizeSlice, diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/ShortTimeWithTimeZoneType.java b/core/trino-spi/src/main/java/io/trino/spi/type/ShortTimeWithTimeZoneType.java index 2297ae355180..d679a0d37357 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/ShortTimeWithTimeZoneType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/ShortTimeWithTimeZoneType.java @@ -13,14 +13,16 @@ */ package io.trino.spi.type; -import io.airlift.slice.Slice; import io.airlift.slice.XxHash64; import 
io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.LongArrayBlock; import io.trino.spi.block.LongArrayBlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; import io.trino.spi.function.FlatFixed; import io.trino.spi.function.FlatFixedOffset; import io.trino.spi.function.FlatVariableWidth; @@ -56,7 +58,7 @@ final class ShortTimeWithTimeZoneType public ShortTimeWithTimeZoneType(int precision) { - super(precision, long.class); + super(precision, long.class, LongArrayBlock.class); if (precision < 0 || precision > MAX_SHORT_PRECISION) { throw new IllegalArgumentException(format("Precision must be in the range [0, %s]", MAX_SHORT_PRECISION)); @@ -70,42 +72,36 @@ public TypeOperatorDeclaration getTypeOperatorDeclaration(TypeOperators typeOper } @Override - public final int getFixedSize() + public int getFixedSize() { return Long.BYTES; } @Override - public final long getLong(Block block, int position) + public long getLong(Block block, int position) { - return block.getLong(position, 0); + return read((LongArrayBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override - public final Slice getSlice(Block block, int position) - { - return block.getSlice(position, 0, getFixedSize()); - } - - @Override - public final void writeLong(BlockBuilder blockBuilder, long value) + public void writeLong(BlockBuilder blockBuilder, long value) { ((LongArrayBlockBuilder) blockBuilder).writeLong(value); } @Override - public final void appendTo(Block block, int position, BlockBuilder blockBuilder) + public void appendTo(Block block, int position, BlockBuilder blockBuilder) { if (block.isNull(position)) { blockBuilder.appendNull(); } else { - writeLong(blockBuilder, block.getLong(position, 0)); + writeLong(blockBuilder, 
getLong(block, position)); } } @Override - public final BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) + public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) { int maxBlockSizeInBytes; if (blockBuilderStatus == null) { @@ -120,13 +116,13 @@ public final BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStat } @Override - public final BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries) + public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries) { return createBlockBuilder(blockBuilderStatus, expectedEntries, Long.BYTES); } @Override - public final BlockBuilder createFixedSizeBlockBuilder(int positionCount) + public BlockBuilder createFixedSizeBlockBuilder(int positionCount) { return new LongArrayBlockBuilder(null, positionCount); } @@ -138,7 +134,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - long value = block.getLong(position, 0); + long value = getLong(block, position); return SqlTimeWithTimeZone.newInstance(getPrecision(), unpackTimeNanos(value) * PICOSECONDS_PER_NANOSECOND, unpackOffsetMinutes(value)); } @@ -148,6 +144,12 @@ public int getFlatFixedSize() return Long.BYTES; } + @ScalarOperator(READ_VALUE) + private static long read(@BlockPosition LongArrayBlock block, @BlockIndex int position) + { + return block.getLong(position); + } + @ScalarOperator(READ_VALUE) private static long readFlat( @FlatFixed byte[] fixedSizeSlice, diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/ShortTimestampType.java b/core/trino-spi/src/main/java/io/trino/spi/type/ShortTimestampType.java index f3ea087dc1d4..6a86b1febd78 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/ShortTimestampType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/ShortTimestampType.java 
@@ -17,9 +17,12 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.LongArrayBlock; import io.trino.spi.block.LongArrayBlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; import io.trino.spi.function.FlatFixed; import io.trino.spi.function.FlatFixedOffset; import io.trino.spi.function.FlatVariableWidth; @@ -48,7 +51,7 @@ * The value is encoded as microseconds from the 1970-01-01 00:00:00 epoch and is to be interpreted as * local date time without regards to any time zone. */ -class ShortTimestampType +final class ShortTimestampType extends TimestampType { private static final TypeOperatorDeclaration TYPE_OPERATOR_DECLARATION = extractOperatorDeclaration(ShortTimestampType.class, lookup(), long.class); @@ -57,13 +60,13 @@ class ShortTimestampType public ShortTimestampType(int precision) { - super(precision, long.class); + super(precision, long.class, LongArrayBlock.class); if (precision < 0 || precision > MAX_SHORT_PRECISION) { throw new IllegalArgumentException(format("Precision must be in the range [0, %s]", MAX_SHORT_PRECISION)); } - // ShortTimestampType instances are created eagerly and shared so it's OK to precompute some things. + // ShortTimestampType instances are created eagerly and shared, so it's OK to precompute some things. 
if (getPrecision() == MAX_SHORT_PRECISION) { range = new Range(Long.MIN_VALUE, Long.MAX_VALUE); } @@ -80,25 +83,25 @@ public TypeOperatorDeclaration getTypeOperatorDeclaration(TypeOperators typeOper } @Override - public final int getFixedSize() + public int getFixedSize() { return Long.BYTES; } @Override - public final long getLong(Block block, int position) + public long getLong(Block block, int position) { - return block.getLong(position, 0); + return read((LongArrayBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override - public final void writeLong(BlockBuilder blockBuilder, long value) + public void writeLong(BlockBuilder blockBuilder, long value) { ((LongArrayBlockBuilder) blockBuilder).writeLong(value); } @Override - public final void appendTo(Block block, int position, BlockBuilder blockBuilder) + public void appendTo(Block block, int position, BlockBuilder blockBuilder) { if (block.isNull(position)) { blockBuilder.appendNull(); @@ -109,7 +112,7 @@ public final void appendTo(Block block, int position, BlockBuilder blockBuilder) } @Override - public final BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) + public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) { int maxBlockSizeInBytes; if (blockBuilderStatus == null) { @@ -124,13 +127,13 @@ public final BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStat } @Override - public final BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries) + public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries) { return createBlockBuilder(blockBuilderStatus, expectedEntries, Long.BYTES); } @Override - public final BlockBuilder createFixedSizeBlockBuilder(int positionCount) + public BlockBuilder createFixedSizeBlockBuilder(int positionCount) { return new 
LongArrayBlockBuilder(null, positionCount); } @@ -176,6 +179,12 @@ public Optional getNextValue(Object value) return Optional.of((long) value + rescale(1_000_000, getPrecision(), 0)); } + @ScalarOperator(READ_VALUE) + private static long read(@BlockPosition LongArrayBlock block, @BlockIndex int position) + { + return block.getLong(position); + } + @ScalarOperator(READ_VALUE) private static long readFlat( @FlatFixed byte[] fixedSizeSlice, diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/ShortTimestampWithTimeZoneType.java b/core/trino-spi/src/main/java/io/trino/spi/type/ShortTimestampWithTimeZoneType.java index 42db05208b69..401ba757e344 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/ShortTimestampWithTimeZoneType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/ShortTimestampWithTimeZoneType.java @@ -13,14 +13,16 @@ */ package io.trino.spi.type; -import io.airlift.slice.Slice; import io.airlift.slice.XxHash64; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.LongArrayBlock; import io.trino.spi.block.LongArrayBlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; import io.trino.spi.function.FlatFixed; import io.trino.spi.function.FlatFixedOffset; import io.trino.spi.function.FlatVariableWidth; @@ -56,7 +58,7 @@ final class ShortTimestampWithTimeZoneType public ShortTimestampWithTimeZoneType(int precision) { - super(precision, long.class); + super(precision, long.class, LongArrayBlock.class); if (precision < 0 || precision > MAX_SHORT_PRECISION) { throw new IllegalArgumentException(format("Precision must be in the range [0, %s]", MAX_SHORT_PRECISION)); @@ -70,42 +72,36 @@ public TypeOperatorDeclaration getTypeOperatorDeclaration(TypeOperators typeOper } @Override - public final int getFixedSize() + 
public int getFixedSize() { return Long.BYTES; } @Override - public final long getLong(Block block, int position) + public long getLong(Block block, int position) { - return block.getLong(position, 0); + return read((LongArrayBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override - public final Slice getSlice(Block block, int position) - { - return block.getSlice(position, 0, getFixedSize()); - } - - @Override - public final void writeLong(BlockBuilder blockBuilder, long value) + public void writeLong(BlockBuilder blockBuilder, long value) { ((LongArrayBlockBuilder) blockBuilder).writeLong(value); } @Override - public final void appendTo(Block block, int position, BlockBuilder blockBuilder) + public void appendTo(Block block, int position, BlockBuilder blockBuilder) { if (block.isNull(position)) { blockBuilder.appendNull(); } else { - writeLong(blockBuilder, block.getLong(position, 0)); + writeLong(blockBuilder, getLong(block, position)); } } @Override - public final BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) + public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) { int maxBlockSizeInBytes; if (blockBuilderStatus == null) { @@ -120,13 +116,13 @@ public final BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStat } @Override - public final BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries) + public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries) { return createBlockBuilder(blockBuilderStatus, expectedEntries, Long.BYTES); } @Override - public final BlockBuilder createFixedSizeBlockBuilder(int positionCount) + public BlockBuilder createFixedSizeBlockBuilder(int positionCount) { return new LongArrayBlockBuilder(null, positionCount); } @@ -138,7 +134,7 @@ public Object 
getObjectValue(ConnectorSession session, Block block, int position return null; } - long value = block.getLong(position, 0); + long value = getLong(block, position); return SqlTimestampWithTimeZone.newInstance(getPrecision(), unpackMillisUtc(value), 0, unpackZoneKey(value)); } @@ -148,6 +144,12 @@ public int getFlatFixedSize() return Long.BYTES; } + @ScalarOperator(READ_VALUE) + private static long read(@BlockPosition LongArrayBlock block, @BlockIndex int position) + { + return block.getLong(position); + } + @ScalarOperator(READ_VALUE) private static long readFlat( @FlatFixed byte[] fixedSizeSlice, diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/SmallintType.java b/core/trino-spi/src/main/java/io/trino/spi/type/SmallintType.java index fb77472356a6..2114679bf4b5 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/SmallintType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/SmallintType.java @@ -19,8 +19,11 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; import io.trino.spi.block.PageBuilderStatus; +import io.trino.spi.block.ShortArrayBlock; import io.trino.spi.block.ShortArrayBlockBuilder; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; import io.trino.spi.function.FlatFixed; import io.trino.spi.function.FlatFixedOffset; import io.trino.spi.function.FlatVariableWidth; @@ -56,7 +59,7 @@ public final class SmallintType private SmallintType() { - super(new TypeSignature(StandardTypes.SMALLINT), long.class); + super(new TypeSignature(StandardTypes.SMALLINT), long.class, ShortArrayBlock.class); } @Override @@ -117,7 +120,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return block.getShort(position, 0); + return getShort(block, position); } @Override @@ -161,7 +164,7 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) 
blockBuilder.appendNull(); } else { - ((ShortArrayBlockBuilder) blockBuilder).writeShort(block.getShort(position, 0)); + ((ShortArrayBlockBuilder) blockBuilder).writeShort(getShort(block, position)); } } @@ -173,7 +176,7 @@ public long getLong(Block block, int position) public short getShort(Block block, int position) { - return block.getShort(position, 0); + return readShort((ShortArrayBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override @@ -188,7 +191,7 @@ public void writeShort(BlockBuilder blockBuilder, short value) ((ShortArrayBlockBuilder) blockBuilder).writeShort(value); } - private void checkValueValid(long value) + private static void checkValueValid(long value) { if (value > Short.MAX_VALUE) { throw new TrinoException(GENERIC_INTERNAL_ERROR, format("Value %d exceeds MAX_SHORT", value)); @@ -217,6 +220,17 @@ public int hashCode() return getClass().hashCode(); } + @ScalarOperator(READ_VALUE) + private static long read(@BlockPosition ShortArrayBlock block, @BlockIndex int position) + { + return readShort(block, position); + } + + private static short readShort(ShortArrayBlock block, int position) + { + return block.getShort(position); + } + @ScalarOperator(READ_VALUE) private static long readFlat( @FlatFixed byte[] fixedSizeSlice, diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/TimeType.java b/core/trino-spi/src/main/java/io/trino/spi/type/TimeType.java index 43533598b0e7..1c66925c69c7 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/TimeType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/TimeType.java @@ -98,7 +98,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return SqlTime.newInstance(precision, block.getLong(position, 0)); + return SqlTime.newInstance(precision, getLong(block, position)); } @ScalarOperator(READ_VALUE) diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/TimeWithTimeZoneType.java 
b/core/trino-spi/src/main/java/io/trino/spi/type/TimeWithTimeZoneType.java index e3d0fbc706c6..ee9e406080bb 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/TimeWithTimeZoneType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/TimeWithTimeZoneType.java @@ -14,6 +14,7 @@ package io.trino.spi.type; import io.trino.spi.TrinoException; +import io.trino.spi.block.ValueBlock; import static io.trino.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE; import static java.lang.String.format; @@ -59,9 +60,9 @@ public static TimeWithTimeZoneType createTimeWithTimeZoneType(int precision) return TYPES[precision]; } - protected TimeWithTimeZoneType(int precision, Class javaType) + protected TimeWithTimeZoneType(int precision, Class javaType, Class valueBlockType) { - super(new TypeSignature(StandardTypes.TIME_WITH_TIME_ZONE, TypeSignatureParameter.numericParameter(precision)), javaType); + super(new TypeSignature(StandardTypes.TIME_WITH_TIME_ZONE, TypeSignatureParameter.numericParameter(precision)), javaType, valueBlockType); this.precision = precision; } diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/TimestampType.java b/core/trino-spi/src/main/java/io/trino/spi/type/TimestampType.java index 5d6cd360371d..03749b781ade 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/TimestampType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/TimestampType.java @@ -14,6 +14,7 @@ package io.trino.spi.type; import io.trino.spi.TrinoException; +import io.trino.spi.block.ValueBlock; import static io.trino.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE; import static java.lang.String.format; @@ -24,9 +25,10 @@ * @see ShortTimestampType * @see LongTimestampType */ -public abstract class TimestampType +public abstract sealed class TimestampType extends AbstractType implements FixedWidthType + permits LongTimestampType, ShortTimestampType { public static final int MAX_PRECISION = 12; @@ -57,9 +59,9 @@ public static TimestampType 
createTimestampType(int precision) return TYPES[precision]; } - TimestampType(int precision, Class javaType) + TimestampType(int precision, Class javaType, Class valueBlockType) { - super(new TypeSignature(StandardTypes.TIMESTAMP, TypeSignatureParameter.numericParameter(precision)), javaType); + super(new TypeSignature(StandardTypes.TIMESTAMP, TypeSignatureParameter.numericParameter(precision)), javaType, valueBlockType); this.precision = precision; } diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/TimestampWithTimeZoneType.java b/core/trino-spi/src/main/java/io/trino/spi/type/TimestampWithTimeZoneType.java index d900d47553c8..4f75e8176c5f 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/TimestampWithTimeZoneType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/TimestampWithTimeZoneType.java @@ -14,6 +14,7 @@ package io.trino.spi.type; import io.trino.spi.TrinoException; +import io.trino.spi.block.ValueBlock; import static io.trino.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE; import static java.lang.String.format; @@ -56,9 +57,9 @@ public static TimestampWithTimeZoneType createTimestampWithTimeZoneType(int prec return TYPES[precision]; } - TimestampWithTimeZoneType(int precision, Class javaType) + TimestampWithTimeZoneType(int precision, Class javaType, Class valueBlockType) { - super(new TypeSignature(StandardTypes.TIMESTAMP_WITH_TIME_ZONE, TypeSignatureParameter.numericParameter(precision)), javaType); + super(new TypeSignature(StandardTypes.TIMESTAMP_WITH_TIME_ZONE, TypeSignatureParameter.numericParameter(precision)), javaType, valueBlockType); if (precision < 0 || precision > MAX_PRECISION) { throw new IllegalArgumentException(format("Precision must be in the range [0, %s]", MAX_PRECISION)); diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/TinyintType.java b/core/trino-spi/src/main/java/io/trino/spi/type/TinyintType.java index 10ed974602a2..b8b254eef3d4 100644 --- 
a/core/trino-spi/src/main/java/io/trino/spi/type/TinyintType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/TinyintType.java @@ -18,9 +18,12 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.ByteArrayBlock; import io.trino.spi.block.ByteArrayBlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; import io.trino.spi.function.FlatFixed; import io.trino.spi.function.FlatFixedOffset; import io.trino.spi.function.FlatVariableWidth; @@ -52,7 +55,7 @@ public final class TinyintType private TinyintType() { - super(new TypeSignature(StandardTypes.TINYINT), long.class); + super(new TypeSignature(StandardTypes.TINYINT), long.class, ByteArrayBlock.class); } @Override @@ -113,7 +116,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return block.getByte(position, 0); + return getByte(block, position); } @Override @@ -157,7 +160,7 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) blockBuilder.appendNull(); } else { - writeByte(blockBuilder, block.getByte(position, 0)); + writeByte(blockBuilder, getByte(block, position)); } } @@ -169,7 +172,7 @@ public long getLong(Block block, int position) public byte getByte(Block block, int position) { - return block.getByte(position, 0); + return readByte((ByteArrayBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override @@ -184,7 +187,7 @@ public void writeByte(BlockBuilder blockBuilder, byte value) ((ByteArrayBlockBuilder) blockBuilder).writeByte(value); } - private void checkValueValid(long value) + private static void checkValueValid(long value) { if (value > Byte.MAX_VALUE) { throw new TrinoException(GENERIC_INTERNAL_ERROR, format("Value %d exceeds MAX_BYTE", value)); @@ 
-212,6 +215,17 @@ public int hashCode() return getClass().hashCode(); } + @ScalarOperator(READ_VALUE) + private static long read(@BlockPosition ByteArrayBlock block, @BlockIndex int position) + { + return readByte(block, position); + } + + private static byte readByte(ByteArrayBlock block, int position) + { + return block.getByte(position); + } + @ScalarOperator(READ_VALUE) private static long readFlat( @FlatFixed byte[] fixedSizeSlice, diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/Type.java b/core/trino-spi/src/main/java/io/trino/spi/type/Type.java index ae4c2347a272..519abbfbb462 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/Type.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/Type.java @@ -18,6 +18,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.ValueBlock; import io.trino.spi.connector.ConnectorSession; import java.util.List; @@ -81,6 +82,11 @@ default TypeOperatorDeclaration getTypeOperatorDeclaration(TypeOperators typeOpe */ Class getJavaType(); + /** + * Gets the ValueBlock type used to store values of this type. + */ + Class getValueBlockType(); + /** * For parameterized types returns the list of parameters. 
*/ diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/TypeOperatorDeclaration.java b/core/trino-spi/src/main/java/io/trino/spi/type/TypeOperatorDeclaration.java index 71ffeedca4a6..bc604ded7dcd 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/TypeOperatorDeclaration.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/TypeOperatorDeclaration.java @@ -15,6 +15,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.function.BlockIndex; import io.trino.spi.function.BlockPosition; @@ -46,6 +47,8 @@ import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NULL_FLAG; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -462,13 +465,18 @@ private void verifyMethodHandleSignature(int expectedArgumentCount, Class ret case BLOCK_POSITION_NOT_NULL: case BLOCK_POSITION: checkArgument(parameterType.equals(Block.class) && methodType.parameterType(parameterIndex + 1).equals(int.class), - "Expected BLOCK_POSITION argument have parameters Block and int"); + "Expected BLOCK_POSITION argument to have parameters Block and int"); + break; + case VALUE_BLOCK_POSITION_NOT_NULL: + case VALUE_BLOCK_POSITION: + 
checkArgument(Block.class.isAssignableFrom(parameterType) && methodType.parameterType(parameterIndex + 1).equals(int.class), + "Expected VALUE_BLOCK_POSITION argument to have parameters ValueBlock and int"); break; case FLAT: checkArgument(parameterType.equals(byte[].class) && methodType.parameterType(parameterIndex + 1).equals(int.class) && methodType.parameterType(parameterIndex + 2).equals(byte[].class), - "Expected FLAT argument have parameters byte[], int, and byte[]"); + "Expected FLAT argument to have parameters byte[], int, and byte[]"); break; case FUNCTION: throw new IllegalArgumentException("Function argument convention is not supported in type operators"); @@ -506,6 +514,10 @@ private void verifyMethodHandleSignature(int expectedArgumentCount, Class ret default: throw new UnsupportedOperationException("Unknown return convention: " + returnConvention); } + + if (operatorMethodHandle.getCallingConvention().getArgumentConventions().stream().anyMatch(argumentConvention -> argumentConvention == BLOCK_POSITION || argumentConvention == BLOCK_POSITION_NOT_NULL)) { + throw new IllegalArgumentException("BLOCK_POSITION argument convention is not allowed for type operators"); + } } private static InvocationConvention parseInvocationConvention(OperatorType operatorType, Class typeJavaType, Method method, Class expectedReturnType) @@ -576,11 +588,14 @@ private static InvocationArgumentConvention extractNextArgumentConvention( Method method) { if (isAnnotationPresent(parameterAnnotations.get(0), BlockPosition.class)) { - if (parameterTypes.size() > 1 && - isAnnotationPresent(parameterAnnotations.get(1), BlockIndex.class) && - parameterTypes.get(0).equals(Block.class) && - parameterTypes.get(1).equals(int.class)) { - return isAnnotationPresent(parameterAnnotations.get(0), SqlNullable.class) ? 
BLOCK_POSITION : BLOCK_POSITION_NOT_NULL; + if (parameterTypes.size() > 1 && isAnnotationPresent(parameterAnnotations.get(1), BlockIndex.class)) { + if (!ValueBlock.class.isAssignableFrom(parameterTypes.get(0))) { + throw new IllegalArgumentException("@BlockPosition argument must be a ValueBlock type for %s operator: %s".formatted(operatorType, method)); + } + if (parameterTypes.get(1) != int.class) { + throw new IllegalArgumentException("@BlockIndex argument must be type int for %s operator: %s".formatted(operatorType, method)); + } + return isAnnotationPresent(parameterAnnotations.get(0), SqlNullable.class) ? VALUE_BLOCK_POSITION : VALUE_BLOCK_POSITION_NOT_NULL; } } else if (isAnnotationPresent(parameterAnnotations.get(0), SqlNullable.class)) { diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/TypeUtils.java b/core/trino-spi/src/main/java/io/trino/spi/type/TypeUtils.java index fd599d361bf2..7e5cd97a39ed 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/TypeUtils.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/TypeUtils.java @@ -18,6 +18,7 @@ import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import jakarta.annotation.Nullable; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; @@ -60,11 +61,11 @@ public static Object readNativeValue(Type type, Block block, int position) return type.getObject(block, position); } - public static Block writeNativeValue(Type type, @Nullable Object value) + public static ValueBlock writeNativeValue(Type type, @Nullable Object value) { BlockBuilder blockBuilder = type.createBlockBuilder(null, 1); writeNativeValue(type, blockBuilder, value); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } /** diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/UuidType.java b/core/trino-spi/src/main/java/io/trino/spi/type/UuidType.java index 2897b79c785f..228ad712bff7 100644 --- 
a/core/trino-spi/src/main/java/io/trino/spi/type/UuidType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/UuidType.java @@ -19,6 +19,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.Int128ArrayBlock; import io.trino.spi.block.Int128ArrayBlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorSession; @@ -61,7 +62,7 @@ public class UuidType private UuidType() { - super(new TypeSignature(StandardTypes.UUID), Slice.class); + super(new TypeSignature(StandardTypes.UUID), Slice.class, Int128ArrayBlock.class); } @Override @@ -121,8 +122,10 @@ public Object getObjectValue(ConnectorSession session, Block block, int position if (block.isNull(position)) { return null; } - long high = reverseBytes(block.getLong(position, 0)); - long low = reverseBytes(block.getLong(position, SIZE_OF_LONG)); + Int128ArrayBlock valueBlock = (Int128ArrayBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + long high = reverseBytes(valueBlock.getInt128High(valuePosition)); + long low = reverseBytes(valueBlock.getInt128Low(valuePosition)); return new UUID(high, low).toString(); } @@ -133,9 +136,9 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) blockBuilder.appendNull(); } else { - ((Int128ArrayBlockBuilder) blockBuilder).writeInt128( - block.getLong(position, 0), - block.getLong(position, SIZE_OF_LONG)); + Int128ArrayBlock valueBlock = (Int128ArrayBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + ((Int128ArrayBlockBuilder) blockBuilder).writeInt128(valueBlock.getInt128High(valuePosition), valueBlock.getInt128Low(valuePosition)); } } @@ -159,10 +162,7 @@ public void writeSlice(BlockBuilder blockBuilder, Slice value, int offset, int l @Override public final Slice getSlice(Block block, int position) { - Slice 
value = Slices.allocate(INT128_BYTES); - value.setLong(0, block.getLong(position, 0)); - value.setLong(SIZE_OF_LONG, block.getLong(position, SIZE_OF_LONG)); - return value; + return read((Int128ArrayBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override @@ -189,6 +189,15 @@ public static UUID trinoUuidToJavaUuid(Slice uuid) reverseBytes(uuid.getLong(SIZE_OF_LONG))); } + @ScalarOperator(READ_VALUE) + private static Slice read(@BlockPosition Int128ArrayBlock block, @BlockIndex int position) + { + Slice value = Slices.allocate(INT128_BYTES); + value.setLong(0, block.getInt128High(position)); + value.setLong(SIZE_OF_LONG, block.getInt128Low(position)); + return value; + } + @ScalarOperator(READ_VALUE) private static Slice readFlat( @FlatFixed byte[] fixedSizeSlice, @@ -232,13 +241,13 @@ private static boolean equalOperator(Slice left, Slice right) } @ScalarOperator(EQUAL) - private static boolean equalOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static boolean equalOperator(@BlockPosition Int128ArrayBlock leftBlock, @BlockIndex int leftPosition, @BlockPosition Int128ArrayBlock rightBlock, @BlockIndex int rightPosition) { return equal( - leftBlock.getLong(leftPosition, 0), - leftBlock.getLong(leftPosition, SIZE_OF_LONG), - rightBlock.getLong(rightPosition, 0), - rightBlock.getLong(rightPosition, SIZE_OF_LONG)); + leftBlock.getInt128High(leftPosition), + leftBlock.getInt128Low(leftPosition), + rightBlock.getInt128High(rightPosition), + rightBlock.getInt128Low(rightPosition)); } private static boolean equal(long leftLow, long leftHigh, long rightLow, long rightHigh) @@ -253,9 +262,9 @@ private static long xxHash64Operator(Slice value) } @ScalarOperator(XX_HASH_64) - private static long xxHash64Operator(@BlockPosition Block block, @BlockIndex int position) + private static long xxHash64Operator(@BlockPosition Int128ArrayBlock block, 
@BlockIndex int position) { - return xxHash64(block.getLong(position, 0), block.getLong(position, SIZE_OF_LONG)); + return xxHash64(block.getInt128High(position), block.getInt128Low(position)); } private static long xxHash64(long low, long high) @@ -274,13 +283,13 @@ private static long comparisonOperator(Slice left, Slice right) } @ScalarOperator(COMPARISON_UNORDERED_LAST) - private static long comparisonOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static long comparisonOperator(@BlockPosition Int128ArrayBlock leftBlock, @BlockIndex int leftPosition, @BlockPosition Int128ArrayBlock rightBlock, @BlockIndex int rightPosition) { return compareLittleEndian( - leftBlock.getLong(leftPosition, 0), - leftBlock.getLong(leftPosition, SIZE_OF_LONG), - rightBlock.getLong(rightPosition, 0), - rightBlock.getLong(rightPosition, SIZE_OF_LONG)); + leftBlock.getInt128High(leftPosition), + leftBlock.getInt128Low(leftPosition), + rightBlock.getInt128High(rightPosition), + rightBlock.getInt128Low(rightPosition)); } private static int compareLittleEndian(long leftLow64le, long leftHigh64le, long rightLow64le, long rightHigh64le) diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/VarbinaryType.java b/core/trino-spi/src/main/java/io/trino/spi/type/VarbinaryType.java index e001c983af65..07f192cb83cf 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/VarbinaryType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/VarbinaryType.java @@ -16,6 +16,7 @@ import io.airlift.slice.Slice; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; @@ -69,13 +70,15 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return new SqlVarbinary(block.getSlice(position, 0, 
block.getSliceLength(position)).getBytes()); + return new SqlVarbinary(getSlice(block, position).getBytes()); } @Override public Slice getSlice(Block block, int position) { - return block.getSlice(position, 0, block.getSliceLength(position)); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return valueBlock.getSlice(valuePosition); } @Override diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/VarcharType.java b/core/trino-spi/src/main/java/io/trino/spi/type/VarcharType.java index f69c769b5c0f..02aa8ff06d2e 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/VarcharType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/VarcharType.java @@ -19,6 +19,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; @@ -132,7 +133,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - Slice slice = block.getSlice(position, 0, block.getSliceLength(position)); + Slice slice = getSlice(block, position); if (!isUnbounded() && countCodePoints(slice) > length) { throw new IllegalArgumentException(format("Character count exceeds length limit %s: %s", length, sliceRepresentation(slice))); } @@ -161,7 +162,7 @@ public Optional getRange() if (!cachedRangePresent) { if (length > 100) { // The max/min values may be materialized in the plan, so we don't want them to be too large. - // Range comparison against large values are usually nonsensical, too, so no need to support them + // Range comparison against large values is usually nonsensical, too, so no need to support them // beyond a certain size. They specific choice above is arbitrary and can be adjusted if needed. 
range = Optional.empty(); } @@ -184,7 +185,9 @@ public Optional getRange() @Override public Slice getSlice(Block block, int position) { - return block.getSlice(position, 0, block.getSliceLength(position)); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return valueBlock.getSlice(valuePosition); } public void writeString(BlockBuilder blockBuilder, String value) diff --git a/core/trino-spi/src/test/java/io/trino/spi/block/TestLazyBlock.java b/core/trino-spi/src/test/java/io/trino/spi/block/TestLazyBlock.java index 88d341cea864..540af7561c19 100644 --- a/core/trino-spi/src/test/java/io/trino/spi/block/TestLazyBlock.java +++ b/core/trino-spi/src/test/java/io/trino/spi/block/TestLazyBlock.java @@ -66,21 +66,21 @@ public void testNestedGetLoadedBlock() List actualNotifications = new ArrayList<>(); Block arrayBlock = new IntArrayBlock(1, Optional.empty(), new int[] {0}); LazyBlock lazyArrayBlock = new LazyBlock(1, () -> arrayBlock); - Block dictionaryBlock = DictionaryBlock.create(2, lazyArrayBlock, new int[] {0, 0}); - LazyBlock lazyBlock = new LazyBlock(2, () -> dictionaryBlock); + Block rowBlock = RowBlock.fromFieldBlocks(2, Optional.empty(), new Block[]{lazyArrayBlock}); + LazyBlock lazyBlock = new LazyBlock(2, () -> rowBlock); LazyBlock.listenForLoads(lazyBlock, actualNotifications::add); Block loadedBlock = lazyBlock.getBlock(); - assertThat(loadedBlock).isInstanceOf(DictionaryBlock.class); - assertThat(((DictionaryBlock) loadedBlock).getDictionary()).isInstanceOf(LazyBlock.class); + assertThat(loadedBlock).isInstanceOf(RowBlock.class); + assertThat(((RowBlock) loadedBlock).getRawFieldBlocks().get(0)).isInstanceOf(LazyBlock.class); assertThat(actualNotifications).isEqualTo(ImmutableList.of(loadedBlock)); Block fullyLoadedBlock = lazyBlock.getLoadedBlock(); - assertThat(fullyLoadedBlock).isInstanceOf(DictionaryBlock.class); - assertThat(((DictionaryBlock) 
fullyLoadedBlock).getDictionary()).isInstanceOf(IntArrayBlock.class); + assertThat(fullyLoadedBlock).isInstanceOf(RowBlock.class); + assertThat(((RowBlock) fullyLoadedBlock).getRawFieldBlocks().get(0)).isInstanceOf(IntArrayBlock.class); assertThat(actualNotifications).isEqualTo(ImmutableList.of(loadedBlock, arrayBlock)); assertThat(lazyBlock.isLoaded()).isTrue(); - assertThat(dictionaryBlock.isLoaded()).isTrue(); + assertThat(rowBlock.isLoaded()).isTrue(); } private static void assertNotificationsRecursive(int depth, Block lazyBlock, List actualNotifications, List expectedNotifications) diff --git a/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java b/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java index 678867e35749..ac0ba86d87d4 100644 --- a/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java +++ b/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java @@ -19,8 +19,12 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; import io.trino.spi.TrinoException; +import io.trino.spi.block.ArrayBlock; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.Fixed12Block; +import io.trino.spi.block.LongArrayBlock; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.function.InvocationConvention.InvocationArgumentConvention; import io.trino.spi.function.InvocationConvention.InvocationReturnConvention; import io.trino.spi.type.ArrayType; @@ -36,6 +40,7 @@ import java.lang.invoke.MethodType; import java.util.ArrayList; import java.util.BitSet; +import java.util.EnumSet; import java.util.List; import java.util.stream.IntStream; @@ -52,6 +57,8 @@ import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.IN_OUT; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; import static 
io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NULL_FLAG; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -199,6 +206,58 @@ public void testAdaptFromBlockPositionNotNullObjects() verifyAllAdaptations(actualConvention, "blockPositionObjects", RETURN_TYPE, OBJECTS_ARGUMENT_TYPES); } + @Test + public void testAdaptFromValueBlockPosition() + throws Throwable + { + InvocationConvention actualConvention = new InvocationConvention( + nCopies(ARGUMENT_TYPES.size(), VALUE_BLOCK_POSITION), + FAIL_ON_NULL, + false, + true); + String methodName = "valueBlockPosition"; + verifyAllAdaptations(actualConvention, methodName, RETURN_TYPE, ARGUMENT_TYPES); + } + + @Test + public void testAdaptFromValueBlockPositionObjects() + throws Throwable + { + InvocationConvention actualConvention = new InvocationConvention( + nCopies(OBJECTS_ARGUMENT_TYPES.size(), VALUE_BLOCK_POSITION), + FAIL_ON_NULL, + false, + true); + String methodName = "valueBlockPositionObjects"; + verifyAllAdaptations(actualConvention, methodName, RETURN_TYPE, OBJECTS_ARGUMENT_TYPES); + } + + @Test + public void testAdaptFromValueBlockPositionNotNull() + throws Throwable + { + InvocationConvention actualConvention = new InvocationConvention( + nCopies(ARGUMENT_TYPES.size(), VALUE_BLOCK_POSITION_NOT_NULL), + FAIL_ON_NULL, + false, + true); + String methodName = "valueBlockPosition"; + verifyAllAdaptations(actualConvention, methodName, RETURN_TYPE, ARGUMENT_TYPES); + } + + @Test + public void 
testAdaptFromValueBlockPositionObjectsNotNull() + throws Throwable + { + InvocationConvention actualConvention = new InvocationConvention( + nCopies(OBJECTS_ARGUMENT_TYPES.size(), VALUE_BLOCK_POSITION_NOT_NULL), + FAIL_ON_NULL, + false, + true); + String methodName = "valueBlockPositionObjects"; + verifyAllAdaptations(actualConvention, methodName, RETURN_TYPE, OBJECTS_ARGUMENT_TYPES); + } + private static void verifyAllAdaptations( InvocationConvention actualConvention, String methodName, @@ -219,7 +278,7 @@ private static void verifyAllAdaptations( throws Throwable { List> allArgumentConventions = allCombinations( - ImmutableList.of(NEVER_NULL, BLOCK_POSITION_NOT_NULL, BOXED_NULLABLE, NULL_FLAG, BLOCK_POSITION, FLAT, IN_OUT), + ImmutableList.of(NEVER_NULL, BLOCK_POSITION_NOT_NULL, VALUE_BLOCK_POSITION_NOT_NULL, BOXED_NULLABLE, NULL_FLAG, BLOCK_POSITION, VALUE_BLOCK_POSITION, FLAT, IN_OUT), argumentTypes.size()); for (List argumentConventions : allArgumentConventions) { for (InvocationReturnConvention returnConvention : InvocationReturnConvention.values()) { @@ -258,7 +317,8 @@ private static void adaptAndVerify( assertThat(expectedConvention.getReturnConvention() == FAIL_ON_NULL || expectedConvention.getReturnConvention() == FLAT_RETURN).isTrue(); return; } - if (actualConvention.getArgumentConventions().stream().anyMatch(convention -> convention == BLOCK_POSITION || convention == BLOCK_POSITION_NOT_NULL)) { + if (actualConvention.getArgumentConventions().stream() + .anyMatch(convention -> EnumSet.of(BLOCK_POSITION, BLOCK_POSITION_NOT_NULL, VALUE_BLOCK_POSITION, VALUE_BLOCK_POSITION_NOT_NULL).contains(convention))) { return; } } @@ -343,7 +403,7 @@ private static boolean canCallConventionWithNullArguments(InvocationConvention c { for (int i = 0; i < convention.getArgumentConventions().size(); i++) { InvocationArgumentConvention argumentConvention = convention.getArgumentConvention(i); - if (nullArguments.get(i) && (argumentConvention == NEVER_NULL || 
argumentConvention == BLOCK_POSITION_NOT_NULL || argumentConvention == FLAT)) { + if (nullArguments.get(i) && EnumSet.of(NEVER_NULL, BLOCK_POSITION_NOT_NULL, VALUE_BLOCK_POSITION_NOT_NULL, FLAT).contains(argumentConvention)) { return false; } } @@ -382,6 +442,10 @@ private static List> toCallArgumentTypes(InvocationConvention callingCo expectedArguments.add(Block.class); expectedArguments.add(int.class); } + case VALUE_BLOCK_POSITION_NOT_NULL, VALUE_BLOCK_POSITION -> { + expectedArguments.add(argumentType.getValueBlockType()); + expectedArguments.add(int.class); + } case FLAT -> { expectedArguments.add(Slice.class); expectedArguments.add(int.class); @@ -423,21 +487,31 @@ private static List toCallArgumentValues(InvocationConvention callingCon callArguments.add(testValue == null ? Defaults.defaultValue(argumentType.getJavaType()) : testValue); callArguments.add(testValue == null); } - case BLOCK_POSITION_NOT_NULL -> { + case BLOCK_POSITION_NOT_NULL, VALUE_BLOCK_POSITION_NOT_NULL -> { verify(testValue != null, "null cannot be passed to a block positions not null argument"); BlockBuilder blockBuilder = argumentType.createBlockBuilder(null, 3); blockBuilder.appendNull(); writeNativeValue(argumentType, blockBuilder, testValue); blockBuilder.appendNull(); - callArguments.add(blockBuilder.build()); + if (argumentConvention == BLOCK_POSITION_NOT_NULL) { + callArguments.add(blockBuilder.build()); + } + else { + callArguments.add(blockBuilder.buildValueBlock()); + } callArguments.add(1); } - case BLOCK_POSITION -> { + case BLOCK_POSITION, VALUE_BLOCK_POSITION -> { BlockBuilder blockBuilder = argumentType.createBlockBuilder(null, 3); blockBuilder.appendNull(); writeNativeValue(argumentType, blockBuilder, testValue); blockBuilder.appendNull(); - callArguments.add(blockBuilder.build()); + if (argumentConvention == BLOCK_POSITION) { + callArguments.add(blockBuilder.build()); + } + else { + callArguments.add(blockBuilder.buildValueBlock()); + } callArguments.add(1); } case FLAT 
-> { @@ -736,6 +810,80 @@ public boolean blockPositionObjects( return true; } + @SuppressWarnings("unused") + public boolean valueBlockPosition( + LongArrayBlock doubleBlock, int doublePosition, + VariableWidthBlock sliceBlock, int slicePosition, + ArrayBlock blockBlock, int blockPosition) + { + checkState(!invoked, "Already invoked"); + invoked = true; + objectsMethod = false; + + if (doubleBlock.isNull(doublePosition)) { + this.doubleValue = null; + } + else { + this.doubleValue = DOUBLE.getDouble(doubleBlock, doublePosition); + } + + if (sliceBlock.isNull(slicePosition)) { + this.sliceValue = null; + } + else { + this.sliceValue = VARCHAR.getSlice(sliceBlock, slicePosition); + } + + if (blockBlock.isNull(blockPosition)) { + this.blockValue = null; + } + else { + this.blockValue = ARRAY_TYPE.getObject(blockBlock, blockPosition); + } + return true; + } + + @SuppressWarnings("unused") + public boolean valueBlockPositionObjects( + VariableWidthBlock sliceBlock, int slicePosition, + ArrayBlock blockBlock, int blockPosition, + VariableWidthBlock objectCharBlock, int objectCharPosition, + Fixed12Block objectTimestampBlock, int objectTimestampPosition) + { + checkState(!invoked, "Already invoked"); + invoked = true; + objectsMethod = true; + + if (sliceBlock.isNull(slicePosition)) { + this.sliceValue = null; + } + else { + this.sliceValue = VARCHAR.getSlice(sliceBlock, slicePosition); + } + + if (blockBlock.isNull(blockPosition)) { + this.blockValue = null; + } + else { + this.blockValue = ARRAY_TYPE.getObject(blockBlock, blockPosition); + } + + if (objectCharBlock.isNull(objectCharPosition)) { + this.objectCharValue = null; + } + else { + this.objectCharValue = CHAR_TYPE.getObject(objectCharBlock, objectCharPosition); + } + + if (objectTimestampBlock.isNull(objectTimestampPosition)) { + this.objectTimestampValue = null; + } + else { + this.objectTimestampValue = TIMESTAMP_TYPE.getObject(objectTimestampBlock, objectTimestampPosition); + } + return true; + } + public 
void verify( InvocationConvention actualConvention, BitSet nullArguments, @@ -781,7 +929,7 @@ private static boolean shouldFunctionBeInvoked(InvocationConvention actualConven { for (int i = 0; i < actualConvention.getArgumentConventions().size(); i++) { InvocationArgumentConvention argumentConvention = actualConvention.getArgumentConvention(i); - if ((argumentConvention == NEVER_NULL || argumentConvention == BLOCK_POSITION_NOT_NULL || argumentConvention == FLAT) && nullArguments.get(i)) { + if ((argumentConvention == NEVER_NULL || argumentConvention == BLOCK_POSITION_NOT_NULL || argumentConvention == VALUE_BLOCK_POSITION_NOT_NULL || argumentConvention == FLAT) && nullArguments.get(i)) { return false; } } diff --git a/core/trino-spi/src/test/java/io/trino/spi/type/TestingIdType.java b/core/trino-spi/src/test/java/io/trino/spi/type/TestingIdType.java index 723396c57120..1d60f9c89932 100644 --- a/core/trino-spi/src/test/java/io/trino/spi/type/TestingIdType.java +++ b/core/trino-spi/src/test/java/io/trino/spi/type/TestingIdType.java @@ -40,7 +40,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return block.getLong(position, 0); + return getLong(block, position); } @Override diff --git a/docs/pom.xml b/docs/pom.xml index 514367efd30c..8d8940ef99f1 100644 --- a/docs/pom.xml +++ b/docs/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT trino-docs diff --git a/docs/src/main/sphinx/admin/event-listeners-http.md b/docs/src/main/sphinx/admin/event-listeners-http.md index 0707713cefff..6eeea6d51853 100644 --- a/docs/src/main/sphinx/admin/event-listeners-http.md +++ b/docs/src/main/sphinx/admin/event-listeners-http.md @@ -14,6 +14,7 @@ and metadata about the query processing. Running the capture system separate from Trino reduces the performance impact and avoids downtime for non-client-facing changes. 
+(http-event-listener-requirements)= ## Requirements You need to perform the following steps: @@ -46,62 +47,61 @@ event-listener.config-files=etc/http-event-listener.properties,... ### Configuration properties -```{eval-rst} -.. list-table:: - :widths: 40, 40, 20 - :header-rows: 1 - - * - Property name - - Description - - Default - - * - http-event-listener.log-created - - Enable the plugin to log ``QueryCreatedEvent`` events - - ``false`` - - * - http-event-listener.log-completed - - Enable the plugin to log ``QueryCompletedEvent`` events - - ``false`` - - * - http-event-listener.log-split - - Enable the plugin to log ``SplitCompletedEvent`` events - - ``false`` - - * - http-event-listener.connect-ingest-uri - - The URI that the plugin will POST events to - - None. See the `requirements <#requirements>`_ section. - - * - http-event-listener.connect-http-headers - - List of custom HTTP headers to be sent along with the events. See - :ref:`http-event-listener-custom-headers` for more details - - Empty - - * - http-event-listener.connect-retry-count - - The number of retries on server error. A server is considered to be - in an error state when the response code is 500 or higher - - ``0`` - - * - http-event-listener.connect-retry-delay - - Duration for which to delay between attempts to send a request - - ``1s`` - - * - http-event-listener.connect-backoff-base - - The base used for exponential backoff when retrying on server error. - The formula used to calculate the delay is - :math:`attemptDelay = retryDelay * backoffBase^{attemptCount}`. - Attempt count starts from 0. Leave this empty or set to 1 to disable - exponential backoff and keep constant delays - - ``2`` - - * - http-event-listener.connect-max-delay - - The upper bound of a delay between 2 retries. This should be - used with exponential backoff. 
- - ``1m`` - - * - http-event-listener.* - - Pass configuration onto the HTTP client - - -``` +:::{list-table} +:widths: 40, 40, 20 +:header-rows: 1 + +* - Property name + - Description + - Default + +* - http-event-listener.log-created + - Enable the plugin to log `QueryCreatedEvent` events + - `false` + +* - http-event-listener.log-completed + - Enable the plugin to log `QueryCompletedEvent` events + - `false` + +* - http-event-listener.log-split + - Enable the plugin to log `SplitCompletedEvent` events + - `false` + +* - http-event-listener.connect-ingest-uri + - The URI that the plugin will POST events to + - None. See the [requirements](http-event-listener-requirements) section. + +* - http-event-listener.connect-http-headers + - List of custom HTTP headers to be sent along with the events. See + [](http-event-listener-custom-headers) for more details + - Empty + +* - http-event-listener.connect-retry-count + - The number of retries on server error. A server is considered to be + in an error state when the response code is 500 or higher + - `0` + +* - http-event-listener.connect-retry-delay + - Duration for which to delay between attempts to send a request + - `1s` + +* - http-event-listener.connect-backoff-base + - The base used for exponential backoff when retrying on server error. + The formula used to calculate the delay is + `attemptDelay = retryDelay * backoffBase^{attemptCount}`. + Attempt count starts from 0. Leave this empty or set to 1 to disable + exponential backoff and keep constant delays + - `2` + +* - http-event-listener.connect-max-delay + - The upper bound of a delay between 2 retries. This should be + used with exponential backoff. 
+ - `1m` + +* - http-event-listener.* + - Pass configuration onto the HTTP client + - +::: (http-event-listener-custom-headers)= diff --git a/docs/src/main/sphinx/admin/fault-tolerant-execution.md b/docs/src/main/sphinx/admin/fault-tolerant-execution.md index 3361f27b676e..54185a71f3b8 100644 --- a/docs/src/main/sphinx/admin/fault-tolerant-execution.md +++ b/docs/src/main/sphinx/admin/fault-tolerant-execution.md @@ -55,33 +55,32 @@ connector. The following connectors support fault-tolerant execution: The following configuration properties control the behavior of fault-tolerant execution on a Trino cluster: -```{eval-rst} -.. list-table:: Fault-tolerant execution configuration properties - :widths: 30, 50, 20 - :header-rows: 1 - - * - Property name - - Description - - Default value - * - ``retry-policy`` - - Configures what is retried in the event of failure, either - ``QUERY`` to retry the whole query, or ``TASK`` to retry tasks - individually if they fail. See :ref:`retry policy ` for - more information. - - ``NONE`` - * - ``exchange.deduplication-buffer-size`` - - :ref:`Data size ` of the coordinator's in-memory - buffer used by fault-tolerant execution to store output of query - :ref:`stages `. If this buffer is filled during - query execution, the query fails with a "Task descriptor storage capacity - has been exceeded" error message unless an :ref:`exchange manager - ` is configured. - - ``32MB`` - * - ``exchange.compression-enabled`` - - Enable compression of spooling data. Setting to ``true`` is recommended - when using an :ref:`exchange manager `. - - ``false`` -``` + +:::{list-table} Fault-tolerant execution configuration properties +:widths: 30, 50, 20 +:header-rows: 1 + +* - Property name + - Description + - Default value +* - `retry-policy` + - Configures what is retried in the event of failure, either `QUERY` to retry + the whole query, or `TASK` to retry tasks individually if they fail. See + [retry policy](fte-retry-policy) for more information. 
+ - `NONE` +* - `exchange.deduplication-buffer-size` + - [Data size](prop-type-data-size) of the coordinator's in-memory buffer used + by fault-tolerant execution to store output of query + [stages](trino-concept-stage). If this buffer is filled during query + execution, the query fails with a "Task descriptor storage capacity has been + exceeded" error message unless an [exchange manager](fte-exchange-manager) + is configured. + - `32MB` +* - `exchange.compression-enabled` + - Enable compression of spooling data. Setting to `true` is recommended + when using an [exchange manager](fte-exchange-manager). + - `false` +::: (fte-retry-policy)= @@ -158,46 +157,44 @@ troubleshooting purposes. The following configuration properties control the thresholds at which queries/tasks are no longer retried in the event of repeated failures: -```{eval-rst} -.. list-table:: Fault tolerance retry limit configuration properties - :widths: 30, 50, 20, 30 - :header-rows: 1 - - * - Property name - - Description - - Default value - - Retry policy - * - ``query-retry-attempts`` - - Maximum number of times Trino may attempt to retry a query before - declaring the query as failed. - - ``4`` - - Only ``QUERY`` - * - ``task-retry-attempts-per-task`` - - Maximum number of times Trino may attempt to retry a single task before - declaring the query as failed. - - ``4`` - - Only ``TASK`` - * - ``retry-initial-delay`` - - Minimum :ref:`time ` that a failed query or task must - wait before it is retried. May be overridden with the - ``retry_initial_delay`` :ref:`session property - `. - - ``10s`` - - ``QUERY`` and ``TASK`` - * - ``retry-max-delay`` - - Maximum :ref:`time ` that a failed query or task must - wait before it is retried. Wait time is increased on each subsequent - failure. May be overridden with the ``retry_max_delay`` :ref:`session - property `. 
- - ``1m`` - - ``QUERY`` and ``TASK`` - * - ``retry-delay-scale-factor`` - - Factor by which retry delay is increased on each query or task failure. - May be overridden with the ``retry_delay_scale_factor`` :ref:`session - property `. - - ``2.0`` - - ``QUERY`` and ``TASK`` -``` +:::{list-table} Fault tolerance retry limit configuration properties +:widths: 30, 50, 20, 30 +:header-rows: 1 + +* - Property name + - Description + - Default value + - Retry policy +* - `query-retry-attempts` + - Maximum number of times Trino may attempt to retry a query before declaring + the query as failed. + - `4` + - Only `QUERY` +* - `task-retry-attempts-per-task` + - Maximum number of times Trino may attempt to retry a single task before + declaring the query as failed. + - `4` + - Only `TASK` +* - `retry-initial-delay` + - Minimum [time](prop-type-duration) that a failed query or task must wait + before it is retried. May be overridden with the `retry_initial_delay` + [session property](session-properties-definition). + - `10s` + - `QUERY` and `TASK` +* - `retry-max-delay` + - Maximum [time](prop-type-duration) that a failed query or task must + wait before it is retried. Wait time is increased on each subsequent + failure. May be overridden with the `retry_max_delay` [session + property](session-properties-definition). + - `1m` + - `QUERY` and `TASK` +* - `retry-delay-scale-factor` + - Factor by which retry delay is increased on each query or task failure. May + be overridden with the `retry_delay_scale_factor` [session + property](session-properties-definition). + - `2.0` + - `QUERY` and `TASK` +::: ### Task sizing @@ -213,85 +210,82 @@ during fault-tolerant task execution, you can configure the following configuration properties to manually control task sizing. These configuration properties only apply to a `TASK` retry policy. -```{eval-rst} -.. 
list-table:: Task sizing configuration properties - :widths: 30, 50, 20 - :header-rows: 1 - - * - Property name - - Description - - Default value - * - ``fault-tolerant-execution-standard-split-size`` - - Standard :ref:`split ` :ref:`data size - ` processed by tasks that read data from source - tables. Value is interpreted with split weight taken into account. If the - weight of splits produced by a catalog denotes that they are lighter or - heavier than "standard" split, then the number of splits processed by a - single task is adjusted accordingly. - - May be overridden for the current session with the - ``fault_tolerant_execution_standard_split_size`` - :ref:`session property `. - - ``64MB`` - * - ``fault-tolerant-execution-max-task-split-count`` - - Maximum number of :ref:`splits ` processed by a - single task. This value is not split weight-adjusted and serves as - protection against situations where catalogs report an incorrect split - weight. - - May be overridden for the current session with the - ``fault_tolerant_execution_max_task_split_count`` - :ref:`session property `. - - ``256`` - * - ``fault-tolerant-execution-arbitrary-distribution-compute-task-target-size-growth-period`` - - The number of tasks created for any given non-writer stage of arbitrary - distribution before task size is increased. - - ``64`` - * - ``fault-tolerant-execution-arbitrary-distribution-compute-task-target-size-growth-factor`` - - Growth factor for adaptive sizing of non-writer tasks of arbitrary - distribution for fault-tolerant execution. Lower bound is 1.0. For every - task size increase, new task target size is old task target size - multiplied by this growth factor. - - ``1.26`` - * - ``fault-tolerant-execution-arbitrary-distribution-compute-task-target-size-min`` - - Initial/minimum target input :ref:`data size ` for - non-writer tasks of arbitrary distribution of fault-tolerant execution. 
- - ``512MB`` - * - ``fault-tolerant-execution-arbitrary-distribution-compute-task-target-size-max`` - - Maximum target input :ref:`data size ` for each - non-writer task of arbitrary distribution of fault-tolerant execution. - - ``50GB`` - * - ``fault-tolerant-execution-arbitrary-distribution-write-task-target-size-growth-period`` - - The number of tasks created for any given writer stage of arbitrary - distribution before task size is increased. - - ``64`` - * - ``fault-tolerant-execution-arbitrary-distribution-write-task-target-size-growth-factor`` - - Growth factor for adaptive sizing of writer tasks of arbitrary - distribution for fault-tolerant execution. Lower bound is 1.0. For every - task size increase, new task target size is old task target size - multiplied by this growth factor. - - ``1.26`` - * - ``fault-tolerant-execution-arbitrary-distribution-write-task-target-size-min`` - - Initial/minimum target input :ref:`data size ` for - writer tasks of arbitrary distribution of fault-tolerant execution. - - ``4GB`` - * - ``fault-tolerant-execution-arbitrary-distribution-write-task-target-size-max`` - - Maximum target input :ref:`data size ` for writer - tasks of arbitrary distribution of fault-tolerant execution. - - ``50GB`` - * - ``fault-tolerant-execution-hash-distribution-compute-task-target-size`` - - Target input :ref:`data size ` for non-writer tasks - of hash distribution of fault-tolerant execution. - - ``512MB`` - * - ``fault-tolerant-execution-hash-distribution-write-task-target-size`` - - Target input :ref:`data size ` of writer tasks of - hash distribution of fault-tolerant execution. - - ``4GB`` - * - ``fault-tolerant-execution-hash-distribution-write-task-target-max-count`` - - Soft upper bound on number of writer tasks in a stage of hash - distribution of fault-tolerant execution. 
- - ``2000`` -``` +:::{list-table} Task sizing configuration properties +:widths: 30, 50, 20 +:header-rows: 1 + +* - Property name + - Description + - Default value +* - `fault-tolerant-execution-standard-split-size` + - Standard [split](trino-concept-splits) [data size]( prop-type-data-size) + processed by tasks that read data from source tables. Value is interpreted + with split weight taken into account. If the weight of splits produced by a + catalog denotes that they are lighter or heavier than "standard" split, then + the number of splits processed by a single task is adjusted accordingly. + + May be overridden for the current session with the + `fault_tolerant_execution_standard_split_size` [session + property](session-properties-definition). + - `64MB` +* - `fault-tolerant-execution-max-task-split-count` + - Maximum number of [splits](trino-concept-splits) processed by a single task. + This value is not split weight-adjusted and serves as protection against + situations where catalogs report an incorrect split weight. + + May be overridden for the current session with the + `fault_tolerant_execution_max_task_split_count` [session + property](session-properties-definition). + - `256` +* - `fault-tolerant-execution-arbitrary-distribution-compute-task-target-size-growth-period` + - The number of tasks created for any given non-writer stage of arbitrary + distribution before task size is increased. + - `64` +* - `fault-tolerant-execution-arbitrary-distribution-compute-task-target-size-growth-factor` + - Growth factor for adaptive sizing of non-writer tasks of arbitrary + distribution for fault-tolerant execution. Lower bound is 1.0. For every + task size increase, new task target size is old task target size multiplied + by this growth factor. 
+ - `1.26` +* - `fault-tolerant-execution-arbitrary-distribution-compute-task-target-size-min` + - Initial/minimum target input [data size](prop-type-data-size) for non-writer + tasks of arbitrary distribution of fault-tolerant execution. + - `512MB` +* - `fault-tolerant-execution-arbitrary-distribution-compute-task-target-size-max` + - Maximum target input [data size](prop-type-data-size) for each non-writer + task of arbitrary distribution of fault-tolerant execution. + - `50GB` +* - `fault-tolerant-execution-arbitrary-distribution-write-task-target-size-growth-period` + - The number of tasks created for any given writer stage of arbitrary + distribution before task size is increased. + - `64` +* - `fault-tolerant-execution-arbitrary-distribution-write-task-target-size-growth-factor` + - Growth factor for adaptive sizing of writer tasks of arbitrary distribution + for fault-tolerant execution. Lower bound is 1.0. For every task size + increase, new task target size is old task target size multiplied by this + growth factor. + - `1.26` +* - `fault-tolerant-execution-arbitrary-distribution-write-task-target-size-min` + - Initial/minimum target input [data size](prop-type-data-size) for writer + tasks of arbitrary distribution of fault-tolerant execution. + - `4GB` +* - `fault-tolerant-execution-arbitrary-distribution-write-task-target-size-max` + - Maximum target input [data size](prop-type-data-size) for writer tasks of + arbitrary distribution of fault-tolerant execution. + - `50GB` +* - `fault-tolerant-execution-hash-distribution-compute-task-target-size` + - Target input [data size](prop-type-data-size) for non-writer tasks of hash + distribution of fault-tolerant execution. + - `512MB` +* - `fault-tolerant-execution-hash-distribution-write-task-target-size` + - Target input [data size](prop-type-data-size) of writer tasks of hash + distribution of fault-tolerant execution. 
+ - `4GB` +* - `fault-tolerant-execution-hash-distribution-write-task-target-max-count` + - Soft upper bound on number of writer tasks in a stage of hash distribution + of fault-tolerant execution. + - `2000` +::: ### Node allocation @@ -304,79 +298,76 @@ The initial task memory-requirements estimation is static and configured with the `fault-tolerant-task-memory` configuration property. This property only applies to a `TASK` retry policy. -```{eval-rst} -.. list-table:: Node allocation configuration properties - :widths: 30, 50, 20 - :header-rows: 1 - - * - Property name - - Description - - Default value - * - ``fault-tolerant-execution-task-memory`` - - Initial task memory :ref:`data size ` estimation - used for bin-packing when allocating nodes for tasks. May be overridden - for the current session with the - ``fault_tolerant_execution_task_memory`` - :ref:`session property `. - - ``5GB`` -``` +:::{list-table} Node allocation configuration properties +:widths: 30, 50, 20 +:header-rows: 1 + +* - Property name + - Description + - Default value +* - `fault-tolerant-execution-task-memory` + - Initial task memory [data size](prop-type-data-size) estimation + used for bin-packing when allocating nodes for tasks. May be overridden + for the current session with the + `fault_tolerant_execution_task_memory` + [session property](session-properties-definition). + - `5GB` +::: ### Other tuning The following additional configuration property can be used to manage fault-tolerant execution: -```{eval-rst} -.. list-table:: Other fault-tolerant execution configuration properties - :widths: 30, 50, 20, 30 - :header-rows: 1 - - * - Property name - - Description - - Default value - - Retry policy - * - ``fault-tolerant-execution-task-descriptor-storage-max-memory`` - - Maximum :ref:`data size ` of memory to be used to - store task descriptors for fault tolerant queries on coordinator. Extra - memory is needed to be able to reschedule tasks in case of a failure. 
- - (JVM heap size * 0.15) - - Only ``TASK`` - * - ``fault-tolerant-execution-max-partition-count`` - - Maximum number of partitions to use for distributed joins and - aggregations, similar in function to the - ``query.max-hash-partition-count`` :doc:`query management property - `. It is not recommended to increase - this property value above the default of ``50``, which may result in - instability and poor performance. May be overridden for the current - session with the ``fault_tolerant_execution_max_partition_count`` - :ref:`session property `. - - ``50`` - - Only ``TASK`` - * - ``fault-tolerant-execution-min-partition-count`` - - Minimum number of partitions to use for distributed joins and - aggregations, similar in function to the - ``query.min-hash-partition-count`` :doc:`query management property - `. May be overridden for the current - session with the ``fault_tolerant_execution_min_partition_count`` - :ref:`session property `. - - ``4`` - - Only ``TASK`` - * - ``fault-tolerant-execution-min-partition-count-for-write`` - - Minimum number of partitions to use for distributed joins and - aggregations in write queries, similar in function to the - ``query.min-hash-partition-count-for-write`` :doc:`query management - property `. May be overridden for - the current session with the - ``fault_tolerant_execution_min_partition_count_for_write`` - :ref:`session property `. - - ``50`` - - Only ``TASK`` - * - ``max-tasks-waiting-for-node-per-stage`` - - Allow for up to configured number of tasks to wait for node allocation - per stage, before pausing scheduling for other tasks from this stage. 
- - 5 - - Only ``TASK`` -``` +:::{list-table} Other fault-tolerant execution configuration properties +:widths: 30, 50, 20, 30 +:header-rows: 1 + +* - Property name + - Description + - Default value + - Retry policy +* - `fault-tolerant-execution-task-descriptor-storage-max-memory` + - Maximum [data size](prop-type-data-size) of memory to be used to + store task descriptors for fault tolerant queries on coordinator. Extra + memory is needed to be able to reschedule tasks in case of a failure. + - (JVM heap size * 0.15) + - Only `TASK` +* - `fault-tolerant-execution-max-partition-count` + - Maximum number of partitions to use for distributed joins and aggregations, + similar in function to the `query.max-hash-partition-count` [query + management property](/admin/properties-query-management). It is not + recommended to increase this property value above the default of `50`, which + may result in instability and poor performance. May be overridden for the + current session with the `fault_tolerant_execution_max_partition_count` + [session property](session-properties-definition). + - `50` + - Only `TASK` +* - `fault-tolerant-execution-min-partition-count` + - Minimum number of partitions to use for distributed joins and aggregations, + similar in function to the `query.min-hash-partition-count` [query + management property](/admin/properties-query-management). May be overridden + for the current session with the + `fault_tolerant_execution_min_partition_count` [session + property](session-properties-definition). + - `4` + - Only `TASK` +* - `fault-tolerant-execution-min-partition-count-for-write` + - Minimum number of partitions to use for distributed joins and aggregations + in write queries, similar in function to the + `query.min-hash-partition-count-for-write` [query management + property](/admin/properties-query-management). 
May be overridden for the + current session with the + `fault_tolerant_execution_min_partition_count_for_write` [session + property](session-properties-definition). + - `50` + - Only `TASK` +* - `max-tasks-waiting-for-node-per-stage` + - Allow for up to configured number of tasks to wait for node allocation + per stage, before pausing scheduling for other tasks from this stage. + - 5 + - Only `TASK` +::: (fte-exchange-manager)= @@ -401,118 +392,115 @@ The following table lists the available configuration properties for `exchange-manager.properties`, their default values, and which filesystem(s) the property may be configured for: -```{eval-rst} -.. list-table:: Exchange manager configuration properties - :widths: 30, 50, 20, 30 - :header-rows: 1 - - * - Property name - - Description - - Default value - - Supported filesystem - * - ``exchange.base-directories`` - - Comma-separated list of URI locations that the exchange manager uses to - store spooling data. - - - - Any - * - ``exchange.sink-buffer-pool-min-size`` - - The minimum buffer pool size for an exchange sink. The larger the buffer - pool size, the larger the write parallelism and memory usage. - - ``10`` - - Any - * - ``exchange.sink-buffers-per-partition`` - - The number of buffers per partition in the buffer pool. The larger the - buffer pool size, the larger the write parallelism and memory usage. - - ``2`` - - Any - * - ``exchange.sink-max-file-size`` - - Max :ref:`data size ` of files written by exchange - sinks. - - ``1GB`` - - Any - * - ``exchange.source-concurrent-readers`` - - Number of concurrent readers to read from spooling storage. The - larger the number of concurrent readers, the larger the read parallelism - and memory usage. - - ``4`` - - Any - * - ``exchange.s3.aws-access-key`` - - AWS access key to use. Required for a connection to AWS S3 and GCS, can - be ignored for other S3 storage systems. - - - - AWS S3, GCS - * - ``exchange.s3.aws-secret-key`` - - AWS secret key to use. 
Required for a connection to AWS S3 and GCS, can - be ignored for other S3 storage systems. - - - - AWS S3, GCS - * - ``exchange.s3.iam-role`` - - IAM role to assume. - - - - AWS S3, GCS - * - ``exchange.s3.external-id`` - - External ID for the IAM role trust policy. - - - - AWS S3, GCS - * - ``exchange.s3.region`` - - Region of the S3 bucket. - - - - AWS S3, GCS - * - ``exchange.s3.endpoint`` - - S3 storage endpoint server if using an S3-compatible storage system that - is not AWS. If using AWS S3, this can be ignored. If using GCS, set it - to ``https://storage.googleapis.com``. - - - - Any S3-compatible storage - * - ``exchange.s3.max-error-retries`` - - Maximum number of times the exchange manager's S3 client should retry - a request. - - ``10`` - - Any S3-compatible storage - * - ``exchange.s3.path-style-access`` - - Enables using `path-style access `_ - for all requests to S3. - - ``false`` - - Any S3-compatible storage - * - ``exchange.s3.upload.part-size`` - - Part :ref:`data size ` for S3 multi-part upload. - - ``5MB`` - - Any S3-compatible storage - * - ``exchange.gcs.json-key-file-path`` - - Path to the JSON file that contains your Google Cloud Platform - service account key. Not to be set together with - ``exchange.gcs.json-key`` - - - - GCS - * - ``exchange.gcs.json-key`` - - Your Google Cloud Platform service account key in JSON format. - Not to be set together with ``exchange.gcs.json-key-file-path`` - - - - GCS - * - ``exchange.azure.connection-string`` - - Connection string used to access the spooling container. - - - - Azure Blob Storage - * - ``exchange.azure.block-size`` - - Block :ref:`data size ` for Azure block blob - parallel upload. - - ``4MB`` - - Azure Blob Storage - * - ``exchange.azure.max-error-retries`` - - Maximum number of times the exchange manager's Azure client should - retry a request. - - ``10`` - - Azure Blob Storage - * - ``exchange.hdfs.block-size`` - - Block :ref:`data size ` for HDFS storage. 
-     - ``4MB``
-     - HDFS
-   * - ``hdfs.config.resources``
-     - Comma-separated list of paths to HDFS configuration files, for example ``/etc/hdfs-site.xml``.
-       The files must exist on all nodes in the Trino cluster.
-     -
-     - HDFS
-```
+:::{list-table} Exchange manager configuration properties
+:widths: 30, 50, 20, 30
+:header-rows: 1
+
+* - Property name
+  - Description
+  - Default value
+  - Supported filesystem
+* - `exchange.base-directories`
+  - Comma-separated list of URI locations that the exchange manager uses to
+    store spooling data.
+  -
+  - Any
+* - `exchange.sink-buffer-pool-min-size`
+  - The minimum buffer pool size for an exchange sink. The larger the buffer
+    pool size, the larger the write parallelism and memory usage.
+  - `10`
+  - Any
+* - `exchange.sink-buffers-per-partition`
+  - The number of buffers per partition in the buffer pool. The larger the
+    buffer pool size, the larger the write parallelism and memory usage.
+  - `2`
+  - Any
+* - `exchange.sink-max-file-size`
+  - Max [data size](prop-type-data-size) of files written by exchange sinks.
+  - `1GB`
+  - Any
+* - `exchange.source-concurrent-readers`
+  - Number of concurrent readers to read from spooling storage. The larger the
+    number of concurrent readers, the larger the read parallelism and memory
+    usage.
+  - `4`
+  - Any
+* - `exchange.s3.aws-access-key`
+  - AWS access key to use. Required for a connection to AWS S3 and GCS, can be
+    ignored for other S3 storage systems.
+  -
+  - AWS S3, GCS
+* - `exchange.s3.aws-secret-key`
+  - AWS secret key to use. Required for a connection to AWS S3 and GCS, can be
+    ignored for other S3 storage systems.
+  -
+  - AWS S3, GCS
+* - `exchange.s3.iam-role`
+  - IAM role to assume.
+  -
+  - AWS S3, GCS
+* - `exchange.s3.external-id`
+  - External ID for the IAM role trust policy.
+  -
+  - AWS S3, GCS
+* - `exchange.s3.region`
+  - Region of the S3 bucket.
+ - + - AWS S3, GCS +* - `exchange.s3.endpoint` + - S3 storage endpoint server if using an S3-compatible storage system that + is not AWS. If using AWS S3, this can be ignored. If using GCS, set it + to `https://storage.googleapis.com`. + - + - Any S3-compatible storage +* - `exchange.s3.max-error-retries` + - Maximum number of times the exchange manager's S3 client should retry + a request. + - `10` + - Any S3-compatible storage +* - `exchange.s3.path-style-access` + - Enables using [path-style access](https://docs.aws.amazon.com/AmazonS3/latest/userguide/VirtualHosting.html#path-style-access) + for all requests to S3. + - `false` + - Any S3-compatible storage +* - `exchange.s3.upload.part-size` + - Part [data size](prop-type-data-size) for S3 multi-part upload. + - `5MB` + - Any S3-compatible storage +* - `exchange.gcs.json-key-file-path` + - Path to the JSON file that contains your Google Cloud Platform service + account key. Not to be set together with `exchange.gcs.json-key` + - + - GCS +* - `exchange.gcs.json-key` + - Your Google Cloud Platform service account key in JSON format. Not to be set + together with `exchange.gcs.json-key-file-path` + - + - GCS +* - `exchange.azure.connection-string` + - Connection string used to access the spooling container. + - + - Azure Blob Storage +* - `exchange.azure.block-size` + - Block [data size](prop-type-data-size) for Azure block blob parallel upload. + - `4MB` + - Azure Blob Storage +* - `exchange.azure.max-error-retries` + - Maximum number of times the exchange manager's Azure client should + retry a request. + - `10` + - Azure Blob Storage +* - `exchange.hdfs.block-size` + - Block [data size](prop-type-data-size) for HDFS storage. + - `4MB` + - HDFS +* - `hdfs.config.resources` + - Comma-separated list of paths to HDFS configuration files, for example + `/etc/hdfs-site.xml`. The files must exist on all nodes in the Trino + cluster. 
+ - + - HDFS +::: It is recommended to set the `exchange.compression-enabled` property to `true` in the cluster's `config.properties` file, to reduce the exchange diff --git a/docs/src/main/sphinx/admin/resource-groups.md b/docs/src/main/sphinx/admin/resource-groups.md index ff6e1e497d0b..241b5b552c13 100644 --- a/docs/src/main/sphinx/admin/resource-groups.md +++ b/docs/src/main/sphinx/admin/resource-groups.md @@ -56,39 +56,37 @@ Trino clusters to be stored in the same database if required. The configuration is reloaded from the database every second, and the changes are reflected automatically for incoming queries. -```{eval-rst} -.. list-table:: Database resource group manager properties - :widths: 40, 50, 10 - :header-rows: 1 - - * - Property name - - Description - - Default value - * - ``resource-groups.config-db-url`` - - Database URL to load configuration from. - - ``none`` - * - ``resource-groups.config-db-user`` - - Database user to connect with. - - ``none`` - * - ``resource-groups.config-db-password`` - - Password for database user to connect with. - - ``none`` - * - ``resource-groups.max-refresh-interval`` - - The maximum time period for which the cluster will continue to accept - queries after refresh failures, causing configuration to become stale. - - ``1h`` - * - ``resource-groups.refresh-interval`` - - How often the cluster reloads from the database - - ``1s`` - * - ``resource-groups.exact-match-selector-enabled`` - - Setting this flag enables usage of an additional - ``exact_match_source_selectors`` table to configure resource group - selection rules defined exact name based matches for source, environment - and query type. By default, the rules are only loaded from the - ``selectors`` table, with a regex-based filter for ``source``, among - other filters. 
- - ``false`` -``` +:::{list-table} Database resource group manager properties +:widths: 40, 50, 10 +:header-rows: 1 + +* - Property name + - Description + - Default value +* - `resource-groups.config-db-url` + - Database URL to load configuration from. + - `none` +* - `resource-groups.config-db-user` + - Database user to connect with. + - `none` +* - `resource-groups.config-db-password` + - Password for database user to connect with. + - `none` +* - `resource-groups.max-refresh-interval` + - The maximum time period for which the cluster will continue to accept + queries after refresh failures, causing configuration to become stale. + - `1h` +* - `resource-groups.refresh-interval` + - How often the cluster reloads from the database + - `1s` +* - `resource-groups.exact-match-selector-enabled` + - Setting this flag enables usage of an additional + `exact_match_source_selectors` table to configure resource group selection + rules defined exact name based matches for source, environment and query + type. By default, the rules are only loaded from the `selectors` table, with + a regex-based filter for `source`, among other filters. + - `false` +::: ## Resource group properties diff --git a/docs/src/main/sphinx/client/cli.md b/docs/src/main/sphinx/client/cli.md index dd066f5af81d..2d09a0eabaa2 100644 --- a/docs/src/main/sphinx/client/cli.md +++ b/docs/src/main/sphinx/client/cli.md @@ -125,84 +125,83 @@ trino:tiny> Many other options are available to further configure the CLI in interactive mode: -```{eval-rst} -.. list-table:: - :widths: 40, 60 - :header-rows: 1 - - * - Option - - Description - * - ``--catalog`` - - Sets the default catalog. You can change the default catalog and schema - with :doc:`/sql/use`. - * - ``--client-info`` - - Adds arbitrary text as extra information about the client. - * - ``--client-request-timeout`` - - Sets the duration for query processing, after which, the client request is - terminated. Defaults to ``2m``. 
- * - ``--client-tags`` - - Adds extra tags information about the client and the CLI user. Separate - multiple tags with commas. The tags can be used as input for - :doc:`/admin/resource-groups`. - * - ``--debug`` - - Enables display of debug information during CLI usage for - :ref:`cli-troubleshooting`. Displays more information about query - processing statistics. - * - ``--disable-auto-suggestion`` - - Disables autocomplete suggestions. - * - ``--disable-compression`` - - Disables compression of query results. - * - ``--editing-mode`` - - Sets key bindings in the CLI to be compatible with VI or - EMACS editors. Defaults to ``EMACS``. - * - ``--http-proxy`` - - Configures the URL of the HTTP proxy to connect to Trino. - * - ``--history-file`` - - Path to the :ref:`history file `. Defaults to ``~/.trino_history``. - * - ``--network-logging`` - - Configures the level of detail provided for network logging of the CLI. - Defaults to ``NONE``, other options are ``BASIC``, ``HEADERS``, or - ``BODY``. - * - ``--output-format-interactive=`` - - Specify the :ref:`format ` to use - for printing query results. Defaults to ``ALIGNED``. - * - ``--pager=`` - - Path to the pager program used to display the query results. Set to - an empty value to completely disable pagination. Defaults to ``less`` - with a carefully selected set of options. - * - ``--no-progress`` - - Do not show query processing progress. - * - ``--password`` - - Prompts for a password. Use if your Trino server requires password - authentication. You can set the ``TRINO_PASSWORD`` environment variable - with the password value to avoid the prompt. For more information, see :ref:`cli-username-password-auth`. - * - ``--schema`` - - Sets the default schema. You can change the default catalog and schema - with :doc:`/sql/use`. - * - ``--server`` - - The HTTP/HTTPS address and port of the Trino coordinator. The port must be - set to the port the Trino coordinator is listening for connections on. 
- Trino server location defaults to ``http://localhost:8080``. - Can only be set if URL is not specified. - * - ``--session`` - - Sets one or more :ref:`session properties - `. Property can be used multiple times with - the format ``session_property_name=value``. - * - ``--socks-proxy`` - - Configures the URL of the SOCKS proxy to connect to Trino. - * - ``--source`` - - Specifies the name of the application or source connecting to Trino. - Defaults to ``trino-cli``. The value can be used as input for - :doc:`/admin/resource-groups`. - * - ``--timezone`` - - Sets the time zone for the session using the `time zone name - `_. Defaults - to the timezone set on your workstation. - * - ``--user`` - - Sets the username for :ref:`cli-username-password-auth`. Defaults to your - operating system username. You can override the default username, - if your cluster uses a different username or authentication mechanism. -``` +:::{list-table} +:widths: 40, 60 +:header-rows: 1 + +* - Option + - Description +* - `--catalog` + - Sets the default catalog. You can change the default catalog and schema with + [](/sql/use). +* - `--client-info` + - Adds arbitrary text as extra information about the client. +* - `--client-request-timeout` + - Sets the duration for query processing, after which, the client request is + terminated. Defaults to `2m`. +* - `--client-tags` + - Adds extra tags information about the client and the CLI user. Separate + multiple tags with commas. The tags can be used as input for + [](/admin/resource-groups). +* - `--debug` + - Enables display of debug information during CLI usage for + [](cli-troubleshooting). Displays more information about query + processing statistics. +* - `--disable-auto-suggestion` + - Disables autocomplete suggestions. +* - `--disable-compression` + - Disables compression of query results. +* - `--editing-mode` + - Sets key bindings in the CLI to be compatible with VI or + EMACS editors. Defaults to `EMACS`. 
+* - `--http-proxy`
+  - Configures the URL of the HTTP proxy to connect to Trino.
+* - `--history-file`
+  - Path to the [history file](cli-history). Defaults to `~/.trino_history`.
+* - `--network-logging`
+  - Configures the level of detail provided for network logging of the CLI.
+    Defaults to `NONE`, other options are `BASIC`, `HEADERS`, or `BODY`.
+* - `--output-format-interactive=<format>`
+  - Specify the [format](cli-output-format) to use for printing query results.
+    Defaults to `ALIGNED`.
+* - `--pager=<pager>`
+  - Path to the pager program used to display the query results. Set to an empty
+    value to completely disable pagination. Defaults to `less` with a carefully
+    selected set of options.
+* - `--no-progress`
+  - Do not show query processing progress.
+* - `--password`
+  - Prompts for a password. Use if your Trino server requires password
+    authentication. You can set the `TRINO_PASSWORD` environment variable with
+    the password value to avoid the prompt. For more information, see
+    [](cli-username-password-auth).
+* - `--schema`
+  - Sets the default schema. You can change the default catalog and schema
+    with [](/sql/use).
+* - `--server`
+  - The HTTP/HTTPS address and port of the Trino coordinator. The port must be
+    set to the port the Trino coordinator is listening for connections on. Trino
+    server location defaults to `http://localhost:8080`. Can only be set if URL
+    is not specified.
+* - `--session`
+  - Sets one or more [session properties](session-properties-definition).
+    Property can be used multiple times with the format
+    `session_property_name=value`.
+* - `--socks-proxy`
+  - Configures the URL of the SOCKS proxy to connect to Trino.
+* - `--source`
+  - Specifies the name of the application or source connecting to Trino.
+    Defaults to `trino-cli`. The value can be used as input for
+    [](/admin/resource-groups).
+* - `--timezone`
+  - Sets the time zone for the session using the [time zone name](
+    <https://en.wikipedia.org/wiki/List_of_tz_database_time_zones>). Defaults to
+    the timezone set on your workstation.
+* - `--user` + - Sets the username for [](cli-username-password-auth). Defaults to your + operating system username. You can override the default username, if your + cluster uses a different username or authentication mechanism. +::: Most of the options can also be set as parameters in the URL. This means a JDBC URL can be used in the CLI after removing the `jdbc:` prefix. @@ -236,41 +235,40 @@ recognizes these certificates. Use the options from the following table to further configure TLS and certificate usage: -```{eval-rst} -.. list-table:: - :widths: 40, 60 - :header-rows: 1 - - * - Option - - Description - * - ``--insecure`` - - Skip certificate validation when connecting with TLS/HTTPS (should only be - used for debugging). - * - ``--keystore-path`` - - The location of the Java Keystore file that contains the certificate of - the server to connect with TLS. - * - ``--keystore-password`` - - The password for the keystore. This must match the password you specified - when creating the keystore. - * - ``--keystore-type`` - - Determined by the keystore file format. The default keystore type is JKS. - This advanced option is only necessary if you use a custom Java - Cryptography Architecture (JCA) provider implementation. - * - ``--truststore-password`` - - The password for the truststore. This must match the password you - specified when creating the truststore. - * - ``--truststore-path`` - - The location of the Java truststore file that will be used to secure TLS. - * - ``--truststore-type`` - - Determined by the truststore file format. The default keystore type is - JKS. This advanced option is only necessary if you use a custom Java - Cryptography Architecture (JCA) provider implementation. - * - ``--use-system-truststore`` - - Verify the server certificate using the system truststore of the - operating system. Windows and macOS are supported. For other operating - systems, the default Java truststore is used. 
The truststore type can - be overridden using ``--truststore-type``. -``` +:::{list-table} +:widths: 40, 60 +:header-rows: 1 + +* - Option + - Description +* - `--insecure` + - Skip certificate validation when connecting with TLS/HTTPS (should only be + used for debugging). +* - `--keystore-path` + - The location of the Java Keystore file that contains the certificate of the + server to connect with TLS. +* - `--keystore-password` + - The password for the keystore. This must match the password you specified + when creating the keystore. +* - `--keystore-type` + - Determined by the keystore file format. The default keystore type is JKS. + This advanced option is only necessary if you use a custom Java Cryptography + Architecture (JCA) provider implementation. +* - `--truststore-password` + - The password for the truststore. This must match the password you specified + when creating the truststore. +* - `--truststore-path` + - The location of the Java truststore file that will be used to secure TLS. +* - `--truststore-type` + - Determined by the truststore file format. The default keystore type is JKS. + This advanced option is only necessary if you use a custom Java Cryptography + Architecture (JCA) provider implementation. +* - `--use-system-truststore` + - Verify the server certificate using the system truststore of the operating + system. Windows and macOS are supported. For other operating systems, the + default Java truststore is used. The truststore type can be overridden using + `--truststore-type`. +::: (cli-authentication)= @@ -340,20 +338,19 @@ The detailed behavior is as follows: Use the following CLI arguments to connect to a cluster that uses {doc}`certificate authentication `. -```{eval-rst} -.. 
list-table:: CLI options for certificate authentication - :widths: 35 65 - :header-rows: 1 - - * - Option - - Description - * - ``--keystore-path=`` - - Absolute or relative path to a :doc:`PEM ` or - :doc:`JKS ` file, which must contain a certificate - that is trusted by the Trino cluster you are connecting to. - * - ``--keystore-password=`` - - Only required if the keystore has a password. -``` +:::{list-table} CLI options for certificate authentication +:widths: 35 65 +:header-rows: 1 + +* - Option + - Description +* - `--keystore-path=` + - Absolute or relative path to a [PEM](/security/inspect-pem) or + [JKS](/security/inspect-jks) file, which must contain a certificate + that is trusted by the Trino cluster you are connecting to. +* - `--keystore-password=` + - Only required if the keystore has a password. +::: The truststore related options are independent of client certificate authentication with the CLI; instead, they control the client's trust of the @@ -394,30 +391,28 @@ through {doc}`TLS and HTTPS `. The following table lists the available options for Kerberos authentication: -```{eval-rst} -.. list-table:: CLI options for Kerberos authentication - :widths: 40, 60 - :header-rows: 1 - - * - Option - - Description - * - ``--krb5-config-path`` - - Path to Kerberos configuration files. - * - ``--krb5-credential-cache-path`` - - Kerberos credential cache path. - * - ``--krb5-disable-remote-service-hostname-canonicalization`` - - Disable service hostname canonicalization using the DNS reverse lookup. - * - ``--krb5-keytab-path`` - - The location of the keytab that can be used to authenticate the principal - specified by ``--krb5-principal``. - * - ``--krb5-principal`` - - The principal to use when authenticating to the coordinator. - * - ``--krb5-remote-service-name`` - - Trino coordinator Kerberos service name. - * - ``--krb5-service-principal-pattern`` - - Remote kerberos service principal pattern. Defaults to - ``${SERVICE}@${HOST}``. 
-``` +:::{list-table} CLI options for Kerberos authentication +:widths: 40, 60 +:header-rows: 1 + +* - Option + - Description +* - `--krb5-config-path` + - Path to Kerberos configuration files. +* - `--krb5-credential-cache-path` + - Kerberos credential cache path. +* - `--krb5-disable-remote-service-hostname-canonicalization` + - Disable service hostname canonicalization using the DNS reverse lookup. +* - `--krb5-keytab-path` + - The location of the keytab that can be used to authenticate the principal + specified by `--krb5-principal`. +* - `--krb5-principal` + - The principal to use when authenticating to the coordinator. +* - `--krb5-remote-service-name` + - Trino coordinator Kerberos service name. +* - `--krb5-service-principal-pattern` + - Remote kerberos service principal pattern. Defaults to `${SERVICE}@${HOST}`. +::: (cli-kerberos-debug)= @@ -520,27 +515,26 @@ other formats and redirect the output to a file. The following options are available to further configure the CLI in batch mode: -```{eval-rst} -.. list-table:: - :widths: 40, 60 - :header-rows: 1 - - * - Option - - Description - * - ``--execute=`` - - Execute specified statements and exit. - * - ``-f``, ``--file=`` - - Execute statements from file and exit. - * - ``--ignore-errors`` - - Continue processing in batch mode when an error occurs. Default is to - exit immediately. - * - ``--output-format=`` - - Specify the :ref:`format ` to use - for printing query results. Defaults to ``CSV``. - * - ``--progress`` - - Show query progress in batch mode. It does not affect the output, - which, for example can be safely redirected to a file. -``` +:::{list-table} +:widths: 40, 60 +:header-rows: 1 + +* - Option + - Description +* - `--execute=` + - Execute specified statements and exit. +* - `-f`, `--file=` + - Execute statements from file and exit. +* - `--ignore-errors` + - Continue processing in batch mode when an error occurs. Default is to exit + immediately. 
+* - `--output-format=` + - Specify the [format](cli-output-format) to use for printing query results. + Defaults to `CSV`. +* - `--progress` + - Show query progress in batch mode. It does not affect the output, which, for + example can be safely redirected to a file. +::: ### Examples @@ -615,41 +609,40 @@ The available options shown in the following table must be entered in uppercase. The default value is `ALIGNED` in interactive mode, and `CSV` in non-interactive mode. -```{eval-rst} -.. list-table:: Output format options - :widths: 25, 75 - :header-rows: 1 - - * - Option - - Description - * - ``CSV`` - - Comma-separated values, each value quoted. No header row. - * - ``CSV_HEADER`` - - Comma-separated values, quoted with header row. - * - ``CSV_UNQUOTED`` - - Comma-separated values without quotes. - * - ``CSV_HEADER_UNQUOTED`` - - Comma-separated values with header row but no quotes. - * - ``TSV`` - - Tab-separated values. - * - ``TSV_HEADER`` - - Tab-separated values with header row. - * - ``JSON`` - - Output rows emitted as JSON objects with name-value pairs. - * - ``ALIGNED`` - - Output emitted as an ASCII character table with values. - * - ``VERTICAL`` - - Output emitted as record-oriented top-down lines, one per value. - * - ``AUTO`` - - Same as ``ALIGNED`` if output would fit the current terminal width, - and ``VERTICAL`` otherwise. - * - ``MARKDOWN`` - - Output emitted as a Markdown table. - * - ``NULL`` - - Suppresses normal query results. This can be useful during development - to test a query's shell return code or to see whether it results in - error messages. -``` +:::{list-table} Output format options +:widths: 25, 75 +:header-rows: 1 + +* - Option + - Description +* - `CSV` + - Comma-separated values, each value quoted. No header row. +* - `CSV_HEADER` + - Comma-separated values, quoted with header row. +* - `CSV_UNQUOTED` + - Comma-separated values without quotes. 
+* - `CSV_HEADER_UNQUOTED` + - Comma-separated values with header row but no quotes. +* - `TSV` + - Tab-separated values. +* - `TSV_HEADER` + - Tab-separated values with header row. +* - `JSON` + - Output rows emitted as JSON objects with name-value pairs. +* - `ALIGNED` + - Output emitted as an ASCII character table with values. +* - `VERTICAL` + - Output emitted as record-oriented top-down lines, one per value. +* - `AUTO` + - Same as `ALIGNED` if output would fit the current terminal width, + and `VERTICAL` otherwise. +* - `MARKDOWN` + - Output emitted as a Markdown table. +* - `NULL` + - Suppresses normal query results. This can be useful during development to + test a query's shell return code or to see whether it results in error + messages. +::: (cli-troubleshooting)= diff --git a/docs/src/main/sphinx/client/jdbc.md b/docs/src/main/sphinx/client/jdbc.md index 76d4710d34a3..ab0535ad703e 100644 --- a/docs/src/main/sphinx/client/jdbc.md +++ b/docs/src/main/sphinx/client/jdbc.md @@ -117,139 +117,132 @@ may not be specified using both methods. ## Parameter reference -```{eval-rst} -.. list-table:: - :widths: 35, 65 - :header-rows: 1 - - * - Name - - Description - * - ``user`` - - Username to use for authentication and authorization. - * - ``password`` - - Password to use for LDAP authentication. - * - ``sessionUser`` - - Session username override, used for impersonation. - * - ``socksProxy`` - - SOCKS proxy host and port. Example: ``localhost:1080`` - * - ``httpProxy`` - - HTTP proxy host and port. Example: ``localhost:8888`` - * - ``clientInfo`` - - Extra information about the client. - * - ``clientTags`` - - Client tags for selecting resource groups. Example: ``abc,xyz`` - * - ``traceToken`` - - Trace token for correlating requests across systems. - * - ``source`` - - Source name for the Trino query. This parameter should be used in - preference to ``ApplicationName``. Thus, it takes precedence over - ``ApplicationName`` and/or ``applicationNamePrefix``. 
- * - ``applicationNamePrefix`` - - Prefix to append to any specified ``ApplicationName`` client info - property, which is used to set the source name for the Trino query if the - ``source`` parameter has not been set. If neither this property nor - ``ApplicationName`` or ``source`` are set, the source name for the query - is ``trino-jdbc``. - * - ``accessToken`` - - :doc:`JWT ` access token for token based authentication. - * - ``SSL`` - - Set ``true`` to specify using TLS/HTTPS for connections. - * - ``SSLVerification`` - - The method of TLS verification. There are three modes: ``FULL`` - (default), ``CA`` and ``NONE``. For ``FULL``, the normal TLS verification - is performed. For ``CA``, only the CA is verified but hostname mismatch - is allowed. For ``NONE``, there is no verification. - * - ``SSLKeyStorePath`` - - Use only when connecting to a Trino cluster that has :doc:`certificate - authentication ` enabled. Specifies the path to a - :doc:`PEM ` or :doc:`JKS ` - file, which must contain a certificate that is trusted by the Trino - cluster you connect to. - * - ``SSLKeyStorePassword`` - - The password for the KeyStore, if any. - * - ``SSLKeyStoreType`` - - The type of the KeyStore. The default type is provided by the Java - ``keystore.type`` security property or ``jks`` if none exists. - * - ``SSLTrustStorePath`` - - The location of the Java TrustStore file to use to validate HTTPS server - certificates. - * - ``SSLTrustStorePassword`` - - The password for the TrustStore. - * - ``SSLTrustStoreType`` - - The type of the TrustStore. The default type is provided by the Java - ``keystore.type`` security property or ``jks`` if none exists. - * - ``SSLUseSystemTrustStore`` - - Set ``true`` to automatically use the system TrustStore based on the - operating system. The supported OSes are Windows and macOS. For Windows, - the ``Windows-ROOT`` TrustStore is selected. For macOS, the - ``KeychainStore`` TrustStore is selected. 
For other OSes, the default - Java TrustStore is loaded. The TrustStore specification can be overridden - using ``SSLTrustStoreType``. - * - ``hostnameInCertificate`` - - Expected hostname in the certificate presented by the Trino server. Only - applicable with full SSL verification enabled. - * - ``KerberosRemoteServiceName`` - - Trino coordinator Kerberos service name. This parameter is required for - Kerberos authentication. - * - ``KerberosPrincipal`` - - The principal to use when authenticating to the Trino coordinator. - * - ``KerberosUseCanonicalHostname`` - - Use the canonical hostname of the Trino coordinator for the Kerberos - service principal by first resolving the hostname to an IP address and - then doing a reverse DNS lookup for that IP address. This is enabled by - default. - * - ``KerberosServicePrincipalPattern`` - - Trino coordinator Kerberos service principal pattern. The default is - ``${SERVICE}@${HOST}``. ``${SERVICE}`` is replaced with the value of - ``KerberosRemoteServiceName`` and ``${HOST}`` is replaced with the - hostname of the coordinator (after canonicalization if enabled). - * - ``KerberosConfigPath`` - - Kerberos configuration file. - * - ``KerberosKeytabPath`` - - Kerberos keytab file. - * - ``KerberosCredentialCachePath`` - - Kerberos credential cache. - * - ``KerberosDelegation`` - - Set to ``true`` to use the token from an existing Kerberos context. This - allows client to use Kerberos authentication without passing the Keytab - or credential cache. Defaults to ``false``. - * - ``extraCredentials`` - - Extra credentials for connecting to external services, specified as a - list of key-value pairs. For example, ``foo:bar;abc:xyz`` creates the - credential named ``abc`` with value ``xyz`` and the credential named - ``foo`` with value ``bar``. - * - ``roles`` - - Authorization roles to use for catalogs, specified as a list of key-value - pairs for the catalog and role. 
For example, - ``catalog1:roleA;catalog2:roleB`` sets ``roleA`` for ``catalog1`` and - ``roleB`` for ``catalog2``. - * - ``sessionProperties`` - - Session properties to set for the system and for catalogs, specified as a - list of key-value pairs. For example, ``abc:xyz;example.foo:bar`` sets - the system property ``abc`` to the value ``xyz`` and the ``foo`` property - for catalog ``example`` to the value ``bar``. - * - ``externalAuthentication`` - - Set to true if you want to use external authentication via - :doc:`/security/oauth2`. Use a local web browser to authenticate with an - identity provider (IdP) that has been configured for the Trino - coordinator. - * - ``externalAuthenticationTokenCache`` - - Allows the sharing of external authentication tokens between different - connections for the same authenticated user until the cache is - invalidated, such as when a client is restarted or when the classloader - reloads the JDBC driver. This is disabled by default, with a value of - ``NONE``. To enable, set the value to ``MEMORY``. If the JDBC driver is - used in a shared mode by different users, the first registered token is - stored and authenticates all users. - * - ``disableCompression`` - - Whether compression should be enabled. - * - ``assumeLiteralUnderscoreInMetadataCallsForNonConformingClients`` - - When enabled, the name patterns passed to ``DatabaseMetaData`` methods are - treated as underscores. You can use this as a workaround for - applications that do not escape schema or table names when passing them - to ``DatabaseMetaData`` methods as schema or table name patterns. - * - ``timezone`` - - Sets the time zone for the session using the `time zone passed - `_. Defaults - to the timezone of the JVM running the JDBC driver. -``` +:::{list-table} +:widths: 35, 65 +:header-rows: 1 + +* - Name + - Description +* - `user` + - Username to use for authentication and authorization. +* - `password` + - Password to use for LDAP authentication. 
+* - `sessionUser` + - Session username override, used for impersonation. +* - `socksProxy` + - SOCKS proxy host and port. Example: `localhost:1080` +* - `httpProxy` + - HTTP proxy host and port. Example: `localhost:8888` +* - `clientInfo` + - Extra information about the client. +* - `clientTags` + - Client tags for selecting resource groups. Example: `abc,xyz` +* - `traceToken` + - Trace token for correlating requests across systems. +* - `source` + - Source name for the Trino query. This parameter should be used in preference + to `ApplicationName`. Thus, it takes precedence over `ApplicationName` + and/or `applicationNamePrefix`. +* - `applicationNamePrefix` + - Prefix to append to any specified `ApplicationName` client info property, + which is used to set the source name for the Trino query if the `source` + parameter has not been set. If neither this property nor `ApplicationName` + or `source` are set, the source name for the query is `trino-jdbc`. +* - `accessToken` + - [JWT](/security/jwt) access token for token based authentication. +* - `SSL` + - Set `true` to specify using TLS/HTTPS for connections. +* - `SSLVerification` + - The method of TLS verification. There are three modes: `FULL` + (default), `CA` and `NONE`. For `FULL`, the normal TLS verification + is performed. For `CA`, only the CA is verified but hostname mismatch + is allowed. For `NONE`, there is no verification. +* - `SSLKeyStorePath` + - Use only when connecting to a Trino cluster that has [certificate + authentication](/security/certificate) enabled. Specifies the path to a + [PEM](/security/inspect-pem) or [JKS](/security/inspect-jks) file, which must + contain a certificate that is trusted by the Trino cluster you connect to. +* - `SSLKeyStorePassword` + - The password for the KeyStore, if any. +* - `SSLKeyStoreType` + - The type of the KeyStore. The default type is provided by the Java + `keystore.type` security property or `jks` if none exists. 
+* - `SSLTrustStorePath` + - The location of the Java TrustStore file to use to validate HTTPS server + certificates. +* - `SSLTrustStorePassword` + - The password for the TrustStore. +* - `SSLTrustStoreType` + - The type of the TrustStore. The default type is provided by the Java + `keystore.type` security property or `jks` if none exists. +* - `SSLUseSystemTrustStore` + - Set `true` to automatically use the system TrustStore based on the operating + system. The supported OSes are Windows and macOS. For Windows, the + `Windows-ROOT` TrustStore is selected. For macOS, the `KeychainStore` + TrustStore is selected. For other OSes, the default Java TrustStore is + loaded. The TrustStore specification can be overridden using + `SSLTrustStoreType`. +* - `hostnameInCertificate` + - Expected hostname in the certificate presented by the Trino server. Only + applicable with full SSL verification enabled. +* - `KerberosRemoteServiceName` + - Trino coordinator Kerberos service name. This parameter is required for + Kerberos authentication. +* - `KerberosPrincipal` + - The principal to use when authenticating to the Trino coordinator. +* - `KerberosUseCanonicalHostname` + - Use the canonical hostname of the Trino coordinator for the Kerberos service + principal by first resolving the hostname to an IP address and then doing a + reverse DNS lookup for that IP address. This is enabled by default. +* - `KerberosServicePrincipalPattern` + - Trino coordinator Kerberos service principal pattern. The default is + `${SERVICE}@${HOST}`. `${SERVICE}` is replaced with the value of + `KerberosRemoteServiceName` and `${HOST}` is replaced with the hostname of + the coordinator (after canonicalization if enabled). +* - `KerberosConfigPath` + - Kerberos configuration file. +* - `KerberosKeytabPath` + - Kerberos keytab file. +* - `KerberosCredentialCachePath` + - Kerberos credential cache. +* - `KerberosDelegation` + - Set to `true` to use the token from an existing Kerberos context. 
This + allows client to use Kerberos authentication without passing the Keytab or + credential cache. Defaults to `false`. +* - `extraCredentials` + - Extra credentials for connecting to external services, specified as a list + of key-value pairs. For example, `foo:bar;abc:xyz` creates the credential + named `abc` with value `xyz` and the credential named `foo` with value + `bar`. +* - `roles` + - Authorization roles to use for catalogs, specified as a list of key-value + pairs for the catalog and role. For example, `catalog1:roleA;catalog2:roleB` + sets `roleA` for `catalog1` and `roleB` for `catalog2`. +* - `sessionProperties` + - Session properties to set for the system and for catalogs, specified as a + list of key-value pairs. For example, `abc:xyz;example.foo:bar` sets the + system property `abc` to the value `xyz` and the `foo` property for catalog + `example` to the value `bar`. +* - `externalAuthentication` + - Set to true if you want to use external authentication via + [](/security/oauth2). Use a local web browser to authenticate with an + identity provider (IdP) that has been configured for the Trino coordinator. +* - `externalAuthenticationTokenCache` + - Allows the sharing of external authentication tokens between different + connections for the same authenticated user until the cache is invalidated, + such as when a client is restarted or when the classloader reloads the JDBC + driver. This is disabled by default, with a value of `NONE`. To enable, set + the value to `MEMORY`. If the JDBC driver is used in a shared mode by + different users, the first registered token is stored and authenticates all + users. +* - `disableCompression` + - Whether compression should be enabled. +* - `assumeLiteralUnderscoreInMetadataCallsForNonConformingClients` + - When enabled, the name patterns passed to `DatabaseMetaData` methods are + treated as underscores. 
You can use this as a workaround for applications + that do not escape schema or table names when passing them to + `DatabaseMetaData` methods as schema or table name patterns. +* - `timezone` + - Sets the time zone for the session using the [time zone + passed](https://docs.oracle.com/en/java/javase/17/docs/api/java.base/java/time/ZoneId.html#of(java.lang.String)). + Defaults to the timezone of the JVM running the JDBC driver. +::: diff --git a/docs/src/main/sphinx/connector/mariadb.md b/docs/src/main/sphinx/connector/mariadb.md index 60471a62fe53..2f78a0d702d5 100644 --- a/docs/src/main/sphinx/connector/mariadb.md +++ b/docs/src/main/sphinx/connector/mariadb.md @@ -346,6 +346,26 @@ FROM The connector includes a number of performance improvements, detailed in the following sections. +(mariadb-table-statistics)= +### Table statistics + +The MariaDB connector can use [table and column +statistics](/optimizer/statistics) for [cost based +optimizations](/optimizer/cost-based-optimizations) to improve query processing +performance based on the actual data in the data source. + +The statistics are collected by MariaDB and retrieved by the connector. + +To collect statistics for a table, execute the following statement in +MariaDB. + +```text +ANALYZE TABLE table_name; +``` + +Refer to [MariaDB documentation](https://mariadb.com/kb/en/analyze-table/) for +additional information. + (mariadb-pushdown)= ### Pushdown diff --git a/docs/src/main/sphinx/functions.md b/docs/src/main/sphinx/functions.md index 91885e47236f..3478e1a92786 100644 --- a/docs/src/main/sphinx/functions.md +++ b/docs/src/main/sphinx/functions.md @@ -4,14 +4,24 @@ This section describes the built-in SQL functions and operators supported by Trino. They allow you to implement complex capabilities and behavior of the queries executed by Trino operating on the underlying data sources.
-If you are looking for a specific function or operator, see the {doc}`full -alphabetical list` or the {doc}`full list by -topic`. Using {doc}`SHOW FUNCTIONS -` returns a list of all available functions, including -custom functions, with all supported arguments and a short description. +Refer to the following sections for further details: -Also see the {doc}`SQL data types` -and the {doc}`SQL statement and syntax reference`. +* [SQL data types and other general aspects](/language) +* [SQL statement and syntax reference](/sql) + +## Functions by name + +If you are looking for a specific function or operator by name use +[](/sql/show-functions), or refer to the following resources: + +:::{toctree} +:maxdepth: 1 + +functions/list +functions/list-by-topic +::: + +## Functions per topic ```{toctree} :maxdepth: 1 @@ -47,6 +57,4 @@ T-Digest URL UUID Window -functions/list -functions/list-by-topic ``` diff --git a/docs/src/main/sphinx/functions/comparison.md b/docs/src/main/sphinx/functions/comparison.md index 47f2585cd20e..bb137c267aac 100644 --- a/docs/src/main/sphinx/functions/comparison.md +++ b/docs/src/main/sphinx/functions/comparison.md @@ -4,28 +4,27 @@ ## Comparison operators -```{eval-rst} -..
list-table:: - :widths: 30, 70 - :header-rows: 1 - - * - Operator - - Description - * - ``<`` - - Less than - * - ``>`` - - Greater than - * - ``<=`` - - Less than or equal to - * - ``>=`` - - Greater than or equal to - * - ``=`` - - Equal - * - ``<>`` - - Not equal - * - ``!=`` - - Not equal (non-standard but popular syntax) -``` +:::{list-table} +:widths: 30, 70 +:header-rows: 1 + +* - Operator + - Description +* - `<` + - Less than +* - `>` + - Greater than +* - `<=` + - Less than or equal to +* - `>=` + - Greater than or equal to +* - `=` + - Equal +* - `<>` + - Not equal +* - `!=` + - Not equal (non-standard but popular syntax) +::: (range-operator)= @@ -178,27 +177,26 @@ SELECT 42 >= SOME (SELECT 41 UNION ALL SELECT 42 UNION ALL SELECT 43); -- true Here are the meanings of some quantifier and comparison operator combinations: -```{eval-rst} -.. list-table:: - :widths: 40, 60 - :header-rows: 1 - - * - Expression - - Meaning - * - ``A = ALL (...)`` - - Evaluates to ``true`` when ``A`` is equal to all values. - * - ``A <> ALL (...)`` - - Evaluates to ``true`` when ``A`` doesn't match any value. - * - ``A < ALL (...)`` - - Evaluates to ``true`` when ``A`` is smaller than the smallest value. - * - ``A = ANY (...)`` - - Evaluates to ``true`` when ``A`` is equal to any of the values. This form - is equivalent to ``A IN (...)``. - * - ``A <> ANY (...)`` - - Evaluates to ``true`` when ``A`` doesn't match one or more values. - * - ``A < ANY (...)`` - - Evaluates to ``true`` when ``A`` is smaller than the biggest value. -``` +:::{list-table} +:widths: 40, 60 +:header-rows: 1 + +* - Expression + - Meaning +* - `A = ALL (...)` + - Evaluates to `true` when `A` is equal to all values. +* - `A <> ALL (...)` + - Evaluates to `true` when `A` doesn't match any value. +* - `A < ALL (...)` + - Evaluates to `true` when `A` is smaller than the smallest value. +* - `A = ANY (...)` + - Evaluates to `true` when `A` is equal to any of the values. 
This form + is equivalent to `A IN (...)`. +* - `A <> ANY (...)` + - Evaluates to `true` when `A` doesn't match one or more values. +* - `A < ANY (...)` + - Evaluates to `true` when `A` is smaller than the biggest value. +::: `ANY` and `SOME` have the same meaning and can be used interchangeably. diff --git a/docs/src/main/sphinx/functions/conversion.md b/docs/src/main/sphinx/functions/conversion.md index e12b1a18e91a..7de546ab7ed8 100644 --- a/docs/src/main/sphinx/functions/conversion.md +++ b/docs/src/main/sphinx/functions/conversion.md @@ -62,42 +62,41 @@ SELECT format_number(1000000); -- '1M' The `parse_data_size` function supports the following units: -```{eval-rst} -.. list-table:: - :widths: 30, 40, 30 - :header-rows: 1 - - * - Unit - - Description - - Value - * - ``B`` - - Bytes - - 1 - * - ``kB`` - - Kilobytes - - 1024 - * - ``MB`` - - Megabytes - - 1024\ :sup:`2` - * - ``GB`` - - Gigabytes - - 1024\ :sup:`3` - * - ``TB`` - - Terabytes - - 1024\ :sup:`4` - * - ``PB`` - - Petabytes - - 1024\ :sup:`5` - * - ``EB`` - - Exabytes - - 1024\ :sup:`6` - * - ``ZB`` - - Zettabytes - - 1024\ :sup:`7` - * - ``YB`` - - Yottabytes - - 1024\ :sup:`8` -``` +:::{list-table} +:widths: 30, 40, 30 +:header-rows: 1 + +* - Unit + - Description + - Value +* - ``B`` + - Bytes + - 1 +* - ``kB`` + - Kilobytes + - 1024 +* - ``MB`` + - Megabytes + - 1024{sup}`2` +* - ``GB`` + - Gigabytes + - 1024{sup}`3` +* - ``TB`` + - Terabytes + - 1024{sup}`4` +* - ``PB`` + - Petabytes + - 1024{sup}`5` +* - ``EB`` + - Exabytes + - 1024{sup}`6` +* - ``ZB`` + - Zettabytes + - 1024{sup}`7` +* - ``YB`` + - Yottabytes + - 1024{sup}`8` +::: :::{function} parse_data_size(string) -> decimal(38) Parses `string` of format `value unit` into a number, where diff --git a/docs/src/main/sphinx/functions/decimal.md b/docs/src/main/sphinx/functions/decimal.md index f9b6bfb2c057..d6371f309261 100644 --- a/docs/src/main/sphinx/functions/decimal.md +++ b/docs/src/main/sphinx/functions/decimal.md @@ -10,20 +10,19 @@ 
The precision of a decimal type for a literal will be equal to the number of digits in the literal (including trailing and leading zeros). The scale will be equal to the number of digits in the fractional part (including trailing zeros). -```{eval-rst} -.. list-table:: - :widths: 50, 50 - :header-rows: 1 - - * - Example literal - - Data type - * - ``DECIMAL '0'`` - - ``DECIMAL(1)`` - * - ``DECIMAL '12345'`` - - ``DECIMAL(5)`` - * - ``DECIMAL '0000012345.1234500000'`` - - ``DECIMAL(20, 10)`` -``` +:::{list-table} +:widths: 50, 50 +:header-rows: 1 + +* - Example literal + - Data type +* - `DECIMAL '0'` + - `DECIMAL(1)` +* - `DECIMAL '12345'` + - `DECIMAL(5)` +* - `DECIMAL '0000012345.1234500000'` + - `DECIMAL(20, 10)` +::: ## Binary arithmetic decimal operators @@ -31,41 +30,44 @@ Standard mathematical operators are supported. The table below explains precision and scale calculation rules for result. Assuming `x` is of type `DECIMAL(xp, xs)` and `y` is of type `DECIMAL(yp, ys)`. -```{eval-rst} -.. list-table:: - :widths: 30, 40, 30 - :header-rows: 1 - - * - Operation - - Result type precision - - Result type scale - * - ``x + y`` and ``x - y`` - - .. code-block:: text - - min(38, - 1 + - max(xs, ys) + - max(xp - xs, yp - ys) - ) - - ``max(xs, ys)`` - * - ``x * y`` - - ``min(38, xp + yp)`` - - ``xs + ys`` - * - ``x / y`` - - .. code-block:: text - - min(38, - xp + ys-xs - + max(0, ys-xs) - ) - - ``max(xs, ys)`` - * - ``x % y`` - - ..
code-block:: text - - min(xp - xs, yp - ys) + - max(xs, bs) - - ``max(xs, ys)`` -``` +:::{list-table} +:widths: 30, 40, 30 +:header-rows: 1 + +* - Operation + - Result type precision + - Result type scale +* - `x + y` and `x - y` + - + ``` + min(38, + 1 + + max(xs, ys) + + max(xp - xs, yp - ys) + ) + ``` + - `max(xs, ys)` +* - `x * y` + - ``` + min(38, xp + yp) + ``` + - `xs + ys` +* - `x / y` + - + ``` + min(38, + xp + ys-xs + + max(0, ys-xs) + ) + ``` + - `max(xs, ys)` +* - `x % y` + - ``` + min(xp - xs, yp - ys) + + max(xs, bs) + ``` + - `max(xs, ys)` +::: If the mathematical result of the operation is not exactly representable with the precision and scale of the result data type, diff --git a/docs/src/main/sphinx/functions/json.md b/docs/src/main/sphinx/functions/json.md index f05323fc927e..249db0af144c 100644 --- a/docs/src/main/sphinx/functions/json.md +++ b/docs/src/main/sphinx/functions/json.md @@ -590,37 +590,36 @@ diverge from the expected schema. The following table shows the differences between the two modes. -```{eval-rst} -.. list-table:: - :widths: 40 20 40 - :header-rows: 1 - - * - Condition - - strict mode - - lax mode - * - Performing an operation which requires a non-array on an array, e.g.: - - ``$.key`` requires a JSON object - - ``$.floor()`` requires a numeric value - - ERROR - - The array is automatically unnested, and the operation is performed on - each array element. - * - Performing an operation which requires an array on an non-array, e.g.: - - ``$[0]``, ``$[*]``, ``$.size()`` - - ERROR - - The non-array item is automatically wrapped in a singleton array, and - the operation is performed on the array. - * - A structural error: accessing a non-existent element of an array or a - non-existent member of a JSON object, e.g.: - - ``$[-1]`` (array index out of bounds) - - ``$.key``, where the input JSON object does not have a member ``key`` - - ERROR - - The error is suppressed, and the operation results in an empty sequence. 
-``` +:::{list-table} +:widths: 40 20 40 +:header-rows: 1 + +* - Condition + - strict mode + - lax mode +* - Performing an operation which requires a non-array on an array, e.g.: + + `$.key` requires a JSON object + + `$.floor()` requires a numeric value + - ERROR + - The array is automatically unnested, and the operation is performed on + each array element. +* - Performing an operation which requires an array on an non-array, e.g.: + + `$[0]`, `$[*]`, `$.size()` + - ERROR + - The non-array item is automatically wrapped in a singleton array, and + the operation is performed on the array. +* - A structural error: accessing a non-existent element of an array or a + non-existent member of a JSON object, e.g.: + + `$[-1]` (array index out of bounds) + + `$.key`, where the input JSON object does not have a member `key` + - ERROR + - The error is suppressed, and the operation results in an empty sequence. +::: #### Examples of the lax mode behavior diff --git a/docs/src/main/sphinx/installation/deployment.md b/docs/src/main/sphinx/installation/deployment.md index 5c6eb3019cbd..c82c6b423484 100644 --- a/docs/src/main/sphinx/installation/deployment.md +++ b/docs/src/main/sphinx/installation/deployment.md @@ -302,30 +302,29 @@ The installation provides a `bin/launcher` script, which requires Python in the `PATH`. The script can be used manually or as a daemon startup script. It accepts the following commands: -```{eval-rst} -.. list-table:: ``launcher`` commands - :widths: 15, 85 - :header-rows: 1 - - * - Command - - Action - * - ``run`` - - Starts the server in the foreground and leaves it running. To shut down - the server, use Ctrl+C in this terminal or the ``stop`` command from - another terminal. - * - ``start`` - - Starts the server as a daemon and returns its process ID. - * - ``stop`` - - Shuts down a server started with either ``start`` or ``run``. Sends the - SIGTERM signal. 
- * - ``restart`` - - Stops then restarts a running server, or starts a stopped server, - assigning a new process ID. - * - ``kill`` - - Shuts down a possibly hung server by sending the SIGKILL signal. - * - ``status`` - - Prints a status line, either *Stopped pid* or *Running as pid*. -``` +:::{list-table} `launcher` commands +:widths: 15, 85 +:header-rows: 1 + +* - Command + - Action +* - `run` + - Starts the server in the foreground and leaves it running. To shut down + the server, use Ctrl+C in this terminal or the `stop` command from + another terminal. +* - `start` + - Starts the server as a daemon and returns its process ID. +* - `stop` + - Shuts down a server started with either `start` or `run`. Sends the + SIGTERM signal. +* - `restart` + - Stops then restarts a running server, or starts a stopped server, + assigning a new process ID. +* - `kill` + - Shuts down a possibly hung server by sending the SIGKILL signal. +* - `status` + - Prints a status line, either *Stopped pid* or *Running as pid*. +::: A number of additional options allow you to specify configuration file and directory locations, as well as Java options. Run the launcher with `--help` diff --git a/docs/src/main/sphinx/installation/rpm.md b/docs/src/main/sphinx/installation/rpm.md index 2bb4916d6b31..ca2ebe311014 100644 --- a/docs/src/main/sphinx/installation/rpm.md +++ b/docs/src/main/sphinx/installation/rpm.md @@ -36,24 +36,23 @@ installation, you can manage the Trino server with the `service` command: service trino [start|stop|restart|status] ``` -```{eval-rst} -.. list-table:: ``service`` commands - :widths: 15, 85 - :header-rows: 1 - - * - Command - - Action - * - ``start`` - - Starts the server as a daemon and returns its process ID. - * - ``stop`` - - Shuts down a server started with either ``start`` or ``run``. Sends the - SIGTERM signal. - * - ``restart`` - - Stops and then starts a running server, or starts a stopped server, - assigning a new process ID. 
- * - ``status`` - - Prints a status line, either *Stopped pid* or *Running as pid*. -``` +:::{list-table} `service` commands +:widths: 15, 85 +:header-rows: 1 + +* - Command + - Action +* - `start` + - Starts the server as a daemon and returns its process ID. +* - `stop` + - Shuts down a server started with either `start` or `run`. Sends the + SIGTERM signal. +* - `restart` + - Stops and then starts a running server, or starts a stopped server, + assigning a new process ID. +* - `status` + - Prints a status line, either *Stopped pid* or *Running as pid*. +::: ## Installation directory structure diff --git a/docs/src/main/sphinx/language.md b/docs/src/main/sphinx/language.md index 15286eb208c1..9c163f9252ef 100644 --- a/docs/src/main/sphinx/language.md +++ b/docs/src/main/sphinx/language.md @@ -10,15 +10,22 @@ operations on the connected data source. This section provides a reference to the supported SQL data types and other general characteristics of the SQL support of Trino. -A {doc}`full SQL statement and syntax reference` is -available in a separate section. +Refer to the following sections for further details: + +* [SQL statement and syntax reference](/sql) +* [SQL functions and operators](/functions) -Trino also provides {doc}`numerous SQL functions and operators`. ```{toctree} :maxdepth: 2 language/sql-support language/types +``` + +```{toctree} +:maxdepth: 1 + language/reserved +language/comments ``` diff --git a/docs/src/main/sphinx/sql/comments.md b/docs/src/main/sphinx/language/comments.md similarity index 97% rename from docs/src/main/sphinx/sql/comments.md rename to docs/src/main/sphinx/language/comments.md index 6cdb93f70129..eb16146388c5 100644 --- a/docs/src/main/sphinx/sql/comments.md +++ b/docs/src/main/sphinx/language/comments.md @@ -23,4 +23,4 @@ SELECT * FROM table; -- This comment is ignored. 
## See also -[](./comment) +[](/sql/comment) diff --git a/docs/src/main/sphinx/release.md b/docs/src/main/sphinx/release.md index b63cc0d48ee3..0788d32f0352 100644 --- a/docs/src/main/sphinx/release.md +++ b/docs/src/main/sphinx/release.md @@ -7,6 +7,7 @@ ```{toctree} :maxdepth: 1 +release/release-430 release/release-429 release/release-428 release/release-427 diff --git a/docs/src/main/sphinx/release/release-430.md b/docs/src/main/sphinx/release/release-430.md new file mode 100644 index 000000000000..774a25f642c5 --- /dev/null +++ b/docs/src/main/sphinx/release/release-430.md @@ -0,0 +1,47 @@ +# Release 430 (20 Oct 2023) + +## General + +* Improve performance of queries with `GROUP BY`. ({issue}`19302`) +* Fix incorrect results for queries involving `ORDER BY` and window functions + with ordered frames. ({issue}`19399`) +* Fix incorrect results for query involving an aggregation in a correlated + subquery. ({issue}`19002`) + +## Security + +* Enforce authorization capability of client when receiving commands `RESET` and + `SET` for `SESSION AUTHORIZATION`. ({issue}`19217`) + +## JDBC driver + +* Add support for a `timezone` parameter to set the session timezone. ({issue}`19102`) + +## Iceberg connector + +* Add an option to require filters on partition columns. This can be enabled by + setting the ``iceberg.query-partition-filter-required`` configuration property + or the ``query_partition_filter_required`` session property. ({issue}`17263`) +* Improve performance when reading partition columns. ({issue}`19303`) + +## Ignite connector + +* Fix failure when a query contains `LIKE` with `ESCAPE`. ({issue}`19464`) + +## MariaDB connector + +* Add support for table statistics. ({issue}`19408`) + +## MongoDB connector + +* Fix incorrect results when a query contains several `<>` or `NOT IN` + predicates. ({issue}`19404`) + +## Oracle connector + +* Improve reliability of connecting to the source database. 
({issue}`19191`) + +## SPI + +* Change the Java stack type for a `map` value to `SqlMap` and a `row` value to + `SqlRow`, which do not implement `Block`. ({issue}`18948`) diff --git a/docs/src/main/sphinx/security/built-in-system-access-control.md b/docs/src/main/sphinx/security/built-in-system-access-control.md index a2435c04cee7..cca8b6c772a2 100644 --- a/docs/src/main/sphinx/security/built-in-system-access-control.md +++ b/docs/src/main/sphinx/security/built-in-system-access-control.md @@ -20,27 +20,26 @@ contain a comma separated list of the access control property files to use Trino offers the following built-in system access control implementations: -```{eval-rst} -.. list-table:: - :widths: 20, 80 - :header-rows: 1 - - * - Name - - Description - * - ``default`` - - All operations are permitted, except for user impersonation and triggering - :doc:`/admin/graceful-shutdown`. - - This is the default access control if none are configured. - * - ``allow-all`` - - All operations are permitted. - * - ``read-only`` - - Operations that read data or metadata are permitted, but none of the - operations that write data or metadata are allowed. - * - ``file`` - - Authorization rules are specified in a config file. See - :doc:`file-system-access-control`. -``` +:::{list-table} +:widths: 20, 80 +:header-rows: 1 + +* - Name + - Description +* - `default` + - All operations are permitted, except for user impersonation and triggering + [](/admin/graceful-shutdown). + + This is the default access control if none are configured. +* - `allow-all` + - All operations are permitted. +* - `read-only` + - Operations that read data or metadata are permitted, but none of the + operations that write data or metadata are allowed. +* - `file` + - Authorization rules are specified in a config file. See + [](/security/file-system-access-control). 
+::: If you want to limit access on a system level in any other way than the ones listed above, you must implement a custom {doc}`/develop/system-access-control`. diff --git a/docs/src/main/sphinx/security/certificate.md b/docs/src/main/sphinx/security/certificate.md index 8c725ec81c28..dfabad8d2918 100644 --- a/docs/src/main/sphinx/security/certificate.md +++ b/docs/src/main/sphinx/security/certificate.md @@ -79,21 +79,20 @@ http-server.authentication.type=CERTIFICATE,PASSWORD The following configuration properties are also available: -```{eval-rst} -.. list-table:: Configuration properties - :widths: 50 50 - :header-rows: 1 - - * - Property name - - Description - * - ``http-server.authentication.certificate.user-mapping.pattern`` - - A regular expression pattern to :doc:`map all user names - ` for this authentication type to the format - expected by Trino. - * - ``http-server.authentication.certificate.user-mapping.file`` - - The path to a JSON file that contains a set of :doc:`user mapping - rules ` for this authentication type. -``` +:::{list-table} Configuration properties +:widths: 50 50 +:header-rows: 1 + +* - Property name + - Description +* - `http-server.authentication.certificate.user-mapping.pattern` + - A regular expression pattern to [map all user + names](/security/user-mapping) for this authentication type to the format + expected by Trino. +* - `http-server.authentication.certificate.user-mapping.file` + - The path to a JSON file that contains a set of [user mapping + rules](/security/user-mapping) for this authentication type. 
+::: ## Use certificate authentication with clients diff --git a/docs/src/main/sphinx/security/file-system-access-control.md b/docs/src/main/sphinx/security/file-system-access-control.md index 934c441beaf1..64ef7a615471 100644 --- a/docs/src/main/sphinx/security/file-system-access-control.md +++ b/docs/src/main/sphinx/security/file-system-access-control.md @@ -115,24 +115,25 @@ The following table summarizes the permissions required for each SQL command: Permissions required for executing functions: -```{eval-rst} -.. list-table:: - :widths: 30, 10, 15, 30 - :header-rows: 1 - - * - SQL command - - Catalog - - Function permission - - Note - * - ``SELECT function()`` - - ``read-only`` - - ``execute``, ``grant_execute*`` - - ``grant_execute`` is required when the function is used in a SECURITY DEFINER view. - * - ``SELECT FROM TABLE(table_function())`` - - ``read-only`` - - ``execute``, ``grant_execute*`` - - ``grant_execute`` is required when the function is used in a SECURITY DEFINER view. -``` +:::{list-table} +:widths: 30, 10, 20, 40 +:header-rows: 1 + +* - SQL command + - Catalog + - Function permission + - Note +* - `SELECT function()` + - + - `execute`, `grant_execute*` + - `grant_execute` is required when the function is used in a `SECURITY DEFINER` + view. +* - `SELECT FROM TABLE(table_function())` + - `all` + - `execute`, `grant_execute*` + - `grant_execute` is required when the function is used in a `SECURITY DEFINER` + view. +::: (system-file-auth-visibility)= @@ -360,8 +361,8 @@ The example below defines the following table access policy: These rules control the user's ability to execute functions. :::{note} -By default, all users have access to functions in the `system.builtin` schema. -You can override this behavior by adding a rule, but this will break most queries. +Users always have access to functions in the `system.builtin` schema, and +you cannot override this behavior by adding a rule. 
::: Each function rule is composed of the following fields: diff --git a/docs/src/main/sphinx/security/jwt.md b/docs/src/main/sphinx/security/jwt.md index 1a0b15c0a921..32858227f0f7 100644 --- a/docs/src/main/sphinx/security/jwt.md +++ b/docs/src/main/sphinx/security/jwt.md @@ -74,39 +74,34 @@ http-server.authentication.jwt.key-file=https://cluster.example.net/.well-known/ The following configuration properties are available: -```{eval-rst} -.. list-table:: Configuration properties for JWT authentication - :widths: 50 50 - :header-rows: 1 - - * - Property - - Description - * - ``http-server.authentication.jwt.key-file`` - - Required. Specifies either the URL to a JWKS service or the path to a - PEM or HMAC file, as described below this table. - * - ``http-server.authentication.jwt.required-issuer`` - - Specifies a string that must match the value of the JWT's - issuer (``iss``) field in order to consider this JWT valid. - The ``iss`` field in the JWT identifies the principal that issued the - JWT. - * - ``http-server.authentication.jwt.required-audience`` - - Specifies a string that must match the value of the JWT's - Audience (``aud``) field in order to consider this JWT valid. - The ``aud`` field in the JWT identifies the recipients that the - JWT is intended for. - * - ``http-server.authentication.jwt.principal-field`` - - String to identify the field in the JWT that identifies the - subject of the JWT. The default value is ``sub``. This field is used to - create the Trino principal. - * - ``http-server.authentication.jwt.user-mapping.pattern`` - - A regular expression pattern to :doc:`map all user names - ` for this authentication system to the format - expected by the Trino server. - * - ``http-server.authentication.jwt.user-mapping.file`` - - The path to a JSON file that contains a set of - :doc:`user mapping rules ` for this - authentication system. 
-``` +:::{list-table} Configuration properties for JWT authentication +:widths: 50 50 +:header-rows: 1 + +* - Property + - Description +* - `http-server.authentication.jwt.key-file` + - Required. Specifies either the URL to a JWKS service or the path to a PEM or + HMAC file, as described below this table. +* - `http-server.authentication.jwt.required-issuer` + - Specifies a string that must match the value of the JWT's issuer (`iss`) + field in order to consider this JWT valid. The `iss` field in the JWT + identifies the principal that issued the JWT. +* - `http-server.authentication.jwt.required-audience` + - Specifies a string that must match the value of the JWT's Audience (`aud`) + field in order to consider this JWT valid. The `aud` field in the JWT + identifies the recipients that the JWT is intended for. +* - `http-server.authentication.jwt.principal-field` + - String to identify the field in the JWT that identifies the subject of the + JWT. The default value is `sub`. This field is used to create the Trino + principal. +* - `http-server.authentication.jwt.user-mapping.pattern` + - A regular expression pattern to [map all user names](/security/user-mapping) + for this authentication system to the format expected by the Trino server. +* - `http-server.authentication.jwt.user-mapping.file` + - The path to a JSON file that contains a set of [user mapping + rules](/security/user-mapping) for this authentication system. +::: Use the `http-server.authentication.jwt.key-file` property to specify either: diff --git a/docs/src/main/sphinx/security/oauth2.md b/docs/src/main/sphinx/security/oauth2.md index 2bce1f2b15cd..9143406977ba 100644 --- a/docs/src/main/sphinx/security/oauth2.md +++ b/docs/src/main/sphinx/security/oauth2.md @@ -18,8 +18,8 @@ Set the callback/redirect URL to `https:///oauth2 when configuring an OAuth 2.0 authorization server like an OpenID Connect (OIDC) provider. 
-If Web UI is enabled, set the post-logout callback URL to -`https:///ui/logout/logout.html` when configuring +If Web UI is enabled, set the post-logout callback URL to +`https:///ui/logout/logout.html` when configuring an OAuth 2.0 authentication server like an OpenID Connect (OIDC) provider. Using {doc}`TLS ` and {doc}`a configured shared secret @@ -84,89 +84,91 @@ web-ui.authentication.type=oauth2 The following configuration properties are available: -```{eval-rst} -.. list-table:: OAuth2 configuration properties - :widths: 40 60 - :header-rows: 1 - - * - Property - - Description - * - ``http-server.authentication.type`` - - The type of authentication to use. Must be set to ``oauth2`` to enable - OAuth2 authentication for the Trino coordinator. - * - ``http-server.authentication.oauth2.issuer`` - - The issuer URL of the IdP. All issued tokens must have this in the ``iss`` field. - * - ``http-server.authentication.oauth2.access-token-issuer`` - - The issuer URL of the IdP for access tokens, if different. - All issued access tokens must have this in the ``iss`` field. - Providing this value while OIDC discovery is enabled overrides the value - from the OpenID provider metadata document. - Defaults to the value of ``http-server.authentication.oauth2.issuer``. - * - ``http-server.authentication.oauth2.auth-url`` - - The authorization URL. The URL a user's browser will be redirected to in - order to begin the OAuth 2.0 authorization process. Providing this value - while OIDC discovery is enabled overrides the value from the OpenID - provider metadata document. - * - ``http-server.authentication.oauth2.token-url`` - - The URL of the endpoint on the authorization server which Trino uses to - obtain an access token. Providing this value while OIDC discovery is - enabled overrides the value from the OpenID provider metadata document. - * - ``http-server.authentication.oauth2.jwks-url`` - - The URL of the JSON Web Key Set (JWKS) endpoint on the authorization - server. 
It provides Trino the set of keys containing the public key - to verify any JSON Web Token (JWT) from the authorization server. - Providing this value while OIDC discovery is enabled overrides the value - from the OpenID provider metadata document. - * - ``http-server.authentication.oauth2.userinfo-url`` - - The URL of the IdPs ``/userinfo`` endpoint. If supplied then this URL is - used to validate the OAuth access token and retrieve any associated - claims. This is required if the IdP issues opaque tokens. Providing this - value while OIDC discovery is enabled overrides the value from the OpenID - provider metadata document. - * - ``http-server.authentication.oauth2.client-id`` - - The public identifier of the Trino client. - * - ``http-server.authentication.oauth2.client-secret`` - - The secret used to authorize Trino client with the authorization server. - * - ``http-server.authentication.oauth2.additional-audiences`` - - Additional audiences to trust in addition to the client ID which is - always a trusted audience. - * - ``http-server.authentication.oauth2.scopes`` - - Scopes requested by the server during the authorization challenge. See: - https://tools.ietf.org/html/rfc6749#section-3.3 - * - ``http-server.authentication.oauth2.challenge-timeout`` - - Maximum :ref:`duration ` of the authorization challenge. - Default is ``15m``. - * - ``http-server.authentication.oauth2.state-key`` - - A secret key used by the SHA-256 - `HMAC `_ - algorithm to sign the state parameter in order to ensure that the - authorization request was not forged. Default is a random string - generated during the coordinator start. - * - ``http-server.authentication.oauth2.user-mapping.pattern`` - - Regex to match against user. If matched, the user name is replaced with - first regex group. If not matched, authentication is denied. Default is - ``(.*)`` which allows any user name. - * - ``http-server.authentication.oauth2.user-mapping.file`` - - File containing rules for mapping user. 
See :doc:`/security/user-mapping` - for more information. - * - ``http-server.authentication.oauth2.principal-field`` - - The field of the access token used for the Trino user principal. Defaults to ``sub``. Other commonly used fields include ``sAMAccountName``, ``name``, ``upn``, and ``email``. - * - ``http-server.authentication.oauth2.oidc.discovery`` - - Enable reading the `OIDC provider metadata `_. - Default is ``true``. - * - ``http-server.authentication.oauth2.oidc.discovery.timeout`` - - The timeout when reading OpenID provider metadata. Default is ``30s``. - * - ``http-server.authentication.oauth2.oidc.use-userinfo-endpoint`` - - Use the value of ``userinfo_endpoint`` `in the provider metadata `_. - When a ``userinfo_endpoint`` value is supplied this URL is used to - validate the OAuth 2.0 access token, and retrieve any associated claims. - This flag allows ignoring the value provided in the metadata document. - Default is ``true``. - * - ``http-server.authentication.oauth2.end-session-url`` - - The URL of the endpoint on the authorization server to which user's browser - will be redirected to so that End-User will be logged out from the authorization - server when logging out from Trino. -``` +:::{list-table} OAuth2 configuration properties +:widths: 40 60 +:header-rows: 1 + +* - Property + - Description +* - `http-server.authentication.type` + - The type of authentication to use. Must be set to `oauth2` to enable OAuth2 + authentication for the Trino coordinator. +* - `http-server.authentication.oauth2.issuer` + - The issuer URL of the IdP. All issued tokens must have this in the `iss` + field. +* - `http-server.authentication.oauth2.access-token-issuer` + - The issuer URL of the IdP for access tokens, if different. All issued access + tokens must have this in the `iss` field. Providing this value while OIDC + discovery is enabled overrides the value from the OpenID provider metadata + document. 
Defaults to the value of + `http-server.authentication.oauth2.issuer`. +* - `http-server.authentication.oauth2.auth-url` + - The authorization URL. The URL a user's browser will be redirected to in + order to begin the OAuth 2.0 authorization process. Providing this value + while OIDC discovery is enabled overrides the value from the OpenID provider + metadata document. +* - `http-server.authentication.oauth2.token-url` + - The URL of the endpoint on the authorization server which Trino uses to + obtain an access token. Providing this value while OIDC discovery is enabled + overrides the value from the OpenID provider metadata document. +* - `http-server.authentication.oauth2.jwks-url` + - The URL of the JSON Web Key Set (JWKS) endpoint on the authorization server. + It provides Trino the set of keys containing the public key to verify any + JSON Web Token (JWT) from the authorization server. Providing this value + while OIDC discovery is enabled overrides the value from the OpenID provider + metadata document. +* - `http-server.authentication.oauth2.userinfo-url` + - The URL of the IdPs `/userinfo` endpoint. If supplied then this URL is used + to validate the OAuth access token and retrieve any associated claims. This + is required if the IdP issues opaque tokens. Providing this value while OIDC + discovery is enabled overrides the value from the OpenID provider metadata + document. +* - `http-server.authentication.oauth2.client-id` + - The public identifier of the Trino client. +* - `http-server.authentication.oauth2.client-secret` + - The secret used to authorize Trino client with the authorization server. +* - `http-server.authentication.oauth2.additional-audiences` + - Additional audiences to trust in addition to the client ID which is + always a trusted audience. +* - `http-server.authentication.oauth2.scopes` + - Scopes requested by the server during the authorization challenge. 
See: + https://tools.ietf.org/html/rfc6749#section-3.3 +* - `http-server.authentication.oauth2.challenge-timeout` + - Maximum [duration](prop-type-duration) of the authorization challenge. + Default is `15m`. +* - `http-server.authentication.oauth2.state-key` + - A secret key used by the SHA-256 [HMAC](https://tools.ietf.org/html/rfc2104) + algorithm to sign the state parameter in order to ensure that the + authorization request was not forged. Default is a random string generated + during the coordinator start. +* - `http-server.authentication.oauth2.user-mapping.pattern` + - Regex to match against user. If matched, the user name is replaced with + first regex group. If not matched, authentication is denied. Default is + `(.*)` which allows any user name. +* - `http-server.authentication.oauth2.user-mapping.file` + - File containing rules for mapping user. See [](/security/user-mapping) for + more information. +* - `http-server.authentication.oauth2.principal-field` + - The field of the access token used for the Trino user principal. Defaults to + `sub`. Other commonly used fields include `sAMAccountName`, `name`, + `upn`, and `email`. +* - `http-server.authentication.oauth2.oidc.discovery` + - Enable reading the [OIDC provider metadata](https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata). + Default is `true`. +* - `http-server.authentication.oauth2.oidc.discovery.timeout` + - The timeout when reading OpenID provider metadata. Default is `30s`. +* - `http-server.authentication.oauth2.oidc.use-userinfo-endpoint` + - Use the value of `userinfo_endpoint` in the [provider + metadata](https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata). + When a `userinfo_endpoint` value is supplied this URL is used to validate + the OAuth 2.0 access token, and retrieve any associated claims. This flag + allows ignoring the value provided in the metadata document. Default is + `true`. 
+* - `http-server.authentication.oauth2.end-session-url` + - The URL of the endpoint on the authentication server to which the user's + browser is redirected to so that End-User is logged out from the + authentication server when logging out from Trino. +::: (trino-oauth2-refresh-tokens)= @@ -215,33 +217,30 @@ http-server.authentication.oauth2.scopes=openid,offline_access [or offline] The following configuration properties are available: -```{eval-rst} -.. list-table:: OAuth2 configuration properties for refresh flow - :widths: 40 60 - :header-rows: 1 - - * - Property - - Description - * - ``http-server.authentication.oauth2.refresh-tokens.issued-token.timeout`` - - Expiration time for an issued token, which is the Trino-encrypted token - that contains an access token and a refresh token. The timeout value must - be less than or equal to the :ref:`duration ` of the - refresh token expiration issued by the IdP. Defaults to ``1h``. The - timeout value is the maximum session time for an OAuth2-authenticated - client with refresh tokens enabled. For more details, see - :ref:`trino-oauth2-troubleshooting`. - * - ``http-server.authentication.oauth2.refresh-tokens.issued-token.issuer`` - - Issuer representing the coordinator instance, that is referenced in the - issued token, defaults to ``Trino_coordinator``. The current - Trino version is appended to the value. This is mainly used for - debugging purposes. - * - ``http-server.authentication.oauth2.refresh-tokens.issued-token.audience`` - - Audience representing this coordinator instance, that is used in the - issued token. Defaults to ``Trino_coordinator``. - * - ``http-server.authentication.oauth2.refresh-tokens.secret-key`` - - Base64-encoded secret key used to encrypt the generated token. - By default it's generated during startup. 
-``` +:::{list-table} OAuth2 configuration properties for refresh flow +:widths: 40 60 +:header-rows: 1 + +* - Property + - Description +* - `http-server.authentication.oauth2.refresh-tokens.issued-token.timeout` + - Expiration time for an issued token, which is the Trino-encrypted token that + contains an access token and a refresh token. The timeout value must be less + than or equal to the [duration](prop-type-duration) of the refresh token + expiration issued by the IdP. Defaults to `1h`. The timeout value is the + maximum session time for an OAuth2-authenticated client with refresh tokens + enabled. For more details, see [](trino-oauth2-troubleshooting). +* - `http-server.authentication.oauth2.refresh-tokens.issued-token.issuer` + - Issuer representing the coordinator instance, that is referenced in the + issued token, defaults to `Trino_coordinator`. The current Trino version is + appended to the value. This is mainly used for debugging purposes. +* - `http-server.authentication.oauth2.refresh-tokens.issued-token.audience` + - Audience representing this coordinator instance, that is used in the + issued token. Defaults to `Trino_coordinator`. +* - `http-server.authentication.oauth2.refresh-tokens.secret-key` + - Base64-encoded secret key used to encrypt the generated token. By default + it's generated during startup. +::: (trino-oauth2-troubleshooting)= diff --git a/docs/src/main/sphinx/sql.md b/docs/src/main/sphinx/sql.md index 6305cdfd323c..59a65fa9d41c 100644 --- a/docs/src/main/sphinx/sql.md +++ b/docs/src/main/sphinx/sql.md @@ -1,10 +1,12 @@ # SQL statement syntax -This section describes the SQL syntax used in Trino. +This section describes the syntax for SQL statements that can be executed in +Trino. -A {doc}`reference to the supported SQL data types` is available. +Refer to the following sections for further details: -Trino also provides {doc}`numerous SQL functions and operators`. 
+* [SQL data types and other general aspects](/language) +* [SQL functions and operators](/functions) ```{toctree} :maxdepth: 1 @@ -16,7 +18,6 @@ sql/alter-view sql/analyze sql/call sql/comment -sql/comments sql/commit sql/create-materialized-view sql/create-role @@ -44,7 +45,6 @@ sql/grant-roles sql/insert sql/match-recognize sql/merge -sql/pattern-recognition-in-window sql/prepare sql/refresh-materialized-view sql/reset-session @@ -78,3 +78,9 @@ sql/update sql/use sql/values ``` + +```{toctree} +:hidden: + +sql/pattern-recognition-in-window +``` diff --git a/docs/src/main/sphinx/sql/comment.md b/docs/src/main/sphinx/sql/comment.md index 317fb0cb8859..8ed32a746ae1 100644 --- a/docs/src/main/sphinx/sql/comment.md +++ b/docs/src/main/sphinx/sql/comment.md @@ -32,4 +32,4 @@ COMMENT ON COLUMN users.name IS 'full name'; ## See also -[](./comments) +[](/language/comments) diff --git a/docs/src/main/sphinx/sql/select.md b/docs/src/main/sphinx/sql/select.md index f9cb799b1bd7..c3d9fc092816 100644 --- a/docs/src/main/sphinx/sql/select.md +++ b/docs/src/main/sphinx/sql/select.md @@ -711,8 +711,7 @@ A window specification has the following components: consists of the rows matched by a pattern starting from that row. Additionally, if the frame specifies row pattern measures, they can be called over the window, similarly to window functions. For more details, see - {doc}`Row pattern recognition in window structures - `. + [Row pattern recognition in window structures](/sql/pattern-recognition-in-window) . Each window component is optional. If a window specification does not specify window partitioning, ordering or frame, those components are obtained from diff --git a/docs/src/main/sphinx/sql/show-stats.md b/docs/src/main/sphinx/sql/show-stats.md index 56f2c0b407dd..2c5f27af0123 100644 --- a/docs/src/main/sphinx/sql/show-stats.md +++ b/docs/src/main/sphinx/sql/show-stats.md @@ -19,40 +19,39 @@ table lists the returned columns and what statistics they represent. 
Any additional statistics collected on the data source, other than those listed here, are not included. -```{eval-rst} -.. list-table:: Statistics - :widths: 20, 40, 40 - :header-rows: 1 +:::{list-table} Statistics +:widths: 20, 40, 40 +:header-rows: 1 - * - Column - - Description - - Notes - * - ``column_name`` - - The name of the column - - ``NULL`` in the table summary row - * - ``data_size`` - - The total size in bytes of all of the values in the column - - ``NULL`` in the table summary row. Available for columns of :ref:`string - ` data types with variable widths. - * - ``distinct_values_count`` - - The estimated number of distinct values in the column - - ``NULL`` in the table summary row - * - ``nulls_fractions`` - - The portion of the values in the column that are ``NULL`` - - ``NULL`` in the table summary row. - * - ``row_count`` - - The estimated number of rows in the table - - ``NULL`` in column statistic rows - * - ``low_value`` - - The lowest value found in this column - - ``NULL`` in the table summary row. Available for columns of :ref:`DATE - `, :ref:`integer `, - :ref:`floating-point `, and - :ref:`fixed-precision ` data types. - * - ``high_value`` - - The highest value found in this column - - ``NULL`` in the table summary row. Available for columns of :ref:`DATE - `, :ref:`integer `, - :ref:`floating-point `, and - :ref:`fixed-precision ` data types. -``` +* - Column + - Description + - Notes +* - `column_name` + - The name of the column + - `NULL` in the table summary row +* - `data_size` + - The total size in bytes of all of the values in the column + - `NULL` in the table summary row. Available for columns of + [string](string-data-types) data types with variable widths. +* - `distinct_values_count` + - The estimated number of distinct values in the column + - `NULL` in the table summary row +* - `nulls_fractions` + - The portion of the values in the column that are `NULL` + - `NULL` in the table summary row. 
+* - `row_count` + - The estimated number of rows in the table + - `NULL` in column statistic rows +* - `low_value` + - The lowest value found in this column + - `NULL` in the table summary row. Available for columns of + [DATE](date-data-type), [integer](integer-data-types), + [floating-point](floating-point-data-types), and + [fixed-precision](fixed-precision-data-types) data types. +* - `high_value` + - The highest value found in this column + - `NULL` in the table summary row. Available for columns of + [DATE](date-data-type), [integer](integer-data-types), + [floating-point](floating-point-data-types), and + [fixed-precision](fixed-precision-data-types) data types. + ::: diff --git a/lib/trino-array/pom.xml b/lib/trino-array/pom.xml index 9c90946528d7..d26078498f46 100644 --- a/lib/trino-array/pom.xml +++ b/lib/trino-array/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/lib/trino-cache/pom.xml b/lib/trino-cache/pom.xml index dd149047b543..2a59f03ebd10 100644 --- a/lib/trino-cache/pom.xml +++ b/lib/trino-cache/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml @@ -37,6 +37,12 @@ modernizer-maven-annotations + + io.airlift + junit-extensions + test + + io.airlift testing @@ -55,10 +61,43 @@ test + + org.junit.jupiter + junit-jupiter-api + test + + org.testng testng test + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + + org.apache.maven.surefire + surefire-junit-platform + ${dep.plugin.surefire.version} + + + org.apache.maven.surefire + surefire-testng + ${dep.plugin.surefire.version} + + + org.junit.jupiter + junit-jupiter-engine + ${dep.junit.version} + + + + + diff --git a/lib/trino-cache/src/test/java/io/trino/cache/TestEmptyCache.java b/lib/trino-cache/src/test/java/io/trino/cache/TestEmptyCache.java index f1a8beb53003..1a9c5156cace 100644 --- a/lib/trino-cache/src/test/java/io/trino/cache/TestEmptyCache.java +++ 
b/lib/trino-cache/src/test/java/io/trino/cache/TestEmptyCache.java @@ -15,7 +15,7 @@ import com.google.common.cache.Cache; import com.google.common.cache.CacheLoader; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.List; @@ -34,7 +34,7 @@ public class TestEmptyCache { private static final int TEST_TIMEOUT_MILLIS = 10_000; - @Test(timeOut = TEST_TIMEOUT_MILLIS) + @Test public void testLoadFailure() throws Exception { diff --git a/lib/trino-cache/src/test/java/io/trino/cache/TestSafeCaches.java b/lib/trino-cache/src/test/java/io/trino/cache/TestSafeCaches.java index 4528b94bbdd6..99c4c10db81c 100644 --- a/lib/trino-cache/src/test/java/io/trino/cache/TestSafeCaches.java +++ b/lib/trino-cache/src/test/java/io/trino/cache/TestSafeCaches.java @@ -17,7 +17,7 @@ import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; import com.google.common.collect.ImmutableList; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.util.Map; diff --git a/lib/trino-filesystem-azure/pom.xml b/lib/trino-filesystem-azure/pom.xml index 4c9ee30afc96..e93616e1b1a4 100644 --- a/lib/trino-filesystem-azure/pom.xml +++ b/lib/trino-filesystem-azure/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/lib/trino-filesystem-manager/pom.xml b/lib/trino-filesystem-manager/pom.xml index a7af3717823c..fbd6ec2d28ff 100644 --- a/lib/trino-filesystem-manager/pom.xml +++ b/lib/trino-filesystem-manager/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/lib/trino-filesystem-s3/pom.xml b/lib/trino-filesystem-s3/pom.xml index 8994913f0138..964336853ef6 100644 --- a/lib/trino-filesystem-s3/pom.xml +++ b/lib/trino-filesystem-s3/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/lib/trino-filesystem/pom.xml b/lib/trino-filesystem/pom.xml 
index 32ac16ed9b02..bba381869817 100644 --- a/lib/trino-filesystem/pom.xml +++ b/lib/trino-filesystem/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/lib/trino-geospatial-toolkit/pom.xml b/lib/trino-geospatial-toolkit/pom.xml index 99aa31efddff..f0cd92abb9cd 100644 --- a/lib/trino-geospatial-toolkit/pom.xml +++ b/lib/trino-geospatial-toolkit/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/lib/trino-hadoop-toolkit/pom.xml b/lib/trino-hadoop-toolkit/pom.xml index 94a0b834c6c7..6a9b0aaedba2 100644 --- a/lib/trino-hadoop-toolkit/pom.xml +++ b/lib/trino-hadoop-toolkit/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml @@ -33,8 +33,8 @@ - org.testng - testng + org.junit.jupiter + junit-jupiter-api test diff --git a/lib/trino-hadoop-toolkit/src/test/java/io/trino/hadoop/TestDummy.java b/lib/trino-hadoop-toolkit/src/test/java/io/trino/hadoop/TestDummy.java index 05fd986b32a9..ad45437142b3 100644 --- a/lib/trino-hadoop-toolkit/src/test/java/io/trino/hadoop/TestDummy.java +++ b/lib/trino-hadoop-toolkit/src/test/java/io/trino/hadoop/TestDummy.java @@ -13,7 +13,7 @@ */ package io.trino.hadoop; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; public class TestDummy { diff --git a/lib/trino-hdfs/pom.xml b/lib/trino-hdfs/pom.xml index b63fc08593aa..2f8b90c2020c 100644 --- a/lib/trino-hdfs/pom.xml +++ b/lib/trino-hdfs/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/lib/trino-hive-formats/pom.xml b/lib/trino-hive-formats/pom.xml index cb0ab2d5365b..f6ad572c5ded 100644 --- a/lib/trino-hive-formats/pom.xml +++ b/lib/trino-hive-formats/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/BinaryColumnEncodingFactory.java 
b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/BinaryColumnEncodingFactory.java index 02e40bf6c8a2..d05069d1ce90 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/BinaryColumnEncodingFactory.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/BinaryColumnEncodingFactory.java @@ -85,19 +85,19 @@ public BinaryColumnEncoding getEncoding(Type type) if (type instanceof TimestampType) { return new TimestampEncoding((TimestampType) type, timeZone); } - if (type instanceof ArrayType) { - return new ListEncoding(type, getEncoding(type.getTypeParameters().get(0))); + if (type instanceof ArrayType arrayType) { + return new ListEncoding(arrayType, getEncoding(arrayType.getElementType())); } - if (type instanceof MapType) { + if (type instanceof MapType mapType) { return new MapEncoding( - type, - getEncoding(type.getTypeParameters().get(0)), - getEncoding(type.getTypeParameters().get(1))); + mapType, + getEncoding(mapType.getKeyType()), + getEncoding(mapType.getValueType())); } - if (type instanceof RowType) { + if (type instanceof RowType rowType) { return new StructEncoding( - type, - type.getTypeParameters().stream() + rowType, + rowType.getTypeParameters().stream() .map(this::getEncoding) .collect(Collectors.toList())); } diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/ListEncoding.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/ListEncoding.java index 4aff509f9d29..5843b3ad853d 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/ListEncoding.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/ListEncoding.java @@ -19,25 +19,27 @@ import io.trino.spi.block.ArrayBlockBuilder; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; -import io.trino.spi.type.Type; +import io.trino.spi.type.ArrayType; import 
static java.lang.Math.toIntExact; public class ListEncoding extends BlockEncoding { + private final ArrayType arrayType; private final BinaryColumnEncoding elementEncoding; - public ListEncoding(Type type, BinaryColumnEncoding elementEncoding) + public ListEncoding(ArrayType arrayType, BinaryColumnEncoding elementEncoding) { - super(type); + super(arrayType); + this.arrayType = arrayType; this.elementEncoding = elementEncoding; } @Override public void encodeValue(Block block, int position, SliceOutput output) { - Block list = block.getObject(position, Block.class); + Block list = arrayType.getObject(block, position); ReadWriteUtils.writeVInt(output, list.getPositionCount()); // write null bits diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/MapEncoding.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/MapEncoding.java index 9238b8585ef4..ace73d1f942e 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/MapEncoding.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/MapEncoding.java @@ -22,19 +22,21 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.MapBlockBuilder; import io.trino.spi.block.SqlMap; -import io.trino.spi.type.Type; +import io.trino.spi.type.MapType; import static java.lang.Math.toIntExact; public class MapEncoding extends BlockEncoding { + private final MapType mapType; private final BinaryColumnEncoding keyReader; private final BinaryColumnEncoding valueReader; - public MapEncoding(Type type, BinaryColumnEncoding keyReader, BinaryColumnEncoding valueReader) + public MapEncoding(MapType mapType, BinaryColumnEncoding keyReader, BinaryColumnEncoding valueReader) { - super(type); + super(mapType); + this.mapType = mapType; this.keyReader = keyReader; this.valueReader = valueReader; } @@ -42,7 +44,7 @@ public MapEncoding(Type type, BinaryColumnEncoding keyReader, BinaryColumnEncodi @Override 
public void encodeValue(Block block, int position, SliceOutput output) { - SqlMap sqlMap = block.getObject(position, SqlMap.class); + SqlMap sqlMap = mapType.getObject(block, position); int rawOffset = sqlMap.getRawOffset(); Block rawKeyBlock = sqlMap.getRawKeyBlock(); Block rawValueBlock = sqlMap.getRawValueBlock(); diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/StructEncoding.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/StructEncoding.java index 33d26e9d03cb..038deae7b97a 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/StructEncoding.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/StructEncoding.java @@ -20,7 +20,7 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.RowBlockBuilder; import io.trino.spi.block.SqlRow; -import io.trino.spi.type.Type; +import io.trino.spi.type.RowType; import java.util.List; @@ -28,17 +28,19 @@ public class StructEncoding extends BlockEncoding { private final List structFields; + private final RowType rowType; - public StructEncoding(Type type, List structFields) + public StructEncoding(RowType rowType, List structFields) { - super(type); + super(rowType); + this.rowType = rowType; this.structFields = ImmutableList.copyOf(structFields); } @Override public void encodeValue(Block block, int position, SliceOutput output) { - SqlRow row = block.getObject(position, SqlRow.class); + SqlRow row = rowType.getObject(block, position); int rawIndex = row.getRawIndex(); // write values diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/ListEncoding.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/ListEncoding.java index 2074e042a99b..a50d7ae5e078 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/ListEncoding.java +++ 
b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/ListEncoding.java @@ -19,17 +19,19 @@ import io.trino.spi.block.ArrayBlockBuilder; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; -import io.trino.spi.type.Type; +import io.trino.spi.type.ArrayType; public class ListEncoding extends BlockEncoding { + private final ArrayType arrayType; private final byte separator; private final TextColumnEncoding elementEncoding; - public ListEncoding(Type type, Slice nullSequence, byte separator, Byte escapeByte, TextColumnEncoding elementEncoding) + public ListEncoding(ArrayType arrayType, Slice nullSequence, byte separator, Byte escapeByte, TextColumnEncoding elementEncoding) { - super(type, nullSequence, escapeByte); + super(arrayType, nullSequence, escapeByte); + this.arrayType = arrayType; this.separator = separator; this.elementEncoding = elementEncoding; } @@ -38,7 +40,7 @@ public ListEncoding(Type type, Slice nullSequence, byte separator, Byte escapeBy public void encodeValueInto(Block block, int position, SliceOutput output) throws FileCorruptionException { - Block list = block.getObject(position, Block.class); + Block list = arrayType.getObject(block, position); for (int elementIndex = 0; elementIndex < list.getPositionCount(); elementIndex++) { if (elementIndex > 0) { output.writeByte(separator); diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/MapEncoding.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/MapEncoding.java index 40840cd205e9..9006bd3325ce 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/MapEncoding.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/MapEncoding.java @@ -24,7 +24,6 @@ import io.trino.spi.block.MapBlockBuilder; import io.trino.spi.block.SqlMap; import io.trino.spi.type.MapType; -import io.trino.spi.type.Type; public class MapEncoding extends 
BlockEncoding @@ -39,7 +38,7 @@ public class MapEncoding private BlockBuilder keyBlockBuilder; public MapEncoding( - Type type, + MapType mapType, Slice nullSequence, byte elementSeparator, byte keyValueSeparator, @@ -47,8 +46,8 @@ public MapEncoding( TextColumnEncoding keyEncoding, TextColumnEncoding valueEncoding) { - super(type, nullSequence, escapeByte); - this.mapType = (MapType) type; + super(mapType, nullSequence, escapeByte); + this.mapType = mapType; this.elementSeparator = elementSeparator; this.keyValueSeparator = keyValueSeparator; this.keyEncoding = keyEncoding; @@ -61,7 +60,7 @@ public MapEncoding( public void encodeValueInto(Block block, int position, SliceOutput output) throws FileCorruptionException { - SqlMap sqlMap = block.getObject(position, SqlMap.class); + SqlMap sqlMap = mapType.getObject(block, position); int rawOffset = sqlMap.getRawOffset(); Block rawKeyBlock = sqlMap.getRawKeyBlock(); Block rawValueBlock = sqlMap.getRawValueBlock(); diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/StructEncoding.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/StructEncoding.java index b40f77d5daba..fb78ce553b7a 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/StructEncoding.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/StructEncoding.java @@ -20,26 +20,28 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.RowBlockBuilder; import io.trino.spi.block.SqlRow; -import io.trino.spi.type.Type; +import io.trino.spi.type.RowType; import java.util.List; public class StructEncoding extends BlockEncoding { + private final RowType rowType; private final byte separator; private final boolean lastColumnTakesRest; private final List structFields; public StructEncoding( - Type type, + RowType rowType, Slice nullSequence, byte separator, Byte escapeByte, boolean lastColumnTakesRest, List structFields) { - 
super(type, nullSequence, escapeByte); + super(rowType, nullSequence, escapeByte); + this.rowType = rowType; this.separator = separator; this.lastColumnTakesRest = lastColumnTakesRest; this.structFields = structFields; @@ -49,7 +51,7 @@ public StructEncoding( public void encodeValueInto(Block block, int position, SliceOutput output) throws FileCorruptionException { - SqlRow row = block.getObject(position, SqlRow.class); + SqlRow row = rowType.getObject(block, position); int rawIndex = row.getRawIndex(); for (int fieldIndex = 0; fieldIndex < structFields.size(); fieldIndex++) { if (fieldIndex > 0) { diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/TextColumnEncodingFactory.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/TextColumnEncodingFactory.java index 60f46091f4c7..24dbb1e33f4f 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/TextColumnEncodingFactory.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/TextColumnEncodingFactory.java @@ -115,20 +115,20 @@ private TextColumnEncoding getEncoding(Type type, int depth) if (type instanceof TimestampType) { return new TimestampEncoding((TimestampType) type, textEncodingOptions.getNullSequence(), textEncodingOptions.getTimestampFormats()); } - if (type instanceof ArrayType) { - TextColumnEncoding elementEncoding = getEncoding(type.getTypeParameters().get(0), depth + 1); + if (type instanceof ArrayType arrayType) { + TextColumnEncoding elementEncoding = getEncoding(arrayType.getElementType(), depth + 1); return new ListEncoding( - type, + arrayType, textEncodingOptions.getNullSequence(), getSeparator(depth + 1), textEncodingOptions.getEscapeByte(), elementEncoding); } - if (type instanceof MapType) { - TextColumnEncoding keyEncoding = getEncoding(type.getTypeParameters().get(0), depth + 2); - TextColumnEncoding valueEncoding = getEncoding(type.getTypeParameters().get(1), depth + 
2); + if (type instanceof MapType mapType) { + TextColumnEncoding keyEncoding = getEncoding(mapType.getKeyType(), depth + 2); + TextColumnEncoding valueEncoding = getEncoding(mapType.getValueType(), depth + 2); return new MapEncoding( - type, + mapType, textEncodingOptions.getNullSequence(), getSeparator(depth + 1), getSeparator(depth + 2), @@ -136,12 +136,12 @@ private TextColumnEncoding getEncoding(Type type, int depth) keyEncoding, valueEncoding); } - if (type instanceof RowType) { - List fieldEncodings = type.getTypeParameters().stream() + if (type instanceof RowType rowType) { + List fieldEncodings = rowType.getTypeParameters().stream() .map(fieldType -> getEncoding(fieldType, depth + 1)) .collect(toImmutableList()); return new StructEncoding( - type, + rowType, textEncodingOptions.getNullSequence(), getSeparator(depth + 1), textEncodingOptions.getEscapeByte(), diff --git a/lib/trino-ignite-patched/pom.xml b/lib/trino-ignite-patched/pom.xml index 5af07859f181..6e77c457dd4e 100644 --- a/lib/trino-ignite-patched/pom.xml +++ b/lib/trino-ignite-patched/pom.xml @@ -6,7 +6,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/lib/trino-matching/pom.xml b/lib/trino-matching/pom.xml index 51c6b55d1d83..8d99dd19d57f 100644 --- a/lib/trino-matching/pom.xml +++ b/lib/trino-matching/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/lib/trino-memory-context/pom.xml b/lib/trino-memory-context/pom.xml index 1ada12e7e6c4..fee00a97a54b 100644 --- a/lib/trino-memory-context/pom.xml +++ b/lib/trino-memory-context/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/lib/trino-orc/pom.xml b/lib/trino-orc/pom.xml index bfc317da6c23..ca646af43fdb 100644 --- a/lib/trino-orc/pom.xml +++ b/lib/trino-orc/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git 
a/lib/trino-orc/src/main/java/io/trino/orc/ValidationHash.java b/lib/trino-orc/src/main/java/io/trino/orc/ValidationHash.java index 5562871481c2..7b30700b5c37 100644 --- a/lib/trino-orc/src/main/java/io/trino/orc/ValidationHash.java +++ b/lib/trino-orc/src/main/java/io/trino/orc/ValidationHash.java @@ -21,7 +21,7 @@ import io.trino.spi.type.ArrayType; import io.trino.spi.type.MapType; import io.trino.spi.type.RowType; -import io.trino.spi.type.StandardTypes; +import io.trino.spi.type.TimestampType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; @@ -93,7 +93,7 @@ public static ValidationHash createValidationHash(Type type) return new ValidationHash(ROW_HASH.bindTo(rowType).bindTo(fieldHashes)); } - if (type.getTypeSignature().getBase().equals(StandardTypes.TIMESTAMP)) { + if (type instanceof TimestampType timestampType && timestampType.isShort()) { return new ValidationHash(TIMESTAMP_HASH); } diff --git a/lib/trino-orc/src/main/java/io/trino/orc/writer/DictionaryBuilder.java b/lib/trino-orc/src/main/java/io/trino/orc/writer/DictionaryBuilder.java index af5cd9195ccf..51a921ad5829 100644 --- a/lib/trino-orc/src/main/java/io/trino/orc/writer/DictionaryBuilder.java +++ b/lib/trino-orc/src/main/java/io/trino/orc/writer/DictionaryBuilder.java @@ -14,10 +14,10 @@ package io.trino.orc.writer; import io.airlift.slice.DynamicSliceOutput; +import io.airlift.slice.Slice; import io.airlift.slice.SliceOutput; import io.airlift.slice.XxHash64; import io.trino.array.IntBigArray; -import io.trino.spi.block.Block; import io.trino.spi.block.VariableWidthBlock; import java.util.Arrays; @@ -86,7 +86,7 @@ public long getRetainedSizeInBytes() blockPositionByHash.sizeOf(); } - public Block getElementBlock() + public VariableWidthBlock getElementBlock() { boolean[] isNull = new boolean[entryCount]; isNull[NULL_POSITION] = true; @@ -103,7 +103,7 @@ public void clear() Arrays.fill(offsets, 0); } - public int putIfAbsent(Block block, int position) + public int 
putIfAbsent(VariableWidthBlock block, int position) { requireNonNull(block, "block must not be null"); @@ -131,11 +131,14 @@ public int getEntryCount() /** * Get slot position of the element at {@code position} of {@code block} */ - private long getHashPositionOfElement(Block block, int position) + private long getHashPositionOfElement(VariableWidthBlock block, int position) { checkArgument(!block.isNull(position), "position is null"); + Slice rawSlice = block.getRawSlice(); + int rawSliceOffset = block.getRawSliceOffset(position); int length = block.getSliceLength(position); - long hashPosition = getMaskedHash(block.hash(position, 0, length)); + + long hashPosition = getMaskedHash(XxHash64.hash(rawSlice, rawSliceOffset, length)); while (true) { int entryPosition = blockPositionByHash.get(hashPosition); if (entryPosition == EMPTY_SLOT) { @@ -144,7 +147,7 @@ private long getHashPositionOfElement(Block block, int position) } int entryOffset = offsets[entryPosition]; int entryLength = offsets[entryPosition + 1] - entryOffset; - if (entryLength == length && block.bytesEqual(position, 0, sliceOutput.getUnderlyingSlice(), entryOffset, entryLength)) { + if (rawSlice.equals(rawSliceOffset, length, sliceOutput.getUnderlyingSlice(), entryOffset, entryLength)) { // Already has this element return hashPosition; } @@ -153,14 +156,13 @@ private long getHashPositionOfElement(Block block, int position) } } - private int addNewElement(long hashPosition, Block block, int position) + private int addNewElement(long hashPosition, VariableWidthBlock block, int position) { checkArgument(!block.isNull(position), "position is null"); int newElementPositionInBlock = entryCount; - int length = block.getSliceLength(position); - block.writeSliceTo(position, 0, length, sliceOutput); + sliceOutput.writeBytes(block.getRawSlice(), block.getRawSliceOffset(position), block.getSliceLength(position)); entryCount++; offsets[entryCount] = sliceOutput.size(); diff --git 
a/lib/trino-orc/src/main/java/io/trino/orc/writer/SliceDictionaryColumnWriter.java b/lib/trino-orc/src/main/java/io/trino/orc/writer/SliceDictionaryColumnWriter.java index 045d2fee05fb..1766a2a45d1a 100644 --- a/lib/trino-orc/src/main/java/io/trino/orc/writer/SliceDictionaryColumnWriter.java +++ b/lib/trino-orc/src/main/java/io/trino/orc/writer/SliceDictionaryColumnWriter.java @@ -38,6 +38,7 @@ import io.trino.orc.stream.StreamDataOutput; import io.trino.spi.block.Block; import io.trino.spi.block.DictionaryBlock; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.type.Type; import it.unimi.dsi.fastutil.ints.IntArrays; @@ -282,17 +283,19 @@ public void writeBlock(Block block) // record values values.ensureCapacity(rowGroupValueCount + block.getPositionCount()); - for (int position = 0; position < block.getPositionCount(); position++) { - int index = dictionary.putIfAbsent(block, position); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + for (int i = 0; i < block.getPositionCount(); i++) { + int position = block.getUnderlyingValuePosition(i); + int index = dictionary.putIfAbsent(valueBlock, position); values.set(rowGroupValueCount, index); rowGroupValueCount++; totalValueCount++; - if (!block.isNull(position)) { + if (!valueBlock.isNull(position)) { // todo min/max statistics only need to be updated if value was not already in the dictionary, but non-null count does - statisticsBuilder.addValue(type.getSlice(block, position)); + statisticsBuilder.addValue(type.getSlice(valueBlock, position)); - rawBytes += block.getSliceLength(position); + rawBytes += valueBlock.getSliceLength(position); totalNonNullValueCount++; } } @@ -349,7 +352,7 @@ private void bufferOutputData() checkState(closed); checkState(!directEncoded); - Block dictionaryElements = dictionary.getElementBlock(); + VariableWidthBlock dictionaryElements = dictionary.getElementBlock(); // write dictionary in sorted order int[] sortedDictionaryIndexes = 
getSortedDictionaryNullsLast(dictionaryElements); @@ -404,13 +407,14 @@ private void bufferOutputData() presentStream.close(); } - private static int[] getSortedDictionaryNullsLast(Block elementBlock) + private static int[] getSortedDictionaryNullsLast(VariableWidthBlock elementBlock) { int[] sortedPositions = new int[elementBlock.getPositionCount()]; for (int i = 0; i < sortedPositions.length; i++) { sortedPositions[i] = i; } + Slice rawSlice = elementBlock.getRawSlice(); IntArrays.quickSort(sortedPositions, 0, sortedPositions.length, (int left, int right) -> { boolean nullLeft = elementBlock.isNull(left); boolean nullRight = elementBlock.isNull(right); @@ -423,13 +427,11 @@ private static int[] getSortedDictionaryNullsLast(Block elementBlock) if (nullRight) { return -1; } - return elementBlock.compareTo( - left, - 0, + return rawSlice.compareTo( + elementBlock.getRawSliceOffset(left), elementBlock.getSliceLength(left), - elementBlock, - right, - 0, + rawSlice, + elementBlock.getRawSliceOffset(right), elementBlock.getSliceLength(right)); }); diff --git a/lib/trino-orc/src/test/java/io/trino/orc/TestDictionaryBuilder.java b/lib/trino-orc/src/test/java/io/trino/orc/TestDictionaryBuilder.java index e486503fe60f..3a56af57f175 100644 --- a/lib/trino-orc/src/test/java/io/trino/orc/TestDictionaryBuilder.java +++ b/lib/trino-orc/src/test/java/io/trino/orc/TestDictionaryBuilder.java @@ -14,7 +14,6 @@ package io.trino.orc; import com.google.common.collect.ImmutableSet; -import io.airlift.slice.Slice; import io.trino.orc.writer.DictionaryBuilder; import io.trino.spi.block.VariableWidthBlock; import org.testng.annotations.Test; @@ -34,25 +33,9 @@ public void testSkipReservedSlots() Set positions = new HashSet<>(); DictionaryBuilder dictionaryBuilder = new DictionaryBuilder(64); for (int i = 0; i < 64; i++) { - positions.add(dictionaryBuilder.putIfAbsent(new TestHashCollisionBlock(1, wrappedBuffer(new byte[] {1}), new int[] {0, 1}, new boolean[] {false}), 0)); - 
positions.add(dictionaryBuilder.putIfAbsent(new TestHashCollisionBlock(1, wrappedBuffer(new byte[] {2}), new int[] {0, 1}, new boolean[] {false}), 0)); + positions.add(dictionaryBuilder.putIfAbsent(new VariableWidthBlock(1, wrappedBuffer(new byte[] {1}), new int[] {0, 1}, Optional.of(new boolean[] {false})), 0)); + positions.add(dictionaryBuilder.putIfAbsent(new VariableWidthBlock(1, wrappedBuffer(new byte[] {2}), new int[] {0, 1}, Optional.of(new boolean[] {false})), 0)); } assertEquals(positions, ImmutableSet.of(1, 2)); } - - private static class TestHashCollisionBlock - extends VariableWidthBlock - { - public TestHashCollisionBlock(int positionCount, Slice slice, int[] offsets, boolean[] valueIsNull) - { - super(positionCount, slice, offsets, Optional.of(valueIsNull)); - } - - @Override - public long hash(int position, int offset, int length) - { - // return 0 to hash to the reserved null position which is zero - return 0; - } - } } diff --git a/lib/trino-parquet/pom.xml b/lib/trino-parquet/pom.xml index ad41c93f5dbd..5f7d6f5b75b2 100644 --- a/lib/trino-parquet/pom.xml +++ b/lib/trino-parquet/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/lib/trino-phoenix5-patched/pom.xml b/lib/trino-phoenix5-patched/pom.xml index e2a2fad91d04..ce9fc2500810 100644 --- a/lib/trino-phoenix5-patched/pom.xml +++ b/lib/trino-phoenix5-patched/pom.xml @@ -6,7 +6,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/lib/trino-plugin-toolkit/pom.xml b/lib/trino-plugin-toolkit/pom.xml index cad2d45ca7da..5797bcad6d79 100644 --- a/lib/trino-plugin-toolkit/pom.xml +++ b/lib/trino-plugin-toolkit/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/lib/trino-record-decoder/pom.xml b/lib/trino-record-decoder/pom.xml index ae9ee7ec2a87..2c49679d9c23 100644 --- a/lib/trino-record-decoder/pom.xml +++ b/lib/trino-record-decoder/pom.xml @@ -5,7 +5,7 @@ io.trino 
trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-accumulo-iterators/pom.xml b/plugin/trino-accumulo-iterators/pom.xml index 87247a6263aa..531fe2483965 100644 --- a/plugin/trino-accumulo-iterators/pom.xml +++ b/plugin/trino-accumulo-iterators/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-accumulo/pom.xml b/plugin/trino-accumulo/pom.xml index 352df63cf98d..a949b13a5f79 100644 --- a/plugin/trino-accumulo/pom.xml +++ b/plugin/trino-accumulo/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml @@ -15,7 +15,7 @@ ${project.parent.basedir} - 2.12.0 + 2.13.0 diff --git a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/model/Field.java b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/model/Field.java index 141083a9e104..74d3ac2cbe01 100644 --- a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/model/Field.java +++ b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/model/Field.java @@ -19,7 +19,6 @@ import io.airlift.slice.Slice; import io.trino.plugin.accumulo.Types; import io.trino.spi.TrinoException; -import io.trino.spi.block.ArrayBlock; import io.trino.spi.block.Block; import io.trino.spi.block.SqlMap; import io.trino.spi.type.Type; @@ -27,8 +26,6 @@ import java.sql.Time; import java.sql.Timestamp; -import java.util.Arrays; -import java.util.Objects; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; @@ -45,7 +42,6 @@ import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_MILLISECOND; import static io.trino.spi.type.TinyintType.TINYINT; import static io.trino.spi.type.VarbinaryType.VARBINARY; -import static io.trino.spi.type.VarcharType.VARCHAR; import static java.lang.Float.intBitsToFloat; import static java.lang.Math.floorDiv; import static java.lang.Math.toIntExact; @@ -69,57 +65,6 @@ public 
Field(Object nativeValue, Type type, boolean indexed) this.indexed = indexed; } - public Field(Field field) - { - this.type = field.type; - this.indexed = false; - - if (Types.isArrayType(this.type) || Types.isMapType(this.type)) { - this.value = field.value; - return; - } - - if (type.equals(BIGINT)) { - this.value = field.getLong(); - } - else if (type.equals(BOOLEAN)) { - this.value = field.getBoolean(); - } - else if (type.equals(DATE)) { - this.value = field.getDate(); - } - else if (type.equals(DOUBLE)) { - this.value = field.getDouble(); - } - else if (type.equals(INTEGER)) { - this.value = field.getInt(); - } - else if (type.equals(REAL)) { - this.value = field.getFloat(); - } - else if (type.equals(SMALLINT)) { - this.value = field.getShort(); - } - else if (type.equals(TIME_MILLIS)) { - this.value = new Time(field.getTime().getTime()); - } - else if (type.equals(TIMESTAMP_MILLIS)) { - this.value = new Timestamp(field.getTimestamp().getTime()); - } - else if (type.equals(TINYINT)) { - this.value = field.getByte(); - } - else if (type.equals(VARBINARY)) { - this.value = Arrays.copyOf(field.getVarbinary(), field.getVarbinary().length); - } - else if (type.equals(VARCHAR)) { - this.value = field.getVarchar(); - } - else { - throw new TrinoException(NOT_SUPPORTED, "Unsupported type " + type); - } - } - public Type getType() { return type; @@ -210,59 +155,6 @@ public boolean isNull() return value == null; } - @Override - public int hashCode() - { - return Objects.hash(value, type, indexed); - } - - @Override - public boolean equals(Object obj) - { - boolean retval = true; - if (obj instanceof Field field) { - if (type.equals(field.getType())) { - if (this.isNull() && field.isNull()) { - retval = true; - } - else if (this.isNull() != field.isNull()) { - retval = false; - } - else if (type.equals(VARBINARY)) { - // special case for byte arrays - // aren't they so fancy - retval = Arrays.equals((byte[]) value, (byte[]) field.getObject()); - } - else if 
(type.equals(DATE) || type.equals(TIME_MILLIS) || type.equals(TIMESTAMP_MILLIS)) { - retval = value.toString().equals(field.getObject().toString()); - } - else { - if (value instanceof Block) { - retval = equals((Block) value, (Block) field.getObject()); - } - else { - retval = value.equals(field.getObject()); - } - } - } - } - return retval; - } - - private static boolean equals(Block block1, Block block2) - { - boolean retval = block1.getPositionCount() == block2.getPositionCount(); - for (int i = 0; i < block1.getPositionCount() && retval; ++i) { - if (block1 instanceof ArrayBlock && block2 instanceof ArrayBlock) { - retval = equals(block1.getObject(i, Block.class), block2.getObject(i, Block.class)); - } - else { - retval = block1.compareTo(i, 0, block1.getSliceLength(i), block2, i, 0, block2.getSliceLength(i)) == 0; - } - } - return retval; - } - @Override public String toString() { diff --git a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/model/Row.java b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/model/Row.java index bf43b5b60484..3b5c674a5fc9 100644 --- a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/model/Row.java +++ b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/model/Row.java @@ -16,10 +16,7 @@ import io.trino.spi.type.Type; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; -import java.util.Objects; -import java.util.stream.Collectors; import static java.util.Objects.requireNonNull; @@ -29,12 +26,6 @@ public class Row public Row() {} - public Row(Row row) - { - requireNonNull(row, "row is null"); - fields.addAll(row.fields.stream().map(Field::new).collect(Collectors.toList())); - } - public Row addField(Field field) { requireNonNull(field, "field is null"); @@ -54,33 +45,11 @@ public Field getField(int i) return fields.get(i); } - /** - * Gets a list of all internal fields. Any changes to this list will affect this row. 
- * - * @return List of fields - */ - public List getFields() - { - return fields; - } - public int length() { return fields.size(); } - @Override - public int hashCode() - { - return Arrays.hashCode(fields.toArray()); - } - - @Override - public boolean equals(Object obj) - { - return obj instanceof Row && Objects.equals(this.fields, ((Row) obj).getFields()); - } - @Override public String toString() { diff --git a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/serializers/AccumuloRowSerializer.java b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/serializers/AccumuloRowSerializer.java index d1ba43f12262..e83836255815 100644 --- a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/serializers/AccumuloRowSerializer.java +++ b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/serializers/AccumuloRowSerializer.java @@ -21,6 +21,7 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.MapBlockBuilder; import io.trino.spi.block.SqlMap; +import io.trino.spi.type.ArrayType; import io.trino.spi.type.MapType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeUtils; @@ -604,12 +605,12 @@ else if (type instanceof MapType mapType) { */ static Object readObject(Type type, Block block, int position) { - if (Types.isArrayType(type)) { - Type elementType = Types.getElementType(type); - return getArrayFromBlock(elementType, block.getObject(position, Block.class)); + if (type instanceof ArrayType arrayType) { + Type elementType = arrayType.getElementType(); + return getArrayFromBlock(elementType, arrayType.getObject(block, position)); } - if (Types.isMapType(type)) { - return getMapFromSqlMap(type, block.getObject(position, SqlMap.class)); + if (type instanceof MapType mapType) { + return getMapFromSqlMap(type, mapType.getObject(block, position)); } if (type.getJavaType() == Slice.class) { Slice slice = (Slice) TypeUtils.readNativeValue(type, block, position); diff --git 
a/plugin/trino-accumulo/src/test/java/io/trino/plugin/accumulo/model/TestField.java b/plugin/trino-accumulo/src/test/java/io/trino/plugin/accumulo/model/TestField.java index 825b9c5b6d13..9edc7755aaec 100644 --- a/plugin/trino-accumulo/src/test/java/io/trino/plugin/accumulo/model/TestField.java +++ b/plugin/trino-accumulo/src/test/java/io/trino/plugin/accumulo/model/TestField.java @@ -68,9 +68,6 @@ public void testArray() assertEquals(f1.getArray(), expected); assertEquals(f1.getObject(), expected); assertEquals(f1.getType(), type); - - Field f2 = new Field(f1); - assertEquals(f2, f1); } @Test @@ -86,9 +83,6 @@ public void testBoolean() assertEquals(f1.getBoolean().booleanValue(), false); assertEquals(f1.getObject(), false); assertEquals(f1.getType(), type); - - Field f2 = new Field(f1); - assertEquals(f2, f1); } @Test @@ -100,9 +94,6 @@ public void testDate() assertEquals(f1.getDate(), expected); assertEquals(f1.getObject(), expected); assertEquals(f1.getType(), type); - - Field f2 = new Field(f1); - assertEquals(f2, f1); } @Test @@ -114,9 +105,6 @@ public void testDouble() assertEquals(f1.getDouble(), expected); assertEquals(f1.getObject(), expected); assertEquals(f1.getType(), type); - - Field f2 = new Field(f1); - assertEquals(f2, f1); } @Test @@ -128,9 +116,6 @@ public void testFloat() assertEquals(f1.getFloat(), expected); assertEquals(f1.getObject(), expected); assertEquals(f1.getType(), type); - - Field f2 = new Field(f1); - assertEquals(f2, f1); } @Test @@ -142,9 +127,6 @@ public void testInt() assertEquals(f1.getInt(), expected); assertEquals(f1.getObject(), expected); assertEquals(f1.getType(), type); - - Field f2 = new Field(f1); - assertEquals(f2, f1); } @Test @@ -156,9 +138,6 @@ public void testLong() assertEquals(f1.getLong(), expected); assertEquals(f1.getObject(), expected); assertEquals(f1.getType(), type); - - Field f2 = new Field(f1); - assertEquals(f2, f1); } @Test @@ -183,9 +162,6 @@ public void testSmallInt() assertEquals(f1.getShort(), 
expected); assertEquals(f1.getObject(), expected); assertEquals(f1.getType(), type); - - Field f2 = new Field(f1); - assertEquals(f2, f1); } @Test @@ -197,9 +173,6 @@ public void testTime() assertEquals(f1.getTime(), expected); assertEquals(f1.getObject(), expected); assertEquals(f1.getType(), type); - - Field f2 = new Field(f1); - assertEquals(f2, f1); } @Test @@ -211,9 +184,6 @@ public void testTimestamp() assertEquals(f1.getTimestamp(), expected); assertEquals(f1.getObject(), expected); assertEquals(f1.getType(), type); - - Field f2 = new Field(f1); - assertEquals(f2, f1); } @Test @@ -225,9 +195,6 @@ public void testTinyInt() assertEquals(f1.getByte(), expected); assertEquals(f1.getObject(), expected); assertEquals(f1.getType(), type); - - Field f2 = new Field(f1); - assertEquals(f2, f1); } @Test @@ -239,9 +206,6 @@ public void testVarbinary() assertEquals(f1.getVarbinary(), expected); assertEquals(f1.getObject(), expected); assertEquals(f1.getType(), type); - - Field f2 = new Field(f1); - assertEquals(f2, f1); } @Test @@ -253,8 +217,5 @@ public void testVarchar() assertEquals(f1.getVarchar(), expected); assertEquals(f1.getObject(), expected); assertEquals(f1.getType(), type); - - Field f2 = new Field(f1); - assertEquals(f2, f1); } } diff --git a/plugin/trino-accumulo/src/test/java/io/trino/plugin/accumulo/model/TestRow.java b/plugin/trino-accumulo/src/test/java/io/trino/plugin/accumulo/model/TestRow.java index 87787049275e..67de0e6a892d 100644 --- a/plugin/trino-accumulo/src/test/java/io/trino/plugin/accumulo/model/TestRow.java +++ b/plugin/trino-accumulo/src/test/java/io/trino/plugin/accumulo/model/TestRow.java @@ -63,9 +63,6 @@ public void testRow() r1.addField(null, VARCHAR); assertEquals(r1.length(), 14); - - Row r2 = new Row(r1); - assertEquals(r2, r1); } @Test diff --git a/plugin/trino-atop/pom.xml b/plugin/trino-atop/pom.xml index ca3647921d3f..8ae55d7c68e5 100644 --- a/plugin/trino-atop/pom.xml +++ b/plugin/trino-atop/pom.xml @@ -5,7 +5,7 @@ io.trino 
trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-base-jdbc/pom.xml b/plugin/trino-base-jdbc/pom.xml index 22349925a52f..193eaa31d1bd 100644 --- a/plugin/trino-base-jdbc/pom.xml +++ b/plugin/trino-base-jdbc/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForLazyConnectionFactory.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForLazyConnectionFactory.java deleted file mode 100644 index 21ee99114bd5..000000000000 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForLazyConnectionFactory.java +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package io.trino.plugin.jdbc; - -import com.google.inject.BindingAnnotation; - -import java.lang.annotation.Retention; -import java.lang.annotation.Target; - -import static java.lang.annotation.ElementType.FIELD; -import static java.lang.annotation.ElementType.METHOD; -import static java.lang.annotation.ElementType.PARAMETER; -import static java.lang.annotation.RetentionPolicy.RUNTIME; - -@Retention(RUNTIME) -@Target({FIELD, PARAMETER, METHOD}) -@BindingAnnotation -public @interface ForLazyConnectionFactory -{ -} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcDiagnosticModule.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcDiagnosticModule.java index 557095b85053..93bed3852ec8 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcDiagnosticModule.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcDiagnosticModule.java @@ -18,6 +18,7 @@ import com.google.inject.Module; import com.google.inject.Provider; import com.google.inject.Provides; +import com.google.inject.Scopes; import com.google.inject.Singleton; import io.airlift.log.Logger; import io.trino.plugin.base.CatalogName; @@ -39,11 +40,12 @@ public void configure(Binder binder) { binder.install(new MBeanServerModule()); binder.install(new MBeanModule()); + binder.bind(StatisticsAwareConnectionFactory.class).in(Scopes.SINGLETON); Provider catalogName = binder.getProvider(CatalogName.class); newExporter(binder).export(Key.get(JdbcClient.class, StatsCollecting.class)) .as(generator -> generator.generatedNameOf(JdbcClient.class, catalogName.get().toString())); - newExporter(binder).export(Key.get(ConnectionFactory.class, StatsCollecting.class)) + newExporter(binder).export(StatisticsAwareConnectionFactory.class) .as(generator -> generator.generatedNameOf(ConnectionFactory.class, catalogName.get().toString())); newExporter(binder).export(JdbcClient.class) .as(generator -> 
generator.generatedNameOf(CachingJdbcClient.class, catalogName.get().toString())); @@ -65,12 +67,4 @@ public JdbcClient createJdbcClientWithStats(@ForBaseJdbc JdbcClient client, Cata return client; })); } - - @Provides - @Singleton - @StatsCollecting - public static ConnectionFactory createConnectionFactoryWithStats(@ForBaseJdbc ConnectionFactory connectionFactory) - { - return new StatisticsAwareConnectionFactory(connectionFactory); - } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcModule.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcModule.java index fbb3d98760aa..ba18738c0d96 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcModule.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcModule.java @@ -50,6 +50,7 @@ public void setup(Binder binder) install(new JdbcDiagnosticModule()); install(new IdentifierMappingModule()); install(new RemoteQueryModifierModule()); + install(new RetryingConnectionFactoryModule()); newOptionalBinder(binder, ConnectorAccessControl.class); newOptionalBinder(binder, QueryBuilder.class).setDefault().to(DefaultQueryBuilder.class).in(Scopes.SINGLETON); @@ -88,10 +89,6 @@ public void setup(Binder binder) newSetBinder(binder, ConnectorTableFunction.class); - binder.bind(ConnectionFactory.class) - .annotatedWith(ForLazyConnectionFactory.class) - .to(Key.get(ConnectionFactory.class, StatsCollecting.class)) - .in(Scopes.SINGLETON); install(conditionalModule( QueryConfig.class, QueryConfig::isReuseConnection, diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/LazyConnectionFactory.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/LazyConnectionFactory.java index f056e07c7b95..284cc1ef8d01 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/LazyConnectionFactory.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/LazyConnectionFactory.java @@ -32,7 +32,7 @@ public final class 
LazyConnectionFactory private final ConnectionFactory delegate; @Inject - public LazyConnectionFactory(@ForLazyConnectionFactory ConnectionFactory delegate) + public LazyConnectionFactory(RetryingConnectionFactory delegate) { this.delegate = requireNonNull(delegate, "delegate is null"); } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/RetryingConnectionFactory.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/RetryingConnectionFactory.java index 6063feaa0ecf..0993cdb26e59 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/RetryingConnectionFactory.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/RetryingConnectionFactory.java @@ -14,15 +14,17 @@ package io.trino.plugin.jdbc; import com.google.common.base.Throwables; +import com.google.inject.Inject; import dev.failsafe.Failsafe; import dev.failsafe.FailsafeException; import dev.failsafe.RetryPolicy; +import io.trino.plugin.jdbc.jmx.StatisticsAwareConnectionFactory; import io.trino.spi.TrinoException; import io.trino.spi.connector.ConnectorSession; import java.sql.Connection; import java.sql.SQLException; -import java.sql.SQLRecoverableException; +import java.sql.SQLTransientException; import static java.time.temporal.ChronoUnit.MILLIS; import static java.time.temporal.ChronoUnit.SECONDS; @@ -31,19 +33,22 @@ public class RetryingConnectionFactory implements ConnectionFactory { - private static final RetryPolicy RETRY_POLICY = RetryPolicy.builder() - .withMaxDuration(java.time.Duration.of(30, SECONDS)) - .withMaxAttempts(5) - .withBackoff(50, 5_000, MILLIS, 4) - .handleIf(RetryingConnectionFactory::isSqlRecoverableException) - .abortOn(TrinoException.class) - .build(); + private final RetryPolicy retryPolicy; private final ConnectionFactory delegate; - public RetryingConnectionFactory(ConnectionFactory delegate) + @Inject + public RetryingConnectionFactory(StatisticsAwareConnectionFactory delegate, RetryStrategy retryStrategy) { + 
requireNonNull(retryStrategy); this.delegate = requireNonNull(delegate, "delegate is null"); + this.retryPolicy = RetryPolicy.builder() + .withMaxDuration(java.time.Duration.of(30, SECONDS)) + .withMaxAttempts(5) + .withBackoff(50, 5_000, MILLIS, 4) + .handleIf(retryStrategy::isExceptionRecoverable) + .abortOn(TrinoException.class) + .build(); } @Override @@ -51,7 +56,7 @@ public Connection openConnection(ConnectorSession session) throws SQLException { try { - return Failsafe.with(RETRY_POLICY) + return Failsafe.with(retryPolicy) .get(() -> delegate.openConnection(session)); } catch (FailsafeException ex) { @@ -69,9 +74,19 @@ public void close() delegate.close(); } - private static boolean isSqlRecoverableException(Throwable exception) + public interface RetryStrategy { - return Throwables.getCausalChain(exception).stream() - .anyMatch(SQLRecoverableException.class::isInstance); + boolean isExceptionRecoverable(Throwable exception); + } + + public static class DefaultRetryStrategy + implements RetryStrategy + { + @Override + public boolean isExceptionRecoverable(Throwable exception) + { + return Throwables.getCausalChain(exception).stream() + .anyMatch(SQLTransientException.class::isInstance); + } } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/RetryingConnectionFactoryModule.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/RetryingConnectionFactoryModule.java new file mode 100644 index 000000000000..a0815d38e84c --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/RetryingConnectionFactoryModule.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc; + +import com.google.inject.AbstractModule; +import com.google.inject.Scopes; +import io.trino.plugin.jdbc.RetryingConnectionFactory.DefaultRetryStrategy; +import io.trino.plugin.jdbc.RetryingConnectionFactory.RetryStrategy; + +import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder; + +public class RetryingConnectionFactoryModule + extends AbstractModule +{ + @Override + public void configure() + { + bind(RetryingConnectionFactory.class).in(Scopes.SINGLETON); + newOptionalBinder(binder(), RetryStrategy.class) + .setDefault() + .to(DefaultRetryStrategy.class) + .in(Scopes.SINGLETON); + } +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareConnectionFactory.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareConnectionFactory.java index 6c9153997a06..9f59238ba790 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareConnectionFactory.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareConnectionFactory.java @@ -13,7 +13,9 @@ */ package io.trino.plugin.jdbc.jmx; +import com.google.inject.Inject; import io.trino.plugin.jdbc.ConnectionFactory; +import io.trino.plugin.jdbc.ForBaseJdbc; import io.trino.spi.connector.ConnectorSession; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; @@ -30,7 +32,8 @@ public class StatisticsAwareConnectionFactory private final JdbcApiStats closeConnection = new JdbcApiStats(); private final ConnectionFactory 
delegate; - public StatisticsAwareConnectionFactory(ConnectionFactory delegate) + @Inject + public StatisticsAwareConnectionFactory(@ForBaseJdbc ConnectionFactory delegate) { this.delegate = requireNonNull(delegate, "delegate is null"); } diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseCaseInsensitiveMappingTest.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseCaseInsensitiveMappingTest.java index e4b3c1496864..5fc3f0d9c401 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseCaseInsensitiveMappingTest.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseCaseInsensitiveMappingTest.java @@ -20,8 +20,9 @@ import io.trino.plugin.base.mapping.TableMappingRule; import io.trino.testing.AbstractTestQueryFramework; import io.trino.testing.sql.SqlExecutor; -import org.testng.annotations.BeforeClass; -import org.testng.annotations.Test; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; import java.nio.file.Path; import java.util.List; @@ -35,9 +36,10 @@ import static java.lang.String.format; import static java.util.Locale.ENGLISH; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; // Tests are using JSON based identifier mapping which is one for all tests -@Test(singleThreaded = true) +@TestInstance(PER_CLASS) public abstract class BaseCaseInsensitiveMappingTest extends AbstractTestQueryFramework { @@ -45,7 +47,7 @@ public abstract class BaseCaseInsensitiveMappingTest protected abstract SqlExecutor onRemoteDatabase(); - @BeforeClass + @BeforeAll public void disableMappingRefreshVerboseLogging() { Logging logging = Logging.initialize(); @@ -126,6 +128,8 @@ protected Optional optionalFromDual() public void testSchemaNameClash() throws Exception { + updateRuleBasedIdentifierMappingFile(getMappingFile(), ImmutableList.of(), ImmutableList.of()); 
+ String[] nameVariants = {"casesensitivename", "CaseSensitiveName", "CASESENSITIVENAME"}; assertThat(Stream.of(nameVariants) .map(name -> name.toLowerCase(ENGLISH)) @@ -155,6 +159,8 @@ public void testSchemaNameClash() public void testTableNameClash() throws Exception { + updateRuleBasedIdentifierMappingFile(getMappingFile(), ImmutableList.of(), ImmutableList.of()); + String[] nameVariants = {"casesensitivename", "CaseSensitiveName", "CASESENSITIVENAME"}; assertThat(Stream.of(nameVariants) .map(name -> name.toLowerCase(ENGLISH)) diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectionCreationTest.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectionCreationTest.java index 565c0f85c582..bec5406e0d18 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectionCreationTest.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectionCreationTest.java @@ -17,9 +17,8 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.testing.AbstractTestQueryFramework; import org.intellij.lang.annotations.Language; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; -import org.testng.annotations.Test; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import java.sql.Connection; import java.sql.SQLException; @@ -34,13 +33,12 @@ import static java.util.Objects.requireNonNull; import static org.assertj.core.api.Assertions.assertThat; -@Test(singleThreaded = true) // this class is stateful, see fields public abstract class BaseJdbcConnectionCreationTest extends AbstractTestQueryFramework { protected ConnectionCountingConnectionFactory connectionFactory; - @BeforeClass + @BeforeAll public void verifySetup() { // Test expects connectionFactory to be provided with AbstractTestQueryFramework.createQueryRunner implementation @@ -48,7 +46,7 @@ public void verifySetup() 
connectionFactory.assertThatNoConnectionHasLeaked(); } - @AfterClass(alwaysRun = true) + @AfterAll public void destroy() throws Exception { diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java index a1079ee5ae87..01ccd6290212 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java @@ -44,7 +44,6 @@ import org.intellij.lang.annotations.Language; import org.testng.SkipException; import org.testng.annotations.AfterClass; -import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.util.ArrayList; @@ -76,7 +75,6 @@ import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.node; import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom; -import static io.trino.testing.DataProviders.toDataProvider; import static io.trino.testing.QueryAssertions.assertEqualsIgnoreOrder; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_AGGREGATION_PUSHDOWN; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_AGGREGATION_PUSHDOWN_CORRELATION; @@ -1207,148 +1205,144 @@ public void verifySupportsJoinPushdownWithFullJoinDeclaration() .joinIsNotFullyPushedDown(); } - @Test(dataProvider = "joinOperators") - public void testJoinPushdown(JoinOperator joinOperator) + @Test + public void testJoinPushdown() { - Session session = joinPushdownEnabled(getSession()); + for (JoinOperator joinOperator : JoinOperator.values()) { + Session session = joinPushdownEnabled(getSession()); - if (!hasBehavior(SUPPORTS_JOIN_PUSHDOWN)) { - assertThat(query(session, "SELECT r.name, n.name FROM nation n JOIN region r ON n.regionkey = r.regionkey")) - .joinIsNotFullyPushedDown(); - return; - } + if 
(!hasBehavior(SUPPORTS_JOIN_PUSHDOWN)) { + assertThat(query(session, "SELECT r.name, n.name FROM nation n JOIN region r ON n.regionkey = r.regionkey")) + .joinIsNotFullyPushedDown(); + return; + } - if (joinOperator == FULL_JOIN && !hasBehavior(SUPPORTS_JOIN_PUSHDOWN_WITH_FULL_JOIN)) { - // Covered by verifySupportsJoinPushdownWithFullJoinDeclaration - return; - } + if (joinOperator == FULL_JOIN && !hasBehavior(SUPPORTS_JOIN_PUSHDOWN_WITH_FULL_JOIN)) { + // Covered by verifySupportsJoinPushdownWithFullJoinDeclaration + return; + } - // Disable DF here for the sake of negative test cases' expected plan. With DF enabled, some operators return in DF's FilterNode and some do not. - Session withoutDynamicFiltering = Session.builder(session) - .setSystemProperty("enable_dynamic_filtering", "false") - .build(); + // Disable DF here for the sake of negative test cases' expected plan. With DF enabled, some operators return in DF's FilterNode and some do not. + Session withoutDynamicFiltering = Session.builder(session) + .setSystemProperty("enable_dynamic_filtering", "false") + .build(); + + String notDistinctOperator = "IS NOT DISTINCT FROM"; + List nonEqualities = Stream.concat( + Stream.of(JoinCondition.Operator.values()) + .filter(operator -> operator != JoinCondition.Operator.EQUAL) + .map(JoinCondition.Operator::getValue), + Stream.of(notDistinctOperator)) + .collect(toImmutableList()); + + try (TestTable nationLowercaseTable = new TestTable( + // If a connector supports Join pushdown, but does not allow CTAS, we need to make the table creation here overridable. 
+ getQueryRunner()::execute, + "nation_lowercase", + "AS SELECT nationkey, lower(name) name, regionkey FROM nation")) { + // basic case + assertThat(query(session, format("SELECT r.name, n.name FROM nation n %s region r ON n.regionkey = r.regionkey", joinOperator))).isFullyPushedDown(); + + // join over different columns + assertThat(query(session, format("SELECT r.name, n.name FROM nation n %s region r ON n.nationkey = r.regionkey", joinOperator))).isFullyPushedDown(); + + // pushdown when using USING + assertThat(query(session, format("SELECT r.name, n.name FROM nation n %s region r USING(regionkey)", joinOperator))).isFullyPushedDown(); + + // varchar equality predicate + assertJoinConditionallyPushedDown( + session, + format("SELECT n.name, n2.regionkey FROM nation n %s nation n2 ON n.name = n2.name", joinOperator), + hasBehavior(SUPPORTS_JOIN_PUSHDOWN_WITH_VARCHAR_EQUALITY)); + assertJoinConditionallyPushedDown( + session, + format("SELECT n.name, nl.regionkey FROM nation n %s %s nl ON n.name = nl.name", joinOperator, nationLowercaseTable.getName()), + hasBehavior(SUPPORTS_JOIN_PUSHDOWN_WITH_VARCHAR_EQUALITY)); - String notDistinctOperator = "IS NOT DISTINCT FROM"; - List nonEqualities = Stream.concat( - Stream.of(JoinCondition.Operator.values()) - .filter(operator -> operator != JoinCondition.Operator.EQUAL) - .map(JoinCondition.Operator::getValue), - Stream.of(notDistinctOperator)) - .collect(toImmutableList()); + // multiple bigint predicates + assertThat(query(session, format("SELECT n.name, c.name FROM nation n %s customer c ON n.nationkey = c.nationkey and n.regionkey = c.custkey", joinOperator))) + .isFullyPushedDown(); - try (TestTable nationLowercaseTable = new TestTable( - // If a connector supports Join pushdown, but does not allow CTAS, we need to make the table creation here overridable. 
- getQueryRunner()::execute, - "nation_lowercase", - "AS SELECT nationkey, lower(name) name, regionkey FROM nation")) { - // basic case - assertThat(query(session, format("SELECT r.name, n.name FROM nation n %s region r ON n.regionkey = r.regionkey", joinOperator))).isFullyPushedDown(); - - // join over different columns - assertThat(query(session, format("SELECT r.name, n.name FROM nation n %s region r ON n.nationkey = r.regionkey", joinOperator))).isFullyPushedDown(); - - // pushdown when using USING - assertThat(query(session, format("SELECT r.name, n.name FROM nation n %s region r USING(regionkey)", joinOperator))).isFullyPushedDown(); - - // varchar equality predicate - assertJoinConditionallyPushedDown( - session, - format("SELECT n.name, n2.regionkey FROM nation n %s nation n2 ON n.name = n2.name", joinOperator), - hasBehavior(SUPPORTS_JOIN_PUSHDOWN_WITH_VARCHAR_EQUALITY)); - assertJoinConditionallyPushedDown( - session, - format("SELECT n.name, nl.regionkey FROM nation n %s %s nl ON n.name = nl.name", joinOperator, nationLowercaseTable.getName()), - hasBehavior(SUPPORTS_JOIN_PUSHDOWN_WITH_VARCHAR_EQUALITY)); - - // multiple bigint predicates - assertThat(query(session, format("SELECT n.name, c.name FROM nation n %s customer c ON n.nationkey = c.nationkey and n.regionkey = c.custkey", joinOperator))) - .isFullyPushedDown(); + // inequality + for (String operator : nonEqualities) { + // bigint inequality predicate + assertJoinConditionallyPushedDown( + withoutDynamicFiltering, + format("SELECT r.name, n.name FROM nation n %s region r ON n.regionkey %s r.regionkey", joinOperator, operator), + expectJoinPushdown(operator) && expectJoinPushdowOnInequalityOperator(joinOperator)); + + // varchar inequality predicate + assertJoinConditionallyPushedDown( + withoutDynamicFiltering, + format("SELECT n.name, nl.name FROM nation n %s %s nl ON n.name %s nl.name", joinOperator, nationLowercaseTable.getName(), operator), + expectVarcharJoinPushdown(operator) && 
expectJoinPushdowOnInequalityOperator(joinOperator)); + } + + // inequality along with an equality, which constitutes an equi-condition and allows filter to remain as part of the Join + for (String operator : nonEqualities) { + assertJoinConditionallyPushedDown( + session, + format("SELECT n.name, c.name FROM nation n %s customer c ON n.nationkey = c.nationkey AND n.regionkey %s c.custkey", joinOperator, operator), + expectJoinPushdown(operator)); + } + + // varchar inequality along with an equality, which constitutes an equi-condition and allows filter to remain as part of the Join + for (String operator : nonEqualities) { + assertJoinConditionallyPushedDown( + session, + format("SELECT n.name, nl.name FROM nation n %s %s nl ON n.regionkey = nl.regionkey AND n.name %s nl.name", joinOperator, nationLowercaseTable.getName(), operator), + expectVarcharJoinPushdown(operator)); + } + + // Join over a (double) predicate + assertThat(query(session, format("" + + "SELECT c.name, n.name " + + "FROM (SELECT * FROM customer WHERE acctbal > 8000) c " + + "%s nation n ON c.custkey = n.nationkey", joinOperator))) + .isFullyPushedDown(); - // inequality - for (String operator : nonEqualities) { - // bigint inequality predicate + // Join over a varchar equality predicate assertJoinConditionallyPushedDown( - withoutDynamicFiltering, - format("SELECT r.name, n.name FROM nation n %s region r ON n.regionkey %s r.regionkey", joinOperator, operator), - expectJoinPushdown(operator) && expectJoinPushdowOnInequalityOperator(joinOperator)); + session, + format("SELECT c.name, n.name FROM (SELECT * FROM customer WHERE address = 'TcGe5gaZNgVePxU5kRrvXBfkasDTea') c " + + "%s nation n ON c.custkey = n.nationkey", joinOperator), + hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY)); - // varchar inequality predicate + // Join over a varchar inequality predicate assertJoinConditionallyPushedDown( - withoutDynamicFiltering, - format("SELECT n.name, nl.name FROM nation n %s %s nl ON 
n.name %s nl.name", joinOperator, nationLowercaseTable.getName(), operator), - expectVarcharJoinPushdown(operator) && expectJoinPushdowOnInequalityOperator(joinOperator)); - } + session, + format("SELECT c.name, n.name FROM (SELECT * FROM customer WHERE address < 'TcGe5gaZNgVePxU5kRrvXBfkasDTea') c " + + "%s nation n ON c.custkey = n.nationkey", joinOperator), + hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY)); - // inequality along with an equality, which constitutes an equi-condition and allows filter to remain as part of the Join - for (String operator : nonEqualities) { + // join over aggregation assertJoinConditionallyPushedDown( session, - format("SELECT n.name, c.name FROM nation n %s customer c ON n.nationkey = c.nationkey AND n.regionkey %s c.custkey", joinOperator, operator), - expectJoinPushdown(operator)); - } + format("SELECT * FROM (SELECT regionkey rk, count(nationkey) c FROM nation GROUP BY regionkey) n " + + "%s region r ON n.rk = r.regionkey", joinOperator), + hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN)); - // varchar inequality along with an equality, which constitutes an equi-condition and allows filter to remain as part of the Join - for (String operator : nonEqualities) { + // join over LIMIT assertJoinConditionallyPushedDown( session, - format("SELECT n.name, nl.name FROM nation n %s %s nl ON n.regionkey = nl.regionkey AND n.name %s nl.name", joinOperator, nationLowercaseTable.getName(), operator), - expectVarcharJoinPushdown(operator)); - } + format("SELECT * FROM (SELECT nationkey FROM nation LIMIT 30) n " + + "%s region r ON n.nationkey = r.regionkey", joinOperator), + hasBehavior(SUPPORTS_LIMIT_PUSHDOWN)); - // Join over a (double) predicate - assertThat(query(session, format("" + - "SELECT c.name, n.name " + - "FROM (SELECT * FROM customer WHERE acctbal > 8000) c " + - "%s nation n ON c.custkey = n.nationkey", joinOperator))) - .isFullyPushedDown(); + // join over TopN + assertJoinConditionallyPushedDown( + session, + 
format("SELECT * FROM (SELECT nationkey FROM nation ORDER BY regionkey LIMIT 5) n " + + "%s region r ON n.nationkey = r.regionkey", joinOperator), + hasBehavior(SUPPORTS_TOPN_PUSHDOWN)); - // Join over a varchar equality predicate - assertJoinConditionallyPushedDown( - session, - format("SELECT c.name, n.name FROM (SELECT * FROM customer WHERE address = 'TcGe5gaZNgVePxU5kRrvXBfkasDTea') c " + - "%s nation n ON c.custkey = n.nationkey", joinOperator), - hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY)); - - // Join over a varchar inequality predicate - assertJoinConditionallyPushedDown( - session, - format("SELECT c.name, n.name FROM (SELECT * FROM customer WHERE address < 'TcGe5gaZNgVePxU5kRrvXBfkasDTea') c " + - "%s nation n ON c.custkey = n.nationkey", joinOperator), - hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY)); - - // join over aggregation - assertJoinConditionallyPushedDown( - session, - format("SELECT * FROM (SELECT regionkey rk, count(nationkey) c FROM nation GROUP BY regionkey) n " + - "%s region r ON n.rk = r.regionkey", joinOperator), - hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN)); - - // join over LIMIT - assertJoinConditionallyPushedDown( - session, - format("SELECT * FROM (SELECT nationkey FROM nation LIMIT 30) n " + - "%s region r ON n.nationkey = r.regionkey", joinOperator), - hasBehavior(SUPPORTS_LIMIT_PUSHDOWN)); - - // join over TopN - assertJoinConditionallyPushedDown( - session, - format("SELECT * FROM (SELECT nationkey FROM nation ORDER BY regionkey LIMIT 5) n " + - "%s region r ON n.nationkey = r.regionkey", joinOperator), - hasBehavior(SUPPORTS_TOPN_PUSHDOWN)); - - // join over join - assertThat(query(session, "SELECT * FROM nation n, region r, customer c WHERE n.regionkey = r.regionkey AND r.regionkey = c.custkey")) - .isFullyPushedDown(); + // join over join + assertThat(query(session, "SELECT * FROM nation n, region r, customer c WHERE n.regionkey = r.regionkey AND r.regionkey = c.custkey")) + 
.isFullyPushedDown(); + } } } - @DataProvider - public Object[][] joinOperators() - { - return Stream.of(JoinOperator.values()).collect(toDataProvider()); - } - @Test public void testExplainAnalyzePhysicalReadWallTime() { @@ -1825,14 +1819,23 @@ public void testInsertWithoutTemporaryTable() } } - @Test(dataProvider = "batchSizeAndTotalNumberOfRowsToInsertDataProvider") - public void testWriteBatchSizeSessionProperty(Integer batchSize, Integer numberOfRows) + @Test + public void testWriteBatchSizeSessionProperty() + { + testWriteBatchSizeSessionProperty(10, 8); // number of rows < batch size + testWriteBatchSizeSessionProperty(10, 10); // number of rows = batch size + testWriteBatchSizeSessionProperty(10, 11); // number of rows > batch size + testWriteBatchSizeSessionProperty(10, 50); // number of rows = n * batch size + testWriteBatchSizeSessionProperty(10, 52); // number of rows > n * batch size + } + + private void testWriteBatchSizeSessionProperty(int batchSize, int numberOfRows) { if (!hasBehavior(SUPPORTS_CREATE_TABLE)) { throw new SkipException("CREATE TABLE is required for write_batch_size test but is not supported"); } Session session = Session.builder(getSession()) - .setCatalogSessionProperty(getSession().getCatalog().orElseThrow(), "write_batch_size", batchSize.toString()) + .setCatalogSessionProperty(getSession().getCatalog().orElseThrow(), "write_batch_size", Integer.toString(batchSize)) .build(); try (TestTable table = new TestTable( @@ -1845,8 +1848,17 @@ public void testWriteBatchSizeSessionProperty(Integer batchSize, Integer numberO } } - @Test(dataProvider = "writeTaskParallelismDataProvider") - public void testWriteTaskParallelismSessionProperty(int parallelism, int numberOfRows) + @Test + public void testWriteTaskParallelismSessionProperty() + { + testWriteTaskParallelismSessionProperty(1, 10_000); + testWriteTaskParallelismSessionProperty(2, 10_000); + testWriteTaskParallelismSessionProperty(4, 10_000); + 
testWriteTaskParallelismSessionProperty(16, 10_000); + testWriteTaskParallelismSessionProperty(32, 10_000); + } + + private void testWriteTaskParallelismSessionProperty(int parallelism, int numberOfRows) { if (!hasBehavior(SUPPORTS_CREATE_TABLE)) { throw new SkipException("CREATE TABLE is required for write_parallelism test but is not supported"); @@ -1875,17 +1887,6 @@ public void testWriteTaskParallelismSessionProperty(int parallelism, int numberO } } - @DataProvider - public static Object[][] writeTaskParallelismDataProvider() - { - return new Object[][]{ - {1, 10_000}, - {2, 10_000}, - {4, 10_000}, - {16, 10_000}, - {32, 10_000}}; - } - private static List buildRowsForInsert(int numberOfRows) { List result = new ArrayList<>(numberOfRows); @@ -1895,18 +1896,6 @@ private static List buildRowsForInsert(int numberOfRows) return result; } - @DataProvider - public static Object[][] batchSizeAndTotalNumberOfRowsToInsertDataProvider() - { - return new Object[][] { - {10, 8}, // number of rows < batch size - {10, 10}, // number of rows = batch size - {10, 11}, // number of rows > batch size - {10, 50}, // number of rows = n * batch size - {10, 52}, // number of rows > n * batch size - }; - } - @Test public void verifySupportsNativeQueryDeclaration() { @@ -2037,19 +2026,18 @@ protected TestTable simpleTable() return new TestTable(onRemoteDatabase(), format("%s.simple_table", getSession().getSchema().orElseThrow()), "(col BIGINT)", ImmutableList.of("1", "2")); } - @DataProvider - public Object[][] fixedJoinDistributionTypes() - { - return new Object[][] {{BROADCAST}, {PARTITIONED}}; - } - - @Test(dataProvider = "fixedJoinDistributionTypes") - public void testDynamicFiltering(JoinDistributionType joinDistributionType) + @Test + public void testDynamicFiltering() { skipTestUnless(hasBehavior(SUPPORTS_DYNAMIC_FILTER_PUSHDOWN)); + assertDynamicFiltering( "SELECT * FROM orders a JOIN orders b ON a.orderkey = b.orderkey AND b.totalprice < 1000", - joinDistributionType); + 
BROADCAST); + + assertDynamicFiltering( + "SELECT * FROM orders a JOIN orders b ON a.orderkey = b.orderkey AND b.totalprice < 1000", + PARTITIONED); } @Test diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcTableStatisticsTest.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcTableStatisticsTest.java index d4793de2664a..bd1682a2125c 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcTableStatisticsTest.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcTableStatisticsTest.java @@ -17,9 +17,9 @@ import io.trino.SystemSessionProperties; import io.trino.testing.AbstractTestQueryFramework; import io.trino.testing.sql.TestTable; -import org.testng.annotations.BeforeClass; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; import java.util.Locale; @@ -27,14 +27,16 @@ import static io.trino.testing.TestingNames.randomNameSuffix; import static java.lang.String.format; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +@TestInstance(PER_CLASS) public abstract class BaseJdbcTableStatisticsTest extends AbstractTestQueryFramework { // Currently this class serves as a common "interface" to define cases that should be covered. // TODO extend it to provide reusable blocks to reduce boiler-plate. 
- @BeforeClass + @BeforeAll public void setUpTables() { setUpTableFromTpch("region"); @@ -115,22 +117,19 @@ protected void checkEmptyTableStats(String tableName) @Test public abstract void testMaterializedView(); - @Test(dataProvider = "testCaseColumnNamesDataProvider") - public abstract void testCaseColumnNames(String tableName); - - @DataProvider - public Object[][] testCaseColumnNamesDataProvider() + @Test + public void testCaseColumnNames() { - return new Object[][] { - {"TEST_STATS_MIXED_UNQUOTED_UPPER"}, - {"test_stats_mixed_unquoted_lower"}, - {"test_stats_mixed_uNQuoTeD_miXED"}, - {"\"TEST_STATS_MIXED_QUOTED_UPPER\""}, - {"\"test_stats_mixed_quoted_lower\""}, - {"\"test_stats_mixed_QuoTeD_miXED\""}, - }; + testCaseColumnNames("TEST_STATS_MIXED_UNQUOTED_UPPER"); + testCaseColumnNames("test_stats_mixed_unquoted_lower"); + testCaseColumnNames("test_stats_mixed_uNQuoTeD_miXED"); + testCaseColumnNames("\"TEST_STATS_MIXED_QUOTED_UPPER\""); + testCaseColumnNames("\"test_stats_mixed_quoted_lower\""); + testCaseColumnNames("\"test_stats_mixed_QuoTeD_miXED\""); } + protected abstract void testCaseColumnNames(String tableName); + @Test public abstract void testNumericCornerCases(); diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestDefaultJdbcMetadata.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestDefaultJdbcMetadata.java index e49edac8e7e8..81b09fba8f79 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestDefaultJdbcMetadata.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestDefaultJdbcMetadata.java @@ -33,9 +33,10 @@ import io.trino.spi.predicate.TupleDomain; import io.trino.spi.session.PropertyMetadata; import io.trino.testing.TestingConnectorSession; -import org.testng.annotations.AfterMethod; -import org.testng.annotations.BeforeMethod; -import org.testng.annotations.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import 
org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; import java.util.List; import java.util.Map; @@ -54,15 +55,16 @@ import static io.trino.testing.assertions.TrinoExceptionAssert.assertTrinoExceptionThrownBy; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_METHOD; -@Test(singleThreaded = true) +@TestInstance(PER_METHOD) public class TestDefaultJdbcMetadata { private TestingDatabase database; private DefaultJdbcMetadata metadata; private JdbcTableHandle tableHandle; - @BeforeMethod + @BeforeEach public void setUp() throws Exception { @@ -115,7 +117,7 @@ public void testNonTransactionalInsertValidation() }).hasMessageContaining("Query and task retries are incompatible with non-transactional inserts"); } - @AfterMethod(alwaysRun = true) + @AfterEach public void tearDown() throws Exception { diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcConnectionCreation.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcConnectionCreation.java index 663fa1585802..c894ac616299 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcConnectionCreation.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcConnectionCreation.java @@ -24,9 +24,7 @@ import io.trino.plugin.jdbc.credential.EmptyCredentialProvider; import io.trino.testing.QueryRunner; import org.h2.Driver; -import org.intellij.lang.annotations.Language; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.util.Optional; import java.util.Properties; @@ -38,7 +36,6 @@ import static io.trino.tpch.TpchTable.REGION; import static java.util.Objects.requireNonNull; -@Test(singleThreaded = true) // inherited from BaseJdbcConnectionCreationTest public class TestJdbcConnectionCreation extends 
BaseJdbcConnectionCreationTest { @@ -57,37 +54,29 @@ protected QueryRunner createQueryRunner() new TestingConnectionH2Module(connectionFactory)); } - @Test(dataProvider = "testCases") - public void testJdbcConnectionCreations(@Language("SQL") String query, int expectedJdbcConnectionsCount, Optional errorMessage) + @Test + public void testJdbcConnectionCreations() { - assertJdbcConnections(query, expectedJdbcConnectionsCount, errorMessage); - } - - @DataProvider - public Object[][] testCases() - { - return new Object[][] { - {"SELECT * FROM nation LIMIT 1", 2, Optional.empty()}, - {"SELECT * FROM nation ORDER BY nationkey LIMIT 1", 2, Optional.empty()}, - {"SELECT * FROM nation WHERE nationkey = 1", 2, Optional.empty()}, - {"SELECT avg(nationkey) FROM nation", 2, Optional.empty()}, - {"SELECT * FROM nation, region", 3, Optional.empty()}, - {"SELECT * FROM nation n, region r WHERE n.regionkey = r.regionkey", 3, Optional.empty()}, - {"SELECT * FROM nation JOIN region USING(regionkey)", 3, Optional.empty()}, - {"SELECT * FROM information_schema.schemata", 1, Optional.empty()}, - {"SELECT * FROM information_schema.tables", 1, Optional.empty()}, - {"SELECT * FROM information_schema.columns", 1, Optional.empty()}, - {"SELECT * FROM nation", 2, Optional.empty()}, - {"CREATE TABLE copy_of_nation AS SELECT * FROM nation", 6, Optional.empty()}, - {"INSERT INTO copy_of_nation SELECT * FROM nation", 6, Optional.empty()}, - {"DELETE FROM copy_of_nation WHERE nationkey = 3", 1, Optional.empty()}, - {"UPDATE copy_of_nation SET name = 'POLAND' WHERE nationkey = 1", 1, Optional.empty()}, - {"MERGE INTO copy_of_nation n USING region r ON r.regionkey= n.regionkey WHEN MATCHED THEN DELETE", 1, Optional.of(MODIFYING_ROWS_MESSAGE)}, - {"DROP TABLE copy_of_nation", 1, Optional.empty()}, - {"SHOW SCHEMAS", 1, Optional.empty()}, - {"SHOW TABLES", 1, Optional.empty()}, - {"SHOW STATS FOR nation", 1, Optional.empty()}, - }; + assertJdbcConnections("SELECT * FROM nation LIMIT 1", 2, 
Optional.empty()); + assertJdbcConnections("SELECT * FROM nation ORDER BY nationkey LIMIT 1", 2, Optional.empty()); + assertJdbcConnections("SELECT * FROM nation WHERE nationkey = 1", 2, Optional.empty()); + assertJdbcConnections("SELECT avg(nationkey) FROM nation", 2, Optional.empty()); + assertJdbcConnections("SELECT * FROM nation, region", 3, Optional.empty()); + assertJdbcConnections("SELECT * FROM nation n, region r WHERE n.regionkey = r.regionkey", 3, Optional.empty()); + assertJdbcConnections("SELECT * FROM nation JOIN region USING(regionkey)", 3, Optional.empty()); + assertJdbcConnections("SELECT * FROM information_schema.schemata", 1, Optional.empty()); + assertJdbcConnections("SELECT * FROM information_schema.tables", 1, Optional.empty()); + assertJdbcConnections("SELECT * FROM information_schema.columns", 1, Optional.empty()); + assertJdbcConnections("SELECT * FROM nation", 2, Optional.empty()); + assertJdbcConnections("CREATE TABLE copy_of_nation AS SELECT * FROM nation", 6, Optional.empty()); + assertJdbcConnections("INSERT INTO copy_of_nation SELECT * FROM nation", 6, Optional.empty()); + assertJdbcConnections("DELETE FROM copy_of_nation WHERE nationkey = 3", 1, Optional.empty()); + assertJdbcConnections("UPDATE copy_of_nation SET name = 'POLAND' WHERE nationkey = 1", 1, Optional.empty()); + assertJdbcConnections("MERGE INTO copy_of_nation n USING region r ON r.regionkey= n.regionkey WHEN MATCHED THEN DELETE", 1, Optional.of(MODIFYING_ROWS_MESSAGE)); + assertJdbcConnections("DROP TABLE copy_of_nation", 1, Optional.empty()); + assertJdbcConnections("SHOW SCHEMAS", 1, Optional.empty()); + assertJdbcConnections("SHOW TABLES", 1, Optional.empty()); + assertJdbcConnections("SHOW STATS FOR nation", 1, Optional.empty()); } private static class TestingConnectionH2Module diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcFlushMetadataCacheProcedure.java 
b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcFlushMetadataCacheProcedure.java index 8da61b8714f9..8fd09f0a4b8e 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcFlushMetadataCacheProcedure.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcFlushMetadataCacheProcedure.java @@ -17,7 +17,7 @@ import io.trino.testing.AbstractTestQueryFramework; import io.trino.testing.QueryRunner; import io.trino.testing.sql.JdbcSqlExecutor; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.util.List; import java.util.Map; @@ -27,7 +27,6 @@ import static io.trino.tpch.TpchTable.NATION; import static org.assertj.core.api.Assertions.assertThatThrownBy; -@Test(singleThreaded = true) // some test assertions rely on `flush_metadata_cache()` being not executed yet, so cannot run concurrently public class TestJdbcFlushMetadataCacheProcedure extends AbstractTestQueryFramework { diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcTableProperties.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcTableProperties.java index 9bcbd10fb5e1..2c354d50aa46 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcTableProperties.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcTableProperties.java @@ -19,8 +19,7 @@ import io.trino.testing.AbstractTestQueryFramework; import io.trino.testing.QueryRunner; import io.trino.tpch.TpchTable; -import org.testng.annotations.BeforeMethod; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; @@ -29,8 +28,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.fail; -// Single-threaded because of shared mutable state, e.g. 
onGetTableProperties -@Test(singleThreaded = true) public class TestJdbcTableProperties extends AbstractTestQueryFramework { @@ -53,12 +50,6 @@ public Map getTableProperties(ConnectorSession session, JdbcTabl return createH2QueryRunner(ImmutableList.copyOf(TpchTable.getTables()), properties, module); } - @BeforeMethod - public void reset() - { - onGetTableProperties = () -> {}; - } - @Test public void testGetTablePropertiesIsNotCalledForSelect() { diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestLazyConnectionFactory.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestLazyConnectionFactory.java index 1f6ddb9a9d0a..cd0fb3dd1d6e 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestLazyConnectionFactory.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestLazyConnectionFactory.java @@ -13,6 +13,8 @@ */ package io.trino.plugin.jdbc; +import com.google.inject.Guice; +import com.google.inject.Injector; import io.trino.plugin.jdbc.credential.EmptyCredentialProvider; import org.h2.Driver; import org.junit.jupiter.api.Test; @@ -30,11 +32,15 @@ public class TestLazyConnectionFactory public void testNoConnectionIsCreated() throws Exception { - ConnectionFactory failingConnectionFactory = session -> { - throw new AssertionError("Expected no connection creation"); - }; + Injector injector = Guice.createInjector(binder -> { + binder.bind(ConnectionFactory.class).annotatedWith(ForBaseJdbc.class).toInstance( + session -> { + throw new AssertionError("Expected no connection creation"); + }); + binder.install(new RetryingConnectionFactoryModule()); + }); - try (LazyConnectionFactory lazyConnectionFactory = new LazyConnectionFactory(failingConnectionFactory); + try (LazyConnectionFactory lazyConnectionFactory = injector.getInstance(LazyConnectionFactory.class); Connection ignored = lazyConnectionFactory.openConnection(SESSION)) { // no-op } @@ -47,8 +53,13 @@ public void 
testConnectionCannotBeReusedAfterClose() BaseJdbcConfig config = new BaseJdbcConfig() .setConnectionUrl(format("jdbc:h2:mem:test%s;DB_CLOSE_DELAY=-1", System.nanoTime() + ThreadLocalRandom.current().nextLong())); - try (DriverConnectionFactory h2ConnectionFactory = new DriverConnectionFactory(new Driver(), config, new EmptyCredentialProvider()); - LazyConnectionFactory lazyConnectionFactory = new LazyConnectionFactory(h2ConnectionFactory)) { + Injector injector = Guice.createInjector(binder -> { + binder.bind(ConnectionFactory.class).annotatedWith(ForBaseJdbc.class).toInstance( + new DriverConnectionFactory(new Driver(), config, new EmptyCredentialProvider())); + binder.install(new RetryingConnectionFactoryModule()); + }); + + try (LazyConnectionFactory lazyConnectionFactory = injector.getInstance(LazyConnectionFactory.class)) { Connection connection = lazyConnectionFactory.openConnection(SESSION); connection.close(); assertThatThrownBy(() -> connection.createStatement()) diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestRetryingConnectionFactory.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestRetryingConnectionFactory.java index 985136fc422e..d85c1c5d1ef8 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestRetryingConnectionFactory.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestRetryingConnectionFactory.java @@ -13,6 +13,13 @@ */ package io.trino.plugin.jdbc; +import com.google.common.base.Throwables; +import com.google.inject.Guice; +import com.google.inject.Inject; +import com.google.inject.Injector; +import com.google.inject.Key; +import com.google.inject.Scopes; +import io.trino.plugin.jdbc.RetryingConnectionFactory.RetryStrategy; import io.trino.spi.StandardErrorCode; import io.trino.spi.TrinoException; import io.trino.spi.connector.ConnectorSession; @@ -21,17 +28,20 @@ import java.sql.Connection; import java.sql.SQLException; import 
java.sql.SQLRecoverableException; +import java.sql.SQLTransientException; import java.util.ArrayDeque; import java.util.Deque; import java.util.stream.Stream; import static com.google.common.reflect.Reflection.newProxy; +import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder; import static io.trino.plugin.jdbc.TestRetryingConnectionFactory.MockConnectorFactory.Action.RETURN; import static io.trino.plugin.jdbc.TestRetryingConnectionFactory.MockConnectorFactory.Action.THROW_NPE; import static io.trino.plugin.jdbc.TestRetryingConnectionFactory.MockConnectorFactory.Action.THROW_SQL_EXCEPTION; import static io.trino.plugin.jdbc.TestRetryingConnectionFactory.MockConnectorFactory.Action.THROW_SQL_RECOVERABLE_EXCEPTION; +import static io.trino.plugin.jdbc.TestRetryingConnectionFactory.MockConnectorFactory.Action.THROW_SQL_TRANSIENT_EXCEPTION; import static io.trino.plugin.jdbc.TestRetryingConnectionFactory.MockConnectorFactory.Action.THROW_TRINO_EXCEPTION; -import static io.trino.plugin.jdbc.TestRetryingConnectionFactory.MockConnectorFactory.Action.THROW_WRAPPED_SQL_RECOVERABLE_EXCEPTION; +import static io.trino.plugin.jdbc.TestRetryingConnectionFactory.MockConnectorFactory.Action.THROW_WRAPPED_SQL_TRANSIENT_EXCEPTION; import static io.trino.spi.block.TestingSession.SESSION; import static io.trino.spi.testing.InterfaceTestUtils.assertAllMethodsOverridden; import static java.util.Objects.requireNonNull; @@ -50,42 +60,55 @@ public void testEverythingImplemented() public void testSimplyReturnConnection() throws Exception { - MockConnectorFactory mock = new MockConnectorFactory(RETURN); - ConnectionFactory factory = new RetryingConnectionFactory(mock); - assertThat(factory.openConnection(SESSION)).isNotNull(); + Injector injector = createInjector(RETURN); + ConnectionFactory factory = injector.getInstance(RetryingConnectionFactory.class); + MockConnectorFactory mock = injector.getInstance(MockConnectorFactory.class); + + Connection connection = 
factory.openConnection(SESSION); + + assertThat(connection).isNotNull(); assertThat(mock.getCallCount()).isEqualTo(1); } @Test public void testRetryAndStopOnTrinoException() { - MockConnectorFactory mock = new MockConnectorFactory(THROW_SQL_RECOVERABLE_EXCEPTION, THROW_TRINO_EXCEPTION); - ConnectionFactory factory = new RetryingConnectionFactory(mock); + Injector injector = createInjector(THROW_SQL_TRANSIENT_EXCEPTION, THROW_TRINO_EXCEPTION); + ConnectionFactory factory = injector.getInstance(RetryingConnectionFactory.class); + MockConnectorFactory mock = injector.getInstance(MockConnectorFactory.class); + assertThatThrownBy(() -> factory.openConnection(SESSION)) .isInstanceOf(TrinoException.class) .hasMessage("Testing Trino exception"); + assertThat(mock.getCallCount()).isEqualTo(2); } @Test public void testRetryAndStopOnSqlException() { - MockConnectorFactory mock = new MockConnectorFactory(THROW_SQL_RECOVERABLE_EXCEPTION, THROW_SQL_EXCEPTION); - ConnectionFactory factory = new RetryingConnectionFactory(mock); + Injector injector = createInjector(THROW_SQL_TRANSIENT_EXCEPTION, THROW_SQL_EXCEPTION); + ConnectionFactory factory = injector.getInstance(RetryingConnectionFactory.class); + MockConnectorFactory mock = injector.getInstance(MockConnectorFactory.class); + assertThatThrownBy(() -> factory.openConnection(SESSION)) .isInstanceOf(SQLException.class) .hasMessage("Testing sql exception"); + assertThat(mock.getCallCount()).isEqualTo(2); } @Test public void testNullPointerException() { - MockConnectorFactory mock = new MockConnectorFactory(THROW_NPE); - ConnectionFactory factory = new RetryingConnectionFactory(mock); + Injector injector = createInjector(THROW_NPE); + ConnectionFactory factory = injector.getInstance(RetryingConnectionFactory.class); + MockConnectorFactory mock = injector.getInstance(MockConnectorFactory.class); + assertThatThrownBy(() -> factory.openConnection(SESSION)) .isInstanceOf(NullPointerException.class) .hasMessage("Testing NPE"); + 
assertThat(mock.getCallCount()).isEqualTo(1); } @@ -93,9 +116,13 @@ public void testNullPointerException() public void testRetryAndReturn() throws Exception { - MockConnectorFactory mock = new MockConnectorFactory(THROW_SQL_RECOVERABLE_EXCEPTION, RETURN); - ConnectionFactory factory = new RetryingConnectionFactory(mock); - assertThat(factory.openConnection(SESSION)).isNotNull(); + Injector injector = createInjector(THROW_SQL_TRANSIENT_EXCEPTION, RETURN); + ConnectionFactory factory = injector.getInstance(RetryingConnectionFactory.class); + MockConnectorFactory mock = injector.getInstance(MockConnectorFactory.class); + + Connection connection = factory.openConnection(SESSION); + + assertThat(connection).isNotNull(); assertThat(mock.getCallCount()).isEqualTo(2); } @@ -103,18 +130,69 @@ public void testRetryAndReturn() public void testRetryOnWrappedAndReturn() throws Exception { - MockConnectorFactory mock = new MockConnectorFactory(THROW_WRAPPED_SQL_RECOVERABLE_EXCEPTION, RETURN); - ConnectionFactory factory = new RetryingConnectionFactory(mock); - assertThat(factory.openConnection(SESSION)).isNotNull(); + Injector injector = createInjector(THROW_WRAPPED_SQL_TRANSIENT_EXCEPTION, RETURN); + ConnectionFactory factory = injector.getInstance(RetryingConnectionFactory.class); + MockConnectorFactory mock = injector.getInstance(MockConnectorFactory.class); + + Connection connection = factory.openConnection(SESSION); + + assertThat(connection).isNotNull(); + assertThat(mock.getCallCount()).isEqualTo(2); + } + + @Test + public void testOverridingRetryStrategyWorks() + throws Exception + { + Injector injector = createInjectorWithOverridenStrategy(THROW_SQL_RECOVERABLE_EXCEPTION, RETURN); + ConnectionFactory factory = injector.getInstance(RetryingConnectionFactory.class); + MockConnectorFactory mock = injector.getInstance(MockConnectorFactory.class); + + Connection connection = factory.openConnection(SESSION); + + assertThat(connection).isNotNull(); 
assertThat(mock.getCallCount()).isEqualTo(2); } + private static Injector createInjector(MockConnectorFactory.Action... actions) + { + return Guice.createInjector(binder -> { + binder.bind(MockConnectorFactory.Action[].class).toInstance(actions); + binder.bind(MockConnectorFactory.class).in(Scopes.SINGLETON); + binder.bind(ConnectionFactory.class).annotatedWith(ForBaseJdbc.class).to(Key.get(MockConnectorFactory.class)); + binder.install(new RetryingConnectionFactoryModule()); + }); + } + + private static Injector createInjectorWithOverridenStrategy(MockConnectorFactory.Action... actions) + { + return Guice.createInjector(binder -> { + binder.bind(MockConnectorFactory.Action[].class).toInstance(actions); + binder.bind(MockConnectorFactory.class).in(Scopes.SINGLETON); + binder.bind(ConnectionFactory.class).annotatedWith(ForBaseJdbc.class).to(Key.get(MockConnectorFactory.class)); + binder.install(new RetryingConnectionFactoryModule()); + newOptionalBinder(binder, RetryStrategy.class).setBinding().to(OverrideRetryStrategy.class).in(Scopes.SINGLETON); + }); + } + + private static class OverrideRetryStrategy + implements RetryStrategy + { + @Override + public boolean isExceptionRecoverable(Throwable exception) + { + return Throwables.getCausalChain(exception).stream() + .anyMatch(SQLRecoverableException.class::isInstance); + } + } + public static class MockConnectorFactory implements ConnectionFactory { private final Deque actions = new ArrayDeque<>(); private int callCount; + @Inject public MockConnectorFactory(Action... 
actions) { Stream.of(actions) @@ -145,6 +223,10 @@ public Connection openConnection(ConnectorSession session) throw new SQLRecoverableException("Testing sql recoverable exception"); case THROW_WRAPPED_SQL_RECOVERABLE_EXCEPTION: throw new RuntimeException(new SQLRecoverableException("Testing sql recoverable exception")); + case THROW_SQL_TRANSIENT_EXCEPTION: + throw new SQLTransientException("Testing sql transient exception"); + case THROW_WRAPPED_SQL_TRANSIENT_EXCEPTION: + throw new RuntimeException(new SQLTransientException("Testing sql transient exception")); } throw new IllegalStateException("Unsupported action:" + action); } @@ -155,6 +237,8 @@ public enum Action THROW_SQL_EXCEPTION, THROW_SQL_RECOVERABLE_EXCEPTION, THROW_WRAPPED_SQL_RECOVERABLE_EXCEPTION, + THROW_SQL_TRANSIENT_EXCEPTION, + THROW_WRAPPED_SQL_TRANSIENT_EXCEPTION, THROW_NPE, RETURN, } diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/logging/TestFormatBasedRemoteQueryModifier.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/logging/TestFormatBasedRemoteQueryModifier.java index faa64733a79e..64b729dd6bcd 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/logging/TestFormatBasedRemoteQueryModifier.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/logging/TestFormatBasedRemoteQueryModifier.java @@ -16,8 +16,7 @@ import io.trino.spi.TrinoException; import io.trino.spi.security.ConnectorIdentity; import io.trino.testing.TestingConnectorSession; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -150,8 +149,22 @@ public void testFormatQueryModifierWithTraceToken() .isEqualTo("SELECT * FROM USERS /*ttoken=valid-value*/"); } - @Test(dataProvider = "validValues") - public void testFormatWithValidValues(String value) + @Test + public void 
testFormatWithValidValues() + { + testFormatWithValidValues("trino"); + testFormatWithValidValues("123"); + testFormatWithValidValues("1t2r3i4n0"); + testFormatWithValidValues("trino-cli"); + testFormatWithValidValues("trino_cli"); + testFormatWithValidValues("trino-cli_123"); + testFormatWithValidValues("123_trino-cli"); + testFormatWithValidValues("123-trino_cli"); + testFormatWithValidValues("-trino-cli"); + testFormatWithValidValues("_trino_cli"); + } + + private void testFormatWithValidValues(String value) { TestingConnectorSession connectorSession = TestingConnectorSession.builder() .setIdentity(ConnectorIdentity.ofUser("Alice")) @@ -167,23 +180,6 @@ public void testFormatWithValidValues(String value) .isEqualTo("SELECT * FROM USERS /*source=%1$s ttoken=%1$s*/".formatted(value)); } - @DataProvider - public Object[][] validValues() - { - return new Object[][] { - {"trino"}, - {"123"}, - {"1t2r3i4n0"}, - {"trino-cli"}, - {"trino_cli"}, - {"trino-cli_123"}, - {"123_trino-cli"}, - {"123-trino_cli"}, - {"-trino-cli"}, - {"_trino_cli"} - }; - } - private static FormatBasedRemoteQueryModifier createRemoteQueryModifier(String commentFormat) { return new FormatBasedRemoteQueryModifier(new FormatBasedRemoteQueryModifierConfig().setFormat(commentFormat)); diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/logging/TestFormatBasedRemoteQueryModifierConfig.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/logging/TestFormatBasedRemoteQueryModifierConfig.java index a261fd92e39e..5f0d2d4b49a6 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/logging/TestFormatBasedRemoteQueryModifierConfig.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/logging/TestFormatBasedRemoteQueryModifierConfig.java @@ -14,8 +14,7 @@ package io.trino.plugin.jdbc.logging; import com.google.common.collect.ImmutableMap; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import 
org.junit.jupiter.api.Test; import java.util.Map; @@ -23,7 +22,6 @@ import static io.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; import static io.airlift.configuration.testing.ConfigAssertions.recordDefaults; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.util.Arrays.array; public class TestFormatBasedRemoteQueryModifierConfig { @@ -43,40 +41,32 @@ public void testExplicitPropertyMappings() assertFullMapping(properties, expected); } - @Test(dataProvider = "getForbiddenValuesInFormat") - public void testInvalidFormatValue(String incorrectValue) - { - assertThat(new FormatBasedRemoteQueryModifierConfig().setFormat(incorrectValue).isFormatValid()) - .isFalse(); - } - - @DataProvider - public static Object[][] getForbiddenValuesInFormat() + @Test + public void testInvalidFormatValue() { - return array( - array("*"), - array("("), - array(")"), - array("["), - array("]"), - array("{"), - array("}"), - array("&"), - array("@"), - array("!"), - array("#"), - array("%"), - array("^"), - array("$"), - array("\\"), - array("/"), - array("?"), - array(">"), - array("<"), - array(";"), - array("\""), - array(":"), - array("|")); + assertThat(configWithFormat("*").isFormatValid()).isFalse(); + assertThat(configWithFormat("(").isFormatValid()).isFalse(); + assertThat(configWithFormat(")").isFormatValid()).isFalse(); + assertThat(configWithFormat("[").isFormatValid()).isFalse(); + assertThat(configWithFormat("]").isFormatValid()).isFalse(); + assertThat(configWithFormat("{").isFormatValid()).isFalse(); + assertThat(configWithFormat("}").isFormatValid()).isFalse(); + assertThat(configWithFormat("&").isFormatValid()).isFalse(); + assertThat(configWithFormat("@").isFormatValid()).isFalse(); + assertThat(configWithFormat("!").isFormatValid()).isFalse(); + assertThat(configWithFormat("#").isFormatValid()).isFalse(); + assertThat(configWithFormat("%").isFormatValid()).isFalse(); + 
assertThat(configWithFormat("^").isFormatValid()).isFalse(); + assertThat(configWithFormat("$").isFormatValid()).isFalse(); + assertThat(configWithFormat("\\").isFormatValid()).isFalse(); + assertThat(configWithFormat("/").isFormatValid()).isFalse(); + assertThat(configWithFormat("?").isFormatValid()).isFalse(); + assertThat(configWithFormat(">").isFormatValid()).isFalse(); + assertThat(configWithFormat("<").isFormatValid()).isFalse(); + assertThat(configWithFormat(";").isFormatValid()).isFalse(); + assertThat(configWithFormat("\"").isFormatValid()).isFalse(); + assertThat(configWithFormat(":").isFormatValid()).isFalse(); + assertThat(configWithFormat("|").isFormatValid()).isFalse(); } @Test @@ -90,4 +80,9 @@ public void testValidFormatWithDuplicatedPredefinedValues() { assertThat(new FormatBasedRemoteQueryModifierConfig().setFormat("$QUERY_ID $QUERY_ID $USER $USER $SOURCE $SOURCE $TRACE_TOKEN $TRACE_TOKEN").isFormatValid()).isTrue(); } + + private FormatBasedRemoteQueryModifierConfig configWithFormat(String format) + { + return new FormatBasedRemoteQueryModifierConfig().setFormat(format); + } } diff --git a/plugin/trino-bigquery/pom.xml b/plugin/trino-bigquery/pom.xml index 4d6e8ee5f655..79c39c022241 100644 --- a/plugin/trino-bigquery/pom.xml +++ b/plugin/trino-bigquery/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryTypeUtils.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryTypeUtils.java index 8dd65dae1d0e..3bdcb3f30627 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryTypeUtils.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryTypeUtils.java @@ -111,7 +111,7 @@ public static Object readNativeValue(Type type, Block block, int position) return timestampToStringConverter(timestamp); } if (type instanceof ArrayType arrayType) { - Block arrayBlock = 
block.getObject(position, Block.class); + Block arrayBlock = arrayType.getObject(block, position); ImmutableList.Builder list = ImmutableList.builderWithExpectedSize(arrayBlock.getPositionCount()); for (int i = 0; i < arrayBlock.getPositionCount(); i++) { Object element = readNativeValue(arrayType.getElementType(), arrayBlock, i); @@ -123,7 +123,7 @@ public static Object readNativeValue(Type type, Block block, int position) return list.build(); } if (type instanceof RowType rowType) { - SqlRow sqlRow = block.getObject(position, SqlRow.class); + SqlRow sqlRow = rowType.getObject(block, position); List fieldTypes = rowType.getTypeParameters(); if (fieldTypes.size() != sqlRow.getFieldCount()) { diff --git a/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/BaseBigQueryConnectorTest.java b/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/BaseBigQueryConnectorTest.java index e0bca7b0b148..692239f4e65a 100644 --- a/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/BaseBigQueryConnectorTest.java +++ b/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/BaseBigQueryConnectorTest.java @@ -679,10 +679,10 @@ private void assertLabelForTable(String expectedView, QueryId queryId, String tr SELECT * FROM UNNEST(labels) AS label WHERE label.key = 'trino_query' AND label.value = '%s' )""".formatted(expectedLabel); - assertThat(bigQuerySqlExecutor.executeQuery(checkForLabelQuery).getValues()) + assertEventually(() -> assertThat(bigQuerySqlExecutor.executeQuery(checkForLabelQuery).getValues()) .extracting(values -> values.get("query").getStringValue()) .singleElement() - .matches(statement -> statement.contains(expectedView)); + .matches(statement -> statement.contains(expectedView))); } @Test diff --git a/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/BaseBigQueryTypeMapping.java b/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/BaseBigQueryTypeMapping.java index ccb1cd23e667..d2130132351d 100644 --- 
a/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/BaseBigQueryTypeMapping.java +++ b/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/BaseBigQueryTypeMapping.java @@ -29,9 +29,7 @@ import io.trino.testing.sql.SqlExecutor; import io.trino.testing.sql.TestTable; import io.trino.testing.sql.TrinoSqlExecutor; -import org.testng.annotations.BeforeClass; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.time.ZoneId; import java.util.Optional; @@ -57,17 +55,11 @@ public abstract class BaseBigQueryTypeMapping extends AbstractTestQueryFramework { - private BigQueryQueryRunner.BigQuerySqlExecutor bigQuerySqlExecutor; + private final BigQueryQueryRunner.BigQuerySqlExecutor bigQuerySqlExecutor = new BigQueryQueryRunner.BigQuerySqlExecutor(); private final ZoneId jvmZone = ZoneId.systemDefault(); private final ZoneId vilnius = ZoneId.of("Europe/Vilnius"); private final ZoneId kathmandu = ZoneId.of("Asia/Kathmandu"); - @BeforeClass(alwaysRun = true) - public void initBigQueryExecutor() - { - bigQuerySqlExecutor = new BigQueryQueryRunner.BigQuerySqlExecutor(); - } - @Test public void testBoolean() { @@ -110,8 +102,19 @@ public void testBytes() .execute(getQueryRunner(), trinoCreateAndInsert("test.varbinary")); } - @Test(dataProvider = "bigqueryIntegerTypeProvider") - public void testInt64(String inputType) + @Test + public void testInt64() + { + testInt64("BYTEINT"); + testInt64("TINYINT"); + testInt64("SMALLINT"); + testInt64("INTEGER"); + testInt64("INT64"); + testInt64("INT"); + testInt64("BIGINT"); + } + + private void testInt64(String inputType) { SqlDataTypeTest.create() .addRoundTrip(inputType, "-9223372036854775808", BIGINT, "-9223372036854775808") @@ -122,21 +125,6 @@ public void testInt64(String inputType) .execute(getQueryRunner(), bigqueryViewCreateAndInsert("test.integer")); } - @DataProvider - public Object[][] bigqueryIntegerTypeProvider() - { - // BYTEINT, 
TINYINT, SMALLINT, INTEGER, INT and BIGINT are aliases for INT64 in BigQuery - return new Object[][] { - {"BYTEINT"}, - {"TINYINT"}, - {"SMALLINT"}, - {"INTEGER"}, - {"INT64"}, - {"INT"}, - {"BIGINT"}, - }; - } - @Test public void testTinyint() { @@ -383,8 +371,14 @@ public void testUnsupportedBigNumericMappingView() .hasMessageContaining("SELECT * not allowed from relation that has no columns"); } - @Test(dataProvider = "bigqueryUnsupportedBigNumericTypeProvider") - public void testUnsupportedBigNumericMapping(String unsupportedTypeName) + @Test + public void testUnsupportedBigNumericMapping() + { + testUnsupportedBigNumericMapping("BIGNUMERIC"); + testUnsupportedBigNumericMapping("BIGNUMERIC(40,2)"); + } + + private void testUnsupportedBigNumericMapping(String unsupportedTypeName) { try (TestTable table = new TestTable(getBigQuerySqlExecutor(), "test.unsupported_bignumeric", format("(supported_column INT64, unsupported_column %s)", unsupportedTypeName))) { assertQuery( @@ -393,15 +387,6 @@ public void testUnsupportedBigNumericMapping(String unsupportedTypeName) } } - @DataProvider - public Object[][] bigqueryUnsupportedBigNumericTypeProvider() - { - return new Object[][] { - {"BIGNUMERIC"}, - {"BIGNUMERIC(40,2)"}, - }; - } - @Test public void testDate() { @@ -533,8 +518,19 @@ public void testTime() .execute(getQueryRunner(), bigqueryViewCreateAndInsert("test.time")); } - @Test(dataProvider = "sessionZonesDataProvider") - public void testTimestampWithTimeZone(ZoneId zoneId) + @Test + public void testTimestampWithTimeZone() + { + testTimestampWithTimeZone(UTC); + testTimestampWithTimeZone(jvmZone); + + // using two non-JVM zones so that we don't need to worry what BigQuery system zone is + testTimestampWithTimeZone(vilnius); + testTimestampWithTimeZone(kathmandu); + testTimestampWithTimeZone(TestingSession.DEFAULT_TIME_ZONE_KEY.getZoneId()); + } + + private void testTimestampWithTimeZone(ZoneId zoneId) { Session session = Session.builder(getSession()) 
.setTimeZoneKey(TimeZoneKey.getTimeZoneKey(zoneId.getId())) @@ -625,19 +621,6 @@ private SqlDataTypeTest testTimestampWithTimeZone(String inputType) TIMESTAMP_TZ_MICROS, "TIMESTAMP '9999-12-31 23:59:59.999999 UTC'"); } - @DataProvider - public Object[][] sessionZonesDataProvider() - { - return new Object[][] { - {UTC}, - {jvmZone}, - // using two non-JVM zones so that we don't need to worry what BigQuery system zone is - {vilnius}, - {kathmandu}, - {TestingSession.DEFAULT_TIME_ZONE_KEY.getZoneId()}, - }; - } - @Test public void testUnsupportedTimestampWithTimeZone() { diff --git a/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/TestBigQueryCaseInsensitiveMapping.java b/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/TestBigQueryCaseInsensitiveMapping.java index b377e1b5fa50..32b10c7a6289 100644 --- a/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/TestBigQueryCaseInsensitiveMapping.java +++ b/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/TestBigQueryCaseInsensitiveMapping.java @@ -21,8 +21,7 @@ import io.trino.testing.QueryRunner; import io.trino.testing.sql.TestTable; import io.trino.testing.sql.TestView; -import org.testng.annotations.BeforeClass; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.util.stream.Stream; @@ -35,18 +34,11 @@ // With case-insensitive-name-matching enabled colliding schema/table names are considered as errors. // Some tests here create colliding names which can cause any other concurrent test to fail. 
-@Test(singleThreaded = true) public class TestBigQueryCaseInsensitiveMapping // TODO extends BaseCaseInsensitiveMappingTest - https://github.com/trinodb/trino/issues/7864 extends AbstractTestQueryFramework { - protected BigQuerySqlExecutor bigQuerySqlExecutor; - - @BeforeClass(alwaysRun = true) - public void initBigQueryExecutor() - { - this.bigQuerySqlExecutor = new BigQuerySqlExecutor(); - } + private final BigQuerySqlExecutor bigQuerySqlExecutor = new BigQuerySqlExecutor(); @Override protected QueryRunner createQueryRunner() diff --git a/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/TestBigQueryInstanceCleaner.java b/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/TestBigQueryInstanceCleaner.java index 340f9b1ac872..bcd4d2f966c6 100644 --- a/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/TestBigQueryInstanceCleaner.java +++ b/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/TestBigQueryInstanceCleaner.java @@ -19,9 +19,9 @@ import io.airlift.log.Logger; import io.trino.plugin.bigquery.BigQueryQueryRunner.BigQuerySqlExecutor; import io.trino.tpch.TpchTable; -import org.testng.annotations.BeforeClass; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; import java.util.AbstractMap.SimpleImmutableEntry; import java.util.Collection; @@ -36,7 +36,9 @@ import static java.lang.String.join; import static java.util.stream.Collectors.joining; import static java.util.stream.Collectors.toUnmodifiableSet; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +@TestInstance(PER_CLASS) public class TestBigQueryInstanceCleaner { public static final Logger LOG = Logger.get(TestBigQueryInstanceCleaner.class); @@ -54,7 +56,7 @@ public class TestBigQueryInstanceCleaner private BigQuerySqlExecutor bigQuerySqlExecutor; - @BeforeClass + @BeforeAll public void 
setUp() { this.bigQuerySqlExecutor = new BigQuerySqlExecutor(); @@ -89,8 +91,15 @@ public void cleanUpDatasets() }); } - @Test(dataProvider = "cleanUpSchemasDataProvider") - public void cleanUpTables(String schemaName) + @Test + public void cleanUpTables() + { + // Other schemas created by tests are taken care of by cleanUpDatasets + cleanUpTables(TPCH_SCHEMA); + cleanUpTables(TEST_SCHEMA); + } + + private void cleanUpTables(String schemaName) { logObjectsCount(schemaName); if (!tablesToKeep.isEmpty()) { @@ -128,16 +137,6 @@ public void cleanUpTables(String schemaName) logObjectsCount(schemaName); } - @DataProvider - public static Object[][] cleanUpSchemasDataProvider() - { - // Other schemas created by tests are taken care of by cleanUpDatasets - return new Object[][] { - {TPCH_SCHEMA}, - {TEST_SCHEMA}, - }; - } - private void logObjectsCount(String schemaName) { TableResult result = bigQuerySqlExecutor.executeQuery(format("" + diff --git a/plugin/trino-blackhole/pom.xml b/plugin/trino-blackhole/pom.xml index 20f0ed9d04b6..09a55a4b5d40 100644 --- a/plugin/trino-blackhole/pom.xml +++ b/plugin/trino-blackhole/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-cassandra/pom.xml b/plugin/trino-cassandra/pom.xml index 4febcf90ba8e..546dcbf78a39 100644 --- a/plugin/trino-cassandra/pom.xml +++ b/plugin/trino-cassandra/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnector.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnector.java index 74b7f816ce01..44bcedea1250 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnector.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnector.java @@ -35,7 +35,6 @@ import io.trino.spi.connector.Constraint; import 
io.trino.spi.connector.DynamicFilter; import io.trino.spi.connector.RecordCursor; -import io.trino.spi.connector.SchemaNotFoundException; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SchemaTablePrefix; import io.trino.spi.predicate.Domain; @@ -48,9 +47,11 @@ import io.trino.testing.TestingConnectorContext; import io.trino.testing.TestingConnectorSession; import io.trino.type.IpAddressType; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; -import org.testng.annotations.Test; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; import java.net.InetAddress; import java.net.UnknownHostException; @@ -86,11 +87,13 @@ import static java.lang.String.format; import static java.util.Locale.ENGLISH; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNull; import static org.testng.Assert.assertTrue; import static org.testng.Assert.fail; +@TestInstance(PER_CLASS) public class TestCassandraConnector { protected static final String INVALID_DATABASE = "totally_invalid_database"; @@ -108,7 +111,7 @@ public class TestCassandraConnector private ConnectorSplitManager splitManager; private ConnectorRecordSetProvider recordSetProvider; - @BeforeClass + @BeforeAll public void setup() throws Exception { @@ -142,7 +145,7 @@ public void setup() tableUdt = new SchemaTableName(database, TABLE_USER_DEFINED_TYPE.toLowerCase(ENGLISH)); } - @AfterClass(alwaysRun = true) + @AfterAll public void tearDown() { server.close(); @@ -162,8 +165,8 @@ public void testGetTableNames() assertTrue(tables.contains(table)); } - // disabled until metadata manager is updated to handle invalid catalogs and schemas - @Test(enabled = false, 
expectedExceptions = SchemaNotFoundException.class) + @Test + @Disabled // disabled until metadata manager is updated to handle invalid catalogs and schemas public void testGetTableNamesException() { metadata.listTables(SESSION, Optional.of(INVALID_DATABASE)); diff --git a/plugin/trino-clickhouse/pom.xml b/plugin/trino-clickhouse/pom.xml index 4bc92462c752..066975765820 100644 --- a/plugin/trino-clickhouse/pom.xml +++ b/plugin/trino-clickhouse/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-delta-lake/pom.xml b/plugin/trino-delta-lake/pom.xml index 899d79309bf3..4df45d2a80d3 100644 --- a/plugin/trino-delta-lake/pom.xml +++ b/plugin/trino-delta-lake/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java index a10166cf0213..9326ad38bac9 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java @@ -504,7 +504,7 @@ private ReaderPageSource createParquetPageSource(Location path) dataColumns.stream() .map(DeltaLakeColumnHandle::toHiveColumnHandle) .collect(toImmutableList()), - TupleDomain.all(), + ImmutableList.of(TupleDomain.all()), true, parquetDateTimeZone, new FileFormatDataSourceStats(), diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java index 6145d123218f..a828fa47b2b3 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java @@ -1425,7 +1425,7 @@ public void 
addColumn(ConnectorSession session, ConnectorTableHandle tableHandle } checkUnsupportedWriterFeatures(protocolEntry); - if (!newColumnMetadata.isNullable() && !transactionLogAccess.getActiveFiles(getSnapshot(session, handle), session).isEmpty()) { + if (!newColumnMetadata.isNullable() && !transactionLogAccess.getActiveFiles(getSnapshot(session, handle), handle.getMetadataEntry(), handle.getProtocolEntry(), session).isEmpty()) { throw new TrinoException(DELTA_LAKE_BAD_WRITE, format("Unable to add NOT NULL column '%s' for non-empty table: %s.%s", newColumnMetadata.getName(), handle.getSchemaName(), handle.getTableName())); } @@ -3141,7 +3141,7 @@ public void finishStatisticsCollection(ConnectorSession session, ConnectorTableH private void generateMissingFileStatistics(ConnectorSession session, DeltaLakeTableHandle tableHandle, Collection computedStatistics) { Map addFileEntriesWithNoStats = transactionLogAccess.getActiveFiles( - getSnapshot(session, tableHandle), session) + getSnapshot(session, tableHandle), tableHandle.getMetadataEntry(), tableHandle.getProtocolEntry(), session) .stream() .filter(addFileEntry -> addFileEntry.getStats().isEmpty() || addFileEntry.getStats().get().getNumRecords().isEmpty() @@ -3491,7 +3491,7 @@ private OptionalLong executeDelete(ConnectorSession session, ConnectorTableHandl private List getAddFileEntriesMatchingEnforcedPartitionConstraint(ConnectorSession session, DeltaLakeTableHandle tableHandle) { TableSnapshot tableSnapshot = getSnapshot(session, tableHandle); - List validDataFiles = transactionLogAccess.getActiveFiles(tableSnapshot, session); + List validDataFiles = transactionLogAccess.getActiveFiles(tableSnapshot, tableHandle.getMetadataEntry(), tableHandle.getProtocolEntry(), session); TupleDomain enforcedPartitionConstraint = tableHandle.getEnforcedPartitionConstraint(); if (enforcedPartitionConstraint.isAll()) { return validDataFiles; diff --git 
a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java index 0d848a613ede..4a56b213cd53 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java @@ -230,7 +230,7 @@ public ConnectorPageSource createPageSource( split.getStart(), split.getLength(), hiveColumnHandles.build(), - parquetPredicate, + ImmutableList.of(parquetPredicate), true, parquetDateTimeZone, fileFormatDataSourceStats, diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeSplitManager.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeSplitManager.java index 26fd1eb2d153..c7acaf3d8533 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeSplitManager.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeSplitManager.java @@ -156,7 +156,7 @@ private Stream getSplits( catch (IOException e) { throw new RuntimeException(e); } - List validDataFiles = transactionLogAccess.getActiveFiles(tableSnapshot, session); + List validDataFiles = transactionLogAccess.getActiveFiles(tableSnapshot, tableHandle.getMetadataEntry(), tableHandle.getProtocolEntry(), session); TupleDomain enforcedPartitionConstraint = tableHandle.getEnforcedPartitionConstraint(); TupleDomain nonPartitionConstraint = tableHandle.getNonPartitionConstraint(); Domain pathDomain = getPathDomain(nonPartitionConstraint); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/functions/tablechanges/TableChangesFunctionProcessor.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/functions/tablechanges/TableChangesFunctionProcessor.java index f485d3e0238e..d5cf1f482538 100644 --- 
a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/functions/tablechanges/TableChangesFunctionProcessor.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/functions/tablechanges/TableChangesFunctionProcessor.java @@ -196,7 +196,7 @@ private static DeltaLakePageSource createDeltaLakePageSource( 0, split.fileSize(), splitColumns.stream().filter(column -> column.getColumnType() == REGULAR).map(DeltaLakeColumnHandle::toHiveColumnHandle).collect(toImmutableList()), - TupleDomain.all(), // TODO add predicate pushdown https://github.com/trinodb/trino/issues/16990 + ImmutableList.of(TupleDomain.all()), // TODO add predicate pushdown https://github.com/trinodb/trino/issues/16990 true, parquetDateTimeZone, fileFormatDataSourceStats, diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/procedure/VacuumProcedure.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/procedure/VacuumProcedure.java index fc86f851d1d6..e751861558fb 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/procedure/VacuumProcedure.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/procedure/VacuumProcedure.java @@ -197,7 +197,7 @@ private void doVacuum( // Any remaining file are not live, and not needed to read any "recent" snapshot. 
List recentVersions = transactionLogAccess.getPastTableVersions(fileSystem, transactionLogDir, threshold, tableSnapshot.getVersion()); Set retainedPaths = Stream.concat( - transactionLogAccess.getActiveFiles(tableSnapshot, session).stream() + transactionLogAccess.getActiveFiles(tableSnapshot, handle.getMetadataEntry(), handle.getProtocolEntry(), session).stream() .map(AddFileEntry::getPath), transactionLogAccess.getJsonEntries( fileSystem, diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/statistics/FileBasedTableStatisticsProvider.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/statistics/FileBasedTableStatisticsProvider.java index ec080238e7ab..eddcfcb99f84 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/statistics/FileBasedTableStatisticsProvider.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/statistics/FileBasedTableStatisticsProvider.java @@ -110,7 +110,7 @@ public TableStatistics getTableStatistics(ConnectorSession session, DeltaLakeTab .filter(column -> predicatedColumnNames.contains(column.getName())) .collect(toImmutableList()); - for (AddFileEntry addEntry : transactionLogAccess.getActiveFiles(tableSnapshot, session)) { + for (AddFileEntry addEntry : transactionLogAccess.getActiveFiles(tableSnapshot, tableHandle.getMetadataEntry(), tableHandle.getProtocolEntry(), session)) { Optional fileStatistics = addEntry.getStats(); if (fileStatistics.isEmpty()) { // Open source Delta Lake does not collect stats diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/DeltaLakeTransactionLogEntry.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/DeltaLakeTransactionLogEntry.java index 3ddc2eff488c..cb86f1c99f4e 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/DeltaLakeTransactionLogEntry.java +++ 
b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/DeltaLakeTransactionLogEntry.java @@ -17,6 +17,8 @@ import com.fasterxml.jackson.annotation.JsonProperty; import jakarta.annotation.Nullable; +import java.util.Objects; + import static java.util.Objects.requireNonNull; public class DeltaLakeTransactionLogEntry @@ -156,6 +158,31 @@ public DeltaLakeTransactionLogEntry withCommitInfo(CommitInfoEntry commitInfo) return new DeltaLakeTransactionLogEntry(txn, add, remove, metaData, protocol, commitInfo, cdcEntry); } + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + DeltaLakeTransactionLogEntry that = (DeltaLakeTransactionLogEntry) o; + return Objects.equals(txn, that.txn) && + Objects.equals(add, that.add) && + Objects.equals(remove, that.remove) && + Objects.equals(metaData, that.metaData) && + Objects.equals(protocol, that.protocol) && + Objects.equals(commitInfo, that.commitInfo) && + Objects.equals(cdcEntry, that.cdcEntry); + } + + @Override + public int hashCode() + { + return Objects.hash(txn, add, remove, metaData, protocol, commitInfo, cdcEntry); + } + @Override public String toString() { diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/TableSnapshot.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/TableSnapshot.java index cafbc0fcea65..5069517239a8 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/TableSnapshot.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/TableSnapshot.java @@ -14,7 +14,6 @@ package io.trino.plugin.deltalake.transactionlog; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; import io.trino.filesystem.Location; import io.trino.filesystem.TrinoFileSystem; import io.trino.filesystem.TrinoInputFile; @@ 
-37,14 +36,12 @@ import java.util.Set; import java.util.stream.Stream; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.Streams.stream; -import static io.trino.plugin.deltalake.DeltaLakeErrorCode.DELTA_LAKE_BAD_DATA; import static io.trino.plugin.deltalake.DeltaLakeErrorCode.DELTA_LAKE_INVALID_SCHEMA; import static io.trino.plugin.deltalake.transactionlog.TransactionLogParser.readLastCheckpoint; import static io.trino.plugin.deltalake.transactionlog.TransactionLogUtil.getTransactionLogDir; import static io.trino.plugin.deltalake.transactionlog.checkpoint.CheckpointEntryIterator.EntryType.ADD; -import static io.trino.plugin.deltalake.transactionlog.checkpoint.CheckpointEntryIterator.EntryType.METADATA; -import static io.trino.plugin.deltalake.transactionlog.checkpoint.CheckpointEntryIterator.EntryType.PROTOCOL; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -180,7 +177,8 @@ public Stream getCheckpointTransactionLogEntries( CheckpointSchemaManager checkpointSchemaManager, TypeManager typeManager, TrinoFileSystem fileSystem, - FileFormatDataSourceStats stats) + FileFormatDataSourceStats stats, + Optional metadataAndProtocol) throws IOException { if (lastCheckpoint.isEmpty()) { @@ -190,15 +188,8 @@ public Stream getCheckpointTransactionLogEntries( LastCheckpoint checkpoint = lastCheckpoint.get(); // Add entries contain statistics. When struct statistics are used the format of the Parquet file depends on the schema. It is important to use the schema at the time // of the Checkpoint creation, in case the schema has evolved since it was written. 
- Optional metadataAndProtocol = Optional.empty(); if (entryTypes.contains(ADD)) { - metadataAndProtocol = Optional.of(getCheckpointMetadataAndProtocolEntries( - session, - checkpointSchemaManager, - typeManager, - fileSystem, - stats, - checkpoint)); + checkState(metadataAndProtocol.isPresent(), "metadata and protocol information is needed to process the add log entries"); } Stream resultStream = Stream.empty(); @@ -259,51 +250,9 @@ private Iterator getCheckpointTransactionLogEntrie domainCompactionThreshold); } - private MetadataAndProtocolEntry getCheckpointMetadataAndProtocolEntries( - ConnectorSession session, - CheckpointSchemaManager checkpointSchemaManager, - TypeManager typeManager, - TrinoFileSystem fileSystem, - FileFormatDataSourceStats stats, - LastCheckpoint checkpoint) - throws IOException - { - MetadataEntry metadata = null; - ProtocolEntry protocol = null; - for (Location checkpointPath : getCheckpointPartPaths(checkpoint)) { - TrinoInputFile checkpointFile = fileSystem.newInputFile(checkpointPath); - Iterator entries = getCheckpointTransactionLogEntries( - session, - ImmutableSet.of(METADATA, PROTOCOL), - Optional.empty(), - Optional.empty(), - checkpointSchemaManager, - typeManager, - stats, - checkpoint, - checkpointFile); - while (entries.hasNext()) { - DeltaLakeTransactionLogEntry entry = entries.next(); - if (metadata == null && entry.getMetaData() != null) { - metadata = entry.getMetaData(); - } - if (protocol == null && entry.getProtocol() != null) { - protocol = entry.getProtocol(); - } - if (metadata != null && protocol != null) { - break; - } - } - } - if (metadata == null || protocol == null) { - throw new TrinoException(DELTA_LAKE_BAD_DATA, "Checkpoint found without metadata and protocol entry: " + checkpoint); - } - return new MetadataAndProtocolEntry(metadata, protocol); - } - - private record MetadataAndProtocolEntry(MetadataEntry metadataEntry, ProtocolEntry protocolEntry) + public record MetadataAndProtocolEntry(MetadataEntry 
metadataEntry, ProtocolEntry protocolEntry) { - private MetadataAndProtocolEntry + public MetadataAndProtocolEntry { requireNonNull(metadataEntry, "metadataEntry is null"); requireNonNull(protocolEntry, "protocolEntry is null"); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/TransactionLogAccess.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/TransactionLogAccess.java index cb2c0d7b289f..36bc1a1c24a6 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/TransactionLogAccess.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/TransactionLogAccess.java @@ -30,6 +30,7 @@ import io.trino.parquet.ParquetReaderOptions; import io.trino.plugin.deltalake.DeltaLakeColumnMetadata; import io.trino.plugin.deltalake.DeltaLakeConfig; +import io.trino.plugin.deltalake.transactionlog.TableSnapshot.MetadataAndProtocolEntry; import io.trino.plugin.deltalake.transactionlog.checkpoint.CheckpointEntryIterator; import io.trino.plugin.deltalake.transactionlog.checkpoint.CheckpointSchemaManager; import io.trino.plugin.deltalake.transactionlog.checkpoint.LastCheckpoint; @@ -255,7 +256,7 @@ public MetadataEntry getMetadataEntry(TableSnapshot tableSnapshot, ConnectorSess .orElseThrow(() -> new TrinoException(DELTA_LAKE_INVALID_SCHEMA, "Metadata not found in transaction log for " + tableSnapshot.getTable())); } - public List getActiveFiles(TableSnapshot tableSnapshot, ConnectorSession session) + public List getActiveFiles(TableSnapshot tableSnapshot, MetadataEntry metadataEntry, ProtocolEntry protocolEntry, ConnectorSession session) { try { TableVersion tableVersion = new TableVersion(new TableLocation(tableSnapshot.getTable(), tableSnapshot.getTableLocation()), tableSnapshot.getVersion()); @@ -285,7 +286,7 @@ public List getActiveFiles(TableSnapshot tableSnapshot, ConnectorS } } - List activeFiles = loadActiveFiles(tableSnapshot, 
session); + List activeFiles = loadActiveFiles(tableSnapshot, metadataEntry, protocolEntry, session); return new DeltaLakeDataFileCacheEntry(tableSnapshot.getVersion(), activeFiles); }); return cacheEntry.getActiveFiles(); @@ -295,17 +296,22 @@ public List getActiveFiles(TableSnapshot tableSnapshot, ConnectorS } } - private List loadActiveFiles(TableSnapshot tableSnapshot, ConnectorSession session) + private List loadActiveFiles(TableSnapshot tableSnapshot, MetadataEntry metadataEntry, ProtocolEntry protocolEntry, ConnectorSession session) { - try (Stream entries = getEntries( - tableSnapshot, - ImmutableSet.of(ADD), - this::activeAddEntries, + List transactions = tableSnapshot.getTransactions(); + try (Stream checkpointEntries = tableSnapshot.getCheckpointTransactionLogEntries( session, + ImmutableSet.of(ADD), + checkpointSchemaManager, + typeManager, fileSystemFactory.create(session), - fileFormatDataSourceStats)) { - List activeFiles = entries.collect(toImmutableList()); - return activeFiles; + fileFormatDataSourceStats, + Optional.of(new MetadataAndProtocolEntry(metadataEntry, protocolEntry)))) { + return activeAddEntries(checkpointEntries, transactions) + .collect(toImmutableList()); + } + catch (IOException e) { + throw new TrinoException(DELTA_LAKE_INVALID_SCHEMA, "Error reading transaction log for " + tableSnapshot.getTable(), e); } } @@ -439,7 +445,7 @@ private Stream getEntries( try { List transactions = tableSnapshot.getTransactions(); Stream checkpointEntries = tableSnapshot.getCheckpointTransactionLogEntries( - session, entryTypes, checkpointSchemaManager, typeManager, fileSystem, stats); + session, entryTypes, checkpointSchemaManager, typeManager, fileSystem, stats, Optional.empty()); return entryMapper.apply( checkpointEntries, diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java 
b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java index d97e86f19a04..1892b44643dd 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java @@ -17,6 +17,7 @@ import com.google.common.collect.AbstractIterator; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.math.LongMath; import io.airlift.log.Logger; import io.trino.filesystem.TrinoInputFile; @@ -41,8 +42,16 @@ import io.trino.plugin.hive.parquet.ParquetPageSourceFactory; import io.trino.spi.Page; import io.trino.spi.TrinoException; +import io.trino.spi.block.ArrayBlock; import io.trino.spi.block.Block; +import io.trino.spi.block.ByteArrayBlock; +import io.trino.spi.block.IntArrayBlock; +import io.trino.spi.block.LongArrayBlock; +import io.trino.spi.block.MapBlock; +import io.trino.spi.block.RowBlock; import io.trino.spi.block.SqlRow; +import io.trino.spi.block.ValueBlock; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.connector.ConnectorPageSource; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.predicate.Domain; @@ -71,10 +80,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.ImmutableSet.toImmutableSet; -import static com.google.common.collect.Iterables.getOnlyElement; import static io.trino.plugin.deltalake.DeltaLakeColumnType.REGULAR; -import static io.trino.plugin.deltalake.DeltaLakeErrorCode.DELTA_LAKE_BAD_DATA; import static io.trino.plugin.deltalake.DeltaLakeErrorCode.DELTA_LAKE_INVALID_SCHEMA; import 
static io.trino.plugin.deltalake.transactionlog.DeltaLakeSchemaSupport.extractSchema; import static io.trino.plugin.deltalake.transactionlog.DeltaLakeSchemaSupport.isDeletionVectorEnabled; @@ -97,7 +103,6 @@ import static java.lang.Math.floorDiv; import static java.lang.String.format; import static java.math.RoundingMode.UNNECESSARY; -import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Objects.requireNonNull; public class CheckpointEntryIterator @@ -137,7 +142,8 @@ public String getColumnName() private final boolean checkpointRowStatisticsWritingEnabled; private MetadataEntry metadataEntry; private ProtocolEntry protocolEntry; - private List schema; // Use DeltaLakeColumnMetadata? + private List schema; + private List columnsWithMinMaxStats; private Page page; private long pageIndex; private int pagePosition; @@ -161,7 +167,7 @@ public CheckpointEntryIterator( this.stringList = (ArrayType) typeManager.getType(TypeSignature.arrayType(VARCHAR.getTypeSignature())); this.stringMap = (MapType) typeManager.getType(TypeSignature.mapType(VARCHAR.getTypeSignature(), VARCHAR.getTypeSignature())); this.checkpointRowStatisticsWritingEnabled = checkpointRowStatisticsWritingEnabled; - checkArgument(fields.size() > 0, "fields is empty"); + checkArgument(!fields.isEmpty(), "fields is empty"); Map extractors = ImmutableMap.builder() .put(TRANSACTION, this::buildTxnEntry) .put(ADD, this::buildAddEntry) @@ -177,22 +183,23 @@ public CheckpointEntryIterator( checkArgument(protocolEntry.isPresent(), "Protocol entry must be provided when reading ADD entries from Checkpoint files"); this.protocolEntry = protocolEntry.get(); this.schema = extractSchema(this.metadataEntry, this.protocolEntry, typeManager); + this.columnsWithMinMaxStats = columnsWithStats(schema, this.metadataEntry.getOriginalPartitionColumns()); } - List columns = fields.stream() - .map(field -> buildColumnHandle(field, checkpointSchemaManager, this.metadataEntry, 
this.protocolEntry).toHiveColumnHandle()) - .collect(toImmutableList()); - - TupleDomain tupleDomain = columns.size() > 1 ? - TupleDomain.all() : - buildTupleDomainColumnHandle(getOnlyElement(fields), getOnlyElement(columns)); + ImmutableList.Builder columnsBuilder = ImmutableList.builderWithExpectedSize(fields.size()); + ImmutableList.Builder> disjunctDomainsBuilder = ImmutableList.builderWithExpectedSize(fields.size()); + for (EntryType field : fields) { + HiveColumnHandle column = buildColumnHandle(field, checkpointSchemaManager, this.metadataEntry, this.protocolEntry).toHiveColumnHandle(); + columnsBuilder.add(column); + disjunctDomainsBuilder.add(buildTupleDomainColumnHandle(field, column)); + } ReaderPageSource pageSource = ParquetPageSourceFactory.createPageSource( checkpoint, 0, fileSize, - columns, - tupleDomain, + columnsBuilder.build(), + disjunctDomainsBuilder.build(), // OR-ed condition true, DateTimeZone.UTC, stats, @@ -278,41 +285,40 @@ private DeltaLakeTransactionLogEntry buildCommitInfoEntry(ConnectorSession sessi int jobFields = 5; int notebookFields = 1; SqlRow commitInfoRow = block.getObject(pagePosition, SqlRow.class); - int commitInfoRawIndex = commitInfoRow.getRawIndex(); log.debug("Block %s has %s fields", block, commitInfoRow.getFieldCount()); if (commitInfoRow.getFieldCount() != commitInfoFields) { throw new TrinoException(DELTA_LAKE_INVALID_SCHEMA, format("Expected block %s to have %d children, but found %s", block, commitInfoFields, commitInfoRow.getFieldCount())); } - SqlRow jobRow = commitInfoRow.getRawFieldBlock(9).getObject(commitInfoRawIndex, SqlRow.class); + SqlRow jobRow = getRowField(commitInfoRow, 9); if (jobRow.getFieldCount() != jobFields) { throw new TrinoException(DELTA_LAKE_INVALID_SCHEMA, format("Expected block %s to have %d children, but found %s", jobRow, jobFields, jobRow.getFieldCount())); } - SqlRow notebookRow = commitInfoRow.getRawFieldBlock(7).getObject(commitInfoRawIndex, SqlRow.class); + SqlRow notebookRow = 
getRowField(commitInfoRow, 7); if (notebookRow.getFieldCount() != notebookFields) { throw new TrinoException(DELTA_LAKE_INVALID_SCHEMA, format("Expected block %s to have %d children, but found %s", notebookRow, notebookFields, notebookRow.getFieldCount())); } CommitInfoEntry result = new CommitInfoEntry( - getLong(commitInfoRow.getRawFieldBlock(0), commitInfoRawIndex), - getLong(commitInfoRow.getRawFieldBlock(1), commitInfoRawIndex), - getString(commitInfoRow.getRawFieldBlock(2), commitInfoRawIndex), - getString(commitInfoRow.getRawFieldBlock(3), commitInfoRawIndex), - getString(commitInfoRow.getRawFieldBlock(4), commitInfoRawIndex), - getMap(commitInfoRow.getRawFieldBlock(5), commitInfoRawIndex), + getLongField(commitInfoRow, 0), + getLongField(commitInfoRow, 1), + getStringField(commitInfoRow, 2), + getStringField(commitInfoRow, 3), + getStringField(commitInfoRow, 4), + getMapField(commitInfoRow, 5), new CommitInfoEntry.Job( - getString(jobRow.getRawFieldBlock(0), jobRow.getRawIndex()), - getString(jobRow.getRawFieldBlock(1), jobRow.getRawIndex()), - getString(jobRow.getRawFieldBlock(2), jobRow.getRawIndex()), - getString(jobRow.getRawFieldBlock(3), jobRow.getRawIndex()), - getString(jobRow.getRawFieldBlock(4), jobRow.getRawIndex())), + getStringField(jobRow, 0), + getStringField(jobRow, 1), + getStringField(jobRow, 2), + getStringField(jobRow, 3), + getStringField(jobRow, 4)), new CommitInfoEntry.Notebook( - getString(notebookRow.getRawFieldBlock(0), notebookRow.getRawIndex())), - getString(commitInfoRow.getRawFieldBlock(8), commitInfoRawIndex), - getLong(commitInfoRow.getRawFieldBlock(9), commitInfoRawIndex), - getString(commitInfoRow.getRawFieldBlock(10), commitInfoRawIndex), - Optional.of(getByte(commitInfoRow.getRawFieldBlock(11), commitInfoRawIndex) != 0)); + getStringField(notebookRow, 0)), + getStringField(commitInfoRow, 8), + getLongField(commitInfoRow, 9), + getStringField(commitInfoRow, 10), + Optional.of(getBooleanField(commitInfoRow, 11))); 
log.debug("Result: %s", result); return DeltaLakeTransactionLogEntry.commitInfoEntry(result); } @@ -332,15 +338,14 @@ private DeltaLakeTransactionLogEntry buildProtocolEntry(ConnectorSession session throw new TrinoException(DELTA_LAKE_INVALID_SCHEMA, format("Expected block %s to have between %d and %d children, but found %s", block, minProtocolFields, maxProtocolFields, fieldCount)); } - int rawIndex = protocolEntryRow.getRawIndex(); - Block readerFeaturesField = protocolEntryRow.getRawFieldBlock(2); + Optional> readerFeatures = getOptionalSetField(protocolEntryRow, 2); // The last entry should be writer feature when protocol entry size is 3 https://github.com/delta-io/delta/blob/master/PROTOCOL.md#disabled-features - Block writerFeaturesField = fieldCount != 4 ? readerFeaturesField : protocolEntryRow.getRawFieldBlock(3); + Optional> writerFeatures = fieldCount != 4 ? readerFeatures : getOptionalSetField(protocolEntryRow, 3); ProtocolEntry result = new ProtocolEntry( - getInt(protocolEntryRow.getRawFieldBlock(0), rawIndex), - getInt(protocolEntryRow.getRawFieldBlock(1), rawIndex), - readerFeaturesField.isNull(rawIndex) ? Optional.empty() : Optional.of(getList(readerFeaturesField, rawIndex).stream().collect(toImmutableSet())), - writerFeaturesField.isNull(rawIndex) ? 
Optional.empty() : Optional.of(getList(writerFeaturesField, rawIndex).stream().collect(toImmutableSet()))); + getIntField(protocolEntryRow, 0), + getIntField(protocolEntryRow, 1), + readerFeatures, + writerFeatures); log.debug("Result: %s", result); return DeltaLakeTransactionLogEntry.protocolEntry(result); } @@ -354,28 +359,27 @@ private DeltaLakeTransactionLogEntry buildMetadataEntry(ConnectorSession session int metadataFields = 8; int formatFields = 2; SqlRow metadataEntryRow = block.getObject(pagePosition, SqlRow.class); - int rawIndex = metadataEntryRow.getRawIndex(); log.debug("Block %s has %s fields", block, metadataEntryRow.getFieldCount()); if (metadataEntryRow.getFieldCount() != metadataFields) { throw new TrinoException(DELTA_LAKE_INVALID_SCHEMA, format("Expected block %s to have %d children, but found %s", block, metadataFields, metadataEntryRow.getFieldCount())); } - SqlRow formatRow = metadataEntryRow.getRawFieldBlock(3).getObject(rawIndex, SqlRow.class); + SqlRow formatRow = getRowField(metadataEntryRow, 3); if (formatRow.getFieldCount() != formatFields) { throw new TrinoException(DELTA_LAKE_INVALID_SCHEMA, format("Expected block %s to have %d children, but found %s", formatRow, formatFields, formatRow.getFieldCount())); } MetadataEntry result = new MetadataEntry( - getString(metadataEntryRow.getRawFieldBlock(0), rawIndex), - getString(metadataEntryRow.getRawFieldBlock(1), rawIndex), - getString(metadataEntryRow.getRawFieldBlock(2), rawIndex), + getStringField(metadataEntryRow, 0), + getStringField(metadataEntryRow, 1), + getStringField(metadataEntryRow, 2), new MetadataEntry.Format( - getString(formatRow.getRawFieldBlock(0), formatRow.getRawIndex()), - getMap(formatRow.getRawFieldBlock(1), formatRow.getRawIndex())), - getString(metadataEntryRow.getRawFieldBlock(4), rawIndex), - getList(metadataEntryRow.getRawFieldBlock(5), rawIndex), - getMap(metadataEntryRow.getRawFieldBlock(6), rawIndex), - getLong(metadataEntryRow.getRawFieldBlock(7), rawIndex)); 
+ getStringField(formatRow, 0), + getMapField(formatRow, 1)), + getStringField(metadataEntryRow, 4), + getListField(metadataEntryRow, 5), + getMapField(metadataEntryRow, 6), + getLongField(metadataEntryRow, 7)); log.debug("Result: %s", result); return DeltaLakeTransactionLogEntry.metadataEntry(result); } @@ -393,11 +397,10 @@ private DeltaLakeTransactionLogEntry buildRemoveEntry(ConnectorSession session, throw new TrinoException(DELTA_LAKE_INVALID_SCHEMA, format("Expected block %s to have %d children, but found %s", block, removeFields, removeEntryRow.getFieldCount())); } - int rawIndex = removeEntryRow.getRawIndex(); RemoveFileEntry result = new RemoveFileEntry( - getString(removeEntryRow.getRawFieldBlock(0), rawIndex), - getLong(removeEntryRow.getRawFieldBlock(1), rawIndex), - getByte(removeEntryRow.getRawFieldBlock(2), rawIndex) != 0); + getStringField(removeEntryRow, 0), + getLongField(removeEntryRow, 1), + getBooleanField(removeEntryRow, 2)); log.debug("Result: %s", result); return DeltaLakeTransactionLogEntry.removeFileEntry(result); } @@ -411,98 +414,71 @@ private DeltaLakeTransactionLogEntry buildAddEntry(ConnectorSession session, Blo boolean deletionVectorsEnabled = isDeletionVectorEnabled(metadataEntry, protocolEntry); SqlRow addEntryRow = block.getObject(pagePosition, SqlRow.class); log.debug("Block %s has %s fields", block, addEntryRow.getFieldCount()); - int rawIndex = addEntryRow.getRawIndex(); - String path = getString(addEntryRow.getRawFieldBlock(0), rawIndex); - Map partitionValues = getMap(addEntryRow.getRawFieldBlock(1), rawIndex); - long size = getLong(addEntryRow.getRawFieldBlock(2), rawIndex); - long modificationTime = getLong(addEntryRow.getRawFieldBlock(3), rawIndex); - boolean dataChange = getByte(addEntryRow.getRawFieldBlock(4), rawIndex) != 0; + String path = getStringField(addEntryRow, 0); + Map partitionValues = getMapField(addEntryRow, 1); + long size = getLongField(addEntryRow, 2); + long modificationTime = getLongField(addEntryRow, 
3); + boolean dataChange = getBooleanField(addEntryRow, 4); + Optional deletionVector = Optional.empty(); - int position = 5; + int statsFieldIndex; if (deletionVectorsEnabled) { - if (!addEntryRow.getRawFieldBlock(5).isNull(rawIndex)) { - deletionVector = Optional.of(parseDeletionVectorFromParquet(addEntryRow.getRawFieldBlock(5).getObject(rawIndex, Block.class))); - } - position = 6; - } - Map tags = getMap(addEntryRow.getRawFieldBlock(position + 2), rawIndex); - - AddFileEntry result; - if (!addEntryRow.getRawFieldBlock(position + 1).isNull(rawIndex)) { - result = new AddFileEntry( - path, - partitionValues, - size, - modificationTime, - dataChange, - Optional.empty(), - Optional.of(parseStatisticsFromParquet(addEntryRow.getRawFieldBlock(position + 1).getObject(rawIndex, SqlRow.class))), - tags, - deletionVector); - } - else if (!addEntryRow.getRawFieldBlock(position).isNull(rawIndex)) { - result = new AddFileEntry( - path, - partitionValues, - size, - modificationTime, - dataChange, - Optional.of(getString(addEntryRow.getRawFieldBlock(position), rawIndex)), - Optional.empty(), - tags, - deletionVector); + deletionVector = Optional.ofNullable(getRowField(addEntryRow, 5)).map(CheckpointEntryIterator::parseDeletionVectorFromParquet); + statsFieldIndex = 6; } else { - result = new AddFileEntry( - path, - partitionValues, - size, - modificationTime, - dataChange, - Optional.empty(), - Optional.empty(), - tags, - deletionVector); + statsFieldIndex = 5; + } + + Optional parsedStats = Optional.ofNullable(getRowField(addEntryRow, statsFieldIndex + 1)).map(this::parseStatisticsFromParquet); + Optional stats = Optional.empty(); + if (parsedStats.isEmpty()) { + stats = Optional.ofNullable(getStringField(addEntryRow, statsFieldIndex)); } + Map tags = getMapField(addEntryRow, statsFieldIndex + 2); + AddFileEntry result = new AddFileEntry( + path, + partitionValues, + size, + modificationTime, + dataChange, + stats, + parsedStats, + tags, + deletionVector); + 
log.debug("Result: %s", result); return DeltaLakeTransactionLogEntry.addFileEntry(result); } - private DeletionVectorEntry parseDeletionVectorFromParquet(Block block) + private static DeletionVectorEntry parseDeletionVectorFromParquet(SqlRow row) { - checkArgument(block.getPositionCount() == 5, "Deletion vector entry must have 5 fields"); + checkArgument(row.getFieldCount() == 5, "Deletion vector entry must have 5 fields"); - String storageType = getString(block, 0); - String pathOrInlineDv = getString(block, 1); - OptionalInt offset = block.isNull(2) ? OptionalInt.empty() : OptionalInt.of(getInt(block, 2)); - int sizeInBytes = getInt(block, 3); - long cardinality = getLong(block, 4); + String storageType = getStringField(row, 0); + String pathOrInlineDv = getStringField(row, 1); + OptionalInt offset = getOptionalIntField(row, 2); + int sizeInBytes = getIntField(row, 3); + long cardinality = getLongField(row, 4); return new DeletionVectorEntry(storageType, pathOrInlineDv, offset, sizeInBytes, cardinality); } private DeltaLakeParquetFileStatistics parseStatisticsFromParquet(SqlRow statsRow) { - if (metadataEntry == null) { - throw new TrinoException(DELTA_LAKE_BAD_DATA, "Checkpoint file found without metadata entry"); - } - // Block ordering is determined by TransactionLogAccess#buildAddColumnHandle, using the same method to ensure blocks are matched with the correct column - List columnsWithMinMaxStats = columnsWithStats(schema, metadataEntry.getOriginalPartitionColumns()); - - int rawIndex = statsRow.getRawIndex(); - long numRecords = getLong(statsRow.getRawFieldBlock(0), rawIndex); + long numRecords = getLongField(statsRow, 0); Optional> minValues = Optional.empty(); Optional> maxValues = Optional.empty(); Optional> nullCount; if (!columnsWithMinMaxStats.isEmpty()) { - minValues = Optional.of(readMinMax(statsRow.getRawFieldBlock(1), rawIndex, columnsWithMinMaxStats)); - maxValues = Optional.of(readMinMax(statsRow.getRawFieldBlock(2), rawIndex, 
columnsWithMinMaxStats)); - nullCount = Optional.of(readNullCount(statsRow.getRawFieldBlock(3), rawIndex, schema)); + minValues = Optional.of(parseMinMax(getRowField(statsRow, 1), columnsWithMinMaxStats)); + maxValues = Optional.of(parseMinMax(getRowField(statsRow, 2), columnsWithMinMaxStats)); + nullCount = Optional.of(parseNullCount(getRowField(statsRow, 3), schema)); } else { - nullCount = Optional.of(readNullCount(statsRow.getRawFieldBlock(1), rawIndex, schema)); + nullCount = Optional.of(parseNullCount(getRowField(statsRow, 1), schema)); } return new DeltaLakeParquetFileStatistics( @@ -512,71 +488,69 @@ private DeltaLakeParquetFileStatistics parseStatisticsFromParquet(SqlRow statsRo nullCount); } - private Map readMinMax(Block block, int blockPosition, List eligibleColumns) + private ImmutableMap parseMinMax(@Nullable SqlRow row, List eligibleColumns) { - if (block.isNull(blockPosition)) { + if (row == null) { // Statistics were not collected return ImmutableMap.of(); } - SqlRow row = block.getObject(blockPosition, SqlRow.class); ImmutableMap.Builder values = ImmutableMap.builder(); - int rawIndex = row.getRawIndex(); for (int i = 0; i < eligibleColumns.size(); i++) { DeltaLakeColumnMetadata metadata = eligibleColumns.get(i); String name = metadata.getPhysicalName(); Type type = metadata.getPhysicalColumnType(); - Block fieldBlock = row.getRawFieldBlock(i); - if (fieldBlock.isNull(rawIndex)) { + ValueBlock fieldBlock = row.getUnderlyingFieldBlock(i); + int fieldIndex = row.getUnderlyingFieldPosition(i); + if (fieldBlock.isNull(fieldIndex)) { continue; } if (type instanceof RowType rowType) { if (checkpointRowStatisticsWritingEnabled) { // RowType column statistics are not used for query planning, but need to be copied when writing out new Checkpoint files. 
- values.put(name, rowType.getObject(fieldBlock, rawIndex)); + values.put(name, rowType.getObject(fieldBlock, fieldIndex)); } continue; } if (type instanceof TimestampWithTimeZoneType) { - long epochMillis = LongMath.divide((long) readNativeValue(TIMESTAMP_MILLIS, fieldBlock, rawIndex), MICROSECONDS_PER_MILLISECOND, UNNECESSARY); + long epochMillis = LongMath.divide((long) readNativeValue(TIMESTAMP_MILLIS, fieldBlock, fieldIndex), MICROSECONDS_PER_MILLISECOND, UNNECESSARY); if (floorDiv(epochMillis, MILLISECONDS_PER_DAY) >= START_OF_MODERN_ERA_EPOCH_DAY) { values.put(name, packDateTimeWithZone(epochMillis, UTC_KEY)); } continue; } - values.put(name, readNativeValue(type, fieldBlock, rawIndex)); + values.put(name, readNativeValue(type, fieldBlock, fieldIndex)); } return values.buildOrThrow(); } - private Map readNullCount(Block block, int blockPosition, List columns) + private Map parseNullCount(SqlRow row, List columns) { - if (block.isNull(blockPosition)) { + if (row == null) { // Statistics were not collected return ImmutableMap.of(); } - SqlRow row = block.getObject(blockPosition, SqlRow.class); - int rawIndex = row.getRawIndex(); ImmutableMap.Builder values = ImmutableMap.builder(); for (int i = 0; i < columns.size(); i++) { DeltaLakeColumnMetadata metadata = columns.get(i); - Block fieldBlock = row.getRawFieldBlock(i); - if (fieldBlock.isNull(rawIndex)) { + ValueBlock fieldBlock = row.getUnderlyingFieldBlock(i); + int fieldIndex = row.getUnderlyingFieldPosition(i); + if (fieldBlock.isNull(fieldIndex)) { continue; } if (metadata.getType() instanceof RowType) { if (checkpointRowStatisticsWritingEnabled) { // RowType column statistics are not used for query planning, but need to be copied when writing out new Checkpoint files. 
- values.put(metadata.getPhysicalName(), fieldBlock.getObject(rawIndex, SqlRow.class)); + values.put(metadata.getPhysicalName(), fieldBlock.getObject(fieldIndex, SqlRow.class)); } continue; } - values.put(metadata.getPhysicalName(), getLong(fieldBlock, rawIndex)); + values.put(metadata.getPhysicalName(), getLongField(row, i)); } return values.buildOrThrow(); } @@ -594,52 +568,88 @@ private DeltaLakeTransactionLogEntry buildTxnEntry(ConnectorSession session, Blo throw new TrinoException(DELTA_LAKE_INVALID_SCHEMA, format("Expected block %s to have %d children, but found %s", block, txnFields, txnEntryRow.getFieldCount())); } - int rawIndex = txnEntryRow.getRawIndex(); TransactionEntry result = new TransactionEntry( - getString(txnEntryRow.getRawFieldBlock(0), rawIndex), - getLong(txnEntryRow.getRawFieldBlock(1), rawIndex), - getLong(txnEntryRow.getRawFieldBlock(2), rawIndex)); + getStringField(txnEntryRow, 0), + getLongField(txnEntryRow, 1), + getLongField(txnEntryRow, 2)); log.debug("Result: %s", result); return DeltaLakeTransactionLogEntry.transactionEntry(result); } @Nullable - private String getString(Block block, int position) + private static SqlRow getRowField(SqlRow row, int field) + { + RowBlock valueBlock = (RowBlock) row.getUnderlyingFieldBlock(field); + int index = row.getUnderlyingFieldPosition(field); + if (valueBlock.isNull(index)) { + return null; + } + return valueBlock.getRow(index); + } + + @Nullable + private static String getStringField(SqlRow row, int field) { - if (block.isNull(position)) { + VariableWidthBlock valueBlock = (VariableWidthBlock) row.getUnderlyingFieldBlock(field); + int index = row.getUnderlyingFieldPosition(field); + if (valueBlock.isNull(index)) { return null; } - return block.getSlice(position, 0, block.getSliceLength(position)).toString(UTF_8); + return valueBlock.getSlice(index).toStringUtf8(); } - private long getLong(Block block, int position) + private static long getLongField(SqlRow row, int field) { - 
checkArgument(!block.isNull(position)); - return block.getLong(position, 0); + LongArrayBlock valueBlock = (LongArrayBlock) row.getUnderlyingFieldBlock(field); + return valueBlock.getLong(row.getUnderlyingFieldPosition(field)); } - private int getInt(Block block, int position) + private static int getIntField(SqlRow row, int field) { - checkArgument(!block.isNull(position)); - return block.getInt(position, 0); + IntArrayBlock valueBlock = (IntArrayBlock) row.getUnderlyingFieldBlock(field); + return valueBlock.getInt(row.getUnderlyingFieldPosition(field)); } - private byte getByte(Block block, int position) + private static OptionalInt getOptionalIntField(SqlRow row, int field) { - checkArgument(!block.isNull(position)); - return block.getByte(position, 0); + IntArrayBlock valueBlock = (IntArrayBlock) row.getUnderlyingFieldBlock(field); + int index = row.getUnderlyingFieldPosition(field); + if (valueBlock.isNull(index)) { + return OptionalInt.empty(); + } + return OptionalInt.of(valueBlock.getInt(index)); + } + + private static boolean getBooleanField(SqlRow row, int field) + { + ByteArrayBlock valueBlock = (ByteArrayBlock) row.getUnderlyingFieldBlock(field); + return valueBlock.getByte(row.getUnderlyingFieldPosition(field)) != 0; + } + + @SuppressWarnings("unchecked") + private Map getMapField(SqlRow row, int field) + { + MapBlock valueBlock = (MapBlock) row.getUnderlyingFieldBlock(field); + return (Map) stringMap.getObjectValue(session, valueBlock, row.getUnderlyingFieldPosition(field)); } @SuppressWarnings("unchecked") - private Map getMap(Block block, int position) + private List getListField(SqlRow row, int field) { - return (Map) stringMap.getObjectValue(session, block, position); + ArrayBlock valueBlock = (ArrayBlock) row.getUnderlyingFieldBlock(field); + return (List) stringList.getObjectValue(session, valueBlock, row.getUnderlyingFieldPosition(field)); } @SuppressWarnings("unchecked") - private List getList(Block block, int position) + private Optional> 
getOptionalSetField(SqlRow row, int field) { - return (List) stringList.getObjectValue(session, block, position); + ArrayBlock valueBlock = (ArrayBlock) row.getUnderlyingFieldBlock(field); + int index = row.getUnderlyingFieldPosition(field); + if (valueBlock.isNull(index)) { + return Optional.empty(); + } + List list = (List) stringList.getObjectValue(session, valueBlock, index); + return Optional.of(ImmutableSet.copyOf(list)); } @Override diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointWriterManager.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointWriterManager.java index 0e350ed1b5ab..44ebc5d96aed 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointWriterManager.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointWriterManager.java @@ -22,9 +22,11 @@ import io.trino.filesystem.TrinoOutputFile; import io.trino.plugin.deltalake.transactionlog.DeltaLakeTransactionLogEntry; import io.trino.plugin.deltalake.transactionlog.TableSnapshot; +import io.trino.plugin.deltalake.transactionlog.TableSnapshot.MetadataAndProtocolEntry; import io.trino.plugin.deltalake.transactionlog.TransactionLogAccess; import io.trino.plugin.hive.FileFormatDataSourceStats; import io.trino.plugin.hive.NodeVersion; +import io.trino.spi.TrinoException; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.type.TypeManager; @@ -32,11 +34,12 @@ import java.io.IOException; import java.io.OutputStream; import java.io.UncheckedIOException; +import java.util.List; import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkState; -import static com.google.common.collect.MoreCollectors.toOptional; +import static 
com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.plugin.deltalake.DeltaLakeErrorCode.DELTA_LAKE_INVALID_SCHEMA; import static io.trino.plugin.deltalake.transactionlog.TransactionLogParser.LAST_CHECKPOINT_FILENAME; import static io.trino.plugin.deltalake.transactionlog.TransactionLogUtil.getTransactionLogDir; import static io.trino.plugin.deltalake.transactionlog.checkpoint.CheckpointEntryIterator.EntryType.ADD; @@ -92,16 +95,19 @@ public void writeCheckpoint(ConnectorSession session, TableSnapshot snapshot) CheckpointBuilder checkpointBuilder = new CheckpointBuilder(); TrinoFileSystem fileSystem = fileSystemFactory.create(session); - Optional checkpointMetadataLogEntry = snapshot + List checkpointLogEntries = snapshot .getCheckpointTransactionLogEntries( session, - ImmutableSet.of(METADATA), + ImmutableSet.of(METADATA, PROTOCOL), checkpointSchemaManager, typeManager, fileSystem, - fileFormatDataSourceStats) - .collect(toOptional()); - if (checkpointMetadataLogEntry.isPresent()) { + fileFormatDataSourceStats, + Optional.empty()) + .filter(entry -> entry.getMetaData() != null || entry.getProtocol() != null) + .collect(toImmutableList()); + + if (!checkpointLogEntries.isEmpty()) { // TODO HACK: this call is required only to ensure that cachedMetadataEntry is set in snapshot (https://github.com/trinodb/trino/issues/12032), // so we can read add entries below this should be reworked so we pass metadata entry explicitly to getCheckpointTransactionLogEntries, // and we should get rid of `setCachedMetadata` in TableSnapshot to make it immutable. 
@@ -109,17 +115,27 @@ public void writeCheckpoint(ConnectorSession session, TableSnapshot snapshot) transactionLogAccess.getMetadataEntry(snapshot, session); // register metadata entry in writer - checkState(checkpointMetadataLogEntry.get().getMetaData() != null, "metaData not present in log entry"); - checkpointBuilder.addLogEntry(checkpointMetadataLogEntry.get()); + DeltaLakeTransactionLogEntry metadataLogEntry = checkpointLogEntries.stream() + .filter(logEntry -> logEntry.getMetaData() != null) + .findFirst() + .orElseThrow(() -> new TrinoException(DELTA_LAKE_INVALID_SCHEMA, "Metadata not found in transaction log for " + snapshot.getTable())); + DeltaLakeTransactionLogEntry protocolLogEntry = checkpointLogEntries.stream() + .filter(logEntry -> logEntry.getProtocol() != null) + .findFirst() + .orElseThrow(() -> new TrinoException(DELTA_LAKE_INVALID_SCHEMA, "Protocol not found in transaction log for " + snapshot.getTable())); + + checkpointBuilder.addLogEntry(metadataLogEntry); + checkpointBuilder.addLogEntry(protocolLogEntry); // read remaining entries from checkpoint register them in writer snapshot.getCheckpointTransactionLogEntries( session, - ImmutableSet.of(PROTOCOL, TRANSACTION, ADD, REMOVE, COMMIT), + ImmutableSet.of(TRANSACTION, ADD, REMOVE, COMMIT), checkpointSchemaManager, typeManager, fileSystem, - fileFormatDataSourceStats) + fileFormatDataSourceStats, + Optional.of(new MetadataAndProtocolEntry(metadataLogEntry.getMetaData(), protocolLogEntry.getProtocol()))) .forEach(checkpointBuilder::addLogEntry); } diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeCompatibility.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeCompatibility.java index f7be901886c1..3a9ed4155419 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeCompatibility.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeCompatibility.java @@ 
-18,8 +18,7 @@ import io.trino.testing.AbstractTestQueryFramework; import io.trino.testing.QueryRunner; import io.trino.tpch.TpchTable; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import static io.trino.plugin.deltalake.DeltaLakeQueryRunner.DELTA_CATALOG; import static io.trino.plugin.deltalake.DeltaLakeQueryRunner.createS3DeltaLakeQueryRunner; @@ -69,19 +68,14 @@ protected QueryRunner createQueryRunner() return queryRunner; } - @Test(dataProvider = "tpchTablesDataProvider") - public void testSelectAll(String tableName) + @Test + public void testSelectAll() { - assertThat(query("SELECT * FROM " + tableName)) - .skippingTypesCheck() // Delta Lake connector returns varchar, but TPCH connector returns varchar(n) - .matches("SELECT * FROM tpch.tiny." + tableName); - } - - @DataProvider - public static Object[][] tpchTablesDataProvider() - { - return TpchTable.getTables().stream() - .map(table -> new Object[] {table.getTableName()}) - .toArray(Object[][]::new); + for (TpchTable table : TpchTable.getTables()) { + String tableName = table.getTableName(); + assertThat(query("SELECT * FROM " + tableName)) + .skippingTypesCheck() // Delta Lake connector returns varchar, but TPCH connector returns varchar(n) + .matches("SELECT * FROM tpch.tiny." 
+ tableName); + } } } diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeAdlsStorage.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeAdlsStorage.java index 8bfcf5f1582a..066260afe557 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeAdlsStorage.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeAdlsStorage.java @@ -19,11 +19,11 @@ import io.trino.plugin.hive.containers.HiveHadoop; import io.trino.testing.AbstractTestQueryFramework; import io.trino.testing.QueryRunner; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; import org.testcontainers.containers.Network; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; -import org.testng.annotations.Parameters; -import org.testng.annotations.Test; import java.nio.file.Files; import java.nio.file.Path; @@ -43,7 +43,9 @@ import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Objects.requireNonNull; import static java.util.UUID.randomUUID; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +@TestInstance(PER_CLASS) public class TestDeltaLakeAdlsStorage extends AbstractTestQueryFramework { @@ -58,15 +60,12 @@ public class TestDeltaLakeAdlsStorage private HiveHadoop hiveHadoop; - @Parameters({ - "hive.hadoop2.azure-abfs-container", - "hive.hadoop2.azure-abfs-account", - "hive.hadoop2.azure-abfs-access-key"}) - public TestDeltaLakeAdlsStorage(String container, String account, String accessKey) + public TestDeltaLakeAdlsStorage() { + String container = System.getProperty("hive.hadoop2.azure-abfs-container"); requireNonNull(container, "container is null"); - this.account = requireNonNull(account, "account is null"); - this.accessKey = requireNonNull(accessKey, "accessKey is null"); + 
this.account = requireNonNull(System.getProperty("hive.hadoop2.azure-abfs-account"), "account is null"); + this.accessKey = requireNonNull(System.getProperty("hive.hadoop2.azure-abfs-access-key"), "accessKey is null"); String directoryBase = format("abfs://%s@%s.dfs.core.windows.net", container, account); adlsDirectory = format("%s/tpch-tiny-%s/", directoryBase, randomUUID()); @@ -108,7 +107,7 @@ private Path createHadoopCoreSiteXmlTempFileWithAbfsSettings() return coreSiteXml; } - @BeforeClass(alwaysRun = true) + @BeforeAll public void setUp() { hiveHadoop.executeInContainerFailOnError("hadoop", "fs", "-mkdir", "-p", adlsDirectory); @@ -118,7 +117,7 @@ public void setUp() }); } - @AfterClass(alwaysRun = true) + @AfterAll public void tearDown() { if (adlsDirectory != null && hiveHadoop != null) { diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeAnalyze.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeAnalyze.java index 7950d7e3eda9..19fcdd41f38f 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeAnalyze.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeAnalyze.java @@ -21,10 +21,9 @@ import io.trino.plugin.deltalake.transactionlog.DeltaLakeTransactionLogEntry; import io.trino.plugin.deltalake.transactionlog.statistics.DeltaLakeFileStatistics; import io.trino.testing.AbstractTestQueryFramework; -import io.trino.testing.DataProviders; import io.trino.testing.QueryRunner; import io.trino.testing.sql.TestTable; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.io.IOException; import java.net.URI; @@ -766,8 +765,14 @@ public void testIncrementalStatisticsUpdateOnInsert() assertUpdate("DROP TABLE " + tableName); } - @Test(dataProviderClass = DataProviders.class, dataProvider = "trueFalse") - public void testCollectStatsAfterColumnAdded(boolean collectOnWrite) + @Test + public void 
testCollectStatsAfterColumnAdded() + { + testCollectStatsAfterColumnAdded(false); + testCollectStatsAfterColumnAdded(true); + } + + private void testCollectStatsAfterColumnAdded(boolean collectOnWrite) { String tableName = "test_collect_stats_after_column_added_" + randomNameSuffix(); assertUpdate("CREATE TABLE " + tableName + " (col_int_1 bigint, col_varchar_1 varchar)"); diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeBasic.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeBasic.java index c3c45a1b7cfe..64f3e70d4de4 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeBasic.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeBasic.java @@ -935,6 +935,22 @@ public void testStatsWithMinMaxValuesAsNulls() """); } + /** + * @see deltalake.multipart_checkpoint + */ + @Test + public void testReadMultipartCheckpoint() + throws Exception + { + String tableName = "test_multipart_checkpoint_" + randomNameSuffix(); + Path tableLocation = Files.createTempFile(tableName, null); + copyDirectoryContents(new File(Resources.getResource("deltalake/multipart_checkpoint").toURI()).toPath(), tableLocation); + + assertUpdate("CALL system.register_table('%s', '%s', '%s')".formatted(getSession().getSchema().orElseThrow(), tableName, tableLocation.toUri())); + assertThat(query("DESCRIBE " + tableName)).projected("Column", "Type").skippingTypesCheck().matches("VALUES ('c', 'integer')"); + assertThat(query("SELECT * FROM " + tableName)).matches("VALUES 1, 2, 3, 4, 5, 6, 7"); + } + private static MetadataEntry loadMetadataEntry(long entryNumber, Path tableLocation) throws IOException { diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeColumnMapping.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeColumnMapping.java index 766977a308b7..8001bb889a27 100644 --- 
a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeColumnMapping.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeColumnMapping.java @@ -24,8 +24,7 @@ import io.trino.plugin.deltalake.transactionlog.MetadataEntry; import io.trino.testing.AbstractTestQueryFramework; import io.trino.testing.QueryRunner; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.io.IOException; import java.nio.file.Path; @@ -42,7 +41,6 @@ import static io.trino.plugin.deltalake.transactionlog.checkpoint.TransactionLogTail.getEntriesFromJson; import static io.trino.plugin.hive.HiveTestUtils.HDFS_ENVIRONMENT; import static io.trino.plugin.hive.HiveTestUtils.HDFS_FILE_SYSTEM_STATS; -import static io.trino.testing.DataProviders.toDataProvider; import static io.trino.testing.TestingNames.randomNameSuffix; import static org.assertj.core.api.Assertions.assertThat; @@ -60,22 +58,31 @@ protected QueryRunner createQueryRunner() return createDeltaLakeQueryRunner(DELTA_CATALOG, ImmutableMap.of(), ImmutableMap.of("delta.enable-non-concurrent-writes", "true")); } - @Test(dataProvider = "columnMappingModeDataProvider") - public void testCreateTableWithColumnMappingMode(String columnMappingMode) + @Test + public void testCreateTableWithColumnMappingMode() throws Exception { testCreateTableColumnMappingMode(tableName -> { - assertUpdate("CREATE TABLE " + tableName + "(a_int integer, a_row row(x integer)) WITH (column_mapping_mode='" + columnMappingMode + "')"); + assertUpdate("CREATE TABLE " + tableName + "(a_int integer, a_row row(x integer)) WITH (column_mapping_mode='id')"); + assertUpdate("INSERT INTO " + tableName + " VALUES (1, row(11))", 1); + }); + + testCreateTableColumnMappingMode(tableName -> { + assertUpdate("CREATE TABLE " + tableName + "(a_int integer, a_row row(x integer)) WITH (column_mapping_mode='name')"); assertUpdate("INSERT INTO " + tableName + 
" VALUES (1, row(11))", 1); }); } - @Test(dataProvider = "columnMappingModeDataProvider") - public void testCreateTableAsSelectWithColumnMappingMode(String columnMappingMode) + @Test + public void testCreateTableAsSelectWithColumnMappingMode() throws Exception { testCreateTableColumnMappingMode(tableName -> - assertUpdate("CREATE TABLE " + tableName + " WITH (column_mapping_mode='" + columnMappingMode + "')" + + assertUpdate("CREATE TABLE " + tableName + " WITH (column_mapping_mode='id')" + + " AS SELECT 1 AS a_int, CAST(row(11) AS row(x integer)) AS a_row", 1)); + + testCreateTableColumnMappingMode(tableName -> + assertUpdate("CREATE TABLE " + tableName + " WITH (column_mapping_mode='name')" + " AS SELECT 1 AS a_int, CAST(row(11) AS row(x integer)) AS a_row", 1)); } @@ -116,13 +123,6 @@ private void testCreateTableColumnMappingMode(Consumer createTable) assertUpdate("DROP TABLE " + tableName); } - @DataProvider - public Object[][] columnMappingModeDataProvider() - { - return ImmutableList.of("id", "name").stream() - .collect(toDataProvider()); - } - private String getTableLocation(String tableName) { Pattern locationPattern = Pattern.compile(".*location = '(.*?)'.*", Pattern.DOTALL); diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeDynamicFiltering.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeDynamicFiltering.java index 5e8c0d081862..647c88e48f4e 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeDynamicFiltering.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeDynamicFiltering.java @@ -37,8 +37,8 @@ import io.trino.testing.QueryRunner; import io.trino.transaction.TransactionId; import io.trino.transaction.TransactionManager; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import java.util.ArrayList; 
import java.util.List; @@ -46,7 +46,6 @@ import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; -import java.util.stream.Stream; import static com.google.common.base.Verify.verify; import static io.airlift.concurrent.MoreFutures.unmodifiableFuture; @@ -56,7 +55,6 @@ import static io.trino.plugin.deltalake.DeltaLakeQueryRunner.DELTA_CATALOG; import static io.trino.plugin.deltalake.DeltaLakeQueryRunner.createS3DeltaLakeQueryRunner; import static io.trino.spi.connector.Constraint.alwaysTrue; -import static io.trino.testing.DataProviders.toDataProvider; import static io.trino.testing.TestingNames.randomNameSuffix; import static io.trino.tpch.TpchTable.LINE_ITEM; import static io.trino.tpch.TpchTable.ORDERS; @@ -97,27 +95,24 @@ protected QueryRunner createQueryRunner() return queryRunner; } - @DataProvider - public Object[][] joinDistributionTypes() + @Test + @Timeout(60) + public void testDynamicFiltering() { - return Stream.of(JoinDistributionType.values()) - .collect(toDataProvider()); - } - - @Test(timeOut = 60_000, dataProvider = "joinDistributionTypes") - public void testDynamicFiltering(JoinDistributionType joinDistributionType) - { - String query = "SELECT * FROM lineitem JOIN orders ON lineitem.orderkey = orders.orderkey AND orders.totalprice > 59995 AND orders.totalprice < 60000"; - MaterializedResultWithQueryId filteredResult = getDistributedQueryRunner().executeWithQueryId(sessionWithDynamicFiltering(true, joinDistributionType), query); - MaterializedResultWithQueryId unfilteredResult = getDistributedQueryRunner().executeWithQueryId(sessionWithDynamicFiltering(false, joinDistributionType), query); - assertEqualsIgnoreOrder(filteredResult.getResult().getMaterializedRows(), unfilteredResult.getResult().getMaterializedRows()); - - QueryInputStats filteredStats = getQueryInputStats(filteredResult.getQueryId()); - QueryInputStats unfilteredStats = getQueryInputStats(unfilteredResult.getQueryId()); - 
assertGreaterThan(unfilteredStats.inputPositions, filteredStats.inputPositions); + for (JoinDistributionType joinDistributionType : JoinDistributionType.values()) { + String query = "SELECT * FROM lineitem JOIN orders ON lineitem.orderkey = orders.orderkey AND orders.totalprice > 59995 AND orders.totalprice < 60000"; + MaterializedResultWithQueryId filteredResult = getDistributedQueryRunner().executeWithQueryId(sessionWithDynamicFiltering(true, joinDistributionType), query); + MaterializedResultWithQueryId unfilteredResult = getDistributedQueryRunner().executeWithQueryId(sessionWithDynamicFiltering(false, joinDistributionType), query); + assertEqualsIgnoreOrder(filteredResult.getResult().getMaterializedRows(), unfilteredResult.getResult().getMaterializedRows()); + + QueryInputStats filteredStats = getQueryInputStats(filteredResult.getQueryId()); + QueryInputStats unfilteredStats = getQueryInputStats(unfilteredResult.getQueryId()); + assertGreaterThan(unfilteredStats.inputPositions, filteredStats.inputPositions); + } } - @Test(timeOut = 30_000) + @Test + @Timeout(30) public void testIncompleteDynamicFilterTimeout() throws Exception { diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeDynamicPartitionPruningTest.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeDynamicPartitionPruningTest.java index 34b6f6a4a6ea..fe6e43fb92fd 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeDynamicPartitionPruningTest.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeDynamicPartitionPruningTest.java @@ -18,7 +18,7 @@ import io.trino.testing.DistributedQueryRunner; import io.trino.testing.QueryRunner; import io.trino.tpch.TpchTable; -import org.testng.SkipException; +import org.junit.jupiter.api.Test; import java.io.IOException; import java.io.UncheckedIOException; @@ -30,6 +30,7 @@ import static 
io.trino.plugin.deltalake.DeltaLakeQueryRunner.createDeltaLakeQueryRunner; import static java.lang.String.format; import static java.util.stream.Collectors.joining; +import static org.junit.jupiter.api.Assumptions.abort; public class TestDeltaLakeDynamicPartitionPruningTest extends BaseDynamicPartitionPruningTest @@ -47,10 +48,11 @@ protected QueryRunner createQueryRunner() return queryRunner; } + @Test @Override public void testJoinDynamicFilteringMultiJoinOnBucketedTables() { - throw new SkipException("Delta Lake does not support bucketing"); + abort("Delta Lake does not support bucketing"); } @Override diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeFileOperations.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeFileOperations.java index 71065d635569..8eb0290c1c4d 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeFileOperations.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeFileOperations.java @@ -16,6 +16,7 @@ import com.google.common.collect.HashMultiset; import com.google.common.collect.ImmutableMultiset; import com.google.common.collect.Multiset; +import com.google.common.io.Resources; import io.trino.Session; import io.trino.SystemSessionProperties; import io.trino.filesystem.TrackingFileSystemFactory; @@ -30,6 +31,8 @@ import java.io.File; import java.net.URI; +import java.nio.file.Files; +import java.nio.file.Path; import java.util.Arrays; import java.util.Map; import java.util.Optional; @@ -51,6 +54,7 @@ import static io.trino.plugin.deltalake.TestDeltaLakeFileOperations.FileType.STARBURST_EXTENDED_STATS_JSON; import static io.trino.plugin.deltalake.TestDeltaLakeFileOperations.FileType.TRANSACTION_LOG_JSON; import static io.trino.plugin.deltalake.TestDeltaLakeFileOperations.FileType.TRINO_EXTENDED_STATS_JSON; +import static io.trino.plugin.deltalake.TestingDeltaLakeUtils.copyDirectoryContents; 
import static io.trino.plugin.hive.HiveTestUtils.HDFS_ENVIRONMENT; import static io.trino.plugin.hive.HiveTestUtils.HDFS_FILE_SYSTEM_STATS; import static io.trino.testing.MultisetAssertions.assertMultisetsEqual; @@ -95,7 +99,8 @@ protected DistributedQueryRunner createQueryRunner() Map.of( "hive.metastore", "file", "hive.metastore.catalog.dir", metastoreDirectory, - "delta.enable-non-concurrent-writes", "true")); + "delta.enable-non-concurrent-writes", "true", + "delta.register-table-procedure.enabled", "true")); queryRunner.execute("CREATE SCHEMA " + session.getSchema().orElseThrow()); return queryRunner; @@ -189,8 +194,8 @@ public void testReadTableCheckpointInterval() "TABLE test_read_checkpoint", ImmutableMultiset.builder() .add(new FileOperation(LAST_CHECKPOINT, "_last_checkpoint", INPUT_FILE_NEW_STREAM)) - .addCopies(new FileOperation(CHECKPOINT, "00000000000000000002.checkpoint.parquet", INPUT_FILE_GET_LENGTH), 6) // TODO (https://github.com/trinodb/trino/issues/18916) should be checked once per query - .addCopies(new FileOperation(CHECKPOINT, "00000000000000000002.checkpoint.parquet", INPUT_FILE_NEW_STREAM), 3) // TODO (https://github.com/trinodb/trino/issues/18916) should be checked once per query + .addCopies(new FileOperation(CHECKPOINT, "00000000000000000002.checkpoint.parquet", INPUT_FILE_GET_LENGTH), 4) // TODO (https://github.com/trinodb/trino/issues/18916) should be checked once per query + .addCopies(new FileOperation(CHECKPOINT, "00000000000000000002.checkpoint.parquet", INPUT_FILE_NEW_STREAM), 2) // TODO (https://github.com/trinodb/trino/issues/18916) should be checked once per query .add(new FileOperation(TRANSACTION_LOG_JSON, "00000000000000000003.json", INPUT_FILE_NEW_STREAM)) .addCopies(new FileOperation(DATA, "no partition", INPUT_FILE_NEW_STREAM), 2) .build()); @@ -701,6 +706,28 @@ public void testShowTables() assertFileSystemAccesses("SHOW TABLES", ImmutableMultiset.of()); } + @Test + public void testReadMultipartCheckpoint() + throws 
Exception + { + String tableName = "test_multipart_checkpoint_" + randomNameSuffix(); + Path tableLocation = Files.createTempFile(tableName, null); + copyDirectoryContents(new File(Resources.getResource("deltalake/multipart_checkpoint").toURI()).toPath(), tableLocation); + + assertUpdate("CALL system.register_table('%s', '%s', '%s')".formatted(getSession().getSchema().orElseThrow(), tableName, tableLocation.toUri())); + assertFileSystemAccesses("SELECT * FROM " + tableName, + ImmutableMultiset.builder() + .add(new FileOperation(LAST_CHECKPOINT, "_last_checkpoint", INPUT_FILE_NEW_STREAM)) + .addCopies(new FileOperation(CHECKPOINT, "00000000000000000006.checkpoint.0000000001.0000000002.parquet", INPUT_FILE_GET_LENGTH), 4) // TODO (https://github.com/trinodb/trino/issues/18916) should be checked once per query + .addCopies(new FileOperation(CHECKPOINT, "00000000000000000006.checkpoint.0000000001.0000000002.parquet", INPUT_FILE_NEW_STREAM), 2) // TODO (https://github.com/trinodb/trino/issues/18916) should be checked once per query + .addCopies(new FileOperation(CHECKPOINT, "00000000000000000006.checkpoint.0000000002.0000000002.parquet", INPUT_FILE_GET_LENGTH), 4) // TODO (https://github.com/trinodb/trino/issues/18916) should be checked once per query + .addCopies(new FileOperation(CHECKPOINT, "00000000000000000006.checkpoint.0000000002.0000000002.parquet", INPUT_FILE_NEW_STREAM), 2) // TODO (https://github.com/trinodb/trino/issues/18916) should be checked once per query + .add(new FileOperation(TRANSACTION_LOG_JSON, "00000000000000000007.json", INPUT_FILE_NEW_STREAM)) + .add(new FileOperation(TRANSACTION_LOG_JSON, "00000000000000000008.json", INPUT_FILE_NEW_STREAM)) + .addCopies(new FileOperation(DATA, "no partition", INPUT_FILE_NEW_STREAM), 7) + .build()); + } + private int countCdfFilesForKey(String partitionValue) { String path = (String) computeScalar("SELECT \"$path\" FROM table_changes_file_system_access WHERE key = '" + partitionValue + "'"); @@ -742,12 +769,11 
@@ private record FileOperation(FileType fileType, String fileId, OperationType ope { public static FileOperation create(String path, OperationType operationType) { - Pattern dataFilePattern = Pattern.compile(".*?/(?key=[^/]*/)?(?\\d{8}_\\d{6}_\\d{5}_\\w{5})_(?[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})"); String fileName = path.replaceFirst(".*/", ""); if (path.matches(".*/_delta_log/_last_checkpoint")) { return new FileOperation(LAST_CHECKPOINT, fileName, operationType); } - if (path.matches(".*/_delta_log/\\d+\\.checkpoint\\.parquet")) { + if (path.matches(".*/_delta_log/\\d+\\.checkpoint(\\.\\d+\\.\\d+)?\\.parquet")) { return new FileOperation(CHECKPOINT, fileName, operationType); } if (path.matches(".*/_delta_log/\\d+\\.json")) { @@ -759,6 +785,7 @@ public static FileOperation create(String path, OperationType operationType) if (path.matches(".*/_delta_log/_starburst_meta/extendeded_stats.json")) { return new FileOperation(STARBURST_EXTENDED_STATS_JSON, fileName, operationType); } + Pattern dataFilePattern = Pattern.compile(".*?/(?key=[^/]*/)?[^/]+"); if (path.matches(".*/_change_data/.*")) { Matcher matcher = dataFilePattern.matcher(path); if (matcher.matches()) { diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeMetadata.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeMetadata.java index 99c1e9afd2bf..267eb2f41ffc 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeMetadata.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeMetadata.java @@ -63,10 +63,10 @@ import io.trino.spi.type.VarcharType; import io.trino.testing.TestingConnectorContext; import io.trino.tests.BogusType; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.AfterAll; +import 
org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; import java.io.File; import java.io.IOException; @@ -96,7 +96,9 @@ import static java.util.Locale.ENGLISH; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +@TestInstance(PER_CLASS) public class TestDeltaLakeMetadata { private static final String DATABASE_NAME = "mock_database"; @@ -204,7 +206,7 @@ public class TestDeltaLakeMetadata private File temporaryCatalogDirectory; private DeltaLakeMetadataFactory deltaLakeMetadataFactory; - @BeforeClass + @BeforeAll public void setUp() throws IOException { @@ -259,7 +261,7 @@ public DeltaLakeMetastore getDeltaLakeMetastore(@RawHiveMetastoreFactory HiveMet .build()); } - @AfterClass(alwaysRun = true) + @AfterAll public void tearDown() throws Exception { @@ -374,74 +376,57 @@ public void testGetInsertLayoutTableUnpartitioned() .isNotPresent(); } - @DataProvider - public Object[][] testApplyProjectionProvider() + @Test + public void testApplyProjection() { - return new Object[][] { - { - ImmutableSet.of(), - SYNTHETIC_COLUMN_ASSIGNMENTS, - SIMPLE_COLUMN_PROJECTIONS, - SIMPLE_COLUMN_PROJECTIONS, - ImmutableSet.of(BOGUS_COLUMN_HANDLE, VARCHAR_COLUMN_HANDLE), - SYNTHETIC_COLUMN_ASSIGNMENTS - }, - { - // table handle already contains subset of expected projected columns - ImmutableSet.of(BOGUS_COLUMN_HANDLE), - SYNTHETIC_COLUMN_ASSIGNMENTS, - SIMPLE_COLUMN_PROJECTIONS, - SIMPLE_COLUMN_PROJECTIONS, - ImmutableSet.of(BOGUS_COLUMN_HANDLE, VARCHAR_COLUMN_HANDLE), - SYNTHETIC_COLUMN_ASSIGNMENTS - }, - { - // table handle already contains superset of expected projected columns - ImmutableSet.of(DOUBLE_COLUMN_HANDLE, BOOLEAN_COLUMN_HANDLE, DATE_COLUMN_HANDLE, BOGUS_COLUMN_HANDLE, VARCHAR_COLUMN_HANDLE), - SYNTHETIC_COLUMN_ASSIGNMENTS, - SIMPLE_COLUMN_PROJECTIONS, - 
SIMPLE_COLUMN_PROJECTIONS, - ImmutableSet.of(BOGUS_COLUMN_HANDLE, VARCHAR_COLUMN_HANDLE), - SYNTHETIC_COLUMN_ASSIGNMENTS - }, - { - // table handle has empty assignments - ImmutableSet.of(DOUBLE_COLUMN_HANDLE, BOOLEAN_COLUMN_HANDLE, DATE_COLUMN_HANDLE, BOGUS_COLUMN_HANDLE, VARCHAR_COLUMN_HANDLE), - ImmutableMap.of(), - SIMPLE_COLUMN_PROJECTIONS, - SIMPLE_COLUMN_PROJECTIONS, - ImmutableSet.of(), - ImmutableMap.of() - }, - { - ImmutableSet.of(DOUBLE_COLUMN_HANDLE, BOOLEAN_COLUMN_HANDLE, DATE_COLUMN_HANDLE, BOGUS_COLUMN_HANDLE, VARCHAR_COLUMN_HANDLE), - ImmutableMap.of(), - DEREFERENCE_COLUMN_PROJECTIONS, - DEREFERENCE_COLUMN_PROJECTIONS, - ImmutableSet.of(), - ImmutableMap.of() - }, - { - ImmutableSet.of(NESTED_COLUMN_HANDLE), - NESTED_COLUMN_ASSIGNMENTS, - NESTED_DEREFERENCE_COLUMN_PROJECTIONS, - EXPECTED_NESTED_DEREFERENCE_COLUMN_PROJECTIONS, - ImmutableSet.of(EXPECTED_NESTED_COLUMN_HANDLE), - EXPECTED_NESTED_COLUMN_ASSIGNMENTS - }, - { - ImmutableSet.of(HIGHLY_NESTED_ROW_FIELD), - HIGHLY_NESTED_COLUMN_ASSIGNMENTS, - HIGHLY_NESTED_DEREFERENCE_COLUMN_PROJECTIONS, - EXPECTED_HIGHLY_NESTED_DEREFERENCE_COLUMN_PROJECTIONS, - ImmutableSet.of(EXPECTED_NESTED_COLUMN_HANDLE_WITH_PROJECTION), - EXPECTED_HIGHLY_NESTED_COLUMN_ASSIGNMENTS - } - }; + testApplyProjection( + ImmutableSet.of(), + SYNTHETIC_COLUMN_ASSIGNMENTS, + SIMPLE_COLUMN_PROJECTIONS, + SIMPLE_COLUMN_PROJECTIONS, + ImmutableSet.of(BOGUS_COLUMN_HANDLE, VARCHAR_COLUMN_HANDLE), + SYNTHETIC_COLUMN_ASSIGNMENTS); + testApplyProjection( + // table handle already contains subset of expected projected columns + ImmutableSet.of(BOGUS_COLUMN_HANDLE), + SYNTHETIC_COLUMN_ASSIGNMENTS, + SIMPLE_COLUMN_PROJECTIONS, + SIMPLE_COLUMN_PROJECTIONS, + ImmutableSet.of(BOGUS_COLUMN_HANDLE, VARCHAR_COLUMN_HANDLE), + SYNTHETIC_COLUMN_ASSIGNMENTS); + testApplyProjection( + // table handle already contains superset of expected projected columns + ImmutableSet.of(DOUBLE_COLUMN_HANDLE, BOOLEAN_COLUMN_HANDLE, DATE_COLUMN_HANDLE, 
BOGUS_COLUMN_HANDLE, VARCHAR_COLUMN_HANDLE), + SYNTHETIC_COLUMN_ASSIGNMENTS, + SIMPLE_COLUMN_PROJECTIONS, + SIMPLE_COLUMN_PROJECTIONS, + ImmutableSet.of(BOGUS_COLUMN_HANDLE, VARCHAR_COLUMN_HANDLE), + SYNTHETIC_COLUMN_ASSIGNMENTS); + testApplyProjection( + // table handle has empty assignments + ImmutableSet.of(DOUBLE_COLUMN_HANDLE, BOOLEAN_COLUMN_HANDLE, DATE_COLUMN_HANDLE, BOGUS_COLUMN_HANDLE, VARCHAR_COLUMN_HANDLE), + ImmutableMap.of(), + SIMPLE_COLUMN_PROJECTIONS, + SIMPLE_COLUMN_PROJECTIONS, + ImmutableSet.of(), + ImmutableMap.of()); + testApplyProjection( + ImmutableSet.of(DOUBLE_COLUMN_HANDLE, BOOLEAN_COLUMN_HANDLE, DATE_COLUMN_HANDLE, BOGUS_COLUMN_HANDLE, VARCHAR_COLUMN_HANDLE), + ImmutableMap.of(), + DEREFERENCE_COLUMN_PROJECTIONS, + DEREFERENCE_COLUMN_PROJECTIONS, + ImmutableSet.of(), + ImmutableMap.of()); + testApplyProjection( + ImmutableSet.of(NESTED_COLUMN_HANDLE), + NESTED_COLUMN_ASSIGNMENTS, + NESTED_DEREFERENCE_COLUMN_PROJECTIONS, + EXPECTED_NESTED_DEREFERENCE_COLUMN_PROJECTIONS, + ImmutableSet.of(EXPECTED_NESTED_COLUMN_HANDLE), + EXPECTED_NESTED_COLUMN_ASSIGNMENTS); } - @Test(dataProvider = "testApplyProjectionProvider") - public void testApplyProjection( + private void testApplyProjection( Set inputProjectedColumns, Map inputAssignments, List inputProjections, diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeSplitManager.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeSplitManager.java index d90f057ecdfe..39b35a2be7ec 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeSplitManager.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeSplitManager.java @@ -180,7 +180,7 @@ public TableSnapshot getSnapshot(ConnectorSession session, SchemaTableName table } @Override - public List getActiveFiles(TableSnapshot tableSnapshot, ConnectorSession session) + public List getActiveFiles(TableSnapshot 
tableSnapshot, MetadataEntry metadataEntry, ProtocolEntry protocolEntry, ConnectorSession session) { return addFileEntries; } diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestTransactionLogAccess.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestTransactionLogAccess.java index f8effbbcb2e3..d2ef7fd718ab 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestTransactionLogAccess.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestTransactionLogAccess.java @@ -200,7 +200,9 @@ public void testGetActiveAddEntries() { setupTransactionLogAccessFromResources("person", "databricks73/person"); - List addFileEntries = transactionLogAccess.getActiveFiles(tableSnapshot, SESSION); + MetadataEntry metadataEntry = transactionLogAccess.getMetadataEntry(tableSnapshot, SESSION); + ProtocolEntry protocolEntry = transactionLogAccess.getProtocolEntry(SESSION, tableSnapshot); + List addFileEntries = transactionLogAccess.getActiveFiles(tableSnapshot, metadataEntry, protocolEntry, SESSION); Set paths = addFileEntries .stream() .map(AddFileEntry::getPath) @@ -230,7 +232,9 @@ public void testAddFileEntryUppercase() { setupTransactionLogAccessFromResources("uppercase_columns", "databricks73/uppercase_columns"); - List addFileEntries = transactionLogAccess.getActiveFiles(tableSnapshot, SESSION); + MetadataEntry metadataEntry = transactionLogAccess.getMetadataEntry(tableSnapshot, SESSION); + ProtocolEntry protocolEntry = transactionLogAccess.getProtocolEntry(SESSION, tableSnapshot); + List addFileEntries = transactionLogAccess.getActiveFiles(tableSnapshot, metadataEntry, protocolEntry, SESSION); AddFileEntry addFileEntry = addFileEntries .stream() .filter(entry -> entry.getPath().equals("ALA=1/part-00000-20a863e0-890d-4776-8825-f9dccc8973ba.c000.snappy.parquet")) @@ -252,7 +256,9 @@ public void testAddEntryPruning() // - Added in the parquet checkpoint but removed in a JSON commit // - 
Added in a JSON commit and removed in a later JSON commit setupTransactionLogAccessFromResources("person_test_pruning", "databricks73/person_test_pruning"); - List addFileEntries = transactionLogAccess.getActiveFiles(tableSnapshot, SESSION); + MetadataEntry metadataEntry = transactionLogAccess.getMetadataEntry(tableSnapshot, SESSION); + ProtocolEntry protocolEntry = transactionLogAccess.getProtocolEntry(SESSION, tableSnapshot); + List addFileEntries = transactionLogAccess.getActiveFiles(tableSnapshot, metadataEntry, protocolEntry, SESSION); Set paths = addFileEntries .stream() .map(AddFileEntry::getPath) @@ -267,7 +273,9 @@ public void testAddEntryOverrides() throws Exception { setupTransactionLogAccessFromResources("person_test_pruning", "databricks73/person_test_pruning"); - List addFileEntries = transactionLogAccess.getActiveFiles(tableSnapshot, SESSION); + MetadataEntry metadataEntry = transactionLogAccess.getMetadataEntry(tableSnapshot, SESSION); + ProtocolEntry protocolEntry = transactionLogAccess.getProtocolEntry(SESSION, tableSnapshot); + List addFileEntries = transactionLogAccess.getActiveFiles(tableSnapshot, metadataEntry, protocolEntry, SESSION); // Test data contains two entries which are added multiple times, the most up to date one should be the only one in the active list List overwrittenPaths = ImmutableList.of( @@ -288,7 +296,9 @@ public void testAddRemoveAdd() throws Exception { setupTransactionLogAccessFromResources("person_test_pruning", "databricks73/person_test_pruning"); - List addFileEntries = transactionLogAccess.getActiveFiles(tableSnapshot, SESSION); + MetadataEntry metadataEntry = transactionLogAccess.getMetadataEntry(tableSnapshot, SESSION); + ProtocolEntry protocolEntry = transactionLogAccess.getProtocolEntry(SESSION, tableSnapshot); + List addFileEntries = transactionLogAccess.getActiveFiles(tableSnapshot, metadataEntry, protocolEntry, SESSION); // Test data contains an entry added by the parquet checkpoint, removed by a JSON action, 
and then added back by a later JSON action List activeEntries = addFileEntries.stream() @@ -366,8 +376,9 @@ public void testAllGetActiveAddEntries(String tableName, String resourcePath) throws Exception { setupTransactionLogAccessFromResources(tableName, resourcePath); - - List addFileEntries = transactionLogAccess.getActiveFiles(tableSnapshot, SESSION); + MetadataEntry metadataEntry = transactionLogAccess.getMetadataEntry(tableSnapshot, SESSION); + ProtocolEntry protocolEntry = transactionLogAccess.getProtocolEntry(SESSION, tableSnapshot); + List addFileEntries = transactionLogAccess.getActiveFiles(tableSnapshot, metadataEntry, protocolEntry, SESSION); Set paths = addFileEntries .stream() .map(AddFileEntry::getPath) @@ -445,8 +456,9 @@ public void testUpdatingTailEntriesNoCheckpoint() File resourceDir = new File(getClass().getClassLoader().getResource("databricks73/person/_delta_log").toURI()); copyTransactionLogEntry(0, 7, resourceDir, transactionLogDir); setupTransactionLogAccess(tableName, tableDir.toURI().toString()); - - List activeDataFiles = transactionLogAccess.getActiveFiles(tableSnapshot, SESSION); + MetadataEntry metadataEntry = transactionLogAccess.getMetadataEntry(tableSnapshot, SESSION); + ProtocolEntry protocolEntry = transactionLogAccess.getProtocolEntry(SESSION, tableSnapshot); + List activeDataFiles = transactionLogAccess.getActiveFiles(tableSnapshot, metadataEntry, protocolEntry, SESSION); Set dataFiles = ImmutableSet.of( "age=42/part-00000-b82d8859-84a0-4f05-872c-206b07dd54f0.c000.snappy.parquet", @@ -460,7 +472,7 @@ public void testUpdatingTailEntriesNoCheckpoint() copyTransactionLogEntry(7, 9, resourceDir, transactionLogDir); TableSnapshot updatedSnapshot = transactionLogAccess.getSnapshot(SESSION, new SchemaTableName("schema", tableName), tableDir.toURI().toString(), Optional.empty()); - activeDataFiles = transactionLogAccess.getActiveFiles(updatedSnapshot, SESSION); + activeDataFiles = transactionLogAccess.getActiveFiles(updatedSnapshot, 
metadataEntry, protocolEntry, SESSION); transactionLogAccess.cleanupQuery(SESSION); dataFiles = ImmutableSet.of( @@ -489,7 +501,9 @@ public void testLoadingTailEntriesPastCheckpoint() copyTransactionLogEntry(0, 8, resourceDir, transactionLogDir); setupTransactionLogAccess(tableName, tableDir.toURI().toString()); - List activeDataFiles = transactionLogAccess.getActiveFiles(tableSnapshot, SESSION); + MetadataEntry metadataEntry = transactionLogAccess.getMetadataEntry(tableSnapshot, SESSION); + ProtocolEntry protocolEntry = transactionLogAccess.getProtocolEntry(SESSION, tableSnapshot); + List activeDataFiles = transactionLogAccess.getActiveFiles(tableSnapshot, metadataEntry, protocolEntry, SESSION); transactionLogAccess.cleanupQuery(SESSION); Set dataFiles = ImmutableSet.of( @@ -505,7 +519,7 @@ public void testLoadingTailEntriesPastCheckpoint() copyTransactionLogEntry(8, 12, resourceDir, transactionLogDir); Files.copy(new File(resourceDir, LAST_CHECKPOINT_FILENAME).toPath(), new File(transactionLogDir, LAST_CHECKPOINT_FILENAME).toPath()); TableSnapshot updatedSnapshot = transactionLogAccess.getSnapshot(SESSION, new SchemaTableName("schema", tableName), tableDir.toURI().toString(), Optional.empty()); - activeDataFiles = transactionLogAccess.getActiveFiles(updatedSnapshot, SESSION); + activeDataFiles = transactionLogAccess.getActiveFiles(updatedSnapshot, metadataEntry, protocolEntry, SESSION); transactionLogAccess.cleanupQuery(SESSION); dataFiles = ImmutableSet.of( @@ -536,6 +550,8 @@ public void testIncrementalCacheUpdates() File resourceDir = new File(getClass().getClassLoader().getResource("databricks73/person/_delta_log").toURI()); copyTransactionLogEntry(0, 12, resourceDir, transactionLogDir); Files.copy(new File(resourceDir, LAST_CHECKPOINT_FILENAME).toPath(), new File(transactionLogDir, LAST_CHECKPOINT_FILENAME).toPath()); + MetadataEntry metadataEntry = transactionLogAccess.getMetadataEntry(tableSnapshot, SESSION); + ProtocolEntry protocolEntry = 
transactionLogAccess.getProtocolEntry(SESSION, tableSnapshot); Set originalDataFiles = ImmutableSet.of( "age=42/part-00000-b26c891a-7288-4d96-9d3b-bef648f12a34.c000.snappy.parquet", @@ -552,15 +568,15 @@ public void testIncrementalCacheUpdates() assertFileSystemAccesses( () -> { setupTransactionLogAccess(tableName, tableDir.toURI().toString()); - List activeDataFiles = transactionLogAccess.getActiveFiles(tableSnapshot, SESSION); + List activeDataFiles = transactionLogAccess.getActiveFiles(tableSnapshot, metadataEntry, protocolEntry, SESSION); assertEqualsIgnoreOrder(activeDataFiles.stream().map(AddFileEntry::getPath).collect(Collectors.toSet()), originalDataFiles); }, ImmutableMultiset.builder() - .addCopies(new FileOperation("_last_checkpoint", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000010.checkpoint.parquet", INPUT_FILE_GET_LENGTH), 4) // TODO (https://github.com/trinodb/trino/issues/16775) why not e.g. once? - .addCopies(new FileOperation("00000000000000000010.checkpoint.parquet", INPUT_FILE_NEW_STREAM), 2) - .addCopies(new FileOperation("00000000000000000011.json", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000012.json", INPUT_FILE_NEW_STREAM), 1) + .add(new FileOperation("_last_checkpoint", INPUT_FILE_NEW_STREAM)) + .addCopies(new FileOperation("00000000000000000010.checkpoint.parquet", INPUT_FILE_GET_LENGTH), 2) // TODO (https://github.com/trinodb/trino/issues/16775) why not e.g. once? 
+ .add(new FileOperation("00000000000000000010.checkpoint.parquet", INPUT_FILE_NEW_STREAM)) + .add(new FileOperation("00000000000000000011.json", INPUT_FILE_NEW_STREAM)) + .add(new FileOperation("00000000000000000012.json", INPUT_FILE_NEW_STREAM)) .build()); copyTransactionLogEntry(12, 14, resourceDir, transactionLogDir); @@ -570,15 +586,15 @@ public void testIncrementalCacheUpdates() assertFileSystemAccesses( () -> { TableSnapshot updatedTableSnapshot = transactionLogAccess.getSnapshot(SESSION, new SchemaTableName("schema", tableName), tableDir.toURI().toString(), Optional.empty()); - List activeDataFiles = transactionLogAccess.getActiveFiles(updatedTableSnapshot, SESSION); + List activeDataFiles = transactionLogAccess.getActiveFiles(updatedTableSnapshot, metadataEntry, protocolEntry, SESSION); transactionLogAccess.cleanupQuery(SESSION); assertEqualsIgnoreOrder(activeDataFiles.stream().map(AddFileEntry::getPath).collect(Collectors.toSet()), union(originalDataFiles, newDataFiles)); }, ImmutableMultiset.builder() - .addCopies(new FileOperation("_last_checkpoint", INPUT_FILE_NEW_STREAM), 1) + .add(new FileOperation("_last_checkpoint", INPUT_FILE_NEW_STREAM)) .addCopies(new FileOperation("00000000000000000012.json", INPUT_FILE_NEW_STREAM), 2) .addCopies(new FileOperation("00000000000000000013.json", INPUT_FILE_NEW_STREAM), 2) - .addCopies(new FileOperation("00000000000000000014.json", INPUT_FILE_NEW_STREAM), 1) + .add(new FileOperation("00000000000000000014.json", INPUT_FILE_NEW_STREAM)) .build()); } @@ -597,7 +613,9 @@ public void testSnapshotsAreConsistent() Files.copy(new File(resourceDir, LAST_CHECKPOINT_FILENAME).toPath(), new File(transactionLogDir, LAST_CHECKPOINT_FILENAME).toPath()); setupTransactionLogAccess(tableName, tableDir.toURI().toString()); - List expectedDataFiles = transactionLogAccess.getActiveFiles(tableSnapshot, SESSION); + MetadataEntry metadataEntry = transactionLogAccess.getMetadataEntry(tableSnapshot, SESSION); + ProtocolEntry protocolEntry = 
transactionLogAccess.getProtocolEntry(SESSION, tableSnapshot); + List expectedDataFiles = transactionLogAccess.getActiveFiles(tableSnapshot, metadataEntry, protocolEntry, SESSION); transactionLogAccess.cleanupQuery(SESSION); copyTransactionLogEntry(12, 14, resourceDir, transactionLogDir); @@ -605,8 +623,8 @@ public void testSnapshotsAreConsistent() "age=28/part-00000-40dd1707-1d42-4328-a59a-21f5c945fe60.c000.snappy.parquet", "age=29/part-00000-3794c463-cb0c-4beb-8d07-7cc1e3b5920f.c000.snappy.parquet"); TableSnapshot updatedTableSnapshot = transactionLogAccess.getSnapshot(SESSION, new SchemaTableName("schema", tableName), tableDir.toURI().toString(), Optional.empty()); - List allDataFiles = transactionLogAccess.getActiveFiles(updatedTableSnapshot, SESSION); - List dataFilesWithFixedVersion = transactionLogAccess.getActiveFiles(tableSnapshot, SESSION); + List allDataFiles = transactionLogAccess.getActiveFiles(updatedTableSnapshot, metadataEntry, protocolEntry, SESSION); + List dataFilesWithFixedVersion = transactionLogAccess.getActiveFiles(tableSnapshot, metadataEntry, protocolEntry, SESSION); for (String newFilePath : newDataFiles) { assertTrue(allDataFiles.stream().anyMatch(entry -> entry.getPath().equals(newFilePath))); assertTrue(dataFilesWithFixedVersion.stream().noneMatch(entry -> entry.getPath().equals(newFilePath))); @@ -679,7 +697,9 @@ public void testParquetStructStatistics() String tableName = "parquet_struct_statistics"; setupTransactionLogAccess(tableName, getClass().getClassLoader().getResource("databricks73/pruning/" + tableName).toURI().toString()); - List addFileEntries = transactionLogAccess.getActiveFiles(tableSnapshot, SESSION); + MetadataEntry metadataEntry = transactionLogAccess.getMetadataEntry(tableSnapshot, SESSION); + ProtocolEntry protocolEntry = transactionLogAccess.getProtocolEntry(SESSION, tableSnapshot); + List addFileEntries = transactionLogAccess.getActiveFiles(tableSnapshot, metadataEntry, protocolEntry, SESSION); AddFileEntry 
addFileEntry = addFileEntries.stream() .filter(entry -> entry.getPath().equalsIgnoreCase("part-00000-0e22455f-5650-442f-a094-e1a8b7ed2271-c000.snappy.parquet")) @@ -730,11 +750,11 @@ public void testTableSnapshotsCacheDisabled() setupTransactionLogAccess(tableName, tableDir, cacheDisabledConfig); }, ImmutableMultiset.builder() - .addCopies(new FileOperation("_last_checkpoint", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000011.json", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000012.json", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000013.json", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000014.json", INPUT_FILE_NEW_STREAM), 1) + .add(new FileOperation("_last_checkpoint", INPUT_FILE_NEW_STREAM)) + .add(new FileOperation("00000000000000000011.json", INPUT_FILE_NEW_STREAM)) + .add(new FileOperation("00000000000000000012.json", INPUT_FILE_NEW_STREAM)) + .add(new FileOperation("00000000000000000013.json", INPUT_FILE_NEW_STREAM)) + .add(new FileOperation("00000000000000000014.json", INPUT_FILE_NEW_STREAM)) .build()); // With the transaction log cache disabled, when loading the snapshot again, all the needed files will be opened again @@ -743,11 +763,11 @@ public void testTableSnapshotsCacheDisabled() transactionLogAccess.getSnapshot(SESSION, new SchemaTableName("schema", tableName), tableDir, Optional.empty()); }, ImmutableMultiset.builder() - .addCopies(new FileOperation("_last_checkpoint", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000011.json", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000012.json", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000013.json", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000014.json", INPUT_FILE_NEW_STREAM), 1) + .add(new FileOperation("_last_checkpoint", INPUT_FILE_NEW_STREAM)) + .add(new 
FileOperation("00000000000000000011.json", INPUT_FILE_NEW_STREAM)) + .add(new FileOperation("00000000000000000012.json", INPUT_FILE_NEW_STREAM)) + .add(new FileOperation("00000000000000000013.json", INPUT_FILE_NEW_STREAM)) + .add(new FileOperation("00000000000000000014.json", INPUT_FILE_NEW_STREAM)) .build()); } @@ -759,27 +779,29 @@ public void testTableSnapshotsActiveDataFilesCache() String tableDir = getClass().getClassLoader().getResource("databricks73/" + tableName).toURI().toString(); DeltaLakeConfig shortLivedActiveDataFilesCacheConfig = new DeltaLakeConfig(); shortLivedActiveDataFilesCacheConfig.setDataFileCacheTtl(new Duration(10, TimeUnit.MINUTES)); + MetadataEntry metadataEntry = transactionLogAccess.getMetadataEntry(tableSnapshot, SESSION); + ProtocolEntry protocolEntry = transactionLogAccess.getProtocolEntry(SESSION, tableSnapshot); assertFileSystemAccesses( () -> { setupTransactionLogAccess(tableName, tableDir, shortLivedActiveDataFilesCacheConfig); - List addFileEntries = transactionLogAccess.getActiveFiles(tableSnapshot, SESSION); + List addFileEntries = transactionLogAccess.getActiveFiles(tableSnapshot, metadataEntry, protocolEntry, SESSION); assertEquals(addFileEntries.size(), 12); }, ImmutableMultiset.builder() - .addCopies(new FileOperation("_last_checkpoint", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000011.json", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000012.json", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000013.json", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000014.json", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000010.checkpoint.parquet", INPUT_FILE_GET_LENGTH), 4) // TODO (https://github.com/trinodb/trino/issues/16775) why not e.g. once? 
- .addCopies(new FileOperation("00000000000000000010.checkpoint.parquet", INPUT_FILE_NEW_STREAM), 2) + .add(new FileOperation("_last_checkpoint", INPUT_FILE_NEW_STREAM)) + .add(new FileOperation("00000000000000000011.json", INPUT_FILE_NEW_STREAM)) + .add(new FileOperation("00000000000000000012.json", INPUT_FILE_NEW_STREAM)) + .add(new FileOperation("00000000000000000013.json", INPUT_FILE_NEW_STREAM)) + .add(new FileOperation("00000000000000000014.json", INPUT_FILE_NEW_STREAM)) + .addCopies(new FileOperation("00000000000000000010.checkpoint.parquet", INPUT_FILE_GET_LENGTH), 2) // TODO (https://github.com/trinodb/trino/issues/16775) why not e.g. once? + .add(new FileOperation("00000000000000000010.checkpoint.parquet", INPUT_FILE_NEW_STREAM)) .build()); // The internal data cache should still contain the data files for the table assertFileSystemAccesses( () -> { - List addFileEntries = transactionLogAccess.getActiveFiles(tableSnapshot, SESSION); + List addFileEntries = transactionLogAccess.getActiveFiles(tableSnapshot, metadataEntry, protocolEntry, SESSION); assertEquals(addFileEntries.size(), 12); }, ImmutableMultiset.of()); @@ -793,21 +815,23 @@ public void testFlushSnapshotAndActiveFileCache() String tableDir = getClass().getClassLoader().getResource("databricks73/" + tableName).toURI().toString(); DeltaLakeConfig shortLivedActiveDataFilesCacheConfig = new DeltaLakeConfig(); shortLivedActiveDataFilesCacheConfig.setDataFileCacheTtl(new Duration(10, TimeUnit.MINUTES)); + MetadataEntry metadataEntry = transactionLogAccess.getMetadataEntry(tableSnapshot, SESSION); + ProtocolEntry protocolEntry = transactionLogAccess.getProtocolEntry(SESSION, tableSnapshot); assertFileSystemAccesses( () -> { setupTransactionLogAccess(tableName, tableDir, shortLivedActiveDataFilesCacheConfig); - List addFileEntries = transactionLogAccess.getActiveFiles(tableSnapshot, SESSION); + List addFileEntries = transactionLogAccess.getActiveFiles(tableSnapshot, metadataEntry, protocolEntry, 
SESSION); assertEquals(addFileEntries.size(), 12); }, ImmutableMultiset.builder() - .addCopies(new FileOperation("_last_checkpoint", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000011.json", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000012.json", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000013.json", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000014.json", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000010.checkpoint.parquet", INPUT_FILE_GET_LENGTH), 4) // TODO (https://github.com/trinodb/trino/issues/16775) why not e.g. once? - .addCopies(new FileOperation("00000000000000000010.checkpoint.parquet", INPUT_FILE_NEW_STREAM), 2) + .add(new FileOperation("_last_checkpoint", INPUT_FILE_NEW_STREAM)) + .add(new FileOperation("00000000000000000011.json", INPUT_FILE_NEW_STREAM)) + .add(new FileOperation("00000000000000000012.json", INPUT_FILE_NEW_STREAM)) + .add(new FileOperation("00000000000000000013.json", INPUT_FILE_NEW_STREAM)) + .add(new FileOperation("00000000000000000014.json", INPUT_FILE_NEW_STREAM)) + .addCopies(new FileOperation("00000000000000000010.checkpoint.parquet", INPUT_FILE_GET_LENGTH), 2) // TODO (https://github.com/trinodb/trino/issues/16775) why not e.g. once? 
+ .add(new FileOperation("00000000000000000010.checkpoint.parquet", INPUT_FILE_NEW_STREAM)) .build()); // Flush all cache and then load snapshot and get active files @@ -815,17 +839,17 @@ public void testFlushSnapshotAndActiveFileCache() assertFileSystemAccesses( () -> { transactionLogAccess.getSnapshot(SESSION, new SchemaTableName("schema", tableName), tableDir, Optional.empty()); - List addFileEntries = transactionLogAccess.getActiveFiles(tableSnapshot, SESSION); + List addFileEntries = transactionLogAccess.getActiveFiles(tableSnapshot, metadataEntry, protocolEntry, SESSION); assertEquals(addFileEntries.size(), 12); }, ImmutableMultiset.builder() - .addCopies(new FileOperation("_last_checkpoint", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000011.json", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000012.json", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000013.json", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000014.json", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000010.checkpoint.parquet", INPUT_FILE_GET_LENGTH), 4) // TODO (https://github.com/trinodb/trino/issues/16775) why not e.g. once? - .addCopies(new FileOperation("00000000000000000010.checkpoint.parquet", INPUT_FILE_NEW_STREAM), 2) + .add(new FileOperation("_last_checkpoint", INPUT_FILE_NEW_STREAM)) + .add(new FileOperation("00000000000000000011.json", INPUT_FILE_NEW_STREAM)) + .add(new FileOperation("00000000000000000012.json", INPUT_FILE_NEW_STREAM)) + .add(new FileOperation("00000000000000000013.json", INPUT_FILE_NEW_STREAM)) + .add(new FileOperation("00000000000000000014.json", INPUT_FILE_NEW_STREAM)) + .addCopies(new FileOperation("00000000000000000010.checkpoint.parquet", INPUT_FILE_GET_LENGTH), 2) // TODO (https://github.com/trinodb/trino/issues/16775) why not e.g. once? 
+ .add(new FileOperation("00000000000000000010.checkpoint.parquet", INPUT_FILE_NEW_STREAM)) .build()); } @@ -837,33 +861,35 @@ public void testTableSnapshotsActiveDataFilesCacheDisabled() String tableDir = getClass().getClassLoader().getResource("databricks73/" + tableName).toURI().toString(); DeltaLakeConfig shortLivedActiveDataFilesCacheConfig = new DeltaLakeConfig(); shortLivedActiveDataFilesCacheConfig.setDataFileCacheTtl(new Duration(0, TimeUnit.SECONDS)); + MetadataEntry metadataEntry = transactionLogAccess.getMetadataEntry(tableSnapshot, SESSION); + ProtocolEntry protocolEntry = transactionLogAccess.getProtocolEntry(SESSION, tableSnapshot); assertFileSystemAccesses( () -> { setupTransactionLogAccess(tableName, tableDir, shortLivedActiveDataFilesCacheConfig); - List addFileEntries = transactionLogAccess.getActiveFiles(tableSnapshot, SESSION); + List addFileEntries = transactionLogAccess.getActiveFiles(tableSnapshot, metadataEntry, protocolEntry, SESSION); assertEquals(addFileEntries.size(), 12); }, ImmutableMultiset.builder() - .addCopies(new FileOperation("_last_checkpoint", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000011.json", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000012.json", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000013.json", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000014.json", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation("00000000000000000010.checkpoint.parquet", INPUT_FILE_GET_LENGTH), 4) // TODO (https://github.com/trinodb/trino/issues/16775) why not e.g. once? 
- .addCopies(new FileOperation("00000000000000000010.checkpoint.parquet", INPUT_FILE_NEW_STREAM), 2) + .add(new FileOperation("_last_checkpoint", INPUT_FILE_NEW_STREAM)) + .add(new FileOperation("00000000000000000011.json", INPUT_FILE_NEW_STREAM)) + .add(new FileOperation("00000000000000000012.json", INPUT_FILE_NEW_STREAM)) + .add(new FileOperation("00000000000000000013.json", INPUT_FILE_NEW_STREAM)) + .add(new FileOperation("00000000000000000014.json", INPUT_FILE_NEW_STREAM)) + .addCopies(new FileOperation("00000000000000000010.checkpoint.parquet", INPUT_FILE_GET_LENGTH), 2) // TODO (https://github.com/trinodb/trino/issues/16775) why not e.g. once? + .add(new FileOperation("00000000000000000010.checkpoint.parquet", INPUT_FILE_NEW_STREAM)) .build()); // With no caching for the transaction log entries, when loading the snapshot again, // the checkpoint file will be read again assertFileSystemAccesses( () -> { - List addFileEntries = transactionLogAccess.getActiveFiles(tableSnapshot, SESSION); + List addFileEntries = transactionLogAccess.getActiveFiles(tableSnapshot, metadataEntry, protocolEntry, SESSION); assertEquals(addFileEntries.size(), 12); }, ImmutableMultiset.builder() - .addCopies(new FileOperation("00000000000000000010.checkpoint.parquet", INPUT_FILE_GET_LENGTH), 4) // TODO (https://github.com/trinodb/trino/issues/16775) why not e.g. once? - .addCopies(new FileOperation("00000000000000000010.checkpoint.parquet", INPUT_FILE_NEW_STREAM), 2) + .addCopies(new FileOperation("00000000000000000010.checkpoint.parquet", INPUT_FILE_GET_LENGTH), 2) // TODO (https://github.com/trinodb/trino/issues/16775) why not e.g. once? 
+ .add(new FileOperation("00000000000000000010.checkpoint.parquet", INPUT_FILE_NEW_STREAM)) .build()); } diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestingDeltaLakeUtils.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestingDeltaLakeUtils.java index 18468f184be3..41d04cb1f002 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestingDeltaLakeUtils.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestingDeltaLakeUtils.java @@ -14,6 +14,8 @@ package io.trino.plugin.deltalake; import io.trino.plugin.deltalake.transactionlog.AddFileEntry; +import io.trino.plugin.deltalake.transactionlog.MetadataEntry; +import io.trino.plugin.deltalake.transactionlog.ProtocolEntry; import io.trino.plugin.deltalake.transactionlog.TableSnapshot; import io.trino.plugin.deltalake.transactionlog.TransactionLogAccess; import io.trino.spi.connector.SchemaTableName; @@ -48,7 +50,9 @@ public static List getTableActiveFiles(TransactionLogAccess transa transactionLogAccess.flushCache(); TableSnapshot snapshot = transactionLogAccess.getSnapshot(SESSION, dummyTable, tableLocation, Optional.empty()); - List activeFiles = transactionLogAccess.getActiveFiles(snapshot, SESSION); + MetadataEntry metadataEntry = transactionLogAccess.getMetadataEntry(snapshot, SESSION); + ProtocolEntry protocolEntry = transactionLogAccess.getProtocolEntry(SESSION, snapshot); + List activeFiles = transactionLogAccess.getActiveFiles(snapshot, metadataEntry, protocolEntry, SESSION); transactionLogAccess.cleanupQuery(SESSION); return activeFiles; } diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/transactionlog/TestTableSnapshot.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/transactionlog/TestTableSnapshot.java index af5e8a6aa508..3b6186fe7654 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/transactionlog/TestTableSnapshot.java +++ 
b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/transactionlog/TestTableSnapshot.java @@ -14,8 +14,6 @@ package io.trino.plugin.deltalake.transactionlog; import com.google.common.collect.HashMultiset; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMultiset; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Multiset; @@ -24,11 +22,14 @@ import io.trino.filesystem.TrinoFileSystem; import io.trino.filesystem.hdfs.HdfsFileSystemFactory; import io.trino.parquet.ParquetReaderOptions; +import io.trino.plugin.deltalake.DeltaLakeConfig; import io.trino.plugin.deltalake.transactionlog.checkpoint.CheckpointSchemaManager; import io.trino.plugin.deltalake.transactionlog.checkpoint.LastCheckpoint; import io.trino.plugin.hive.FileFormatDataSourceStats; import io.trino.plugin.hive.parquet.ParquetReaderConfig; import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.type.TypeManager; +import io.trino.testing.TestingConnectorContext; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; @@ -43,6 +44,8 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.filesystem.TrackingFileSystemFactory.OperationType.INPUT_FILE_NEW_STREAM; +import static io.trino.plugin.deltalake.transactionlog.TableSnapshot.MetadataAndProtocolEntry; +import static io.trino.plugin.deltalake.transactionlog.TableSnapshot.load; import static io.trino.plugin.deltalake.transactionlog.TransactionLogParser.readLastCheckpoint; import static io.trino.plugin.deltalake.transactionlog.checkpoint.CheckpointEntryIterator.EntryType.ADD; import static io.trino.plugin.deltalake.transactionlog.checkpoint.CheckpointEntryIterator.EntryType.PROTOCOL; @@ -88,7 +91,7 @@ public void testOnlyReadsTrailingJsonFiles() assertFileSystemAccesses( () -> { Optional lastCheckpoint = 
readLastCheckpoint(trackingFileSystem, tableLocation); - tableSnapshot.set(TableSnapshot.load( + tableSnapshot.set(load( new SchemaTableName("schema", "person"), lastCheckpoint, trackingFileSystem, @@ -118,7 +121,7 @@ public void readsCheckpointFile() throws IOException { Optional lastCheckpoint = readLastCheckpoint(trackingFileSystem, tableLocation); - TableSnapshot tableSnapshot = TableSnapshot.load( + TableSnapshot tableSnapshot = load( new SchemaTableName("schema", "person"), lastCheckpoint, trackingFileSystem, @@ -126,9 +129,20 @@ public void readsCheckpointFile() parquetReaderOptions, true, domainCompactionThreshold); - tableSnapshot.setCachedMetadata(Optional.of(new MetadataEntry("id", "name", "description", null, "schema", ImmutableList.of(), ImmutableMap.of(), 0))); + TestingConnectorContext context = new TestingConnectorContext(); + TypeManager typeManager = context.getTypeManager(); + TransactionLogAccess transactionLogAccess = new TransactionLogAccess( + typeManager, + new CheckpointSchemaManager(typeManager), + new DeltaLakeConfig(), + new FileFormatDataSourceStats(), + trackingFileSystemFactory, + new ParquetReaderConfig()); + MetadataEntry metadataEntry = transactionLogAccess.getMetadataEntry(tableSnapshot, SESSION); + ProtocolEntry protocolEntry = transactionLogAccess.getProtocolEntry(SESSION, tableSnapshot); + tableSnapshot.setCachedMetadata(Optional.of(metadataEntry)); try (Stream stream = tableSnapshot.getCheckpointTransactionLogEntries( - SESSION, ImmutableSet.of(ADD), checkpointSchemaManager, TESTING_TYPE_MANAGER, trackingFileSystem, new FileFormatDataSourceStats())) { + SESSION, ImmutableSet.of(ADD), checkpointSchemaManager, TESTING_TYPE_MANAGER, trackingFileSystem, new FileFormatDataSourceStats(), Optional.of(new MetadataAndProtocolEntry(metadataEntry, protocolEntry)))) { List entries = stream.collect(toImmutableList()); assertThat(entries).hasSize(9); @@ -170,7 +184,7 @@ public void readsCheckpointFile() // lets read two entry types in one 
call; add and protocol try (Stream stream = tableSnapshot.getCheckpointTransactionLogEntries( - SESSION, ImmutableSet.of(ADD, PROTOCOL), checkpointSchemaManager, TESTING_TYPE_MANAGER, trackingFileSystem, new FileFormatDataSourceStats())) { + SESSION, ImmutableSet.of(ADD, PROTOCOL), checkpointSchemaManager, TESTING_TYPE_MANAGER, trackingFileSystem, new FileFormatDataSourceStats(), Optional.of(new MetadataAndProtocolEntry(metadataEntry, protocolEntry)))) { List entries = stream.collect(toImmutableList()); assertThat(entries).hasSize(10); @@ -218,7 +232,7 @@ public void testMaxTransactionId() throws IOException { Optional lastCheckpoint = readLastCheckpoint(trackingFileSystem, tableLocation); - TableSnapshot tableSnapshot = TableSnapshot.load( + TableSnapshot tableSnapshot = load( new SchemaTableName("schema", "person"), lastCheckpoint, trackingFileSystem, diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/transactionlog/checkpoint/TestCheckpointEntryIterator.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/transactionlog/checkpoint/TestCheckpointEntryIterator.java index 0dbe85242b3a..5cd49f81ac57 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/transactionlog/checkpoint/TestCheckpointEntryIterator.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/transactionlog/checkpoint/TestCheckpointEntryIterator.java @@ -61,6 +61,7 @@ import static io.trino.plugin.hive.HiveTestUtils.HDFS_FILE_SYSTEM_STATS; import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; @TestInstance(PER_CLASS) @@ -82,6 +83,16 @@ public void tearDown() checkpointSchemaManager = null; } + @Test + public void testReadNoEntries() + throws Exception + { + URI checkpointUri = 
getResource(TEST_CHECKPOINT).toURI(); + assertThatThrownBy(() -> createCheckpointEntryIterator(checkpointUri, ImmutableSet.of(), Optional.empty(), Optional.empty())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("fields is empty"); + } + @Test public void testReadMetadataEntry() throws Exception @@ -134,6 +145,49 @@ public void testReadProtocolEntries() Optional.empty())); } + @Test + public void testReadMetadataAndProtocolEntry() + throws Exception + { + URI checkpointUri = getResource(TEST_CHECKPOINT).toURI(); + CheckpointEntryIterator checkpointEntryIterator = createCheckpointEntryIterator(checkpointUri, ImmutableSet.of(METADATA, PROTOCOL), Optional.empty(), Optional.empty()); + List entries = ImmutableList.copyOf(checkpointEntryIterator); + + assertThat(entries).hasSize(2); + assertThat(entries).containsExactlyInAnyOrder( + DeltaLakeTransactionLogEntry.metadataEntry(new MetadataEntry( + "b6aeffad-da73-4dde-b68e-937e468b1fde", + null, + null, + new MetadataEntry.Format("parquet", Map.of()), + "{\"type\":\"struct\",\"fields\":[" + + "{\"name\":\"name\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}}," + + "{\"name\":\"age\",\"type\":\"integer\",\"nullable\":true,\"metadata\":{}}," + + "{\"name\":\"married\",\"type\":\"boolean\",\"nullable\":true,\"metadata\":{}}," + + + "{\"name\":\"phones\",\"type\":{\"type\":\"array\",\"elementType\":{\"type\":\"struct\",\"fields\":[" + + "{\"name\":\"number\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}}," + + "{\"name\":\"label\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}}]}," + + "\"containsNull\":true},\"nullable\":true,\"metadata\":{}}," + + + "{\"name\":\"address\",\"type\":{\"type\":\"struct\",\"fields\":[" + + "{\"name\":\"street\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}}," + + "{\"name\":\"city\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}}," + + "{\"name\":\"state\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}}," + + 
"{\"name\":\"zip\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}}]},\"nullable\":true,\"metadata\":{}}," + + + "{\"name\":\"income\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}}]}", + List.of("age"), + Map.of(), + 1579190100722L)), + DeltaLakeTransactionLogEntry.protocolEntry( + new ProtocolEntry( + 1, + 2, + Optional.empty(), + Optional.empty()))); + } + @Test public void testReadAddEntries() throws Exception @@ -309,6 +363,8 @@ public void testSkipRemoveEntries() targetFile.delete(); // file must not exist when writer is called writer.write(entries, createOutputFile(targetPath)); + CheckpointEntryIterator metadataAndProtocolEntryIterator = + createCheckpointEntryIterator(URI.create(targetPath), ImmutableSet.of(METADATA, PROTOCOL), Optional.empty(), Optional.empty()); CheckpointEntryIterator addEntryIterator = createCheckpointEntryIterator( URI.create(targetPath), ImmutableSet.of(ADD), @@ -319,10 +375,12 @@ public void testSkipRemoveEntries() CheckpointEntryIterator txnEntryIterator = createCheckpointEntryIterator(URI.create(targetPath), ImmutableSet.of(TRANSACTION), Optional.empty(), Optional.empty()); + assertThat(Iterators.size(metadataAndProtocolEntryIterator)).isEqualTo(2); assertThat(Iterators.size(addEntryIterator)).isEqualTo(1); assertThat(Iterators.size(removeEntryIterator)).isEqualTo(numRemoveEntries); assertThat(Iterators.size(txnEntryIterator)).isEqualTo(0); + assertThat(metadataAndProtocolEntryIterator.getCompletedPositions().orElseThrow()).isEqualTo(3L); assertThat(addEntryIterator.getCompletedPositions().orElseThrow()).isEqualTo(2L); assertThat(removeEntryIterator.getCompletedPositions().orElseThrow()).isEqualTo(100L); assertThat(txnEntryIterator.getCompletedPositions().orElseThrow()).isEqualTo(0L); diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/transactionlog/checkpoint/TestTransactionLogTail.java 
b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/transactionlog/checkpoint/TestTransactionLogTail.java index 211de2c1ceb1..7d005a9928b8 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/transactionlog/checkpoint/TestTransactionLogTail.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/transactionlog/checkpoint/TestTransactionLogTail.java @@ -16,8 +16,7 @@ import io.trino.filesystem.TrinoFileSystem; import io.trino.filesystem.hdfs.HdfsFileSystemFactory; import io.trino.plugin.deltalake.transactionlog.DeltaLakeTransactionLogEntry; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.util.List; import java.util.Optional; @@ -31,22 +30,20 @@ public class TestTransactionLogTail { - @Test(dataProvider = "dataSource") - public void testTail(String dataSource) + @Test + public void testTail() throws Exception { - String tableLocation = getClass().getClassLoader().getResource(format("%s/person", dataSource)).toURI().toString(); - assertEquals(readJsonTransactionLogTails(tableLocation).size(), 7); - assertEquals(updateJsonTransactionLogTails(tableLocation).size(), 7); + testTail("databricks73"); + testTail("deltalake"); } - @DataProvider - public Object[][] dataSource() + private void testTail(String dataSource) + throws Exception { - return new Object[][] { - {"databricks73"}, - {"deltalake"} - }; + String tableLocation = getClass().getClassLoader().getResource(format("%s/person", dataSource)).toURI().toString(); + assertEquals(readJsonTransactionLogTails(tableLocation).size(), 7); + assertEquals(updateJsonTransactionLogTails(tableLocation).size(), 7); } private List updateJsonTransactionLogTails(String tableLocation) diff --git a/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/README.md b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/README.md new file mode 100644 index 
000000000000..99d540dbca44 --- /dev/null +++ b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/README.md @@ -0,0 +1,28 @@ +Data generated using Apache Spark 3.4.0 & Delta Lake OSS 2.4.0. + +This test resource is used to verify whether reading from Delta Lake tables with +multi-part checkpoint files works as expected. + +Trino +``` +CREATE TABLE multipartcheckpoint(c integer) with (checkpoint_interval = 6); +``` + +From https://docs.delta.io/latest/optimizations-oss.html + +> In Delta Lake, by default each checkpoint is written as a single Parquet file. To use this feature, +> set the SQL configuration ``spark.databricks.delta.checkpoint.partSize=<n>``, where n is the limit of +> number of actions (such as `AddFile`) at which Delta Lake on Apache Spark will start parallelizing the +> checkpoint and attempt to write a maximum of this many actions per checkpoint file. + +Spark +``` +SET spark.databricks.delta.checkpoint.partSize=3; +INSERT INTO multipartcheckpoint values 1; +INSERT INTO multipartcheckpoint values 2; +INSERT INTO multipartcheckpoint values 3; +INSERT INTO multipartcheckpoint values 4; +INSERT INTO multipartcheckpoint values 5; +INSERT INTO multipartcheckpoint values 6; +INSERT INTO multipartcheckpoint values 7; +``` diff --git a/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000000.json b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000000.json new file mode 100644 index 000000000000..ba5929dec80d --- /dev/null +++ b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000000.json @@ -0,0 +1,3 @@ +{"commitInfo":{"version":0,"timestamp":1697439143958,"userId":"marius","userName":"marius","operation":"CREATE
TABLE","operationParameters":{"queryId":"20231016_065223_00001_dhwpa"},"clusterId":"trino-428-191-g91ee252-presto-master","readVersion":0,"isolationLevel":"WriteSerializable","isBlindAppend":true}} +{"protocol":{"minReaderVersion":1,"minWriterVersion":2}} +{"metaData":{"id":"ce0eab6c-75a5-4904-9f90-2fe73bedf1ce","format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"c\",\"type\":\"integer\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":[],"configuration":{"delta.checkpointInterval":"6"},"createdTime":1697439143958}} diff --git a/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000001.json b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000001.json new file mode 100644 index 000000000000..6f728af0f9e0 --- /dev/null +++ b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000001.json @@ -0,0 +1,2 @@ +{"commitInfo":{"timestamp":1697439172229,"operation":"WRITE","operationParameters":{"mode":"Append","partitionBy":"[]"},"readVersion":0,"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{"numFiles":"1","numOutputRows":"1","numOutputBytes":"449"},"engineInfo":"Apache-Spark/3.4.0 Delta-Lake/2.4.0","txnId":"ff40545b-ceb7-4836-8b35-a2147cf21677"}} +{"add":{"path":"part-00000-81f5948d-6e60-4582-b758-2424da5aef0f-c000.snappy.parquet","partitionValues":{},"size":449,"modificationTime":1697439172000,"dataChange":true,"stats":"{\"numRecords\":1,\"minValues\":{\"c\":1},\"maxValues\":{\"c\":1},\"nullCount\":{\"c\":0}}"}} diff --git a/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000002.json b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000002.json new file mode 100644 index 000000000000..72ec0c619113 --- /dev/null +++ 
b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000002.json @@ -0,0 +1,2 @@ +{"commitInfo":{"timestamp":1697439178642,"operation":"WRITE","operationParameters":{"mode":"Append","partitionBy":"[]"},"readVersion":1,"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{"numFiles":"1","numOutputRows":"1","numOutputBytes":"449"},"engineInfo":"Apache-Spark/3.4.0 Delta-Lake/2.4.0","txnId":"e730895b-9e56-4738-90fb-ce62ab08f3b1"}} +{"add":{"path":"part-00000-4a4be67a-1d35-4c93-ac3d-c1b26e2c5f3a-c000.snappy.parquet","partitionValues":{},"size":449,"modificationTime":1697439178000,"dataChange":true,"stats":"{\"numRecords\":1,\"minValues\":{\"c\":2},\"maxValues\":{\"c\":2},\"nullCount\":{\"c\":0}}"}} diff --git a/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000003.json b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000003.json new file mode 100644 index 000000000000..5ce98ef913c7 --- /dev/null +++ b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000003.json @@ -0,0 +1,2 @@ +{"commitInfo":{"timestamp":1697439181640,"operation":"WRITE","operationParameters":{"mode":"Append","partitionBy":"[]"},"readVersion":2,"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{"numFiles":"1","numOutputRows":"1","numOutputBytes":"449"},"engineInfo":"Apache-Spark/3.4.0 Delta-Lake/2.4.0","txnId":"0ee530c2-daef-4e2e-b2ad-de64e3e7b940"}} +{"add":{"path":"part-00000-aeede0e4-d6fd-4425-ab3e-be80963fe765-c000.snappy.parquet","partitionValues":{},"size":449,"modificationTime":1697439181000,"dataChange":true,"stats":"{\"numRecords\":1,\"minValues\":{\"c\":3},\"maxValues\":{\"c\":3},\"nullCount\":{\"c\":0}}"}} diff --git a/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000004.json 
b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000004.json new file mode 100644 index 000000000000..8a079b24cc25 --- /dev/null +++ b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000004.json @@ -0,0 +1,2 @@ +{"commitInfo":{"timestamp":1697439185136,"operation":"WRITE","operationParameters":{"mode":"Append","partitionBy":"[]"},"readVersion":3,"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{"numFiles":"1","numOutputRows":"1","numOutputBytes":"449"},"engineInfo":"Apache-Spark/3.4.0 Delta-Lake/2.4.0","txnId":"3b236730-8187-4d5d-9c2e-ecb281788a15"}} +{"add":{"path":"part-00000-30d20302-2223-4370-b4ec-9129845af85d-c000.snappy.parquet","partitionValues":{},"size":449,"modificationTime":1697439185000,"dataChange":true,"stats":"{\"numRecords\":1,\"minValues\":{\"c\":4},\"maxValues\":{\"c\":4},\"nullCount\":{\"c\":0}}"}} diff --git a/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000005.json b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000005.json new file mode 100644 index 000000000000..92ad5eba1c9a --- /dev/null +++ b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000005.json @@ -0,0 +1,2 @@ +{"commitInfo":{"timestamp":1697439189907,"operation":"WRITE","operationParameters":{"mode":"Append","partitionBy":"[]"},"readVersion":4,"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{"numFiles":"1","numOutputRows":"1","numOutputBytes":"449"},"engineInfo":"Apache-Spark/3.4.0 Delta-Lake/2.4.0","txnId":"1df6e97f-ab4c-4f6d-ac90-48f2e8d0086b"}} 
+{"add":{"path":"part-00000-68d36678-1302-4d91-a23d-1049d2630b60-c000.snappy.parquet","partitionValues":{},"size":449,"modificationTime":1697439189000,"dataChange":true,"stats":"{\"numRecords\":1,\"minValues\":{\"c\":5},\"maxValues\":{\"c\":5},\"nullCount\":{\"c\":0}}"}} diff --git a/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000006.checkpoint.0000000001.0000000002.parquet b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000006.checkpoint.0000000001.0000000002.parquet new file mode 100644 index 000000000000..5a8652e15f98 Binary files /dev/null and b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000006.checkpoint.0000000001.0000000002.parquet differ diff --git a/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000006.checkpoint.0000000002.0000000002.parquet b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000006.checkpoint.0000000002.0000000002.parquet new file mode 100644 index 000000000000..fc88d59544d6 Binary files /dev/null and b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000006.checkpoint.0000000002.0000000002.parquet differ diff --git a/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000006.json b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000006.json new file mode 100644 index 000000000000..94cd9a799777 --- /dev/null +++ b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000006.json @@ -0,0 +1,2 @@ 
+{"commitInfo":{"timestamp":1697439194248,"operation":"WRITE","operationParameters":{"mode":"Append","partitionBy":"[]"},"readVersion":5,"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{"numFiles":"1","numOutputRows":"1","numOutputBytes":"449"},"engineInfo":"Apache-Spark/3.4.0 Delta-Lake/2.4.0","txnId":"c757b395-39ce-4007-871f-b648423ec886"}} +{"add":{"path":"part-00000-c2ea82b3-cbf3-413a-ade1-409a1ce69d9c-c000.snappy.parquet","partitionValues":{},"size":449,"modificationTime":1697439194000,"dataChange":true,"stats":"{\"numRecords\":1,\"minValues\":{\"c\":6},\"maxValues\":{\"c\":6},\"nullCount\":{\"c\":0}}"}} diff --git a/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000007.json b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000007.json new file mode 100644 index 000000000000..60e3daf0a5a0 --- /dev/null +++ b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/00000000000000000007.json @@ -0,0 +1,2 @@ +{"commitInfo":{"timestamp":1697439206526,"operation":"WRITE","operationParameters":{"mode":"Append","partitionBy":"[]"},"readVersion":6,"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{"numFiles":"1","numOutputRows":"1","numOutputBytes":"449"},"engineInfo":"Apache-Spark/3.4.0 Delta-Lake/2.4.0","txnId":"692e911c-78a9-4e97-9fdd-5e2bf33c7a2a"}} +{"add":{"path":"part-00000-07432bff-a65c-4b96-8775-074e4e99e771-c000.snappy.parquet","partitionValues":{},"size":449,"modificationTime":1697439206000,"dataChange":true,"stats":"{\"numRecords\":1,\"minValues\":{\"c\":7},\"maxValues\":{\"c\":7},\"nullCount\":{\"c\":0}}"}} diff --git a/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/_last_checkpoint b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/_last_checkpoint new file mode 100644 index 000000000000..e5d513c4df3a --- 
/dev/null +++ b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/_delta_log/_last_checkpoint @@ -0,0 +1 @@ +{"version":6,"size":8,"parts":2,"sizeInBytes":27011,"numOfAddFiles":6,"checkpointSchema":{"type":"struct","fields":[{"name":"txn","type":{"type":"struct","fields":[{"name":"appId","type":"string","nullable":true,"metadata":{}},{"name":"version","type":"long","nullable":true,"metadata":{}},{"name":"lastUpdated","type":"long","nullable":true,"metadata":{}}]},"nullable":true,"metadata":{}},{"name":"add","type":{"type":"struct","fields":[{"name":"path","type":"string","nullable":true,"metadata":{}},{"name":"partitionValues","type":{"type":"map","keyType":"string","valueType":"string","valueContainsNull":true},"nullable":true,"metadata":{}},{"name":"size","type":"long","nullable":true,"metadata":{}},{"name":"modificationTime","type":"long","nullable":true,"metadata":{}},{"name":"dataChange","type":"boolean","nullable":true,"metadata":{}},{"name":"tags","type":{"type":"map","keyType":"string","valueType":"string","valueContainsNull":true},"nullable":true,"metadata":{}},{"name":"deletionVector","type":{"type":"struct","fields":[{"name":"storageType","type":"string","nullable":true,"metadata":{}},{"name":"pathOrInlineDv","type":"string","nullable":true,"metadata":{}},{"name":"offset","type":"integer","nullable":true,"metadata":{}},{"name":"sizeInBytes","type":"integer","nullable":true,"metadata":{}},{"name":"cardinality","type":"long","nullable":true,"metadata":{}},{"name":"maxRowIndex","type":"long","nullable":true,"metadata":{}}]},"nullable":true,"metadata":{}},{"name":"stats","type":"string","nullable":true,"metadata":{}}]},"nullable":true,"metadata":{}},{"name":"remove","type":{"type":"struct","fields":[{"name":"path","type":"string","nullable":true,"metadata":{}},{"name":"deletionTimestamp","type":"long","nullable":true,"metadata":{}},{"name":"dataChange","type":"boolean","nullable":true,"metadata":{}},{"name":"extendedFileMetadata","type
":"boolean","nullable":true,"metadata":{}},{"name":"partitionValues","type":{"type":"map","keyType":"string","valueType":"string","valueContainsNull":true},"nullable":true,"metadata":{}},{"name":"size","type":"long","nullable":true,"metadata":{}},{"name":"deletionVector","type":{"type":"struct","fields":[{"name":"storageType","type":"string","nullable":true,"metadata":{}},{"name":"pathOrInlineDv","type":"string","nullable":true,"metadata":{}},{"name":"offset","type":"integer","nullable":true,"metadata":{}},{"name":"sizeInBytes","type":"integer","nullable":true,"metadata":{}},{"name":"cardinality","type":"long","nullable":true,"metadata":{}},{"name":"maxRowIndex","type":"long","nullable":true,"metadata":{}}]},"nullable":true,"metadata":{}}]},"nullable":true,"metadata":{}},{"name":"metaData","type":{"type":"struct","fields":[{"name":"id","type":"string","nullable":true,"metadata":{}},{"name":"name","type":"string","nullable":true,"metadata":{}},{"name":"description","type":"string","nullable":true,"metadata":{}},{"name":"format","type":{"type":"struct","fields":[{"name":"provider","type":"string","nullable":true,"metadata":{}},{"name":"options","type":{"type":"map","keyType":"string","valueType":"string","valueContainsNull":true},"nullable":true,"metadata":{}}]},"nullable":true,"metadata":{}},{"name":"schemaString","type":"string","nullable":true,"metadata":{}},{"name":"partitionColumns","type":{"type":"array","elementType":"string","containsNull":true},"nullable":true,"metadata":{}},{"name":"configuration","type":{"type":"map","keyType":"string","valueType":"string","valueContainsNull":true},"nullable":true,"metadata":{}},{"name":"createdTime","type":"long","nullable":true,"metadata":{}}]},"nullable":true,"metadata":{}},{"name":"protocol","type":{"type":"struct","fields":[{"name":"minReaderVersion","type":"integer","nullable":true,"metadata":{}},{"name":"minWriterVersion","type":"integer","nullable":true,"metadata":{}},{"name":"readerFeatures","type":{"type":"array",
"elementType":"string","containsNull":true},"nullable":true,"metadata":{}},{"name":"writerFeatures","type":{"type":"array","elementType":"string","containsNull":true},"nullable":true,"metadata":{}}]},"nullable":true,"metadata":{}}]},"checksum":"e3aeff08e804e2c1d2d8367707f7efca"} diff --git a/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/part-00000-07432bff-a65c-4b96-8775-074e4e99e771-c000.snappy.parquet b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/part-00000-07432bff-a65c-4b96-8775-074e4e99e771-c000.snappy.parquet new file mode 100644 index 000000000000..8cd919e54929 Binary files /dev/null and b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/part-00000-07432bff-a65c-4b96-8775-074e4e99e771-c000.snappy.parquet differ diff --git a/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/part-00000-30d20302-2223-4370-b4ec-9129845af85d-c000.snappy.parquet b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/part-00000-30d20302-2223-4370-b4ec-9129845af85d-c000.snappy.parquet new file mode 100644 index 000000000000..6bbd485721bf Binary files /dev/null and b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/part-00000-30d20302-2223-4370-b4ec-9129845af85d-c000.snappy.parquet differ diff --git a/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/part-00000-4a4be67a-1d35-4c93-ac3d-c1b26e2c5f3a-c000.snappy.parquet b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/part-00000-4a4be67a-1d35-4c93-ac3d-c1b26e2c5f3a-c000.snappy.parquet new file mode 100644 index 000000000000..07e176899a7b Binary files /dev/null and b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/part-00000-4a4be67a-1d35-4c93-ac3d-c1b26e2c5f3a-c000.snappy.parquet differ diff --git 
a/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/part-00000-68d36678-1302-4d91-a23d-1049d2630b60-c000.snappy.parquet b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/part-00000-68d36678-1302-4d91-a23d-1049d2630b60-c000.snappy.parquet new file mode 100644 index 000000000000..08abfd031ef1 Binary files /dev/null and b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/part-00000-68d36678-1302-4d91-a23d-1049d2630b60-c000.snappy.parquet differ diff --git a/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/part-00000-81f5948d-6e60-4582-b758-2424da5aef0f-c000.snappy.parquet b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/part-00000-81f5948d-6e60-4582-b758-2424da5aef0f-c000.snappy.parquet new file mode 100644 index 000000000000..60fa1ce9a490 Binary files /dev/null and b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/part-00000-81f5948d-6e60-4582-b758-2424da5aef0f-c000.snappy.parquet differ diff --git a/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/part-00000-aeede0e4-d6fd-4425-ab3e-be80963fe765-c000.snappy.parquet b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/part-00000-aeede0e4-d6fd-4425-ab3e-be80963fe765-c000.snappy.parquet new file mode 100644 index 000000000000..a4045cfbcc4c Binary files /dev/null and b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/part-00000-aeede0e4-d6fd-4425-ab3e-be80963fe765-c000.snappy.parquet differ diff --git a/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/part-00000-c2ea82b3-cbf3-413a-ade1-409a1ce69d9c-c000.snappy.parquet b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/part-00000-c2ea82b3-cbf3-413a-ade1-409a1ce69d9c-c000.snappy.parquet new file mode 100644 index 000000000000..92a77908f855 Binary files /dev/null and 
b/plugin/trino-delta-lake/src/test/resources/deltalake/multipart_checkpoint/part-00000-c2ea82b3-cbf3-413a-ade1-409a1ce69d9c-c000.snappy.parquet differ diff --git a/plugin/trino-druid/pom.xml b/plugin/trino-druid/pom.xml index b2223865718a..3a8757ed05e4 100644 --- a/plugin/trino-druid/pom.xml +++ b/plugin/trino-druid/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidCaseInsensitiveMapping.java b/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidCaseInsensitiveMapping.java index 1ec43cec97ad..90d991607b0d 100644 --- a/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidCaseInsensitiveMapping.java +++ b/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidCaseInsensitiveMapping.java @@ -21,9 +21,10 @@ import io.trino.testing.MaterializedResult; import io.trino.testing.QueryRunner; import io.trino.testing.sql.SqlExecutor; -import org.testng.SkipException; -import org.testng.annotations.AfterClass; -import org.testng.annotations.Test; +import io.trino.testng.services.ManageTestResources; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; import java.nio.file.Path; import java.util.List; @@ -40,15 +41,18 @@ import static io.trino.tpch.TpchTable.REGION; import static java.util.Objects.requireNonNull; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assumptions.abort; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; -@Test(singleThreaded = true) +@TestInstance(PER_CLASS) public class TestDruidCaseInsensitiveMapping extends BaseCaseInsensitiveMappingTest { + @ManageTestResources.Suppress(because = "Not a TestNG test class") private TestingDruidServer druidServer; private Path mappingFile; - @AfterClass(alwaysRun = true) + @AfterAll public void destroy() { if (druidServer != null) { @@ 
-114,6 +118,8 @@ public void testNonLowerCaseTableName() public void testTableNameClash() throws Exception { + updateRuleBasedIdentifierMappingFile(getMappingFile(), ImmutableList.of(), ImmutableList.of()); + copyAndIngestTpchDataFromSourceToTarget( getQueryRunner().execute(SELECT_FROM_REGION), this.druidServer, @@ -191,7 +197,7 @@ public void testTableNameClashWithRuleMapping() public void testNonLowerCaseSchemaName() { // related to https://github.com/trinodb/trino/issues/14700 - throw new SkipException("Druid connector only supports schema 'druid'."); + abort("Druid connector only supports schema 'druid'."); } @Override @@ -199,7 +205,7 @@ public void testNonLowerCaseSchemaName() public void testSchemaAndTableNameRuleMapping() { // related to https://github.com/trinodb/trino/issues/14700 - throw new SkipException("Druid connector only supports schema 'druid'."); + abort("Druid connector only supports schema 'druid'."); } @Override @@ -207,7 +213,7 @@ public void testSchemaAndTableNameRuleMapping() public void testSchemaNameClash() { // related to https://github.com/trinodb/trino/issues/14700 - throw new SkipException("Druid connector only supports schema 'druid'."); + abort("Druid connector only supports schema 'druid'."); } @Override @@ -215,7 +221,7 @@ public void testSchemaNameClash() public void testSchemaNameClashWithRuleMapping() { // related to https://github.com/trinodb/trino/issues/14700 - throw new SkipException("Druid connector only supports schema 'druid'."); + abort("Druid connector only supports schema 'druid'."); } @Override @@ -223,6 +229,6 @@ public void testSchemaNameClashWithRuleMapping() public void testSchemaNameRuleMapping() { // related to https://github.com/trinodb/trino/issues/14700 - throw new SkipException("Druid connector only supports schema 'druid'."); + abort("Druid connector only supports schema 'druid'."); } } diff --git a/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidConnectorTest.java 
b/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidConnectorTest.java index e9bf979ef3a4..c4c2ae91868b 100644 --- a/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidConnectorTest.java +++ b/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidConnectorTest.java @@ -33,7 +33,6 @@ import org.intellij.lang.annotations.Language; import org.testng.SkipException; import org.testng.annotations.AfterClass; -import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.sql.DatabaseMetaData; @@ -528,8 +527,21 @@ public void testPredicatePushdownForTimestampWithMillisPrecision() .isNotFullyPushedDown(FilterNode.class); } - @Test(dataProvider = "timestampValuesProvider") - public void testPredicatePushdownForTimestampWithHigherPrecision(String timestamp) + @Test + public void testPredicatePushdownForTimestampWithHigherPrecision() + { + testPredicatePushdownForTimestampWithHigherPrecision("1992-01-04 00:00:00.1234"); + testPredicatePushdownForTimestampWithHigherPrecision("1992-01-04 00:00:00.12345"); + testPredicatePushdownForTimestampWithHigherPrecision("1992-01-04 00:00:00.123456"); + testPredicatePushdownForTimestampWithHigherPrecision("1992-01-04 00:00:00.1234567"); + testPredicatePushdownForTimestampWithHigherPrecision("1992-01-04 00:00:00.12345678"); + testPredicatePushdownForTimestampWithHigherPrecision("1992-01-04 00:00:00.123456789"); + testPredicatePushdownForTimestampWithHigherPrecision("1992-01-04 00:00:00.1234567891"); + testPredicatePushdownForTimestampWithHigherPrecision("1992-01-04 00:00:00.12345678912"); + testPredicatePushdownForTimestampWithHigherPrecision("1992-01-04 00:00:00.123456789123"); + } + + private void testPredicatePushdownForTimestampWithHigherPrecision(String timestamp) { // timestamp equality assertThat(query(format("SELECT linenumber, partkey, shipmode FROM lineitem WHERE __time = TIMESTAMP '%s'", timestamp))) @@ -568,20 +580,4 @@ public void 
testPredicatePushdownForTimestampWithHigherPrecision(String timestam "(BIGINT '1', BIGINT '574', CAST('AIR' AS varchar))") .isNotFullyPushedDown(FilterNode.class); } - - @DataProvider - public Object[][] timestampValuesProvider() - { - return new Object[][] { - {"1992-01-04 00:00:00.1234"}, - {"1992-01-04 00:00:00.12345"}, - {"1992-01-04 00:00:00.123456"}, - {"1992-01-04 00:00:00.1234567"}, - {"1992-01-04 00:00:00.12345678"}, - {"1992-01-04 00:00:00.123456789"}, - {"1992-01-04 00:00:00.1234567891"}, - {"1992-01-04 00:00:00.12345678912"}, - {"1992-01-04 00:00:00.123456789123"} - }; - } } diff --git a/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidTypeMapping.java b/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidTypeMapping.java index acb92157e206..3be97ff70345 100644 --- a/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidTypeMapping.java +++ b/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidTypeMapping.java @@ -21,8 +21,9 @@ import io.trino.testing.QueryRunner; import io.trino.testing.datatype.DataSetup; import io.trino.testing.datatype.SqlDataTypeTest; -import org.testng.annotations.AfterClass; -import org.testng.annotations.Test; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; import java.io.BufferedWriter; import java.io.FileWriter; @@ -40,7 +41,9 @@ import static java.util.Objects.requireNonNull; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +@TestInstance(PER_CLASS) public class TestDruidTypeMapping extends AbstractTestQueryFramework { @@ -55,7 +58,7 @@ protected QueryRunner createQueryRunner() return createDruidQueryRunnerTpch(druidServer, ImmutableMap.of(), ImmutableMap.of(), ImmutableList.of()); } - @AfterClass(alwaysRun = true) + @AfterAll public void tearDown() { 
druidServer.close(); diff --git a/plugin/trino-elasticsearch/pom.xml b/plugin/trino-elasticsearch/pom.xml index 1152e7e6534f..0b4c981bcadb 100644 --- a/plugin/trino-elasticsearch/pom.xml +++ b/plugin/trino-elasticsearch/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-example-http/pom.xml b/plugin/trino-example-http/pom.xml index 0030df0efe2f..7bd889978dfd 100644 --- a/plugin/trino-example-http/pom.xml +++ b/plugin/trino-example-http/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-example-jdbc/pom.xml b/plugin/trino-example-jdbc/pom.xml index 409616c11411..21b58cfc641f 100644 --- a/plugin/trino-example-jdbc/pom.xml +++ b/plugin/trino-example-jdbc/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-exchange-filesystem/pom.xml b/plugin/trino-exchange-filesystem/pom.xml index a277e4ab2d67..1a5167cd9377 100644 --- a/plugin/trino-exchange-filesystem/pom.xml +++ b/plugin/trino-exchange-filesystem/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-exchange-hdfs/pom.xml b/plugin/trino-exchange-hdfs/pom.xml index d85173a264f7..4c6aca8f292f 100644 --- a/plugin/trino-exchange-hdfs/pom.xml +++ b/plugin/trino-exchange-hdfs/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-geospatial/pom.xml b/plugin/trino-geospatial/pom.xml index 2dcb99960571..017c36af4297 100644 --- a/plugin/trino-geospatial/pom.xml +++ b/plugin/trino-geospatial/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/BingTileType.java b/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/BingTileType.java index 659ce9d6585f..8f9d5e8b2453 100644 --- 
a/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/BingTileType.java +++ b/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/BingTileType.java @@ -24,7 +24,7 @@ public class BingTileType public static final BingTileType BING_TILE = new BingTileType(); public static final String NAME = "BingTile"; - public BingTileType() + private BingTileType() { super(new TypeSignature(NAME)); } @@ -42,6 +42,6 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return BingTile.decode(block.getLong(position, 0)); + return BingTile.decode(getLong(block, position)); } } diff --git a/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/GeometryType.java b/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/GeometryType.java index 510c28dc2dae..9904fbd0006f 100644 --- a/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/GeometryType.java +++ b/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/GeometryType.java @@ -16,6 +16,7 @@ import io.airlift.slice.Slice; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; @@ -42,7 +43,9 @@ protected GeometryType(TypeSignature signature) @Override public Slice getSlice(Block block, int position) { - return block.getSlice(position, 0, block.getSliceLength(position)); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return valueBlock.getSlice(valuePosition); } @Override @@ -71,7 +74,6 @@ public Object getObjectValue(ConnectorSession session, Block block, int position if (block.isNull(position)) { return null; } - Slice slice = block.getSlice(position, 0, block.getSliceLength(position)); - return 
deserialize(slice).asText(); + return deserialize(getSlice(block, position)).asText(); } } diff --git a/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/KdbTreeType.java b/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/KdbTreeType.java index e9daa2ebf90a..f66e3afd1863 100644 --- a/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/KdbTreeType.java +++ b/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/KdbTreeType.java @@ -19,6 +19,7 @@ import io.trino.geospatial.KdbTreeUtils; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.function.BlockIndex; @@ -54,7 +55,7 @@ private KdbTreeType() { // The KDB tree type should be KdbTree but can not be since KdbTree is in // both the plugin class loader and the system class loader. This was done - // so the plan optimizer can process geo spatial joins. + // so the plan optimizer can process geospatial joins. 
super(new TypeSignature(NAME), Object.class); } @@ -83,9 +84,10 @@ public Object getObject(Block block, int position) if (block.isNull(position)) { return null; } - Slice bytes = block.getSlice(position, 0, block.getSliceLength(position)); - KdbTree kdbTree = KdbTreeUtils.fromJson(bytes.toStringUtf8()); - return kdbTree; + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + String json = valueBlock.getSlice(valuePosition).toStringUtf8(); + return KdbTreeUtils.fromJson(json); } @Override @@ -149,14 +151,16 @@ private static void writeFlat( @ScalarOperator(READ_VALUE) private static void writeBlockToFlat( - @BlockPosition Block block, + @BlockPosition VariableWidthBlock block, @BlockIndex int position, byte[] fixedSizeSlice, int fixedSizeOffset, byte[] variableSizeSlice, int variableSizeOffset) { - Slice bytes = block.getSlice(position, 0, block.getSliceLength(position)); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + Slice bytes = valueBlock.getSlice(valuePosition); bytes.getBytes(0, variableSizeSlice, variableSizeOffset, bytes.length()); INT_HANDLE.set(fixedSizeSlice, fixedSizeOffset, bytes.length()); diff --git a/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/SphericalGeographyType.java b/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/SphericalGeographyType.java index 3dfcff0c3edc..af01848db877 100644 --- a/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/SphericalGeographyType.java +++ b/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/SphericalGeographyType.java @@ -16,6 +16,7 @@ import io.airlift.slice.Slice; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import 
io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; @@ -37,7 +38,9 @@ private SphericalGeographyType() @Override public Slice getSlice(Block block, int position) { - return block.getSlice(position, 0, block.getSliceLength(position)); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return valueBlock.getSlice(valuePosition); } @Override @@ -58,7 +61,6 @@ public Object getObjectValue(ConnectorSession session, Block block, int position if (block.isNull(position)) { return null; } - Slice slice = block.getSlice(position, 0, block.getSliceLength(position)); - return deserialize(slice).asText(); + return deserialize(getSlice(block, position)).asText(); } } diff --git a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestKdbTreeType.java b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestKdbTreeType.java index f800c5f228c7..a233e1f26922 100644 --- a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestKdbTreeType.java +++ b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestKdbTreeType.java @@ -16,8 +16,8 @@ import io.trino.geospatial.KdbTree; import io.trino.geospatial.KdbTree.Node; import io.trino.geospatial.Rectangle; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.type.AbstractTestType; import org.junit.jupiter.api.Test; @@ -36,7 +36,7 @@ protected TestKdbTreeType() super(KDB_TREE, KdbTree.class, createTestBlock()); } - private static Block createTestBlock() + private static ValueBlock createTestBlock() { BlockBuilder blockBuilder = KDB_TREE.createBlockBuilder(null, 1); KdbTree kdbTree = new KdbTree( @@ -46,7 +46,7 @@ private static Block createTestBlock() Optional.empty(), Optional.empty())); KDB_TREE.writeObject(blockBuilder, kdbTree); - return blockBuilder.build(); + 
return blockBuilder.buildValueBlock(); } @Override diff --git a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestSpatialPartitioningInternalAggregation.java b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestSpatialPartitioningInternalAggregation.java index 8c6b70b0afd5..8207cd03895e 100644 --- a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestSpatialPartitioningInternalAggregation.java +++ b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestSpatialPartitioningInternalAggregation.java @@ -30,6 +30,7 @@ import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.testing.LocalQueryRunner; import org.junit.jupiter.api.Test; @@ -71,7 +72,9 @@ public void test(int partitionCount) List geometries = makeGeometries(); Block geometryBlock = makeGeometryBlock(geometries); - Block partitionCountBlock = BlockAssertions.createRepeatedValuesBlock(partitionCount, geometries.size()); + BlockBuilder blockBuilder = INTEGER.createBlockBuilder(null, 1); + INTEGER.writeInt(blockBuilder, partitionCount); + Block partitionCountBlock = RunLengthEncodedBlock.create(blockBuilder.build(), geometries.size()); Rectangle expectedExtent = new Rectangle(-10, -10, Math.nextUp(10.0), Math.nextUp(10.0)); String expectedValue = getSpatialPartitioning(expectedExtent, geometries, partitionCount); diff --git a/plugin/trino-google-sheets/pom.xml b/plugin/trino-google-sheets/pom.xml index aa4d9287ffdf..4b40dea2d8b1 100644 --- a/plugin/trino-google-sheets/pom.xml +++ b/plugin/trino-google-sheets/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-hive-hadoop2/pom.xml b/plugin/trino-hive-hadoop2/pom.xml index e6941c6de1c6..e9b4f2d243dc 100644 --- a/plugin/trino-hive-hadoop2/pom.xml +++ b/plugin/trino-hive-hadoop2/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 
430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml @@ -142,6 +142,12 @@ runtime + + io.airlift + junit-extensions + test + + io.airlift testing @@ -198,6 +204,12 @@ test + + org.junit.jupiter + junit-jupiter-api + test + + org.testng testng diff --git a/plugin/trino-hive-hadoop2/src/test/java/io/trino/plugin/hive/TestHive.java b/plugin/trino-hive-hadoop2/src/test/java/io/trino/plugin/hive/TestHive.java index f1d0a2a366a9..5229758a4297 100644 --- a/plugin/trino-hive-hadoop2/src/test/java/io/trino/plugin/hive/TestHive.java +++ b/plugin/trino-hive-hadoop2/src/test/java/io/trino/plugin/hive/TestHive.java @@ -19,23 +19,25 @@ import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SchemaTablePrefix; import org.apache.hadoop.net.NetUtils; -import org.testng.SkipException; -import org.testng.annotations.BeforeClass; -import org.testng.annotations.Parameters; -import org.testng.annotations.Test; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assumptions.abort; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; // staging directory is shared mutable state -@Test(singleThreaded = true) +@TestInstance(PER_CLASS) public class TestHive extends AbstractTestHive { - @Parameters({"test.metastore", "test.database"}) - @BeforeClass - public void initialize(String metastore, String database) + @BeforeAll + public void initialize() { + String metastore = System.getProperty("test.metastore"); + String database = System.getProperty("test.database"); String hadoopMasterIp = System.getProperty("hadoop-master-ip"); if (hadoopMasterIp != null) { // Even though Hadoop is accessed by proxy, Hadoop still tries to resolve hadoop-master @@ -55,6 +57,7 @@ public void forceTestNgToRespectSingleThreaded() // 
https://github.com/cbeust/testng/issues/2361#issuecomment-688393166 a workaround it to add a dummy test to the leaf test class. } + @Test @Override public void testHideDeltaLakeTables() { @@ -66,7 +69,7 @@ public void testHideDeltaLakeTables() " \\[\\1]\n" + "but found.*"); - throw new SkipException("not supported"); + abort("not supported"); } @Test @@ -91,6 +94,7 @@ public void testHiveViewTranslationError() } } + @Test @Override public void testUpdateBasicPartitionStatistics() throws Exception @@ -112,6 +116,7 @@ public void testUpdateBasicPartitionStatistics() } } + @Test @Override public void testUpdatePartitionColumnStatistics() throws Exception @@ -133,6 +138,7 @@ public void testUpdatePartitionColumnStatistics() } } + @Test @Override public void testUpdatePartitionColumnStatisticsEmptyOptionalFields() throws Exception @@ -154,6 +160,7 @@ public void testUpdatePartitionColumnStatisticsEmptyOptionalFields() } } + @Test @Override public void testStorePartitionWithStatistics() throws Exception @@ -164,6 +171,7 @@ public void testStorePartitionWithStatistics() testStorePartitionWithStatistics(STATISTICS_PARTITIONED_TABLE_COLUMNS, STATISTICS_1, STATISTICS_2, STATISTICS_1_1, EMPTY_ROWCOUNT_STATISTICS); } + @Test @Override public void testDataColumnProperties() { @@ -173,6 +181,7 @@ public void testDataColumnProperties() .hasMessage("Persisting column properties is not supported: Column{name=id, type=bigint}"); } + @Test @Override public void testPartitionColumnProperties() { diff --git a/plugin/trino-hive/pom.xml b/plugin/trino-hive/pom.xml index 8aca78e16ff1..f453179ff574 100644 --- a/plugin/trino-hive/pom.xml +++ b/plugin/trino-hive/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/MergeFileWriter.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/MergeFileWriter.java index a63435535669..301a43da59c2 100644 --- 
a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/MergeFileWriter.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/MergeFileWriter.java @@ -228,8 +228,12 @@ public PartitionUpdateAndMergeResults getPartitionUpdateAndMergeResults(Partitio private Page buildDeletePage(Block rowIds, long writeId) { ColumnarRow columnarRow = toColumnarRow(rowIds); - checkArgument(!columnarRow.mayHaveNull(), "The rowIdsRowBlock may not have null rows"); int positionCount = rowIds.getPositionCount(); + if (columnarRow.mayHaveNull()) { + for (int position = 0; position < positionCount; position++) { + checkArgument(!columnarRow.isNull(position), "The rowIdsRowBlock may not have null rows"); + } + } // We've verified that the rowIds block has no null rows, so it's okay to get the field blocks Block[] blockArray = { RunLengthEncodedBlock.create(DELETE_OPERATION_BLOCK, positionCount), diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/CoercionUtils.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/CoercionUtils.java index 165fc86eb2a5..97585f0e3f09 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/CoercionUtils.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/CoercionUtils.java @@ -128,7 +128,7 @@ public static Type createTypeFromCoercer(TypeManager typeManager, HiveType fromH return Optional.of(new IntegerNumberUpscaleCoercer<>(fromType, toType)); } if (fromHiveType.equals(HIVE_INT) && toHiveType.equals(HIVE_LONG)) { - return Optional.of(new IntegerNumberUpscaleCoercer<>(fromType, toType)); + return Optional.of(new IntegerToBigintCoercer()); } if (fromHiveType.equals(HIVE_FLOAT) && toHiveType.equals(HIVE_DOUBLE)) { return Optional.of(new FloatToDoubleCoercer()); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/FloatToDoubleCoercer.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/FloatToDoubleCoercer.java index 
df8105941069..5a9ce09968e7 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/FloatToDoubleCoercer.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/FloatToDoubleCoercer.java @@ -16,6 +16,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.LongArrayBlock; import io.trino.spi.type.DoubleType; import io.trino.spi.type.RealType; @@ -30,6 +31,16 @@ public FloatToDoubleCoercer() super(REAL, DOUBLE); } + @Override + public Block apply(Block block) + { + // data may have already been coerced by the Avro reader + if (block instanceof LongArrayBlock) { + return block; + } + return super.apply(block); + } + @Override protected void applyCoercedValue(BlockBuilder blockBuilder, Block block, int position) { diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/IntegerToBigintCoercer.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/IntegerToBigintCoercer.java new file mode 100644 index 000000000000..3cf4706b1855 --- /dev/null +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/IntegerToBigintCoercer.java @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.trino.plugin.hive.coercions; + +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.LongArrayBlock; +import io.trino.spi.type.BigintType; +import io.trino.spi.type.IntegerType; + +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.IntegerType.INTEGER; + +public class IntegerToBigintCoercer + extends TypeCoercer +{ + public IntegerToBigintCoercer() + { + super(INTEGER, BIGINT); + } + + @Override + public Block apply(Block block) + { + // data may have already been coerced by the Avro reader + if (block instanceof LongArrayBlock) { + return block; + } + return super.apply(block); + } + + @Override + protected void applyCoercedValue(BlockBuilder blockBuilder, Block block, int position) + { + BIGINT.writeLong(blockBuilder, INTEGER.getInt(block, position)); + } +} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java index b21d0c1f3245..5dd96adc51d0 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java @@ -187,7 +187,7 @@ public Optional createPageSource( start, length, columns, - effectivePredicate, + ImmutableList.of(effectivePredicate), isUseParquetColumnNames(session), timeZone, stats, @@ -210,7 +210,7 @@ public static ReaderPageSource createPageSource( long start, long length, List columns, - TupleDomain effectivePredicate, + List> disjunctTupleDomains, boolean useColumnNames, DateTimeZone timeZone, FileFormatDataSourceStats stats, @@ -237,11 +237,23 @@ public static ReaderPageSource createPageSource( messageColumn = getColumnIO(fileSchema, requestedSchema); Map, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, requestedSchema); - TupleDomain parquetTupleDomain = 
options.isIgnoreStatistics() - ? TupleDomain.all() - : getParquetTupleDomain(descriptorsByPath, effectivePredicate, fileSchema, useColumnNames); - - TupleDomainParquetPredicate parquetPredicate = buildPredicate(requestedSchema, parquetTupleDomain, descriptorsByPath, timeZone); + List> parquetTupleDomains; + List parquetPredicates; + if (options.isIgnoreStatistics()) { + parquetTupleDomains = ImmutableList.of(); + parquetPredicates = ImmutableList.of(); + } + else { + ImmutableList.Builder> parquetTupleDomainsBuilder = ImmutableList.builderWithExpectedSize(disjunctTupleDomains.size()); + ImmutableList.Builder parquetPredicatesBuilder = ImmutableList.builderWithExpectedSize(disjunctTupleDomains.size()); + for (TupleDomain tupleDomain : disjunctTupleDomains) { + TupleDomain parquetTupleDomain = getParquetTupleDomain(descriptorsByPath, tupleDomain, fileSchema, useColumnNames); + parquetTupleDomainsBuilder.add(parquetTupleDomain); + parquetPredicatesBuilder.add(buildPredicate(requestedSchema, parquetTupleDomain, descriptorsByPath, timeZone)); + } + parquetTupleDomains = parquetTupleDomainsBuilder.build(); + parquetPredicates = parquetPredicatesBuilder.build(); + } long nextStart = 0; ImmutableList.Builder blocks = ImmutableList.builder(); @@ -249,23 +261,27 @@ public static ReaderPageSource createPageSource( ImmutableList.Builder> columnIndexes = ImmutableList.builder(); for (BlockMetaData block : parquetMetadata.getBlocks()) { long firstDataPage = block.getColumns().get(0).getFirstDataPageOffset(); - Optional columnIndex = getColumnIndexStore(dataSource, block, descriptorsByPath, parquetTupleDomain, options); - Optional bloomFilterStore = getBloomFilterStore(dataSource, block, parquetTupleDomain, options); - - if (start <= firstDataPage && firstDataPage < start + length - && predicateMatches( - parquetPredicate, - block, - dataSource, - descriptorsByPath, - parquetTupleDomain, - columnIndex, - bloomFilterStore, - timeZone, - domainCompactionThreshold)) { - 
blocks.add(block); - blockStarts.add(nextStart); - columnIndexes.add(columnIndex); + for (int i = 0; i < disjunctTupleDomains.size(); i++) { + TupleDomain parquetTupleDomain = parquetTupleDomains.get(i); + TupleDomainParquetPredicate parquetPredicate = parquetPredicates.get(i); + Optional columnIndex = getColumnIndexStore(dataSource, block, descriptorsByPath, parquetTupleDomain, options); + Optional bloomFilterStore = getBloomFilterStore(dataSource, block, parquetTupleDomain, options); + if (start <= firstDataPage && firstDataPage < start + length + && predicateMatches( + parquetPredicate, + block, + dataSource, + descriptorsByPath, + parquetTupleDomain, + columnIndex, + bloomFilterStore, + timeZone, + domainCompactionThreshold)) { + blocks.add(block); + blockStarts.add(nextStart); + columnIndexes.add(columnIndex); + break; + } } nextStart += block.getRowCount(); } @@ -289,7 +305,9 @@ && predicateMatches( memoryContext, options, exception -> handleException(dataSourceId, exception), - Optional.of(parquetPredicate), + // We avoid using disjuncts of parquetPredicate for page pruning in ParquetReader as currently column indexes + // are not present in the Parquet files which are read with disjunct predicates. + parquetPredicates.size() == 1 ? 
Optional.of(parquetPredicates.get(0)) : Optional.empty(), columnIndexes.build(), parquetWriteValidation); ConnectorPageSource parquetPageSource = createParquetPageSource(baseColumns, fileSchema, messageColumn, useColumnNames, parquetReaderProvider); diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHive.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHive.java index f00d0f151a4c..a3e382ea2601 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHive.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHive.java @@ -144,9 +144,11 @@ import org.apache.hadoop.hive.metastore.TableType; import org.assertj.core.api.InstanceOfAssertFactories; import org.joda.time.DateTime; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; -import org.testng.annotations.Test; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; import java.io.IOException; import java.io.OutputStream; @@ -328,6 +330,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.joda.time.DateTimeZone.UTC; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertNotNull; @@ -336,7 +339,7 @@ import static org.testng.Assert.fail; // staging directory is shared mutable state -@Test(singleThreaded = true) +@TestInstance(PER_CLASS) public abstract class AbstractTestHive { private static final Logger log = Logger.get(AbstractTestHive.class); @@ -671,7 +674,7 @@ private static RowType toRowType(List columns) protected final Set materializedViews = Sets.newConcurrentHashSet(); - @BeforeClass(alwaysRun = true) + 
@BeforeAll public void setupClass() throws Exception { @@ -681,7 +684,7 @@ public void setupClass() temporaryStagingDirectory = createTempDirectory("trino-staging-"); } - @AfterClass(alwaysRun = true) + @AfterAll public void tearDown() { if (executor != null) { @@ -1645,12 +1648,14 @@ public void testPerTransactionDirectoryListerCache() } } - @Test(expectedExceptions = TableNotFoundException.class) + @Test public void testGetPartitionSplitsBatchInvalidTable() { - try (Transaction transaction = newTransaction()) { - getSplits(splitManager, transaction, newSession(), invalidTableHandle); - } + assertThatThrownBy(() -> { + try (Transaction transaction = newTransaction()) { + getSplits(splitManager, transaction, newSession(), invalidTableHandle); + } + }).isInstanceOf(TableNotFoundException.class); } @Test @@ -2390,21 +2395,25 @@ else if (rowNumber % 19 == 1) { } } - @Test(expectedExceptions = TrinoException.class, expectedExceptionsMessageRegExp = ".*The column 't_data' in table '.*\\.trino_test_partition_schema_change' is declared as type 'double', but partition 'ds=2012-12-29' declared column 't_data' as type 'string'.") + @Test public void testPartitionSchemaMismatch() - throws Exception { - try (Transaction transaction = newTransaction()) { - ConnectorMetadata metadata = transaction.getMetadata(); - ConnectorTableHandle table = getTableHandle(metadata, tablePartitionSchemaChange); - ConnectorSession session = newSession(); - metadata.beginQuery(session); - readTable(transaction, table, ImmutableList.of(dsColumn), session, TupleDomain.all(), OptionalInt.empty(), Optional.empty()); - } + assertThatThrownBy(() -> { + try (Transaction transaction = newTransaction()) { + ConnectorMetadata metadata = transaction.getMetadata(); + ConnectorTableHandle table = getTableHandle(metadata, tablePartitionSchemaChange); + ConnectorSession session = newSession(); + metadata.beginQuery(session); + readTable(transaction, table, ImmutableList.of(dsColumn), session, TupleDomain.all(), 
OptionalInt.empty(), Optional.empty()); + } + }) + .isInstanceOf(TrinoException.class) + .hasMessageMatching(".*The column 't_data' in table '.*\\.trino_test_partition_schema_change' is declared as type 'double', but partition 'ds=2012-12-29' declared column 't_data' as type 'string'."); } // TODO coercion of non-canonical values should be supported - @Test(enabled = false) + @Test + @Disabled public void testPartitionSchemaNonCanonical() throws Exception { diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHiveLocal.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHiveLocal.java index a5835891020d..d00fe4350790 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHiveLocal.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHiveLocal.java @@ -36,10 +36,10 @@ import io.trino.testing.MaterializedResult; import org.apache.hadoop.fs.Path; import org.apache.hadoop.hive.metastore.TableType; -import org.testng.SkipException; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; -import org.testng.annotations.Test; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; import java.io.File; import java.io.IOException; @@ -69,8 +69,11 @@ import static io.trino.plugin.hive.util.HiveUtil.SPARK_TABLE_PROVIDER_KEY; import static java.nio.file.Files.copy; import static java.util.Objects.requireNonNull; +import static org.junit.jupiter.api.Assumptions.abort; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; import static org.testng.Assert.assertEquals; +@TestInstance(PER_CLASS) public abstract class AbstractTestHiveLocal extends AbstractTestHive { @@ -92,7 +95,7 @@ protected AbstractTestHiveLocal(String testDbName) protected abstract HiveMetastore createMetastore(File tempDir); - @BeforeClass(alwaysRun = true) + @BeforeAll 
public void initialize() throws Exception { @@ -170,7 +173,7 @@ protected void createTestTable(Table table) metastoreClient.createTable(table, NO_PRIVILEGES); } - @AfterClass(alwaysRun = true) + @AfterAll public void cleanup() throws IOException { @@ -191,31 +194,35 @@ protected ConnectorTableHandle getTableHandle(ConnectorMetadata metadata, Schema if (tableName.getTableName().startsWith(TEMPORARY_TABLE_PREFIX)) { return super.getTableHandle(metadata, tableName); } - throw new SkipException("tests using existing tables are not supported"); + return abort("tests using existing tables are not supported"); } + @Test @Override public void testGetAllTableColumns() { - throw new SkipException("Test disabled for this subclass"); + abort("Test disabled for this subclass"); } + @Test @Override public void testGetAllTableColumnsInSchema() { - throw new SkipException("Test disabled for this subclass"); + abort("Test disabled for this subclass"); } + @Test @Override public void testGetTableNames() { - throw new SkipException("Test disabled for this subclass"); + abort("Test disabled for this subclass"); } + @Test @Override public void testGetTableSchemaOffline() { - throw new SkipException("Test disabled for this subclass"); + abort("Test disabled for this subclass"); } @Test diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFileMetastore.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFileMetastore.java index a93382fd5c4c..0616c5768d58 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFileMetastore.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFileMetastore.java @@ -16,15 +16,14 @@ import io.trino.plugin.hive.metastore.HiveMetastore; import io.trino.plugin.hive.metastore.file.FileHiveMetastore; import io.trino.plugin.hive.metastore.file.FileHiveMetastoreConfig; -import org.testng.SkipException; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import 
java.io.File; import static io.trino.plugin.hive.HiveTestUtils.HDFS_FILE_SYSTEM_FACTORY; +import static org.junit.jupiter.api.Assumptions.abort; // staging directory is shared mutable state -@Test(singleThreaded = true) public class TestHiveFileMetastore extends AbstractTestHiveLocal { @@ -49,37 +48,43 @@ public void forceTestNgToRespectSingleThreaded() // https://github.com/cbeust/testng/issues/2361#issuecomment-688393166 a workaround it to add a dummy test to the leaf test class. } + @Test @Override public void testMismatchSchemaTable() { // FileHiveMetastore only supports replaceTable() for views } + @Test @Override public void testPartitionSchemaMismatch() { // test expects an exception to be thrown - throw new SkipException("FileHiveMetastore only supports replaceTable() for views"); + abort("FileHiveMetastore only supports replaceTable() for views"); } + @Test @Override public void testBucketedTableEvolution() { // FileHiveMetastore only supports replaceTable() for views } + @Test @Override public void testBucketedTableEvolutionWithDifferentReadBucketCount() { // FileHiveMetastore has various incompatibilities } + @Test @Override public void testTransactionDeleteInsert() { // FileHiveMetastore has various incompatibilities } + @Test @Override public void testInsertOverwriteUnpartitioned() { diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveInMemoryMetastore.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveInMemoryMetastore.java index cb8b161be152..4513dd2ba7f6 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveInMemoryMetastore.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveInMemoryMetastore.java @@ -18,17 +18,16 @@ import io.trino.plugin.hive.metastore.thrift.BridgingHiveMetastore; import io.trino.plugin.hive.metastore.thrift.InMemoryThriftMetastore; import io.trino.plugin.hive.metastore.thrift.ThriftMetastoreConfig; -import org.testng.SkipException; -import 
org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.io.File; import java.net.URI; import static java.nio.file.Files.createDirectories; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assumptions.abort; // staging directory is shared mutable state -@Test(singleThreaded = true) public class TestHiveInMemoryMetastore extends AbstractTestHiveLocal { @@ -57,30 +56,35 @@ public void forceTestNgToRespectSingleThreaded() // https://github.com/cbeust/testng/issues/2361#issuecomment-688393166 a workaround it to add a dummy test to the leaf test class. } + @Test @Override public void testMetadataDelete() { // InMemoryHiveMetastore ignores "removeData" flag in dropPartition } + @Test @Override public void testTransactionDeleteInsert() { // InMemoryHiveMetastore does not check whether partition exist in createPartition and dropPartition } + @Test @Override public void testHideDeltaLakeTables() { - throw new SkipException("not supported"); + abort("not supported"); } + @Test @Override public void testDisallowQueryingOfIcebergTables() { - throw new SkipException("not supported"); + abort("not supported"); } + @Test @Override public void testDataColumnProperties() { @@ -90,6 +94,7 @@ public void testDataColumnProperties() .hasMessage("Persisting column properties is not supported: Column{name=id, type=bigint}"); } + @Test @Override public void testPartitionColumnProperties() { @@ -98,4 +103,11 @@ public void testPartitionColumnProperties() .isInstanceOf(IllegalArgumentException.class) .hasMessage("Persisting column properties is not supported: Column{name=part_key, type=varchar(256)}"); } + + @Test + @Override + public void testPartitionSchemaMismatch() + { + abort("not supported"); + } } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestReaderProjectionsAdapter.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestReaderProjectionsAdapter.java index 
00dad7292dc1..19610bf24465 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestReaderProjectionsAdapter.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestReaderProjectionsAdapter.java @@ -45,6 +45,7 @@ import static io.trino.spi.block.RowBlock.fromFieldBlocks; import static io.trino.spi.type.BigintType.BIGINT; import static java.util.Objects.requireNonNull; +import static org.assertj.core.api.Assertions.assertThat; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; @@ -220,11 +221,12 @@ private static Block createLongArrayBlock(List data) private static void verifyBlock(Block actualBlock, Type outputType, Block input, Type inputType, List dereferences) { - Block expectedOutputBlock = createProjectedColumnBlock(input, outputType, inputType, dereferences); + assertThat(inputType).isInstanceOf(RowType.class); + Block expectedOutputBlock = createProjectedColumnBlock(input, outputType, (RowType) inputType, dereferences); assertBlockEquals(outputType, actualBlock, expectedOutputBlock); } - private static Block createProjectedColumnBlock(Block data, Type finalType, Type blockType, List dereferences) + private static Block createProjectedColumnBlock(Block data, Type finalType, RowType blockType, List dereferences) { if (dereferences.size() == 0) { return data; @@ -233,14 +235,14 @@ private static Block createProjectedColumnBlock(Block data, Type finalType, Type BlockBuilder builder = finalType.createBlockBuilder(null, data.getPositionCount()); for (int i = 0; i < data.getPositionCount(); i++) { - Type sourceType = blockType; + RowType sourceType = blockType; SqlRow currentData = null; boolean isNull = data.isNull(i); if (!isNull) { - // Get SingleRowBlock corresponding to element at position i - currentData = data.getObject(i, SqlRow.class); + // Get SqlRow corresponding to element at position i + currentData = sourceType.getObject(data, i); } // Apply all dereferences except for the last one, 
because the type can be different @@ -253,14 +255,14 @@ private static Block createProjectedColumnBlock(Block data, Type finalType, Type int fieldIndex = dereferences.get(j); Block fieldBlock = currentData.getRawFieldBlock(fieldIndex); - RowType rowType = (RowType) sourceType; + RowType rowType = sourceType; int rawIndex = currentData.getRawIndex(); if (fieldBlock.isNull(rawIndex)) { currentData = null; } else { - sourceType = rowType.getFields().get(fieldIndex).getType(); - currentData = fieldBlock.getObject(rawIndex, SqlRow.class); + sourceType = (RowType) rowType.getFields().get(fieldIndex).getType(); + currentData = sourceType.getObject(fieldBlock, rawIndex); } isNull = isNull || (currentData == null); diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/metastore/glue/TestHiveGlueMetastore.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/metastore/glue/TestHiveGlueMetastore.java index 28159ebf32f9..ad6cdd28cca5 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/metastore/glue/TestHiveGlueMetastore.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/metastore/glue/TestHiveGlueMetastore.java @@ -64,12 +64,10 @@ import io.trino.spi.type.SmallintType; import io.trino.spi.type.TimestampType; import io.trino.spi.type.TinyintType; -import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; import io.trino.testing.MaterializedResult; -import org.testng.annotations.BeforeClass; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; import java.io.File; import java.util.ArrayList; @@ -128,6 +126,7 @@ import static org.apache.hadoop.hive.metastore.TableType.VIRTUAL_VIEW; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assumptions.abort; import static org.testng.Assert.assertEquals; import static 
org.testng.Assert.assertTrue; @@ -136,7 +135,6 @@ * See https://docs.aws.amazon.com/sdk-for-java/v1/developer-guide/credentials.html#credentials-default * on ways to set your AWS credentials which will be needed to run this test. */ -@Test(singleThreaded = true) public class TestHiveGlueMetastore extends AbstractTestHiveLocal { @@ -206,7 +204,7 @@ protected AWSGlueAsync getGlueClient() return glueClient; } - @BeforeClass(alwaysRun = true) + @BeforeAll @Override public void initialize() throws Exception @@ -215,11 +213,7 @@ public void initialize() // uncomment to get extra AWS debug information // Logging logging = Logging.initialize(); // logging.setLevel("com.amazonaws.request", Level.DEBUG); - } - @BeforeClass - public void setup() - { metastore = new HiveMetastoreClosure(metastoreClient); glueClient = AWSGlueAsyncClientBuilder.defaultClient(); } @@ -277,12 +271,14 @@ public void cleanupOrphanedDatabases() }); } + @Test @Override public void testRenameTable() { // rename table is not yet supported by Glue } + @Test @Override public void testUpdateTableColumnStatisticsEmptyOptionalFields() { @@ -291,6 +287,7 @@ public void testUpdateTableColumnStatisticsEmptyOptionalFields() // in order to avoid incorrect data we skip writes for statistics with min/max = null } + @Test @Override public void testUpdatePartitionColumnStatisticsEmptyOptionalFields() { @@ -299,6 +296,7 @@ public void testUpdatePartitionColumnStatisticsEmptyOptionalFields() // in order to avoid incorrect data we skip writes for statistics with min/max = null } + @Test @Override public void testUpdateBasicPartitionStatistics() throws Exception @@ -317,6 +315,7 @@ public void testUpdateBasicPartitionStatistics() } } + @Test @Override public void testUpdatePartitionColumnStatistics() throws Exception @@ -338,6 +337,7 @@ public void testUpdatePartitionColumnStatistics() } } + @Test @Override public void testStorePartitionWithStatistics() throws Exception @@ -348,6 +348,7 @@ public void 
testStorePartitionWithStatistics() testStorePartitionWithStatistics(STATISTICS_PARTITIONED_TABLE_COLUMNS, BASIC_STATISTICS_1, BASIC_STATISTICS_2, BASIC_STATISTICS_1, EMPTY_ROWCOUNT_STATISTICS); } + @Test @Override public void testGetPartitions() throws Exception @@ -949,33 +950,43 @@ public void testGetPartitionsFilterIsNotNull() ImmutableList.of(ImmutableList.of("100"))); } - @Test(dataProvider = "unsupportedNullPushdownTypes") - public void testGetPartitionsFilterUnsupportedIsNull(List columnMetadata, Type type, String partitionValue) + @Test + public void testGetPartitionsFilterUnsupported() throws Exception { - TupleDomain isNullFilter = new PartitionFilterBuilder() - .addDomain(PARTITION_KEY, Domain.onlyNull(type)) - .build(); - List partitionList = new ArrayList<>(); - partitionList.add(partitionValue); - partitionList.add(null); + // Numeric types are unsupported for IS (NOT) NULL predicate pushdown + testGetPartitionsFilterUnsupported(CREATE_TABLE_COLUMNS_PARTITIONED_TINYINT, Domain.onlyNull(TinyintType.TINYINT), "127"); + testGetPartitionsFilterUnsupported(CREATE_TABLE_COLUMNS_PARTITIONED_SMALLINT, Domain.onlyNull(SmallintType.SMALLINT), "32767"); + testGetPartitionsFilterUnsupported(CREATE_TABLE_COLUMNS_PARTITIONED_INTEGER, Domain.onlyNull(IntegerType.INTEGER), "2147483647"); + testGetPartitionsFilterUnsupported(CREATE_TABLE_COLUMNS_PARTITIONED_BIGINT, Domain.onlyNull(BigintType.BIGINT), "9223372036854775807"); + testGetPartitionsFilterUnsupported(CREATE_TABLE_COLUMNS_PARTITIONED_DECIMAL, Domain.onlyNull(DECIMAL_TYPE), "12345.12345"); + + testGetPartitionsFilterUnsupported(CREATE_TABLE_COLUMNS_PARTITIONED_TINYINT, Domain.notNull(TinyintType.TINYINT), "127"); + testGetPartitionsFilterUnsupported(CREATE_TABLE_COLUMNS_PARTITIONED_SMALLINT, Domain.notNull(SmallintType.SMALLINT), "32767"); + testGetPartitionsFilterUnsupported(CREATE_TABLE_COLUMNS_PARTITIONED_INTEGER, Domain.notNull(IntegerType.INTEGER), "2147483647"); + 
testGetPartitionsFilterUnsupported(CREATE_TABLE_COLUMNS_PARTITIONED_BIGINT, Domain.notNull(BigintType.BIGINT), "9223372036854775807"); + testGetPartitionsFilterUnsupported(CREATE_TABLE_COLUMNS_PARTITIONED_DECIMAL, Domain.notNull(DECIMAL_TYPE), "12345.12345"); + + // Date and timestamp aren't numeric types, but the pushdown is unsupported because of GlueExpressionUtil.canConvertSqlTypeToStringForGlue + testGetPartitionsFilterUnsupported(CREATE_TABLE_COLUMNS_PARTITIONED_DATE, Domain.onlyNull(DateType.DATE), "2022-07-11"); + testGetPartitionsFilterUnsupported(CREATE_TABLE_COLUMNS_PARTITIONED_TIMESTAMP, Domain.onlyNull(TimestampType.TIMESTAMP_MILLIS), "2022-07-11 01:02:03.123"); + + testGetPartitionsFilterUnsupported(CREATE_TABLE_COLUMNS_PARTITIONED_DATE, Domain.notNull(DateType.DATE), "2022-07-11"); + testGetPartitionsFilterUnsupported(CREATE_TABLE_COLUMNS_PARTITIONED_TIMESTAMP, Domain.notNull(TimestampType.TIMESTAMP_MILLIS), "2022-07-11 01:02:03.123"); + } - doGetPartitionsFilterTest( - columnMetadata, - PARTITION_KEY, - partitionList, - ImmutableList.of(isNullFilter), - // Currently, we get NULL partition from Glue and filter it in our side because - // (column = '__HIVE_DEFAULT_PARTITION__') on numeric types causes exception on Glue. e.g. 
'input string: "__HIVE_D" is not an integer' - ImmutableList.of(ImmutableList.of(partitionValue, GlueExpressionUtil.NULL_STRING))); + @Test + @Override + public void testPartitionSchemaMismatch() + { + abort("tests using existing tables are not supported"); } - @Test(dataProvider = "unsupportedNullPushdownTypes") - public void testGetPartitionsFilterUnsupportedIsNotNull(List columnMetadata, Type type, String partitionValue) + private void testGetPartitionsFilterUnsupported(List columnMetadata, Domain domain, String partitionValue) throws Exception { - TupleDomain isNotNullFilter = new PartitionFilterBuilder() - .addDomain(PARTITION_KEY, Domain.notNull(type)) + TupleDomain isNullFilter = new PartitionFilterBuilder() + .addDomain(PARTITION_KEY, domain) .build(); List partitionList = new ArrayList<>(); partitionList.add(partitionValue); @@ -985,28 +996,12 @@ public void testGetPartitionsFilterUnsupportedIsNotNull(List col columnMetadata, PARTITION_KEY, partitionList, - ImmutableList.of(isNotNullFilter), + ImmutableList.of(isNullFilter), // Currently, we get NULL partition from Glue and filter it in our side because - // (column <> '__HIVE_DEFAULT_PARTITION__') on numeric types causes exception on Glue. e.g. 'input string: "__HIVE_D" is not an integer' + // (column '__HIVE_DEFAULT_PARTITION__') on numeric types causes exception on Glue. e.g. 
'input string: "__HIVE_D" is not an integer' ImmutableList.of(ImmutableList.of(partitionValue, GlueExpressionUtil.NULL_STRING))); } - @DataProvider - public Object[][] unsupportedNullPushdownTypes() - { - return new Object[][] { - // Numeric types are unsupported for IS (NOT) NULL predicate pushdown - {CREATE_TABLE_COLUMNS_PARTITIONED_TINYINT, TinyintType.TINYINT, "127"}, - {CREATE_TABLE_COLUMNS_PARTITIONED_SMALLINT, SmallintType.SMALLINT, "32767"}, - {CREATE_TABLE_COLUMNS_PARTITIONED_INTEGER, IntegerType.INTEGER, "2147483647"}, - {CREATE_TABLE_COLUMNS_PARTITIONED_BIGINT, BigintType.BIGINT, "9223372036854775807"}, - {CREATE_TABLE_COLUMNS_PARTITIONED_DECIMAL, DECIMAL_TYPE, "12345.12345"}, - // Date and timestamp aren't numeric types, but the pushdown is unsupported because of GlueExpressionUtil.canConvertSqlTypeToStringForGlue - {CREATE_TABLE_COLUMNS_PARTITIONED_DATE, DateType.DATE, "2022-07-11"}, - {CREATE_TABLE_COLUMNS_PARTITIONED_TIMESTAMP, TimestampType.TIMESTAMP_MILLIS, "2022-07-11 01:02:03.123"}, - }; - } - @Test public void testGetPartitionsFilterEqualsAndIsNotNull() throws Exception @@ -1342,6 +1337,7 @@ public void testInvalidColumnStatisticsMetadata() } } + @Test @Override public void testPartitionColumnProperties() { diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestHiveParquetComplexTypePredicatePushDown.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestHiveParquetComplexTypePredicatePushDown.java index eef05f9b95e5..5f0df158a571 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestHiveParquetComplexTypePredicatePushDown.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestHiveParquetComplexTypePredicatePushDown.java @@ -16,7 +16,7 @@ import io.trino.plugin.hive.HiveQueryRunner; import io.trino.testing.BaseComplexTypesPredicatePushDownTest; import io.trino.testing.QueryRunner; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import 
static io.trino.testing.TestingNames.randomNameSuffix; import static org.assertj.core.api.Assertions.assertThat; diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/util/TestStatistics.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/util/TestStatistics.java index 64b4646b608e..d47f123f9e1d 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/util/TestStatistics.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/util/TestStatistics.java @@ -24,7 +24,6 @@ import io.trino.plugin.hive.metastore.HiveColumnStatistics; import io.trino.plugin.hive.metastore.IntegerStatistics; import io.trino.spi.block.Block; -import io.trino.spi.block.BlockBuilder; import io.trino.spi.statistics.ComputedStatistics; import io.trino.spi.statistics.TableStatisticType; import io.trino.spi.type.Type; @@ -36,7 +35,6 @@ import java.util.Optional; import java.util.OptionalDouble; import java.util.OptionalLong; -import java.util.function.Function; import static io.trino.plugin.hive.HiveBasicStatistics.createEmptyStatistics; import static io.trino.plugin.hive.HiveBasicStatistics.createZeroStatistics; @@ -57,6 +55,7 @@ import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.RealType.REAL; +import static io.trino.spi.type.TypeUtils.writeNativeValue; import static io.trino.spi.type.VarcharType.VARCHAR; import static java.lang.Float.floatToIntBits; import static org.assertj.core.api.Assertions.assertThat; @@ -318,20 +317,13 @@ public void testMergeHiveColumnStatisticsMap() @Test public void testFromComputedStatistics() { - Function singleIntegerValueBlock = value -> - { - BlockBuilder blockBuilder = BIGINT.createBlockBuilder(null, 1); - BIGINT.writeLong(blockBuilder, value); - return blockBuilder.build(); - }; - ComputedStatistics statistics = ComputedStatistics.builder(ImmutableList.of(), ImmutableList.of()) - .addTableStatistic(TableStatisticType.ROW_COUNT, 
singleIntegerValueBlock.apply(5)) - .addColumnStatistic(MIN_VALUE.createColumnStatisticMetadata("a_column"), singleIntegerValueBlock.apply(1)) - .addColumnStatistic(MAX_VALUE.createColumnStatisticMetadata("a_column"), singleIntegerValueBlock.apply(5)) - .addColumnStatistic(NUMBER_OF_DISTINCT_VALUES.createColumnStatisticMetadata("a_column"), singleIntegerValueBlock.apply(5)) - .addColumnStatistic(NUMBER_OF_NON_NULL_VALUES.createColumnStatisticMetadata("a_column"), singleIntegerValueBlock.apply(5)) - .addColumnStatistic(NUMBER_OF_NON_NULL_VALUES.createColumnStatisticMetadata("b_column"), singleIntegerValueBlock.apply(4)) + .addTableStatistic(TableStatisticType.ROW_COUNT, writeNativeValue(BIGINT, 5L)) + .addColumnStatistic(MIN_VALUE.createColumnStatisticMetadata("a_column"), writeNativeValue(INTEGER, 1L)) + .addColumnStatistic(MAX_VALUE.createColumnStatisticMetadata("a_column"), writeNativeValue(INTEGER, 5L)) + .addColumnStatistic(NUMBER_OF_DISTINCT_VALUES.createColumnStatisticMetadata("a_column"), writeNativeValue(BIGINT, 5L)) + .addColumnStatistic(NUMBER_OF_NON_NULL_VALUES.createColumnStatisticMetadata("a_column"), writeNativeValue(BIGINT, 5L)) + .addColumnStatistic(NUMBER_OF_NON_NULL_VALUES.createColumnStatisticMetadata("b_column"), writeNativeValue(BIGINT, 4L)) .build(); Map columnTypes = ImmutableMap.of("a_column", INTEGER, "b_column", VARCHAR); diff --git a/plugin/trino-http-event-listener/pom.xml b/plugin/trino-http-event-listener/pom.xml index 6f31bb2c2612..88c934c902ac 100644 --- a/plugin/trino-http-event-listener/pom.xml +++ b/plugin/trino-http-event-listener/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-hudi/pom.xml b/plugin/trino-hudi/pom.xml index 6b6b6e3b84b3..d61a852bffc2 100644 --- a/plugin/trino-hudi/pom.xml +++ b/plugin/trino-hudi/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git 
a/plugin/trino-hudi/src/test/java/io/trino/plugin/hudi/BaseHudiConnectorSmokeTest.java b/plugin/trino-hudi/src/test/java/io/trino/plugin/hudi/BaseHudiConnectorSmokeTest.java index b0e0a6303425..4ea065e010c9 100644 --- a/plugin/trino-hudi/src/test/java/io/trino/plugin/hudi/BaseHudiConnectorSmokeTest.java +++ b/plugin/trino-hudi/src/test/java/io/trino/plugin/hudi/BaseHudiConnectorSmokeTest.java @@ -25,30 +25,19 @@ public abstract class BaseHudiConnectorSmokeTest @Override protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) { - switch (connectorBehavior) { - case SUPPORTS_INSERT: - case SUPPORTS_DELETE: - case SUPPORTS_UPDATE: - case SUPPORTS_MERGE: - return false; - - case SUPPORTS_CREATE_SCHEMA: - return false; - - case SUPPORTS_CREATE_TABLE: - case SUPPORTS_RENAME_TABLE: - return false; - - case SUPPORTS_CREATE_VIEW: - case SUPPORTS_CREATE_MATERIALIZED_VIEW: - return false; - - case SUPPORTS_COMMENT_ON_COLUMN: - return false; - - default: - return super.hasBehavior(connectorBehavior); - } + return switch (connectorBehavior) { + case SUPPORTS_INSERT, + SUPPORTS_DELETE, + SUPPORTS_UPDATE, + SUPPORTS_MERGE, + SUPPORTS_CREATE_SCHEMA, + SUPPORTS_CREATE_TABLE, + SUPPORTS_RENAME_TABLE, + SUPPORTS_CREATE_VIEW, + SUPPORTS_CREATE_MATERIALIZED_VIEW, + SUPPORTS_COMMENT_ON_COLUMN -> false; + default -> super.hasBehavior(connectorBehavior); + }; } @Test diff --git a/plugin/trino-hudi/src/test/java/io/trino/plugin/hudi/TestHudiCopyOnWriteConnectorSmokeTest.java b/plugin/trino-hudi/src/test/java/io/trino/plugin/hudi/TestHudiCopyOnWriteConnectorSmokeTest.java deleted file mode 100644 index 8d09480370f7..000000000000 --- a/plugin/trino-hudi/src/test/java/io/trino/plugin/hudi/TestHudiCopyOnWriteConnectorSmokeTest.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.plugin.hudi; - -import com.google.common.collect.ImmutableMap; -import io.trino.plugin.hudi.testing.TpchHudiTablesInitializer; -import io.trino.testing.QueryRunner; - -import static io.trino.plugin.hudi.HudiQueryRunner.createHudiQueryRunner; -import static io.trino.plugin.hudi.testing.HudiTestUtils.COLUMNS_TO_HIDE; -import static org.apache.hudi.common.model.HoodieTableType.COPY_ON_WRITE; - -public class TestHudiCopyOnWriteConnectorSmokeTest - extends BaseHudiConnectorSmokeTest -{ - @Override - protected QueryRunner createQueryRunner() - throws Exception - { - return createHudiQueryRunner( - ImmutableMap.of(), - ImmutableMap.of("hudi.columns-to-hide", COLUMNS_TO_HIDE), - new TpchHudiTablesInitializer(COPY_ON_WRITE, REQUIRED_TPCH_TABLES)); - } -} diff --git a/plugin/trino-hudi/src/test/java/io/trino/plugin/hudi/TestHudiMergeOnReadConnectorSmokeTest.java b/plugin/trino-hudi/src/test/java/io/trino/plugin/hudi/TestHudiMergeOnReadConnectorSmokeTest.java deleted file mode 100644 index 16d24bdbe759..000000000000 --- a/plugin/trino-hudi/src/test/java/io/trino/plugin/hudi/TestHudiMergeOnReadConnectorSmokeTest.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.plugin.hudi; - -import com.google.common.collect.ImmutableMap; -import io.trino.plugin.hudi.testing.TpchHudiTablesInitializer; -import io.trino.testing.QueryRunner; - -import static io.trino.plugin.hudi.HudiQueryRunner.createHudiQueryRunner; -import static io.trino.plugin.hudi.testing.HudiTestUtils.COLUMNS_TO_HIDE; -import static org.apache.hudi.common.model.HoodieTableType.MERGE_ON_READ; - -public class TestHudiMergeOnReadConnectorSmokeTest - extends BaseHudiConnectorSmokeTest -{ - @Override - protected QueryRunner createQueryRunner() - throws Exception - { - return createHudiQueryRunner( - ImmutableMap.of(), - ImmutableMap.of("hudi.columns-to-hide", COLUMNS_TO_HIDE), - new TpchHudiTablesInitializer(MERGE_ON_READ, REQUIRED_TPCH_TABLES)); - } -} diff --git a/plugin/trino-iceberg/pom.xml b/plugin/trino-iceberg/pom.xml index 418547ff5a96..51bb22aa6860 100644 --- a/plugin/trino-iceberg/pom.xml +++ b/plugin/trino-iceberg/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml @@ -194,13 +194,13 @@ org.apache.datasketches datasketches-java - 3.3.0 + 4.2.0 org.apache.datasketches datasketches-memory - 2.1.0 + 2.2.0 diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAvroDataConversion.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAvroDataConversion.java index 367b8f99b9e4..83fc9e7fa5ad 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAvroDataConversion.java +++ 
b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAvroDataConversion.java @@ -196,7 +196,7 @@ public static Object toIcebergAvroObject(Type type, org.apache.iceberg.types.Typ Type elementType = arrayType.getElementType(); org.apache.iceberg.types.Type elementIcebergType = icebergType.asListType().elementType(); - Block arrayBlock = block.getObject(position, Block.class); + Block arrayBlock = arrayType.getObject(block, position); List list = new ArrayList<>(arrayBlock.getPositionCount()); for (int i = 0; i < arrayBlock.getPositionCount(); i++) { @@ -212,7 +212,7 @@ public static Object toIcebergAvroObject(Type type, org.apache.iceberg.types.Typ org.apache.iceberg.types.Type keyIcebergType = icebergType.asMapType().keyType(); org.apache.iceberg.types.Type valueIcebergType = icebergType.asMapType().valueType(); - SqlMap sqlMap = block.getObject(position, SqlMap.class); + SqlMap sqlMap = mapType.getObject(block, position); int rawOffset = sqlMap.getRawOffset(); Block rawKeyBlock = sqlMap.getRawKeyBlock(); Block rawValueBlock = sqlMap.getRawValueBlock(); diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/aggregation/IcebergThetaSketchForStats.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/aggregation/IcebergThetaSketchForStats.java index 1180e9c2b582..36dc301f6b40 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/aggregation/IcebergThetaSketchForStats.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/aggregation/IcebergThetaSketchForStats.java @@ -13,8 +13,8 @@ */ package io.trino.plugin.iceberg.aggregation; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -27,7 +27,7 @@ import io.trino.spi.type.StandardTypes; import io.trino.spi.type.Type; import 
jakarta.annotation.Nullable; -import org.apache.datasketches.Family; +import org.apache.datasketches.common.Family; import org.apache.datasketches.theta.SetOperation; import org.apache.datasketches.theta.Sketch; import org.apache.datasketches.theta.Union; @@ -53,7 +53,7 @@ private IcebergThetaSketchForStats() {} @InputFunction @TypeParameter("T") - public static void input(@TypeParameter("T") Type type, @AggregationState DataSketchState state, @BlockPosition @SqlType("T") Block block, @BlockIndex int index) + public static void input(@TypeParameter("T") Type type, @AggregationState DataSketchState state, @BlockPosition @SqlType("T") ValueBlock block, @BlockIndex int index) { verify(!block.isNull(index), "Input function is not expected to be called on a NULL input"); diff --git a/plugin/trino-ignite/pom.xml b/plugin/trino-ignite/pom.xml index b74108d8bf91..38298e071b35 100644 --- a/plugin/trino-ignite/pom.xml +++ b/plugin/trino-ignite/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml @@ -139,6 +139,12 @@ runtime + + io.airlift + junit-extensions + test + + io.airlift testing @@ -202,6 +208,12 @@ test + + org.junit.jupiter + junit-jupiter-api + test + + org.testcontainers jdbc diff --git a/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClient.java b/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClient.java index d385d4e8174a..ea579992270b 100644 --- a/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClient.java +++ b/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClient.java @@ -164,7 +164,6 @@ public IgniteClient( this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() .addStandardRules(this::quoted) .map("$like(value: varchar, pattern: varchar): boolean").to("value LIKE pattern") - .map("$like(value: varchar, pattern: varchar, escape: varchar(1)): boolean").to("value LIKE pattern ESCAPE escape") .build(); this.aggregateFunctionRewriter 
= new AggregateFunctionRewriter<>( connectorExpressionRewriter, diff --git a/plugin/trino-ignite/src/test/java/io/trino/plugin/ignite/TestIgniteCaseInsensitiveMapping.java b/plugin/trino-ignite/src/test/java/io/trino/plugin/ignite/TestIgniteCaseInsensitiveMapping.java index daa30c28f8e5..e0cd4da7fff4 100644 --- a/plugin/trino-ignite/src/test/java/io/trino/plugin/ignite/TestIgniteCaseInsensitiveMapping.java +++ b/plugin/trino-ignite/src/test/java/io/trino/plugin/ignite/TestIgniteCaseInsensitiveMapping.java @@ -20,8 +20,7 @@ import io.trino.testing.MaterializedRow; import io.trino.testing.QueryRunner; import io.trino.testing.sql.SqlExecutor; -import org.testng.SkipException; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.nio.file.Path; import java.util.List; @@ -33,12 +32,12 @@ import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assumptions.abort; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; // With case-insensitive-name-matching enabled colliding schema/table names are considered as errors. // Some tests here create colliding names which can cause any other concurrent test to fail. 
-@Test(singleThreaded = true) public class TestIgniteCaseInsensitiveMapping extends BaseCaseInsensitiveMappingTest { @@ -76,6 +75,7 @@ protected String quoted(String name) return identifierQuote + name + identifierQuote; } + @Test @Override public void testNonLowerCaseSchemaName() throws Exception @@ -101,6 +101,7 @@ public void testNonLowerCaseSchemaName() } } + @Test @Override public void testNonLowerCaseTableName() throws Exception @@ -144,6 +145,7 @@ public void testNonLowerCaseTableName() } } + @Test @Override public void testSchemaNameClash() throws Exception @@ -167,6 +169,7 @@ public void testSchemaNameClash() } } + @Test @Override public void testTableNameClash() throws Exception @@ -193,34 +196,39 @@ public void testTableNameClash() } } + @Test @Override public void testTableNameClashWithRuleMapping() { - throw new SkipException("Not support creating Ignite custom schema"); + abort("Not support creating Ignite custom schema"); } + @Test @Override public void testSchemaNameClashWithRuleMapping() { - throw new SkipException("Not support creating Ignite custom schema"); + abort("Not support creating Ignite custom schema"); } + @Test @Override public void testSchemaAndTableNameRuleMapping() { - throw new SkipException("Not support creating Ignite custom schema"); + abort("Not support creating Ignite custom schema"); } + @Test @Override public void testSchemaNameRuleMapping() { - throw new SkipException("Not support creating Ignite custom schema"); + abort("Not support creating Ignite custom schema"); } + @Test @Override public void testTableNameRuleMapping() { - throw new SkipException("Not support creating Ignite custom schema"); + abort("Not support creating Ignite custom schema"); } @Override diff --git a/plugin/trino-ignite/src/test/java/io/trino/plugin/ignite/TestIgniteConnectorTest.java b/plugin/trino-ignite/src/test/java/io/trino/plugin/ignite/TestIgniteConnectorTest.java index aaddab434b2b..94ef4dd034b0 100644 --- 
a/plugin/trino-ignite/src/test/java/io/trino/plugin/ignite/TestIgniteConnectorTest.java +++ b/plugin/trino-ignite/src/test/java/io/trino/plugin/ignite/TestIgniteConnectorTest.java @@ -16,6 +16,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.plugin.jdbc.BaseJdbcConnectorTest; +import io.trino.sql.planner.plan.FilterNode; +import io.trino.sql.planner.plan.TableScanNode; import io.trino.testing.QueryRunner; import io.trino.testing.TestingConnectorBehavior; import io.trino.testing.sql.SqlExecutor; @@ -31,6 +33,7 @@ import static com.google.common.base.Strings.nullToEmpty; import static io.trino.plugin.ignite.IgniteQueryRunner.createIgniteQueryRunner; +import static io.trino.sql.planner.assertions.PlanMatchPattern.node; import static io.trino.testing.TestingNames.randomNameSuffix; import static java.lang.String.format; import static org.assertj.core.api.Assertions.assertThat; @@ -96,6 +99,34 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) }; } + @Test + public void testLikeWithEscape() + { + try (TestTable testTable = new TestTable( + getQueryRunner()::execute, + "test_like_with_escape", + "(id int, a varchar(4))", + List.of( + "1, 'abce'", + "2, 'abcd'", + "3, 'a%de'"))) { + String tableName = testTable.getName(); + + assertThat(query("SELECT * FROM " + tableName + " WHERE a LIKE 'a%'")) + .isFullyPushedDown(); + assertThat(query("SELECT * FROM " + tableName + " WHERE a LIKE '%c%' ESCAPE '\\'")) + .matches("VALUES (1, 'abce'), (2, 'abcd')") + .isNotFullyPushedDown(node(FilterNode.class, node(TableScanNode.class))); + + assertThat(query("SELECT * FROM " + tableName + " WHERE a LIKE 'a\\%d%' ESCAPE '\\'")) + .matches("VALUES (3, 'a%de')") + .isNotFullyPushedDown(node(FilterNode.class, node(TableScanNode.class))); + + assertThatThrownBy(() -> onRemoteDatabase().execute("SELECT * FROM " + tableName + " WHERE a LIKE 'a%' ESCAPE '\\'")) + .hasMessageContaining("Failed to execute 
statement"); + } + } + @Test public void testDatabaseMetadataSearchEscapedWildCardCharacters() { diff --git a/plugin/trino-jmx/pom.xml b/plugin/trino-jmx/pom.xml index 8a7548a4ad82..5a0527161fd2 100644 --- a/plugin/trino-jmx/pom.xml +++ b/plugin/trino-jmx/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-kafka/pom.xml b/plugin/trino-kafka/pom.xml index aa29f6786cfb..7794ee0dbdd4 100644 --- a/plugin/trino-kafka/pom.xml +++ b/plugin/trino-kafka/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/schema/confluent/ClassLoaderSafeSchemaRegistryClient.java b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/schema/confluent/ClassLoaderSafeSchemaRegistryClient.java index 66ebf7f015d1..b564a62a08c5 100644 --- a/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/schema/confluent/ClassLoaderSafeSchemaRegistryClient.java +++ b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/schema/confluent/ClassLoaderSafeSchemaRegistryClient.java @@ -13,11 +13,16 @@ */ package io.trino.plugin.kafka.schema.confluent; +import com.google.common.base.Ticker; import io.confluent.kafka.schemaregistry.ParsedSchema; import io.confluent.kafka.schemaregistry.client.SchemaMetadata; import io.confluent.kafka.schemaregistry.client.SchemaRegistryClient; +import io.confluent.kafka.schemaregistry.client.rest.entities.Config; +import io.confluent.kafka.schemaregistry.client.rest.entities.Metadata; +import io.confluent.kafka.schemaregistry.client.rest.entities.RuleSet; import io.confluent.kafka.schemaregistry.client.rest.entities.SchemaReference; import io.confluent.kafka.schemaregistry.client.rest.entities.SubjectVersion; +import io.confluent.kafka.schemaregistry.client.rest.entities.requests.RegisterSchemaResponse; import io.confluent.kafka.schemaregistry.client.rest.exceptions.RestClientException; import 
io.trino.spi.classloader.ThreadContextClassLoader; import org.apache.avro.Schema; @@ -496,4 +501,91 @@ public void reset() delegate.reset(); } } + + @Override + public String tenant() + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.tenant(); + } + } + + @Override + public Ticker ticker() + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.ticker(); + } + } + + @Override + public Optional parseSchema(String schemaType, String schemaString, List references, Metadata metadata, RuleSet ruleSet) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.parseSchema(schemaType, schemaString, references, metadata, ruleSet); + } + } + + @Override + public RegisterSchemaResponse registerWithResponse(String subject, ParsedSchema schema, boolean normalize) + throws RestClientException, IOException + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.registerWithResponse(subject, schema, normalize); + } + } + + @Override + public SchemaMetadata getLatestWithMetadata(String subject, Map metadata, boolean lookupDeletedSchema) + throws RestClientException, IOException + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.getLatestWithMetadata(subject, metadata, lookupDeletedSchema); + } + } + + @Override + public List testCompatibilityVerbose(String subject, ParsedSchema schema, boolean normalize) + throws RestClientException, IOException + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.testCompatibilityVerbose(subject, schema, normalize); + } + } + + @Override + public Config updateConfig(String subject, Config config) + throws RestClientException, IOException + { + try (ThreadContextClassLoader ignored = new 
ThreadContextClassLoader(classLoader)) { + return delegate.updateConfig(subject, config); + } + } + + @Override + public Config getConfig(String subject) + throws IOException, RestClientException + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.getConfig(subject); + } + } + + @Override + public void deleteConfig(String subject) + throws IOException, RestClientException + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + delegate.deleteConfig(subject); + } + } + + @Override + public String setMode(String mode, String subject, boolean force) + throws IOException, RestClientException + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.setMode(mode, subject, force); + } + } } diff --git a/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/schema/confluent/ConfluentModule.java b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/schema/confluent/ConfluentModule.java index be6517b35457..e66d0752d1df 100644 --- a/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/schema/confluent/ConfluentModule.java +++ b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/schema/confluent/ConfluentModule.java @@ -188,6 +188,18 @@ public String schemaType() return "PROTOBUF"; } + @Override + public Optional parseSchema(Schema schema, boolean isNew) + { + return SchemaProvider.super.parseSchema(schema, isNew); + } + + @Override + public Optional parseSchema(Schema schema, boolean isNew, boolean normalize) + { + return SchemaProvider.super.parseSchema(schema, isNew, normalize); + } + @Override public void configure(Map configuration) { @@ -202,9 +214,21 @@ public Optional parseSchema(String schema, List r } @Override - public ParsedSchema parseSchemaOrElseThrow(Schema schema, boolean isNew) + public Optional parseSchema(String schemaString, List references, boolean isNew, boolean normalize) + { + return 
SchemaProvider.super.parseSchema(schemaString, references, isNew, normalize); + } + + @Override + public Optional parseSchema(String schemaString, List references) + { + return SchemaProvider.super.parseSchema(schemaString, references); + } + + @Override + public ParsedSchema parseSchemaOrElseThrow(Schema schema, boolean isNew, boolean normalize) { - return delegate.get().parseSchemaOrElseThrow(schema, isNew); + return delegate.get().parseSchemaOrElseThrow(schema, isNew, normalize); } private SchemaProvider create() diff --git a/plugin/trino-kinesis/pom.xml b/plugin/trino-kinesis/pom.xml index 137564b2a70d..3039079be291 100644 --- a/plugin/trino-kinesis/pom.xml +++ b/plugin/trino-kinesis/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-kudu/pom.xml b/plugin/trino-kudu/pom.xml index f516bc4cc9d0..bb06ee1b7bd9 100644 --- a/plugin/trino-kudu/pom.xml +++ b/plugin/trino-kudu/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-local-file/pom.xml b/plugin/trino-local-file/pom.xml index 7c0285277818..96d0aea9ff64 100644 --- a/plugin/trino-local-file/pom.xml +++ b/plugin/trino-local-file/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-mariadb/pom.xml b/plugin/trino-mariadb/pom.xml index 606437f9ed4a..0275a745aa4a 100644 --- a/plugin/trino-mariadb/pom.xml +++ b/plugin/trino-mariadb/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml @@ -105,6 +105,12 @@ runtime + + io.airlift + junit-extensions + test + + io.airlift testing @@ -174,6 +180,12 @@ test + + org.junit.jupiter + junit-jupiter-api + test + + org.testcontainers mariadb diff --git a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbConnectorTest.java b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbConnectorTest.java index e7de87601107..d2ed5a68ed43 
100644 --- a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbConnectorTest.java +++ b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbConnectorTest.java @@ -84,7 +84,7 @@ protected TestTable createTableWithUnsupportedColumn() "(one bigint, two decimal(50,0), three varchar(10))"); } - @Test + @org.junit.jupiter.api.Test @Override public void testShowColumns() { diff --git a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbTableIndexStatisticsTest.java b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbTableIndexStatisticsTest.java index b585ce4c59cf..044fa3e281e3 100644 --- a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbTableIndexStatisticsTest.java +++ b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbTableIndexStatisticsTest.java @@ -14,10 +14,10 @@ package io.trino.plugin.mariadb; import io.trino.testing.MaterializedRow; -import org.testng.SkipException; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import static java.lang.String.format; +import static org.junit.jupiter.api.Assumptions.abort; public abstract class BaseMariaDbTableIndexStatisticsTest extends BaseMariaDbTableStatisticsTest @@ -49,7 +49,7 @@ protected void gatherStats(String tableName) public void testStatsWithPredicatePushdownWithStatsPrecalculationDisabled() { // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions - throw new SkipException("Test to be implemented"); + abort("Test to be implemented"); } @Test @@ -57,7 +57,7 @@ public void testStatsWithPredicatePushdownWithStatsPrecalculationDisabled() public void testStatsWithPredicatePushdown() { // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions - throw new SkipException("Test to be implemented"); + abort("Test to be implemented"); } @Test @@ 
-65,7 +65,7 @@ public void testStatsWithPredicatePushdown() public void testStatsWithVarcharPredicatePushdown() { // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions - throw new SkipException("Test to be implemented"); + abort("Test to be implemented"); } @Test @@ -73,7 +73,7 @@ public void testStatsWithVarcharPredicatePushdown() public void testStatsWithLimitPushdown() { // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions - throw new SkipException("Test to be implemented"); + abort("Test to be implemented"); } @Test @@ -81,7 +81,7 @@ public void testStatsWithLimitPushdown() public void testStatsWithTopNPushdown() { // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions - throw new SkipException("Test to be implemented"); + abort("Test to be implemented"); } @Test @@ -89,7 +89,7 @@ public void testStatsWithTopNPushdown() public void testStatsWithDistinctPushdown() { // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions - throw new SkipException("Test to be implemented"); + abort("Test to be implemented"); } @Test @@ -97,7 +97,7 @@ public void testStatsWithDistinctPushdown() public void testStatsWithDistinctLimitPushdown() { // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions - throw new SkipException("Test to be implemented"); + abort("Test to be implemented"); } @Test @@ -105,7 +105,7 @@ public void testStatsWithDistinctLimitPushdown() public void testStatsWithAggregationPushdown() { // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions - throw new SkipException("Test to be implemented"); + abort("Test to be implemented"); } 
@Test @@ -113,7 +113,7 @@ public void testStatsWithAggregationPushdown() public void testStatsWithSimpleJoinPushdown() { // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions - throw new SkipException("Test to be implemented"); + abort("Test to be implemented"); } @Test @@ -121,6 +121,6 @@ public void testStatsWithSimpleJoinPushdown() public void testStatsWithJoinPushdown() { // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions - throw new SkipException("Test to be implemented"); + abort("Test to be implemented"); } } diff --git a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbTableStatisticsTest.java b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbTableStatisticsTest.java index 45d230e6e1c6..5078044ce4fb 100644 --- a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbTableStatisticsTest.java +++ b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbTableStatisticsTest.java @@ -20,8 +20,7 @@ import io.trino.testing.QueryRunner; import io.trino.testing.sql.TestTable; import org.assertj.core.api.AbstractDoubleAssert; -import org.testng.SkipException; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.util.HashMap; import java.util.List; @@ -42,6 +41,7 @@ import static java.util.Objects.requireNonNull; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.withinPercentage; +import static org.junit.jupiter.api.Assumptions.abort; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertNull; @@ -223,14 +223,14 @@ public void testNullsFraction() @Override public void testAverageColumnLength() { - throw new SkipException("MariaDB connector does not report average column length"); + 
abort("MariaDB connector does not report average column length"); } @Test @Override public void testPartitionedTable() { - throw new SkipException("Not implemented"); // TODO + abort("Not implemented"); // TODO } @Test @@ -259,12 +259,11 @@ public void testView() @Override public void testMaterializedView() { - throw new SkipException(""); // TODO is there a concept like materialized view in MariaDB? + abort(""); // TODO is there a concept like materialized view in MariaDB? } - @Test(dataProvider = "testCaseColumnNamesDataProvider") @Override - public void testCaseColumnNames(String tableName) + protected void testCaseColumnNames(String tableName) { executeInMariaDb(("" + "CREATE TABLE " + tableName + " " + diff --git a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbCaseInsensitiveMapping.java b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbCaseInsensitiveMapping.java index 369a1bd74bad..b8a9241c7a0a 100644 --- a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbCaseInsensitiveMapping.java +++ b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbCaseInsensitiveMapping.java @@ -18,7 +18,7 @@ import io.trino.plugin.jdbc.BaseCaseInsensitiveMappingTest; import io.trino.testing.QueryRunner; import io.trino.testing.sql.SqlExecutor; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.nio.file.Path; @@ -28,7 +28,6 @@ // With case-insensitive-name-matching enabled colliding schema/table names are considered as errors. // Some tests here create colliding names which can cause any other concurrent test to fail. 
-@Test(singleThreaded = true) public class TestMariaDbCaseInsensitiveMapping extends BaseCaseInsensitiveMappingTest { diff --git a/plugin/trino-memory/pom.xml b/plugin/trino-memory/pom.xml index b77248ec7702..f96a7e7027bf 100644 --- a/plugin/trino-memory/pom.xml +++ b/plugin/trino-memory/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-ml/pom.xml b/plugin/trino-ml/pom.xml index d8ee497b7572..5b5669e133ba 100644 --- a/plugin/trino-ml/pom.xml +++ b/plugin/trino-ml/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-ml/src/main/java/io/trino/plugin/ml/type/ModelType.java b/plugin/trino-ml/src/main/java/io/trino/plugin/ml/type/ModelType.java index 63ecbdde38ea..a8d023df4621 100644 --- a/plugin/trino-ml/src/main/java/io/trino/plugin/ml/type/ModelType.java +++ b/plugin/trino-ml/src/main/java/io/trino/plugin/ml/type/ModelType.java @@ -16,6 +16,7 @@ import io.airlift.slice.Slice; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; @@ -45,7 +46,9 @@ protected ModelType(TypeSignature signature) @Override public Slice getSlice(Block block, int position) { - return block.getSlice(position, 0, block.getSliceLength(position)); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return valueBlock.getSlice(valuePosition); } @Override diff --git a/plugin/trino-mongodb/pom.xml b/plugin/trino-mongodb/pom.xml index 8bd56e0bea29..344aad107cee 100644 --- a/plugin/trino-mongodb/pom.xml +++ b/plugin/trino-mongodb/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git 
a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSink.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSink.java index f317a72bca24..9639d4546d57 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSink.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSink.java @@ -202,7 +202,7 @@ private Object getObjectValue(Type type, Block block, int position) if (type instanceof ArrayType arrayType) { Type elementType = arrayType.getElementType(); - Block arrayBlock = block.getObject(position, Block.class); + Block arrayBlock = arrayType.getObject(block, position); List list = new ArrayList<>(arrayBlock.getPositionCount()); for (int i = 0; i < arrayBlock.getPositionCount(); i++) { @@ -216,7 +216,7 @@ private Object getObjectValue(Type type, Block block, int position) Type keyType = mapType.getKeyType(); Type valueType = mapType.getValueType(); - SqlMap sqlMap = block.getObject(position, SqlMap.class); + SqlMap sqlMap = mapType.getObject(block, position); int size = sqlMap.getSize(); int rawOffset = sqlMap.getRawOffset(); Block rawKeyBlock = sqlMap.getRawKeyBlock(); diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java index e5c8d7723b3a..8f5a36924d8b 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java @@ -138,6 +138,8 @@ public class MongoSession private static final String FIELDS_TYPE_KEY = "type"; private static final String FIELDS_HIDDEN_KEY = "hidden"; + private static final Document EMPTY_DOCUMENT = new Document(); + private static final String AND_OP = "$and"; private static final String OR_OP = "$or"; @@ -587,16 +589,17 @@ static Document buildFilter(MongoTableHandle table) @VisibleForTesting static Document buildQuery(TupleDomain 
tupleDomain) { - Document query = new Document(); + ImmutableList.Builder queryBuilder = ImmutableList.builder(); if (tupleDomain.getDomains().isPresent()) { for (Map.Entry entry : tupleDomain.getDomains().get().entrySet()) { MongoColumnHandle column = (MongoColumnHandle) entry.getKey(); Optional predicate = buildPredicate(column, entry.getValue()); - predicate.ifPresent(query::putAll); + predicate.ifPresent(queryBuilder::add); } } - return query; + List query = queryBuilder.build(); + return query.isEmpty() ? EMPTY_DOCUMENT : andPredicate(query); } private static Optional buildPredicate(MongoColumnHandle column, Domain domain) diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/ObjectIdType.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/ObjectIdType.java index 7124ec9ea47f..d9dcf28a849c 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/ObjectIdType.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/ObjectIdType.java @@ -17,6 +17,7 @@ import io.airlift.slice.Slice; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; @@ -68,13 +69,15 @@ public Object getObjectValue(ConnectorSession session, Block block, int position } // TODO: There's no way to represent string value of a custom type - return new SqlVarbinary(block.getSlice(position, 0, block.getSliceLength(position)).getBytes()); + return new SqlVarbinary(getSlice(block, position).getBytes()); } @Override public Slice getSlice(Block block, int position) { - return block.getSlice(position, 0, block.getSliceLength(position)); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return valueBlock.getSlice(valuePosition); 
} @Override diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java index c8bf9d65cbdf..7291a08ca2cd 100644 --- a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java +++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java @@ -378,6 +378,20 @@ public void testPredicatePushdownCharWithPaddedSpace() } } + @Test + public void testPredicatePushdownMultipleNotEquals() + { + // Regression test for https://github.com/trinodb/trino/issues/19404 + try (TestTable table = new TestTable( + getQueryRunner()::execute, + "test_predicate_pushdown_with_multiple_not_equals", + "(id, value) AS VALUES (1, 10), (2, 20), (3, 30)")) { + assertThat(query("SELECT * FROM " + table.getName() + " WHERE id != 1 AND value != 20")) + .matches("VALUES (3, 30)") + .isFullyPushedDown(); + } + } + @Test public void testHighPrecisionDecimalPredicate() { diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoSession.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoSession.java index d0611b3f0acd..bfc497d769a7 100644 --- a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoSession.java +++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoSession.java @@ -86,8 +86,9 @@ public void testBuildQuery() Document query = MongoSession.buildQuery(tupleDomain); Document expected = new Document() - .append(COL1.getBaseName(), new Document().append("$gt", 100L).append("$lte", 200L)) - .append(COL2.getBaseName(), new Document("$eq", "a value")); + .append("$and", ImmutableList.of( + new Document(COL1.getBaseName(), new Document().append("$gt", 100L).append("$lte", 200L)), + new Document(COL2.getBaseName(), new Document("$eq", "a value")))); assertEquals(query, expected); } @@ -100,8 +101,9 @@ public void 
testBuildQueryStringType() Document query = MongoSession.buildQuery(tupleDomain); Document expected = new Document() - .append(COL3.getBaseName(), new Document().append("$gt", "hello").append("$lte", "world")) - .append(COL2.getBaseName(), new Document("$gte", "a value")); + .append("$and", ImmutableList.of( + new Document(COL3.getBaseName(), new Document().append("$gt", "hello").append("$lte", "world")), + new Document(COL2.getBaseName(), new Document("$gte", "a value")))); assertEquals(query, expected); } @@ -161,10 +163,11 @@ public void testBuildQueryNestedField() Document query = MongoSession.buildQuery(tupleDomain); Document expected = new Document() - .append("$or", asList( - new Document(COL5.getQualifiedName(), new Document("$gt", 200L)), - new Document(COL5.getQualifiedName(), new Document("$eq", null)))) - .append(COL6.getQualifiedName(), new Document("$eq", "a value")); + .append("$and", ImmutableList.of( + new Document("$or", asList( + new Document(COL5.getQualifiedName(), new Document("$gt", 200L)), + new Document(COL5.getQualifiedName(), new Document("$eq", null)))), + new Document(COL6.getQualifiedName(), new Document("$eq", "a value")))); assertEquals(query, expected); } diff --git a/plugin/trino-mysql-event-listener/pom.xml b/plugin/trino-mysql-event-listener/pom.xml index 7e1c47d73cf8..57e8e8dc255b 100644 --- a/plugin/trino-mysql-event-listener/pom.xml +++ b/plugin/trino-mysql-event-listener/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-mysql/pom.xml b/plugin/trino-mysql/pom.xml index a78c2a325ec0..cbcc0ca8cc90 100644 --- a/plugin/trino-mysql/pom.xml +++ b/plugin/trino-mysql/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseMySqlTableStatisticsIndexStatisticsTest.java b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseMySqlTableStatisticsIndexStatisticsTest.java 
index 81662c8d7a58..1773f945d134 100644 --- a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseMySqlTableStatisticsIndexStatisticsTest.java +++ b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseMySqlTableStatisticsIndexStatisticsTest.java @@ -14,9 +14,10 @@ package io.trino.plugin.mysql; import io.trino.testing.MaterializedRow; -import org.testng.SkipException; +import org.junit.jupiter.api.Test; import static java.lang.String.format; +import static org.junit.jupiter.api.Assumptions.abort; public abstract class BaseMySqlTableStatisticsIndexStatisticsTest extends BaseTestMySqlTableStatisticsTest @@ -42,73 +43,83 @@ protected void gatherStats(String tableName) executeInMysql("ANALYZE TABLE " + tableName.replace("\"", "`")); } + @Test @Override public void testStatsWithPredicatePushdownWithStatsPrecalculationDisabled() { // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MySQL, with permissive approximate assertions - throw new SkipException("Test to be implemented"); + abort("Test to be implemented"); } + @Test @Override public void testStatsWithPredicatePushdown() { // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MySQL, with permissive approximate assertions - throw new SkipException("Test to be implemented"); + abort("Test to be implemented"); } + @Test @Override public void testStatsWithVarcharPredicatePushdown() { // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MySQL, with permissive approximate assertions - throw new SkipException("Test to be implemented"); + abort("Test to be implemented"); } + @Test @Override public void testStatsWithLimitPushdown() { // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MySQL, with permissive approximate assertions - throw new SkipException("Test to be implemented"); + abort("Test to be implemented"); } + @Test @Override public void testStatsWithTopNPushdown() { // TODO 
(https://github.com/trinodb/trino/issues/11664) implement the test for MySQL, with permissive approximate assertions - throw new SkipException("Test to be implemented"); + abort("Test to be implemented"); } + @Test @Override public void testStatsWithDistinctPushdown() { // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MySQL, with permissive approximate assertions - throw new SkipException("Test to be implemented"); + abort("Test to be implemented"); } + @Test @Override public void testStatsWithDistinctLimitPushdown() { // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MySQL, with permissive approximate assertions - throw new SkipException("Test to be implemented"); + abort("Test to be implemented"); } + @Test @Override public void testStatsWithAggregationPushdown() { // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MySQL, with permissive approximate assertions - throw new SkipException("Test to be implemented"); + abort("Test to be implemented"); } + @Test @Override public void testStatsWithSimpleJoinPushdown() { // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MySQL, with permissive approximate assertions - throw new SkipException("Test to be implemented"); + abort("Test to be implemented"); } + @Test @Override public void testStatsWithJoinPushdown() { // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MySQL, with permissive approximate assertions - throw new SkipException("Test to be implemented"); + abort("Test to be implemented"); } } diff --git a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseTestMySqlTableStatisticsTest.java b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseTestMySqlTableStatisticsTest.java index 3dce3fcd9108..16c38585aa9f 100644 --- a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseTestMySqlTableStatisticsTest.java +++ 
b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseTestMySqlTableStatisticsTest.java @@ -20,8 +20,7 @@ import io.trino.testing.QueryRunner; import io.trino.testing.sql.TestTable; import org.assertj.core.api.AbstractDoubleAssert; -import org.testng.SkipException; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.util.HashMap; import java.util.List; @@ -42,6 +41,7 @@ import static java.util.Objects.requireNonNull; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.withinPercentage; +import static org.junit.jupiter.api.Assumptions.abort; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertNull; @@ -226,14 +226,14 @@ public void testNullsFraction() @Test public void testAverageColumnLength() { - throw new SkipException("MySQL connector does not report average column length"); + abort("MySQL connector does not report average column length"); } @Override @Test public void testPartitionedTable() { - throw new SkipException("Not implemented"); // TODO + abort("Not implemented"); // TODO } @Override @@ -262,12 +262,11 @@ public void testView() @Test public void testMaterializedView() { - throw new SkipException(""); // TODO is there a concept like materialized view in MySQL? + abort(""); // TODO is there a concept like materialized view in MySQL? 
} @Override - @Test(dataProvider = "testCaseColumnNamesDataProvider") - public void testCaseColumnNames(String tableName) + protected void testCaseColumnNames(String tableName) { executeInMysql(("" + "CREATE TABLE " + tableName + " " + diff --git a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlCaseInsensitiveMapping.java b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlCaseInsensitiveMapping.java index 89fec21128b8..841112ebf187 100644 --- a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlCaseInsensitiveMapping.java +++ b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlCaseInsensitiveMapping.java @@ -18,7 +18,7 @@ import io.trino.plugin.jdbc.BaseCaseInsensitiveMappingTest; import io.trino.testing.QueryRunner; import io.trino.testing.sql.SqlExecutor; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.nio.file.Path; @@ -29,7 +29,6 @@ // With case-insensitive-name-matching enabled colliding schema/table names are considered as errors. // Some tests here create colliding names which can cause any other concurrent test to fail. 
-@Test(singleThreaded = true) public class TestMySqlCaseInsensitiveMapping extends BaseCaseInsensitiveMappingTest { diff --git a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlTableStatisticsMySql8Histograms.java b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlTableStatisticsMySql8Histograms.java index 017270fe17af..2ee898ef6bbc 100644 --- a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlTableStatisticsMySql8Histograms.java +++ b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlTableStatisticsMySql8Histograms.java @@ -15,8 +15,7 @@ import com.google.common.collect.ImmutableMap; import io.trino.testing.sql.TestTable; -import org.testng.SkipException; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.util.List; import java.util.function.Function; @@ -26,6 +25,7 @@ import static io.trino.testing.sql.TestTable.fromColumns; import static java.lang.String.format; import static java.lang.String.join; +import static org.junit.jupiter.api.Assumptions.abort; public class TestMySqlTableStatisticsMySql8Histograms extends BaseTestMySqlTableStatisticsTest @@ -82,10 +82,11 @@ public void testNumericCornerCases() } } + @Test @Override public void testNotAnalyzed() { - throw new SkipException("MySql8 automatically calculates stats - https://dev.mysql.com/doc/refman/8.0/en/innodb-parameters.html#sysvar_innodb_stats_auto_recalc"); + abort("MySql8 automatically calculates stats - https://dev.mysql.com/doc/refman/8.0/en/innodb-parameters.html#sysvar_innodb_stats_auto_recalc"); } @Override diff --git a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlTableStatisticsMySql8IndexStatistics.java b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlTableStatisticsMySql8IndexStatistics.java index 4b18806bc0a4..67d5570736d8 100644 --- a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlTableStatisticsMySql8IndexStatistics.java +++ 
b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlTableStatisticsMySql8IndexStatistics.java @@ -13,7 +13,9 @@ */ package io.trino.plugin.mysql; -import org.testng.SkipException; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assumptions.abort; public class TestMySqlTableStatisticsMySql8IndexStatistics extends BaseMySqlTableStatisticsIndexStatisticsTest @@ -23,9 +25,10 @@ public TestMySqlTableStatisticsMySql8IndexStatistics() super("mysql:8.0.30"); } + @Test @Override public void testNotAnalyzed() { - throw new SkipException("MySql8 automatically calculates stats - https://dev.mysql.com/doc/refman/8.0/en/innodb-parameters.html#sysvar_innodb_stats_auto_recalc"); + abort("MySql8 automatically calculates stats - https://dev.mysql.com/doc/refman/8.0/en/innodb-parameters.html#sysvar_innodb_stats_auto_recalc"); } } diff --git a/plugin/trino-oracle/pom.xml b/plugin/trino-oracle/pom.xml index 6ab026fab12c..239d2ae93e67 100644 --- a/plugin/trino-oracle/pom.xml +++ b/plugin/trino-oracle/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClientModule.java b/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClientModule.java index af36cb507e88..d3c2f545861c 100644 --- a/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClientModule.java +++ b/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClientModule.java @@ -13,6 +13,7 @@ */ package io.trino.plugin.oracle; +import com.google.common.base.Throwables; import com.google.inject.Binder; import com.google.inject.Key; import com.google.inject.Module; @@ -26,7 +27,7 @@ import io.trino.plugin.jdbc.ForBaseJdbc; import io.trino.plugin.jdbc.JdbcClient; import io.trino.plugin.jdbc.MaxDomainCompactionThreshold; -import io.trino.plugin.jdbc.RetryingConnectionFactory; +import io.trino.plugin.jdbc.RetryingConnectionFactory.RetryStrategy; import 
io.trino.plugin.jdbc.credential.CredentialProvider; import io.trino.plugin.jdbc.ptf.Query; import io.trino.spi.function.table.ConnectorTableFunction; @@ -34,6 +35,7 @@ import oracle.jdbc.OracleDriver; import java.sql.SQLException; +import java.sql.SQLRecoverableException; import java.util.Properties; import static com.google.inject.multibindings.Multibinder.newSetBinder; @@ -53,6 +55,7 @@ public void configure(Binder binder) configBinder(binder).bindConfig(OracleConfig.class); newOptionalBinder(binder, Key.get(int.class, MaxDomainCompactionThreshold.class)).setBinding().toInstance(ORACLE_MAX_LIST_EXPRESSIONS); newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(Query.class).in(Scopes.SINGLETON); + newOptionalBinder(binder, RetryStrategy.class).setBinding().to(OracleRetryStrategy.class).in(Scopes.SINGLETON); } @Provides @@ -76,11 +79,22 @@ public static ConnectionFactory connectionFactory(BaseJdbcConfig config, Credent openTelemetry); } - return new RetryingConnectionFactory(new DriverConnectionFactory( + return new DriverConnectionFactory( new OracleDriver(), config.getConnectionUrl(), connectionProperties, credentialProvider, - openTelemetry)); + openTelemetry); + } + + private static class OracleRetryStrategy + implements RetryStrategy + { + @Override + public boolean isExceptionRecoverable(Throwable exception) + { + return Throwables.getCausalChain(exception).stream() + .anyMatch(SQLRecoverableException.class::isInstance); + } } } diff --git a/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/TestOracleCaseInsensitiveMapping.java b/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/TestOracleCaseInsensitiveMapping.java index fe5aee0ce1c3..d5b3b2fe0dfa 100644 --- a/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/TestOracleCaseInsensitiveMapping.java +++ b/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/TestOracleCaseInsensitiveMapping.java @@ -18,7 +18,7 @@ import 
io.trino.plugin.jdbc.BaseCaseInsensitiveMappingTest; import io.trino.testing.QueryRunner; import io.trino.testing.sql.SqlExecutor; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.nio.file.Path; import java.util.Optional; @@ -32,7 +32,6 @@ // With case-insensitive-name-matching enabled colliding schema/table names are considered as errors. // Some tests here create colliding names which can cause any other concurrent test to fail. -@Test(singleThreaded = true) public class TestOracleCaseInsensitiveMapping extends BaseCaseInsensitiveMappingTest { diff --git a/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/TestingOracleServer.java b/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/TestingOracleServer.java index f73ae90215f3..6bf8e1ed04d9 100644 --- a/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/TestingOracleServer.java +++ b/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/TestingOracleServer.java @@ -22,7 +22,9 @@ import io.trino.plugin.jdbc.ConnectionFactory; import io.trino.plugin.jdbc.DriverConnectionFactory; import io.trino.plugin.jdbc.RetryingConnectionFactory; +import io.trino.plugin.jdbc.RetryingConnectionFactory.DefaultRetryStrategy; import io.trino.plugin.jdbc.credential.StaticCredentialProvider; +import io.trino.plugin.jdbc.jmx.StatisticsAwareConnectionFactory; import io.trino.testing.ResourcePresence; import oracle.jdbc.OracleDriver; import org.testcontainers.containers.OracleContainer; @@ -125,11 +127,11 @@ public void execute(String sql, String user, String password) private ConnectionFactory getConnectionFactory(String connectionUrl, String username, String password) { - DriverConnectionFactory connectionFactory = new DriverConnectionFactory( + StatisticsAwareConnectionFactory connectionFactory = new StatisticsAwareConnectionFactory(new DriverConnectionFactory( new OracleDriver(), new BaseJdbcConfig().setConnectionUrl(connectionUrl), - StaticCredentialProvider.of(username, 
password)); - return new RetryingConnectionFactory(connectionFactory); + StaticCredentialProvider.of(username, password))); + return new RetryingConnectionFactory(connectionFactory, new DefaultRetryStrategy()); } @Override diff --git a/plugin/trino-password-authenticators/pom.xml b/plugin/trino-password-authenticators/pom.xml index 1603e6b6f1b2..4e0091e89603 100644 --- a/plugin/trino-password-authenticators/pom.xml +++ b/plugin/trino-password-authenticators/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-phoenix5/pom.xml b/plugin/trino-phoenix5/pom.xml index 5ebe6c0f87cf..d04b1ee897a1 100644 --- a/plugin/trino-phoenix5/pom.xml +++ b/plugin/trino-phoenix5/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClientModule.java b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClientModule.java index 852d76723aa4..c18bdc26d041 100644 --- a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClientModule.java +++ b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClientModule.java @@ -34,7 +34,6 @@ import io.trino.plugin.jdbc.DynamicFilteringStats; import io.trino.plugin.jdbc.ForBaseJdbc; import io.trino.plugin.jdbc.ForJdbcDynamicFiltering; -import io.trino.plugin.jdbc.ForLazyConnectionFactory; import io.trino.plugin.jdbc.ForRecordCursor; import io.trino.plugin.jdbc.JdbcClient; import io.trino.plugin.jdbc.JdbcDiagnosticModule; @@ -48,6 +47,7 @@ import io.trino.plugin.jdbc.LazyConnectionFactory; import io.trino.plugin.jdbc.MaxDomainCompactionThreshold; import io.trino.plugin.jdbc.QueryBuilder; +import io.trino.plugin.jdbc.RetryingConnectionFactoryModule; import io.trino.plugin.jdbc.ReusableConnectionFactoryModule; import io.trino.plugin.jdbc.StatsCollecting; import io.trino.plugin.jdbc.TypeHandlingJdbcConfig; @@ -97,6 +97,7 @@ public 
PhoenixClientModule(String catalogName) protected void setup(Binder binder) { install(new RemoteQueryModifierModule()); + install(new RetryingConnectionFactoryModule()); binder.bind(ConnectorSplitManager.class).annotatedWith(ForJdbcDynamicFiltering.class).to(PhoenixSplitManager.class).in(Scopes.SINGLETON); binder.bind(ConnectorSplitManager.class).annotatedWith(ForClassLoaderSafe.class).to(JdbcDynamicFilteringSplitManager.class).in(Scopes.SINGLETON); binder.bind(ConnectorSplitManager.class).to(ClassLoaderSafeConnectorSplitManager.class).in(Scopes.SINGLETON); @@ -130,10 +131,6 @@ protected void setup(Binder binder) binder.bind(ConnectorMetadata.class).annotatedWith(ForClassLoaderSafe.class).to(PhoenixMetadata.class).in(Scopes.SINGLETON); binder.bind(ConnectorMetadata.class).to(ClassLoaderSafeConnectorMetadata.class).in(Scopes.SINGLETON); - binder.bind(ConnectionFactory.class) - .annotatedWith(ForLazyConnectionFactory.class) - .to(Key.get(ConnectionFactory.class, StatsCollecting.class)) - .in(Scopes.SINGLETON); install(conditionalModule( PhoenixConfig.class, PhoenixConfig::isReuseConnection, diff --git a/plugin/trino-pinot/pom.xml b/plugin/trino-pinot/pom.xml index 387978d5ac53..b7029abee19c 100755 --- a/plugin/trino-pinot/pom.xml +++ b/plugin/trino-pinot/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-postgresql/pom.xml b/plugin/trino-postgresql/pom.xml index 8de8df70a8cb..670bb1dd91fe 100644 --- a/plugin/trino-postgresql/pom.xml +++ b/plugin/trino-postgresql/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index 82644e0414dc..c4a87c4df937 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ 
b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -75,6 +75,7 @@ import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.MapBlock; import io.trino.spi.block.SqlMap; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.ColumnHandle; @@ -1306,8 +1307,8 @@ private ObjectReadFunction varcharMapReadFunction() varcharMapType.getValueType().writeSlice(valueBlockBuilder, utf8Slice(entry.getValue())); } } - return varcharMapType.createBlockFromKeyValue(Optional.empty(), new int[] {0, map.size()}, keyBlockBuilder.build(), valueBlockBuilder.build()) - .getObject(0, SqlMap.class); + MapBlock mapBlock = varcharMapType.createBlockFromKeyValue(Optional.empty(), new int[]{0, map.size()}, keyBlockBuilder.build(), valueBlockBuilder.build()); + return varcharMapType.getObject(mapBlock, 0); }); } diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlCaseInsensitiveMapping.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlCaseInsensitiveMapping.java index 37008b42896b..9e115ab6dbf6 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlCaseInsensitiveMapping.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlCaseInsensitiveMapping.java @@ -18,7 +18,7 @@ import io.trino.plugin.jdbc.BaseCaseInsensitiveMappingTest; import io.trino.testing.QueryRunner; import io.trino.testing.sql.SqlExecutor; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.nio.file.Path; @@ -28,7 +28,6 @@ // With case-insensitive-name-matching enabled colliding schema/table names are considered as errors. // Some tests here create colliding names which can cause any other concurrent test to fail. 
-@Test(singleThreaded = true) public class TestPostgreSqlCaseInsensitiveMapping extends BaseCaseInsensitiveMappingTest { diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlJdbcConnectionCreation.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlJdbcConnectionCreation.java index fc5c1b9739b9..5cc8a3c76021 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlJdbcConnectionCreation.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlJdbcConnectionCreation.java @@ -31,10 +31,8 @@ import io.trino.testing.DistributedQueryRunner; import io.trino.testing.QueryRunner; import io.trino.tpch.TpchTable; -import org.intellij.lang.annotations.Language; +import org.junit.jupiter.api.Test; import org.postgresql.Driver; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; import java.util.Optional; import java.util.Properties; @@ -49,7 +47,6 @@ import static io.trino.tpch.TpchTable.REGION; import static java.util.Objects.requireNonNull; -@Test(singleThreaded = true) // inherited from BaseJdbcConnectionCreationTest public class TestPostgreSqlJdbcConnectionCreation extends BaseJdbcConnectionCreationTest { @@ -67,38 +64,30 @@ protected QueryRunner createQueryRunner() return createPostgreSqlQueryRunner(postgreSqlServer, ImmutableList.of(NATION, REGION), connectionFactory); } - @Test(dataProvider = "testCases") - public void testJdbcConnectionCreations(@Language("SQL") String query, int expectedJdbcConnectionsCount, Optional errorMessage) + @Test + public void testJdbcConnectionCreations() { - assertJdbcConnections(query, expectedJdbcConnectionsCount, errorMessage); - } - - @DataProvider - public Object[][] testCases() - { - return new Object[][] { - {"SELECT * FROM nation LIMIT 1", 3, Optional.empty()}, - {"SELECT * FROM nation ORDER BY nationkey LIMIT 1", 3, Optional.empty()}, - {"SELECT * FROM nation 
WHERE nationkey = 1", 3, Optional.empty()}, - {"SELECT avg(nationkey) FROM nation", 2, Optional.empty()}, - {"SELECT * FROM nation, region", 3, Optional.empty()}, - {"SELECT * FROM nation n, region r WHERE n.regionkey = r.regionkey", 3, Optional.empty()}, - {"SELECT * FROM nation JOIN region USING(regionkey)", 5, Optional.empty()}, - {"SELECT * FROM information_schema.schemata", 1, Optional.empty()}, - {"SELECT * FROM information_schema.tables", 1, Optional.empty()}, - {"SELECT * FROM information_schema.columns", 1, Optional.empty()}, - {"SELECT * FROM nation", 2, Optional.empty()}, - {"SELECT * FROM TABLE (system.query(query => 'SELECT * FROM tpch.nation'))", 2, Optional.empty()}, - {"CREATE TABLE copy_of_nation AS SELECT * FROM nation", 6, Optional.empty()}, - {"INSERT INTO copy_of_nation SELECT * FROM nation", 6, Optional.empty()}, - {"DELETE FROM copy_of_nation WHERE nationkey = 3", 1, Optional.empty()}, - {"UPDATE copy_of_nation SET name = 'POLAND' WHERE nationkey = 1", 1, Optional.empty()}, - {"MERGE INTO copy_of_nation n USING region r ON r.regionkey= n.regionkey WHEN MATCHED THEN DELETE", 1, Optional.of(MODIFYING_ROWS_MESSAGE)}, - {"DROP TABLE copy_of_nation", 1, Optional.empty()}, - {"SHOW SCHEMAS", 1, Optional.empty()}, - {"SHOW TABLES", 1, Optional.empty()}, - {"SHOW STATS FOR nation", 2, Optional.empty()}, - }; + assertJdbcConnections("SELECT * FROM nation LIMIT 1", 3, Optional.empty()); + assertJdbcConnections("SELECT * FROM nation ORDER BY nationkey LIMIT 1", 3, Optional.empty()); + assertJdbcConnections("SELECT * FROM nation WHERE nationkey = 1", 3, Optional.empty()); + assertJdbcConnections("SELECT avg(nationkey) FROM nation", 2, Optional.empty()); + assertJdbcConnections("SELECT * FROM nation, region", 3, Optional.empty()); + assertJdbcConnections("SELECT * FROM nation n, region r WHERE n.regionkey = r.regionkey", 3, Optional.empty()); + assertJdbcConnections("SELECT * FROM nation JOIN region USING(regionkey)", 5, Optional.empty()); + 
assertJdbcConnections("SELECT * FROM information_schema.schemata", 1, Optional.empty()); + assertJdbcConnections("SELECT * FROM information_schema.tables", 1, Optional.empty()); + assertJdbcConnections("SELECT * FROM information_schema.columns", 1, Optional.empty()); + assertJdbcConnections("SELECT * FROM nation", 2, Optional.empty()); + assertJdbcConnections("SELECT * FROM TABLE (system.query(query => 'SELECT * FROM tpch.nation'))", 2, Optional.empty()); + assertJdbcConnections("CREATE TABLE copy_of_nation AS SELECT * FROM nation", 6, Optional.empty()); + assertJdbcConnections("INSERT INTO copy_of_nation SELECT * FROM nation", 6, Optional.empty()); + assertJdbcConnections("DELETE FROM copy_of_nation WHERE nationkey = 3", 1, Optional.empty()); + assertJdbcConnections("UPDATE copy_of_nation SET name = 'POLAND' WHERE nationkey = 1", 1, Optional.empty()); + assertJdbcConnections("MERGE INTO copy_of_nation n USING region r ON r.regionkey= n.regionkey WHEN MATCHED THEN DELETE", 1, Optional.of(MODIFYING_ROWS_MESSAGE)); + assertJdbcConnections("DROP TABLE copy_of_nation", 1, Optional.empty()); + assertJdbcConnections("SHOW SCHEMAS", 1, Optional.empty()); + assertJdbcConnections("SHOW TABLES", 1, Optional.empty()); + assertJdbcConnections("SHOW STATS FOR nation", 2, Optional.empty()); } private static DistributedQueryRunner createPostgreSqlQueryRunner( diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlTableStatistics.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlTableStatistics.java index 19f17b246699..6aa9538e503e 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlTableStatistics.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlTableStatistics.java @@ -20,7 +20,7 @@ import io.trino.testing.sql.TestTable; import org.jdbi.v3.core.HandleConsumer; import org.jdbi.v3.core.Jdbi; -import org.testng.annotations.Test; 
+import org.junit.jupiter.api.Test; import java.util.List; import java.util.Properties; @@ -30,6 +30,7 @@ import static io.trino.tpch.TpchTable.ORDERS; import static java.lang.String.format; import static java.util.stream.Collectors.joining; +import static org.junit.jupiter.api.Assumptions.abort; public class TestPostgreSqlTableStatistics extends BaseJdbcTableStatisticsTest @@ -53,31 +54,11 @@ protected QueryRunner createQueryRunner() ImmutableList.of(ORDERS)); } + @Test @Override - @Test(invocationCount = 10, successPercentage = 50) // PostgreSQL can auto-analyze data before we SHOW STATS public void testNotAnalyzed() { - String tableName = "test_stats_not_analyzed"; - assertUpdate("DROP TABLE IF EXISTS " + tableName); - computeActual(format("CREATE TABLE %s AS SELECT * FROM tpch.tiny.orders", tableName)); - try { - assertQuery( - "SHOW STATS FOR " + tableName, - "VALUES " + - "('orderkey', null, null, null, null, null, null)," + - "('custkey', null, null, null, null, null, null)," + - "('orderstatus', null, null, null, null, null, null)," + - "('totalprice', null, null, null, null, null, null)," + - "('orderdate', null, null, null, null, null, null)," + - "('orderpriority', null, null, null, null, null, null)," + - "('clerk', null, null, null, null, null, null)," + - "('shippriority', null, null, null, null, null, null)," + - "('comment', null, null, null, null, null, null)," + - "(null, null, null, null, 15000, null, null)"); - } - finally { - assertUpdate("DROP TABLE " + tableName); - } + abort("PostgreSQL analyzes tables automatically"); } @Override @@ -331,8 +312,7 @@ public void testMaterializedView() } @Override - @Test(dataProvider = "testCaseColumnNamesDataProvider") - public void testCaseColumnNames(String tableName) + protected void testCaseColumnNames(String tableName) { executeInPostgres("" + "CREATE TABLE " + tableName + " " + diff --git a/plugin/trino-prometheus/pom.xml b/plugin/trino-prometheus/pom.xml index 29044b909e88..f416867c1793 100644 --- 
a/plugin/trino-prometheus/pom.xml +++ b/plugin/trino-prometheus/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusRecordCursor.java b/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusRecordCursor.java index cce1f2a47798..40d160635cf8 100644 --- a/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusRecordCursor.java +++ b/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusRecordCursor.java @@ -262,9 +262,9 @@ else if (type instanceof MapType mapType) { private static Object readObject(Type type, Block block, int position) { - if (type instanceof ArrayType) { - Type elementType = ((ArrayType) type).getElementType(); - return getArrayFromBlock(elementType, block.getObject(position, Block.class)); + if (type instanceof ArrayType arrayType) { + Type elementType = arrayType.getElementType(); + return getArrayFromBlock(elementType, arrayType.getObject(block, position)); } if (type instanceof MapType mapType) { return getMapFromSqlMap(type, mapType.getObject(block, position)); diff --git a/plugin/trino-raptor-legacy/pom.xml b/plugin/trino-raptor-legacy/pom.xml index d8cb5c025e0f..d1e8769b5ca1 100644 --- a/plugin/trino-raptor-legacy/pom.xml +++ b/plugin/trino-raptor-legacy/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-redis/pom.xml b/plugin/trino-redis/pom.xml index 14a0748afb21..761f251bb456 100644 --- a/plugin/trino-redis/pom.xml +++ b/plugin/trino-redis/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml @@ -81,13 +81,13 @@ org.apache.commons commons-pool2 - 2.11.1 + 2.12.0 redis.clients jedis - 5.0.1 + 5.0.2 diff --git a/plugin/trino-redshift/pom.xml b/plugin/trino-redshift/pom.xml index ffea96bcef5d..626a437fd5e6 100644 --- a/plugin/trino-redshift/pom.xml +++ 
b/plugin/trino-redshift/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml @@ -21,7 +21,7 @@ com.amazon.redshift redshift-jdbc42 - 2.1.0.19 + 2.1.0.20 diff --git a/plugin/trino-resource-group-managers/pom.xml b/plugin/trino-resource-group-managers/pom.xml index d0bbc50d79d5..a6e399c12613 100644 --- a/plugin/trino-resource-group-managers/pom.xml +++ b/plugin/trino-resource-group-managers/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-session-property-managers/pom.xml b/plugin/trino-session-property-managers/pom.xml index ef8d4ed19edc..9c57a088ba51 100644 --- a/plugin/trino-session-property-managers/pom.xml +++ b/plugin/trino-session-property-managers/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-singlestore/pom.xml b/plugin/trino-singlestore/pom.xml index 4df3a0fa9c5e..d9dc0f7ee1dd 100644 --- a/plugin/trino-singlestore/pom.xml +++ b/plugin/trino-singlestore/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml @@ -31,7 +31,7 @@ com.singlestore singlestore-jdbc-client - 1.1.9 + 1.2.0 @@ -113,6 +113,12 @@ runtime + + io.airlift + junit-extensions + test + + io.airlift testing @@ -182,6 +188,12 @@ test + + org.junit.jupiter + junit-jupiter-api + test + + org.testcontainers jdbc diff --git a/plugin/trino-singlestore/src/test/java/io/trino/plugin/singlestore/TestSingleStoreCaseInsensitiveMapping.java b/plugin/trino-singlestore/src/test/java/io/trino/plugin/singlestore/TestSingleStoreCaseInsensitiveMapping.java index 09c23bcdc1f2..698ad64bb05e 100644 --- a/plugin/trino-singlestore/src/test/java/io/trino/plugin/singlestore/TestSingleStoreCaseInsensitiveMapping.java +++ b/plugin/trino-singlestore/src/test/java/io/trino/plugin/singlestore/TestSingleStoreCaseInsensitiveMapping.java @@ -18,7 +18,7 @@ import io.trino.plugin.jdbc.BaseCaseInsensitiveMappingTest; import 
io.trino.testing.QueryRunner; import io.trino.testing.sql.SqlExecutor; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.nio.file.Path; @@ -28,7 +28,6 @@ // With case-insensitive-name-matching enabled colliding schema/table names are considered as errors. // Some tests here create colliding names which can cause any other concurrent test to fail. -@Test(singleThreaded = true) public class TestSingleStoreCaseInsensitiveMapping extends BaseCaseInsensitiveMappingTest { diff --git a/plugin/trino-sqlserver/pom.xml b/plugin/trino-sqlserver/pom.xml index 582218adf318..bdfef843ff69 100644 --- a/plugin/trino-sqlserver/pom.xml +++ b/plugin/trino-sqlserver/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerCaseInsensitiveMapping.java b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerCaseInsensitiveMapping.java index b306bd155e92..ca60f511f667 100644 --- a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerCaseInsensitiveMapping.java +++ b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerCaseInsensitiveMapping.java @@ -18,7 +18,7 @@ import io.trino.plugin.jdbc.BaseCaseInsensitiveMappingTest; import io.trino.testing.QueryRunner; import io.trino.testing.sql.SqlExecutor; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.nio.file.Path; import java.sql.Connection; @@ -33,7 +33,6 @@ // With case-insensitive-name-matching enabled colliding schema/table names are considered as errors. // Some tests here create colliding names which can cause any other concurrent test to fail. 
-@Test(singleThreaded = true) public class TestSqlServerCaseInsensitiveMapping extends BaseCaseInsensitiveMappingTest { diff --git a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerTableStatistics.java b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerTableStatistics.java index b2ce0f624108..16660425161b 100644 --- a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerTableStatistics.java +++ b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerTableStatistics.java @@ -20,8 +20,7 @@ import io.trino.testing.sql.TestTable; import org.jdbi.v3.core.Handle; import org.jdbi.v3.core.Jdbi; -import org.testng.SkipException; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.util.List; import java.util.Map; @@ -31,6 +30,7 @@ import static io.trino.testing.sql.TestTable.fromColumns; import static io.trino.tpch.TpchTable.ORDERS; import static java.lang.String.format; +import static org.junit.jupiter.api.Assumptions.abort; public class TestSqlServerTableStatistics extends BaseJdbcTableStatisticsTest @@ -211,7 +211,7 @@ public void testAverageColumnLength() @Test public void testPartitionedTable() { - throw new SkipException("Not implemented"); // TODO + abort("Not implemented"); // TODO } @Override @@ -236,10 +236,11 @@ public void testView() } } + @Test @Override public void testMaterializedView() { - throw new SkipException("see testIndexedView"); + abort("see testIndexedView"); } @Test @@ -275,8 +276,7 @@ public void testIndexedView() // materialized view } @Override - @Test(dataProvider = "testCaseColumnNamesDataProvider") - public void testCaseColumnNames(String tableName) + protected void testCaseColumnNames(String tableName) { sqlServer.execute("" + "SELECT " + diff --git a/plugin/trino-teradata-functions/pom.xml b/plugin/trino-teradata-functions/pom.xml index 323860d4b1f9..5e02fbe1792b 100644 --- 
a/plugin/trino-teradata-functions/pom.xml +++ b/plugin/trino-teradata-functions/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-thrift-api/pom.xml b/plugin/trino-thrift-api/pom.xml index a46fd7ee5778..f04e5d621dfe 100644 --- a/plugin/trino-thrift-api/pom.xml +++ b/plugin/trino-thrift-api/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-thrift-testing-server/pom.xml b/plugin/trino-thrift-testing-server/pom.xml index 86f8430cdb06..cc63339e2ce5 100644 --- a/plugin/trino-thrift-testing-server/pom.xml +++ b/plugin/trino-thrift-testing-server/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-thrift/pom.xml b/plugin/trino-thrift/pom.xml index 59417a1d2dd3..357b3963ee0c 100644 --- a/plugin/trino-thrift/pom.xml +++ b/plugin/trino-thrift/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-tpcds/pom.xml b/plugin/trino-tpcds/pom.xml index 54c8f0d95796..ab9c114993f7 100644 --- a/plugin/trino-tpcds/pom.xml +++ b/plugin/trino-tpcds/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/plugin/trino-tpch/pom.xml b/plugin/trino-tpch/pom.xml index 8c02628f7a0b..26689203ac1e 100644 --- a/plugin/trino-tpch/pom.xml +++ b/plugin/trino-tpch/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/pom.xml b/pom.xml index ebdad4d1c21f..55e50d7382b1 100644 --- a/pom.xml +++ b/pom.xml @@ -10,7 +10,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT pom ${project.artifactId} @@ -159,21 +159,21 @@ 1.11.3 ${dep.airlift.version} 2.1.1 - 1.12.566 - 2.21.0 - 0.12.2 + 1.12.560 + 2.21.4 + 0.12.3 21.9.0.0 1.21 201 2.2.17 - 1.6.11 + 1.6.12 1.9.10 1.43.3 2.22.0 1.19.1 1.0.8 - 7.3.1 - 3.3.2 + 7.4.1 + 3.6.0 4.17.0 8.5.6 1.3.1 @@ -187,7 +187,7 @@ 4.13.1 9.6 - 86 
+ 87 - 430-SNAPSHOT + 431-SNAPSHOT diff --git a/testing/trino-test-jdbc-compatibility-old-server/pom.xml b/testing/trino-test-jdbc-compatibility-old-server/pom.xml index e50483e9a49d..5bc89b0fcdf5 100644 --- a/testing/trino-test-jdbc-compatibility-old-server/pom.xml +++ b/testing/trino-test-jdbc-compatibility-old-server/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/testing/trino-testing-containers/pom.xml b/testing/trino-testing-containers/pom.xml index 941130c8ba59..70fc7fe143b0 100644 --- a/testing/trino-testing-containers/pom.xml +++ b/testing/trino-testing-containers/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/testing/trino-testing-kafka/pom.xml b/testing/trino-testing-kafka/pom.xml index b6a4757ef9bb..ebdd1e1b0eb5 100644 --- a/testing/trino-testing-kafka/pom.xml +++ b/testing/trino-testing-kafka/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/testing/trino-testing-resources/pom.xml b/testing/trino-testing-resources/pom.xml index 03b0901141a7..8744094f9b9b 100644 --- a/testing/trino-testing-resources/pom.xml +++ b/testing/trino-testing-resources/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/testing/trino-testing-services/pom.xml b/testing/trino-testing-services/pom.xml index 5b564004783b..d5db3e92cd10 100644 --- a/testing/trino-testing-services/pom.xml +++ b/testing/trino-testing-services/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/testing/trino-testing/pom.xml b/testing/trino-testing/pom.xml index fb0f794e1701..acff2c52e36c 100644 --- a/testing/trino-testing/pom.xml +++ b/testing/trino-testing/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java 
b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java index 26be645eeba9..5cb4c66e35d3 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java @@ -71,7 +71,6 @@ import static io.trino.testing.QueryAssertions.assertEqualsIgnoreOrder; import static io.trino.tests.QueryTemplate.parameter; import static io.trino.tests.QueryTemplate.queryTemplate; -import static io.trino.type.UnknownType.UNKNOWN; import static java.lang.String.format; import static java.util.Collections.nCopies; import static java.util.stream.Collectors.joining; @@ -1382,7 +1381,7 @@ public void testDescribeInputNoParameters() .addPreparedStatement("my_query", "SELECT * FROM nation") .build(); assertThat(query(session, "DESCRIBE INPUT my_query")) - .hasOutputTypes(List.of(UNKNOWN, UNKNOWN)) + .hasOutputTypes(List.of(BIGINT, VARCHAR)) .returnsEmptyResult(); } diff --git a/testing/trino-testing/src/main/java/io/trino/testing/BaseComplexTypesPredicatePushDownTest.java b/testing/trino-testing/src/main/java/io/trino/testing/BaseComplexTypesPredicatePushDownTest.java index c7640eab85d2..155b7facd06c 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/BaseComplexTypesPredicatePushDownTest.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/BaseComplexTypesPredicatePushDownTest.java @@ -13,7 +13,7 @@ */ package io.trino.testing; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import static io.trino.testing.TestingNames.randomNameSuffix; import static org.assertj.core.api.Assertions.assertThat; diff --git a/testing/trino-testing/src/main/java/io/trino/testing/BaseDynamicPartitionPruningTest.java b/testing/trino-testing/src/main/java/io/trino/testing/BaseDynamicPartitionPruningTest.java index 7cdfeda5270b..1edd986c17f7 100644 --- 
a/testing/trino-testing/src/main/java/io/trino/testing/BaseDynamicPartitionPruningTest.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/BaseDynamicPartitionPruningTest.java @@ -27,14 +27,14 @@ import io.trino.sql.planner.OptimizerConfig.JoinDistributionType; import io.trino.tpch.TpchTable; import org.intellij.lang.annotations.Language; -import org.testng.annotations.BeforeClass; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.Timeout; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.stream.Stream; import static com.google.common.collect.Iterables.getOnlyElement; import static io.trino.SystemSessionProperties.ENABLE_LARGE_DYNAMIC_FILTERS; @@ -46,7 +46,6 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.sql.planner.OptimizerConfig.JoinDistributionType.PARTITIONED; import static io.trino.sql.planner.OptimizerConfig.JoinReorderingStrategy.NONE; -import static io.trino.testing.DataProviders.toDataProvider; import static io.trino.testing.QueryAssertions.assertEqualsIgnoreOrder; import static io.trino.tpch.TpchTable.LINE_ITEM; import static io.trino.tpch.TpchTable.ORDERS; @@ -54,9 +53,11 @@ import static io.trino.util.DynamicFiltersTestUtil.getSimplifiedDomainString; import static java.lang.String.format; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; +@TestInstance(PER_CLASS) public abstract class BaseDynamicPartitionPruningTest extends AbstractTestQueryFramework { @@ -70,7 +71,7 @@ public abstract class BaseDynamicPartitionPruningTest // disable semi join to inner join rewrite to test semi join operators explicitly 
"optimizer.rewrite-filtering-semi-join-to-inner-join", "false"); - @BeforeClass + @BeforeAll public void initTables() throws Exception { @@ -95,7 +96,8 @@ protected Session getSession() .build(); } - @Test(timeOut = 30_000) + @Test + @Timeout(30) public void testJoinWithEmptyBuildSide() { @Language("SQL") String selectQuery = "SELECT * FROM partitioned_lineitem JOIN supplier ON partitioned_lineitem.suppkey = supplier.suppkey AND supplier.name = 'abc'"; @@ -116,7 +118,8 @@ public void testJoinWithEmptyBuildSide() assertTrue(domainStats.getCollectionDuration().isPresent()); } - @Test(timeOut = 30_000) + @Test + @Timeout(30) public void testJoinWithSelectiveBuildSide() { @Language("SQL") String selectQuery = "SELECT * FROM partitioned_lineitem JOIN supplier ON partitioned_lineitem.suppkey = supplier.suppkey " + @@ -137,7 +140,8 @@ public void testJoinWithSelectiveBuildSide() assertEquals(domainStats.getSimplifiedDomain(), singleValue(BIGINT, 1L).toString(getSession().toConnectorSession())); } - @Test(timeOut = 30_000) + @Test + @Timeout(30) public void testJoinWithNonSelectiveBuildSide() { @Language("SQL") String selectQuery = "SELECT * FROM partitioned_lineitem JOIN supplier ON partitioned_lineitem.suppkey = supplier.suppkey"; @@ -158,7 +162,8 @@ public void testJoinWithNonSelectiveBuildSide() .isEqualTo(getSimplifiedDomainString(1L, 100L, 100, BIGINT)); } - @Test(timeOut = 30_000) + @Test + @Timeout(30) public void testJoinLargeBuildSideRangeDynamicFiltering() { @Language("SQL") String selectQuery = "SELECT * FROM partitioned_lineitem JOIN orders ON partitioned_lineitem.orderkey = orders.orderkey"; @@ -181,7 +186,8 @@ public void testJoinLargeBuildSideRangeDynamicFiltering() .toString(getSession().toConnectorSession())); } - @Test(timeOut = 30_000) + @Test + @Timeout(30) public void testJoinWithMultipleDynamicFiltersOnProbe() { // supplier names Supplier#000000001 and Supplier#000000002 match suppkey 1 and 2 @@ -208,7 +214,8 @@ public void 
testJoinWithMultipleDynamicFiltersOnProbe() getSimplifiedDomainString(2L, 2L, 1, BIGINT)); } - @Test(timeOut = 30_000) + @Test + @Timeout(30) public void testJoinWithImplicitCoercion() { // setup partitioned fact table with integer suppkey @@ -237,7 +244,8 @@ public void testJoinWithImplicitCoercion() assertEquals(domainStats.getSimplifiedDomain(), singleValue(BIGINT, 1L).toString(getSession().toConnectorSession())); } - @Test(timeOut = 30_000) + @Test + @Timeout(30) public void testSemiJoinWithEmptyBuildSide() { @Language("SQL") String selectQuery = "SELECT * FROM partitioned_lineitem WHERE suppkey IN (SELECT suppkey FROM supplier WHERE name = 'abc')"; @@ -257,7 +265,8 @@ public void testSemiJoinWithEmptyBuildSide() assertEquals(domainStats.getSimplifiedDomain(), none(BIGINT).toString(getSession().toConnectorSession())); } - @Test(timeOut = 30_000) + @Test + @Timeout(30) public void testSemiJoinWithSelectiveBuildSide() { @Language("SQL") String selectQuery = "SELECT * FROM partitioned_lineitem WHERE suppkey IN (SELECT suppkey FROM supplier WHERE name = 'Supplier#000000001')"; @@ -277,7 +286,8 @@ public void testSemiJoinWithSelectiveBuildSide() assertEquals(domainStats.getSimplifiedDomain(), singleValue(BIGINT, 1L).toString(getSession().toConnectorSession())); } - @Test(timeOut = 30_000) + @Test + @Timeout(30) public void testSemiJoinWithNonSelectiveBuildSide() { @Language("SQL") String selectQuery = "SELECT * FROM partitioned_lineitem WHERE suppkey IN (SELECT suppkey FROM supplier)"; @@ -298,7 +308,8 @@ public void testSemiJoinWithNonSelectiveBuildSide() .isEqualTo(getSimplifiedDomainString(1L, 100L, 100, BIGINT)); } - @Test(timeOut = 30_000) + @Test + @Timeout(30) public void testSemiJoinLargeBuildSideRangeDynamicFiltering() { @Language("SQL") String selectQuery = "SELECT * FROM partitioned_lineitem WHERE orderkey IN (SELECT orderkey FROM orders)"; @@ -321,7 +332,8 @@ public void testSemiJoinLargeBuildSideRangeDynamicFiltering() 
.toString(getSession().toConnectorSession())); } - @Test(timeOut = 30_000) + @Test + @Timeout(30) public void testRightJoinWithEmptyBuildSide() { @Language("SQL") String selectQuery = "SELECT * FROM partitioned_lineitem l RIGHT JOIN supplier s ON l.suppkey = s.suppkey WHERE name = 'abc'"; @@ -341,7 +353,8 @@ public void testRightJoinWithEmptyBuildSide() assertEquals(domainStats.getSimplifiedDomain(), none(BIGINT).toString(getSession().toConnectorSession())); } - @Test(timeOut = 30_000) + @Test + @Timeout(30) public void testRightJoinWithSelectiveBuildSide() { @Language("SQL") String selectQuery = "SELECT * FROM partitioned_lineitem l RIGHT JOIN supplier s ON l.suppkey = s.suppkey WHERE name = 'Supplier#000000001'"; @@ -361,7 +374,8 @@ public void testRightJoinWithSelectiveBuildSide() assertEquals(domainStats.getSimplifiedDomain(), singleValue(BIGINT, 1L).toString(getSession().toConnectorSession())); } - @Test(timeOut = 30_000) + @Test + @Timeout(30) public void testRightJoinWithNonSelectiveBuildSide() { @Language("SQL") String selectQuery = "SELECT * FROM partitioned_lineitem l RIGHT JOIN supplier s ON l.suppkey = s.suppkey"; @@ -382,23 +396,27 @@ public void testRightJoinWithNonSelectiveBuildSide() .isEqualTo(getSimplifiedDomainString(1L, 100L, 100, BIGINT)); } - @Test(timeOut = 30_000, dataProvider = "joinDistributionTypes") - public void testJoinDynamicFilteringMultiJoinOnPartitionedTables(JoinDistributionType joinDistributionType) + @Test + @Timeout(30) + public void testJoinDynamicFilteringMultiJoinOnPartitionedTables() { - assertUpdate("DROP TABLE IF EXISTS t0_part"); - assertUpdate("DROP TABLE IF EXISTS t1_part"); - assertUpdate("DROP TABLE IF EXISTS t2_part"); - createPartitionedTable("t0_part", ImmutableList.of("v0 real", "k0 integer"), ImmutableList.of("k0")); - createPartitionedTable("t1_part", ImmutableList.of("v1 real", "i1 integer"), ImmutableList.of()); - createPartitionedTable("t2_part", ImmutableList.of("v2 real", "i2 integer", "k2 integer"), 
ImmutableList.of("k2")); - assertUpdate("INSERT INTO t0_part VALUES (1.0, 1), (1.0, 2)", 2); - assertUpdate("INSERT INTO t1_part VALUES (2.0, 10), (2.0, 20)", 2); - assertUpdate("INSERT INTO t2_part VALUES (3.0, 1, 1), (3.0, 2, 2)", 2); - testJoinDynamicFilteringMultiJoin(joinDistributionType, "t0_part", "t1_part", "t2_part"); + for (JoinDistributionType joinDistributionType : JoinDistributionType.values()) { + assertUpdate("DROP TABLE IF EXISTS t0_part"); + assertUpdate("DROP TABLE IF EXISTS t1_part"); + assertUpdate("DROP TABLE IF EXISTS t2_part"); + createPartitionedTable("t0_part", ImmutableList.of("v0 real", "k0 integer"), ImmutableList.of("k0")); + createPartitionedTable("t1_part", ImmutableList.of("v1 real", "i1 integer"), ImmutableList.of()); + createPartitionedTable("t2_part", ImmutableList.of("v2 real", "i2 integer", "k2 integer"), ImmutableList.of("k2")); + assertUpdate("INSERT INTO t0_part VALUES (1.0, 1), (1.0, 2)", 2); + assertUpdate("INSERT INTO t1_part VALUES (2.0, 10), (2.0, 20)", 2); + assertUpdate("INSERT INTO t2_part VALUES (3.0, 1, 1), (3.0, 2, 2)", 2); + testJoinDynamicFilteringMultiJoin(joinDistributionType, "t0_part", "t1_part", "t2_part"); + } } // TODO: use joinDistributionTypeProvider when https://github.com/trinodb/trino/issues/4713 is done as currently waiting for BROADCAST DFs doesn't work for bucketed tables - @Test(timeOut = 30_000) + @Test + @Timeout(30) public void testJoinDynamicFilteringMultiJoinOnBucketedTables() { assertUpdate("DROP TABLE IF EXISTS t0_bucketed"); @@ -457,13 +475,6 @@ private long getQueryInputPositions(Session session, @Language("SQL") String sql return stats.getPhysicalInputPositions(); } - @DataProvider - public Object[][] joinDistributionTypes() - { - return Stream.of(JoinDistributionType.values()) - .collect(toDataProvider()); - } - private Session withDynamicFilteringDisabled() { return withDynamicFilteringDisabled(getSession()); diff --git 
a/testing/trino-testing/src/main/java/io/trino/testing/BaseOrcWithBloomFiltersTest.java b/testing/trino-testing/src/main/java/io/trino/testing/BaseOrcWithBloomFiltersTest.java index 44d2c64cabd1..8ce425b4fb5d 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/BaseOrcWithBloomFiltersTest.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/BaseOrcWithBloomFiltersTest.java @@ -15,7 +15,7 @@ import io.trino.Session; import org.intellij.lang.annotations.Language; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import static io.trino.testing.TestingNames.randomNameSuffix; import static java.lang.String.format; diff --git a/testing/trino-testing/src/main/java/io/trino/testing/BaseTestParquetWithBloomFilters.java b/testing/trino-testing/src/main/java/io/trino/testing/BaseTestParquetWithBloomFilters.java index 76eb9f57f0b3..4cd93f25296f 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/BaseTestParquetWithBloomFilters.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/BaseTestParquetWithBloomFilters.java @@ -16,7 +16,7 @@ import com.google.common.collect.ImmutableSet; import io.trino.Session; import io.trino.spi.connector.CatalogSchemaTableName; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.nio.file.Path; import java.util.Arrays; diff --git a/testing/trino-tests/pom.xml b/testing/trino-tests/pom.xml index ace4b0c3c822..541303c62ac6 100644 --- a/testing/trino-tests/pom.xml +++ b/testing/trino-tests/pom.xml @@ -5,7 +5,7 @@ io.trino trino-root - 430-SNAPSHOT + 431-SNAPSHOT ../../pom.xml diff --git a/testing/trino-tests/src/test/java/io/trino/tests/TestServer.java b/testing/trino-tests/src/test/java/io/trino/tests/TestServer.java index 1769da31ffc5..825fc08f3587 100644 --- a/testing/trino-tests/src/test/java/io/trino/tests/TestServer.java +++ b/testing/trino-tests/src/test/java/io/trino/tests/TestServer.java @@ -68,6 +68,7 @@ import static 
io.trino.SystemSessionProperties.MAX_HASH_PARTITION_COUNT; import static io.trino.SystemSessionProperties.QUERY_MAX_MEMORY; import static io.trino.client.ClientCapabilities.PATH; +import static io.trino.client.ClientCapabilities.SESSION_AUTHORIZATION; import static io.trino.client.ProtocolHeaders.TRINO_HEADERS; import static io.trino.spi.StandardErrorCode.INCOMPATIBLE_CLIENT; import static io.trino.testing.TestingSession.testSessionBuilder; @@ -291,6 +292,34 @@ public void testSetPathSupportByClient() } } + @Test + public void testSetSessionSupportByClient() + { + try (TestingTrinoClient testingClient = new TestingTrinoClient(server, testSessionBuilder().setClientCapabilities(Set.of()).build())) { + assertThatThrownBy(() -> testingClient.execute("SET SESSION AUTHORIZATION userA")) + .hasMessage("SET SESSION AUTHORIZATION not supported by client"); + } + + try (TestingTrinoClient testingClient = new TestingTrinoClient(server, testSessionBuilder().setClientCapabilities(Set.of( + SESSION_AUTHORIZATION.name())).build())) { + testingClient.execute("SET SESSION AUTHORIZATION userA"); + } + } + + @Test + public void testResetSessionSupportByClient() + { + try (TestingTrinoClient testingClient = new TestingTrinoClient(server, testSessionBuilder().setClientCapabilities(Set.of()).build())) { + assertThatThrownBy(() -> testingClient.execute("RESET SESSION AUTHORIZATION")) + .hasMessage("RESET SESSION AUTHORIZATION not supported by client"); + } + + try (TestingTrinoClient testingClient = new TestingTrinoClient(server, testSessionBuilder().setClientCapabilities(Set.of( + SESSION_AUTHORIZATION.name())).build())) { + testingClient.execute("RESET SESSION AUTHORIZATION"); + } + } + private void checkVersionOnError(String query, @Language("RegExp") String proofOfOrigin) { QueryResults queryResults = postQuery(request -> request diff --git a/testing/trino-tests/src/test/resources/sql/presto/tpcds/hive/partitioned/q09.plan.txt 
b/testing/trino-tests/src/test/resources/sql/presto/tpcds/hive/partitioned/q09.plan.txt index 4466a2d88feb..cbcd78f04da6 100644 --- a/testing/trino-tests/src/test/resources/sql/presto/tpcds/hive/partitioned/q09.plan.txt +++ b/testing/trino-tests/src/test/resources/sql/presto/tpcds/hive/partitioned/q09.plan.txt @@ -1,93 +1,122 @@ -cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - scan reason - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales +join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over 
() - scan store_sales - final aggregation over () local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPARTITION, HASH, []) + scan reason + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, 
BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales diff --git a/testing/trino-tests/src/test/resources/sql/presto/tpcds/hive/partitioned/q54.plan.txt b/testing/trino-tests/src/test/resources/sql/presto/tpcds/hive/partitioned/q54.plan.txt index a38e81074509..d524385f72f4 100644 --- a/testing/trino-tests/src/test/resources/sql/presto/tpcds/hive/partitioned/q54.plan.txt +++ 
b/testing/trino-tests/src/test/resources/sql/presto/tpcds/hive/partitioned/q54.plan.txt @@ -8,8 +8,8 @@ local exchange (GATHER, SINGLE, []) local exchange (GATHER, SINGLE, []) remote exchange (REPARTITION, HASH, ["ss_customer_sk"]) partial aggregation over (ss_customer_sk) - cross join: - cross join: + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): join (INNER, REPLICATED): join (INNER, REPLICATED): dynamic filter (["ss_customer_sk", "ss_sold_date_sk"]) @@ -48,8 +48,7 @@ local exchange (GATHER, SINGLE, []) scan date_dim local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) - dynamic filter (["d_month_seq_26"]) - scan date_dim + scan date_dim local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) local exchange (GATHER, SINGLE, []) diff --git a/testing/trino-tests/src/test/resources/sql/presto/tpcds/hive/unpartitioned/q09.plan.txt b/testing/trino-tests/src/test/resources/sql/presto/tpcds/hive/unpartitioned/q09.plan.txt index 4466a2d88feb..cbcd78f04da6 100644 --- a/testing/trino-tests/src/test/resources/sql/presto/tpcds/hive/unpartitioned/q09.plan.txt +++ b/testing/trino-tests/src/test/resources/sql/presto/tpcds/hive/unpartitioned/q09.plan.txt @@ -1,93 +1,122 @@ -cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - scan reason - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales +join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, 
REPLICATED): + join (LEFT, REPLICATED): + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales - final aggregation over () local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local 
exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPARTITION, HASH, []) + scan reason + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, 
SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales diff --git a/testing/trino-tests/src/test/resources/sql/presto/tpcds/hive/unpartitioned/q54.plan.txt b/testing/trino-tests/src/test/resources/sql/presto/tpcds/hive/unpartitioned/q54.plan.txt index 6b99ed23c254..623a7d64ea2c 100644 --- a/testing/trino-tests/src/test/resources/sql/presto/tpcds/hive/unpartitioned/q54.plan.txt +++ b/testing/trino-tests/src/test/resources/sql/presto/tpcds/hive/unpartitioned/q54.plan.txt @@ -8,8 +8,8 @@ local exchange (GATHER, SINGLE, []) local exchange (GATHER, SINGLE, []) remote exchange (REPARTITION, HASH, ["ss_customer_sk"]) partial aggregation over (ss_customer_sk) - cross join: - cross join: + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): join (INNER, REPLICATED): join (INNER, REPLICATED): dynamic filter (["ss_customer_sk", "ss_sold_date_sk"]) @@ -48,8 +48,7 @@ local exchange (GATHER, SINGLE, []) scan date_dim local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) - dynamic filter (["d_month_seq_24"]) - scan date_dim + scan date_dim local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) local exchange (GATHER, SINGLE, []) diff --git a/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/orc/partitioned/q09.plan.txt b/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/orc/partitioned/q09.plan.txt index 4466a2d88feb..cbcd78f04da6 100644 --- a/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/orc/partitioned/q09.plan.txt +++ 
b/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/orc/partitioned/q09.plan.txt @@ -1,93 +1,122 @@ -cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - scan reason - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales +join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales - final aggregation over () local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales - final aggregation over () - local 
exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPARTITION, HASH, []) + scan reason + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange 
(REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales diff --git a/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/orc/partitioned/q54.plan.txt b/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/orc/partitioned/q54.plan.txt index 5806dc346635..851fa255cf39 100644 --- a/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/orc/partitioned/q54.plan.txt +++ b/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/orc/partitioned/q54.plan.txt @@ -8,8 +8,8 @@ local exchange (GATHER, SINGLE, []) local exchange (GATHER, SINGLE, []) remote exchange (REPARTITION, HASH, ["ss_customer_sk"]) partial aggregation over 
(ss_customer_sk) - cross join: - cross join: + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): join (INNER, REPLICATED): join (INNER, REPLICATED): dynamic filter (["ss_customer_sk", "ss_sold_date_sk"]) @@ -48,8 +48,7 @@ local exchange (GATHER, SINGLE, []) scan date_dim local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) - dynamic filter (["d_month_seq_17"]) - scan date_dim + scan date_dim local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) local exchange (GATHER, SINGLE, []) diff --git a/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/orc/unpartitioned/q09.plan.txt b/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/orc/unpartitioned/q09.plan.txt index 4466a2d88feb..cbcd78f04da6 100644 --- a/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/orc/unpartitioned/q09.plan.txt +++ b/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/orc/unpartitioned/q09.plan.txt @@ -1,93 +1,122 @@ -cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - scan reason - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales +join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation 
over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales - final aggregation over () local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + 
scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPARTITION, HASH, []) + scan reason + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales - final aggregation over () - local 
exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales diff --git a/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/orc/unpartitioned/q54.plan.txt b/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/orc/unpartitioned/q54.plan.txt index 5806dc346635..851fa255cf39 100644 --- a/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/orc/unpartitioned/q54.plan.txt +++ b/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/orc/unpartitioned/q54.plan.txt @@ -8,8 +8,8 @@ local exchange (GATHER, SINGLE, []) local exchange (GATHER, SINGLE, []) remote exchange (REPARTITION, HASH, ["ss_customer_sk"]) partial aggregation over (ss_customer_sk) - cross join: - cross join: + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): join (INNER, REPLICATED): join (INNER, REPLICATED): dynamic filter (["ss_customer_sk", "ss_sold_date_sk"]) @@ -48,8 +48,7 @@ local exchange (GATHER, SINGLE, []) scan date_dim local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) - dynamic filter (["d_month_seq_17"]) - scan date_dim + scan date_dim local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) local exchange (GATHER, SINGLE, []) diff --git a/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/parquet/partitioned/q09.plan.txt b/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/parquet/partitioned/q09.plan.txt index 4466a2d88feb..cbcd78f04da6 100644 --- a/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/parquet/partitioned/q09.plan.txt +++ b/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/parquet/partitioned/q09.plan.txt @@ -1,93 +1,122 @@ -cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - 
cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - scan reason - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales +join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales - final aggregation over () local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + 
local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPARTITION, HASH, []) + scan reason + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote 
exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales diff --git a/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/parquet/partitioned/q54.plan.txt b/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/parquet/partitioned/q54.plan.txt index 5806dc346635..851fa255cf39 100644 --- a/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/parquet/partitioned/q54.plan.txt +++ b/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/parquet/partitioned/q54.plan.txt @@ -8,8 +8,8 @@ local exchange (GATHER, SINGLE, []) local exchange (GATHER, SINGLE, []) remote exchange (REPARTITION, HASH, ["ss_customer_sk"]) partial aggregation over (ss_customer_sk) - cross join: - cross join: + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): join (INNER, REPLICATED): join (INNER, REPLICATED): dynamic filter (["ss_customer_sk", "ss_sold_date_sk"]) 
@@ -48,8 +48,7 @@ local exchange (GATHER, SINGLE, []) scan date_dim local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) - dynamic filter (["d_month_seq_17"]) - scan date_dim + scan date_dim local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) local exchange (GATHER, SINGLE, []) diff --git a/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/parquet/unpartitioned/q09.plan.txt b/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/parquet/unpartitioned/q09.plan.txt index 5660bb85f04e..cbcd78f04da6 100644 --- a/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/parquet/unpartitioned/q09.plan.txt +++ b/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/parquet/unpartitioned/q09.plan.txt @@ -1,93 +1,122 @@ -cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales - cross join: - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - scan reason +join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, 
SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales - final aggregation over () local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, 
[]) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPARTITION, HASH, []) + scan reason + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales - final aggregation over () - local exchange (GATHER, SINGLE, []) - 
remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales diff --git a/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/parquet/unpartitioned/q54.plan.txt b/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/parquet/unpartitioned/q54.plan.txt index 5806dc346635..851fa255cf39 100644 --- a/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/parquet/unpartitioned/q54.plan.txt +++ b/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg/parquet/unpartitioned/q54.plan.txt @@ -8,8 +8,8 @@ local exchange (GATHER, SINGLE, []) local exchange (GATHER, SINGLE, []) remote exchange (REPARTITION, HASH, ["ss_customer_sk"]) partial aggregation over (ss_customer_sk) - cross join: - cross join: + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): join (INNER, REPLICATED): join (INNER, REPLICATED): dynamic filter (["ss_customer_sk", "ss_sold_date_sk"]) @@ -48,8 +48,7 @@ local exchange (GATHER, SINGLE, []) scan date_dim local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) - dynamic filter (["d_month_seq_17"]) - scan date_dim + scan date_dim local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) local exchange (GATHER, SINGLE, []) diff --git a/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg_small_files/parquet/unpartitioned/q09.plan.txt b/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg_small_files/parquet/unpartitioned/q09.plan.txt index 4466a2d88feb..510c370f94d9 100644 --- a/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg_small_files/parquet/unpartitioned/q09.plan.txt +++ b/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg_small_files/parquet/unpartitioned/q09.plan.txt @@ -1,93 +1,122 @@ -cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: - cross join: 
+join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPARTITION, HASH, []) + scan reason + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) + final aggregation over () + local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + partial aggregation over () + scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - scan reason + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, 
BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) final aggregation over () local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) partial aggregation over () scan store_sales - final aggregation over 
() - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales - final aggregation over () - local exchange (GATHER, SINGLE, []) - remote exchange (GATHER, SINGLE, []) - partial aggregation over () - scan store_sales diff --git a/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg_small_files/parquet/unpartitioned/q54.plan.txt b/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg_small_files/parquet/unpartitioned/q54.plan.txt index f578993e3014..b49678e02521 100644 --- a/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg_small_files/parquet/unpartitioned/q54.plan.txt +++ b/testing/trino-tests/src/test/resources/sql/presto/tpcds/iceberg_small_files/parquet/unpartitioned/q54.plan.txt @@ -8,8 +8,8 @@ local exchange (GATHER, SINGLE, []) local exchange (GATHER, SINGLE, []) remote exchange (REPARTITION, HASH, ["customer_sk"]) partial aggregation over (customer_sk) - cross join: - cross join: + join (LEFT, REPLICATED): + join (LEFT, REPLICATED): join (INNER, REPLICATED): join (INNER, REPLICATED): join (INNER, PARTITIONED): @@ -50,8 +50,7 @@ local exchange (GATHER, SINGLE, []) scan store local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) - dynamic filter (["d_month_seq_17"]) - scan date_dim + scan date_dim local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) local exchange (GATHER, SINGLE, [])