> tokenSource)
+ {
+ requireNonNull(tokenSource, "tokenSource is null");
+
+ knownToken = tokenSource.get();
+ }
+}
diff --git a/client/trino-client/src/main/java/io/trino/client/auth/external/MemoryCachedKnownToken.java b/client/trino-client/src/main/java/io/trino/client/auth/external/MemoryCachedKnownToken.java
new file mode 100644
index 000000000000..e8513e4bd87f
--- /dev/null
+++ b/client/trino-client/src/main/java/io/trino/client/auth/external/MemoryCachedKnownToken.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.client.auth.external;
+
+import javax.annotation.concurrent.ThreadSafe;
+
+import java.util.Optional;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReadWriteLock;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
+import java.util.function.Supplier;
+
+/**
+ * This KnownToken instance forces all Connections to reuse same token.
+ * Every time an existing token is considered to be invalid each Connection
+ * will try to obtain a new token, but only the first one will actually do the job,
+ * where every other connection will be waiting on readLock
+ * until obtaining new token finishes.
+ *
+ * In general the game is to reuse same token and obtain it only once, no matter how
+ * many Connections will be actively using it. It's very important as obtaining the new token
+ * will take minutes, as it mostly requires user thinking time.
+ */
+@ThreadSafe
+class MemoryCachedKnownToken
+ implements KnownToken
+{
+ public static final MemoryCachedKnownToken INSTANCE = new MemoryCachedKnownToken();
+
+ private final ReadWriteLock lock = new ReentrantReadWriteLock();
+ private final Lock readLock = lock.readLock();
+ private final Lock writeLock = lock.writeLock();
+ private Optional knownToken = Optional.empty();
+
+ private MemoryCachedKnownToken()
+ {
+ }
+
+ @Override
+ public Optional getToken()
+ {
+ try {
+ readLock.lockInterruptibly();
+ return knownToken;
+ }
+ catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw new RuntimeException(e);
+ }
+ finally {
+ readLock.unlock();
+ }
+ }
+
+ @Override
+ public void setupToken(Supplier> tokenSource)
+ {
+ // Try to lock and generate new token. If some other thread (Connection) has
+ // already obtained writeLock and is generating new token, then skipp this
+ // to block on getToken()
+ if (writeLock.tryLock()) {
+ try {
+ // Clear knownToken before obtaining new token, as it might fail leaving old invalid token.
+ knownToken = Optional.empty();
+ knownToken = tokenSource.get();
+ }
+ finally {
+ writeLock.unlock();
+ }
+ }
+ }
+}
diff --git a/client/trino-client/src/test/java/io/trino/client/TestServerInfo.java b/client/trino-client/src/test/java/io/trino/client/TestServerInfo.java
index ca74648d4c0d..66390b162f0e 100644
--- a/client/trino-client/src/test/java/io/trino/client/TestServerInfo.java
+++ b/client/trino-client/src/test/java/io/trino/client/TestServerInfo.java
@@ -21,6 +21,7 @@
import static io.airlift.json.JsonCodec.jsonCodec;
import static io.trino.client.NodeVersion.UNKNOWN;
+import static java.util.concurrent.TimeUnit.MINUTES;
import static org.testng.Assert.assertEquals;
public class TestServerInfo
@@ -30,7 +31,7 @@ public class TestServerInfo
@Test
public void testJsonRoundTrip()
{
- assertJsonRoundTrip(new ServerInfo(UNKNOWN, "test", true, false, Optional.of(Duration.valueOf("2m"))));
+ assertJsonRoundTrip(new ServerInfo(UNKNOWN, "test", true, false, Optional.of(new Duration(2, MINUTES))));
assertJsonRoundTrip(new ServerInfo(UNKNOWN, "test", true, false, Optional.empty()));
}
diff --git a/client/trino-client/src/test/java/io/trino/client/auth/external/MockRedirectHandler.java b/client/trino-client/src/test/java/io/trino/client/auth/external/MockRedirectHandler.java
new file mode 100644
index 000000000000..cf671ea391c9
--- /dev/null
+++ b/client/trino-client/src/test/java/io/trino/client/auth/external/MockRedirectHandler.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.client.auth.external;
+
+import java.net.URI;
+import java.time.Duration;
+import java.util.concurrent.atomic.AtomicInteger;
+
+public class MockRedirectHandler
+ implements RedirectHandler
+{
+ private URI redirectedTo;
+ private AtomicInteger redirectionCount = new AtomicInteger(0);
+ private Duration redirectTime;
+
+ @Override
+ public void redirectTo(URI uri)
+ throws RedirectException
+ {
+ redirectedTo = uri;
+ redirectionCount.incrementAndGet();
+ try {
+ if (redirectTime != null) {
+ Thread.sleep(redirectTime.toMillis());
+ }
+ }
+ catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw new RuntimeException(e);
+ }
+ }
+
+ public URI redirectedTo()
+ {
+ return redirectedTo;
+ }
+
+ public int getRedirectionCount()
+ {
+ return redirectionCount.get();
+ }
+
+ public MockRedirectHandler sleepOnRedirect(Duration redirectTime)
+ {
+ this.redirectTime = redirectTime;
+ return this;
+ }
+}
diff --git a/client/trino-client/src/test/java/io/trino/client/auth/external/MockTokenPoller.java b/client/trino-client/src/test/java/io/trino/client/auth/external/MockTokenPoller.java
index fae34ea9c031..d07205f38636 100644
--- a/client/trino-client/src/test/java/io/trino/client/auth/external/MockTokenPoller.java
+++ b/client/trino-client/src/test/java/io/trino/client/auth/external/MockTokenPoller.java
@@ -17,21 +17,21 @@
import java.net.URI;
import java.time.Duration;
-import java.util.ArrayDeque;
-import java.util.HashMap;
import java.util.Map;
-import java.util.Queue;
+import java.util.concurrent.BlockingDeque;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.LinkedBlockingDeque;
public final class MockTokenPoller
implements TokenPoller
{
- private final Map> results = new HashMap<>();
+ private final Map> results = new ConcurrentHashMap<>();
public MockTokenPoller withResult(URI tokenUri, TokenPollResult result)
{
results.compute(tokenUri, (uri, queue) -> {
if (queue == null) {
- return new ArrayDeque<>(ImmutableList.of(result));
+ return new LinkedBlockingDeque<>(ImmutableList.of(result));
}
queue.add(result);
return queue;
@@ -42,7 +42,7 @@ public MockTokenPoller withResult(URI tokenUri, TokenPollResult result)
@Override
public TokenPollResult pollForToken(URI tokenUri, Duration ignored)
{
- Queue queue = results.get(tokenUri);
+ BlockingDeque queue = results.get(tokenUri);
if (queue == null) {
throw new IllegalArgumentException("Unknown token URI: " + tokenUri);
}
diff --git a/client/trino-client/src/test/java/io/trino/client/auth/external/TestExternalAuthentication.java b/client/trino-client/src/test/java/io/trino/client/auth/external/TestExternalAuthentication.java
index 279d1966b4c7..dce341d20392 100644
--- a/client/trino-client/src/test/java/io/trino/client/auth/external/TestExternalAuthentication.java
+++ b/client/trino-client/src/test/java/io/trino/client/auth/external/TestExternalAuthentication.java
@@ -125,22 +125,4 @@ public void testObtainTokenWhenNoRedirectUriHasBeenProvided()
assertThat(redirectHandler.redirectedTo()).isNull();
assertThat(token).map(Token::token).hasValue(AUTH_TOKEN);
}
-
- private static class MockRedirectHandler
- implements RedirectHandler
- {
- private URI redirectedTo;
-
- @Override
- public void redirectTo(URI uri)
- throws RedirectException
- {
- redirectedTo = uri;
- }
-
- public URI redirectedTo()
- {
- return redirectedTo;
- }
- }
}
diff --git a/client/trino-client/src/test/java/io/trino/client/auth/external/TestExternalAuthenticator.java b/client/trino-client/src/test/java/io/trino/client/auth/external/TestExternalAuthenticator.java
index 729ffe6fa023..c9ea6fc73199 100644
--- a/client/trino-client/src/test/java/io/trino/client/auth/external/TestExternalAuthenticator.java
+++ b/client/trino-client/src/test/java/io/trino/client/auth/external/TestExternalAuthenticator.java
@@ -13,31 +13,55 @@
*/
package io.trino.client.auth.external;
+import com.google.common.collect.ImmutableList;
import io.trino.client.ClientException;
import okhttp3.HttpUrl;
import okhttp3.Protocol;
import okhttp3.Request;
import okhttp3.Response;
+import org.assertj.core.api.ListAssert;
+import org.assertj.core.api.ThrowableAssert;
+import org.testng.annotations.AfterClass;
import org.testng.annotations.Test;
import java.net.URI;
import java.net.URISyntaxException;
import java.time.Duration;
+import java.util.ArrayList;
+import java.util.List;
import java.util.Optional;
+import java.util.concurrent.Callable;
+import java.util.concurrent.CancellationException;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+import java.util.stream.Stream;
+import static com.google.common.base.Preconditions.checkState;
+import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.net.HttpHeaders.AUTHORIZATION;
import static com.google.common.net.HttpHeaders.WWW_AUTHENTICATE;
+import static io.airlift.concurrent.Threads.daemonThreadsNamed;
import static io.trino.client.auth.external.ExternalAuthenticator.TOKEN_URI_FIELD;
import static io.trino.client.auth.external.ExternalAuthenticator.toAuthentication;
import static io.trino.client.auth.external.TokenPollResult.successful;
import static java.lang.String.format;
import static java.net.HttpURLConnection.HTTP_UNAUTHORIZED;
import static java.net.URI.create;
+import static java.util.concurrent.Executors.newCachedThreadPool;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
public class TestExternalAuthenticator
{
+ private static final ExecutorService executor = newCachedThreadPool(daemonThreadsNamed(TestExternalAuthenticator.class.getName() + "-%d"));
+
+ @AfterClass(alwaysRun = true)
+ public void shutDownThreadPool()
+ {
+ executor.shutdownNow();
+ }
+
@Test
public void testChallengeWithOnlyTokenServerUri()
{
@@ -110,7 +134,7 @@ public void testAuthentication()
{
MockTokenPoller tokenPoller = new MockTokenPoller()
.withResult(URI.create("http://token.uri"), successful(new Token("valid-token")));
- ExternalAuthenticator authenticator = new ExternalAuthenticator(uri -> {}, tokenPoller, Duration.ofSeconds(1));
+ ExternalAuthenticator authenticator = new ExternalAuthenticator(uri -> {}, tokenPoller, KnownToken.local(), Duration.ofSeconds(1));
Request authenticated = authenticator.authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\""));
@@ -125,7 +149,7 @@ public void testReAuthenticationAfterRejectingToken()
MockTokenPoller tokenPoller = new MockTokenPoller()
.withResult(URI.create("http://token.uri"), successful(new Token("first-token")))
.withResult(URI.create("http://token.uri"), successful(new Token("second-token")));
- ExternalAuthenticator authenticator = new ExternalAuthenticator(uri -> {}, tokenPoller, Duration.ofSeconds(1));
+ ExternalAuthenticator authenticator = new ExternalAuthenticator(uri -> {}, tokenPoller, KnownToken.local(), Duration.ofSeconds(1));
Request request = authenticator.authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\""));
Request reAuthenticated = authenticator.authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\"", request));
@@ -134,6 +158,140 @@ public void testReAuthenticationAfterRejectingToken()
.containsExactly("Bearer second-token");
}
+ @Test(timeOut = 2000)
+ public void testAuthenticationFromMultipleThreadsWithLocallyStoredToken()
+ {
+ MockTokenPoller tokenPoller = new MockTokenPoller()
+ .withResult(URI.create("http://token.uri"), successful(new Token("valid-token-1")))
+ .withResult(URI.create("http://token.uri"), successful(new Token("valid-token-2")))
+ .withResult(URI.create("http://token.uri"), successful(new Token("valid-token-3")))
+ .withResult(URI.create("http://token.uri"), successful(new Token("valid-token-4")));
+ MockRedirectHandler redirectHandler = new MockRedirectHandler();
+
+ ExternalAuthenticator authenticator = new ExternalAuthenticator(redirectHandler, tokenPoller, KnownToken.local(), Duration.ofSeconds(1));
+ List> requests = times(
+ 4,
+ () -> authenticator.authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\", x_redirect_server=\"http://redirect.uri\"")))
+ .map(executor::submit)
+ .collect(toImmutableList());
+
+ ConcurrentRequestAssertion assertion = new ConcurrentRequestAssertion(requests);
+ assertion.requests()
+ .extracting(Request::headers)
+ .extracting(headers -> headers.get(AUTHORIZATION))
+ .contains("Bearer valid-token-1", "Bearer valid-token-2", "Bearer valid-token-3", "Bearer valid-token-4");
+ assertion.assertThatNoExceptionsHasBeenThrown();
+ assertThat(redirectHandler.getRedirectionCount()).isEqualTo(4);
+ }
+
+ @Test(timeOut = 2000)
+ public void testAuthenticationFromMultipleThreadsWithCachedToken()
+ {
+ ExecutorService executor = newCachedThreadPool(daemonThreadsNamed(this.getClass().getName() + "%n"));
+ MockTokenPoller tokenPoller = new MockTokenPoller()
+ .withResult(URI.create("http://token.uri"), successful(new Token("valid-token")));
+ MockRedirectHandler redirectHandler = new MockRedirectHandler()
+ .sleepOnRedirect(Duration.ofMillis(10));
+
+ ExternalAuthenticator authenticator = new ExternalAuthenticator(redirectHandler, tokenPoller, KnownToken.memoryCached(), Duration.ofSeconds(1));
+ List> requests = times(
+ 4,
+ () -> authenticator.authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\", x_redirect_server=\"http://redirect.uri\"")))
+ .map(executor::submit)
+ .collect(toImmutableList());
+
+ ConcurrentRequestAssertion assertion = new ConcurrentRequestAssertion(requests);
+ assertion.requests()
+ .extracting(Request::headers)
+ .extracting(headers -> headers.get(AUTHORIZATION))
+ .containsOnly("Bearer valid-token");
+ assertion.assertThatNoExceptionsHasBeenThrown();
+ assertThat(redirectHandler.getRedirectionCount()).isEqualTo(1);
+ }
+
+ @Test(timeOut = 2000)
+ public void testAuthenticationFromMultipleThreadsWithCachedTokenAfterAuthenticateFails()
+ {
+ MockTokenPoller tokenPoller = new MockTokenPoller()
+ .withResult(URI.create("http://token.uri"), TokenPollResult.successful(new Token("first-token")))
+ .withResult(URI.create("http://token.uri"), TokenPollResult.failed("external authentication error"));
+ MockRedirectHandler redirectHandler = new MockRedirectHandler()
+ .sleepOnRedirect(Duration.ofMillis(10));
+
+ ExternalAuthenticator authenticator = new ExternalAuthenticator(redirectHandler, tokenPoller, KnownToken.memoryCached(), Duration.ofSeconds(1));
+ Request firstRequest = authenticator.authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\", x_redirect_server=\"http://redirect.uri\""));
+
+ List> requests = times(
+ 4,
+ () -> authenticator.authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\", x_redirect_server=\"http://redirect.uri\"", firstRequest)))
+ .map(executor::submit)
+ .collect(toImmutableList());
+
+ ConcurrentRequestAssertion assertion = new ConcurrentRequestAssertion(requests);
+ assertion.requests().containsExactly(null, null, null);
+ assertion.firstException().hasMessage("external authentication error")
+ .isInstanceOf(ClientException.class);
+
+ assertThat(redirectHandler.getRedirectionCount()).isEqualTo(2);
+ }
+
+ @Test(timeOut = 2000)
+ public void testAuthenticationFromMultipleThreadsWithCachedTokenAfterAuthenticateTimesOut()
+ {
+ MockRedirectHandler redirectHandler = new MockRedirectHandler()
+ .sleepOnRedirect(Duration.ofMillis(5));
+
+ ExternalAuthenticator authenticator = new ExternalAuthenticator(redirectHandler, (uri, duration) -> TokenPollResult.pending(uri), KnownToken.memoryCached(), Duration.ofMillis(1));
+ List> requests = times(
+ 4,
+ () -> authenticator.authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\", x_redirect_server=\"http://redirect.uri\"")))
+ .map(executor::submit)
+ .collect(toImmutableList());
+
+ ConcurrentRequestAssertion assertion = new ConcurrentRequestAssertion(requests);
+ assertion.requests()
+ .containsExactly(null, null, null, null);
+ assertion.assertThatNoExceptionsHasBeenThrown();
+ assertThat(redirectHandler.getRedirectionCount()).isEqualTo(1);
+ }
+
+ @Test(timeOut = 2000)
+ public void testAuthenticationFromMultipleThreadsWithCachedTokenAfterAuthenticateIsInterrupted()
+ throws Exception
+ {
+ ExecutorService interruptableThreadPool = newCachedThreadPool(daemonThreadsNamed(this.getClass().getName() + "-interruptable-%d"));
+ MockRedirectHandler redirectHandler = new MockRedirectHandler()
+ .sleepOnRedirect(Duration.ofMinutes(1));
+
+ ExternalAuthenticator authenticator = new ExternalAuthenticator(redirectHandler, (uri, duration) -> TokenPollResult.pending(uri), KnownToken.memoryCached(), Duration.ofMillis(1));
+ Future interruptedAuthentication = interruptableThreadPool.submit(
+ () -> authenticator.authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\", x_redirect_server=\"http://redirect.uri\"")));
+ Thread.sleep(100); //It's here to make sure that authentication will start before the other threads.
+ List> requests = times(
+ 2,
+ () -> authenticator.authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\", x_redirect_server=\"http://redirect.uri\"")))
+ .map(executor::submit)
+ .collect(toImmutableList());
+
+ Thread.sleep(100);
+ interruptableThreadPool.shutdownNow();
+
+ ConcurrentRequestAssertion assertion = new ConcurrentRequestAssertion(ImmutableList.>builder()
+ .addAll(requests)
+ .add(interruptedAuthentication)
+ .build());
+ assertion.requests().containsExactly(null, null);
+ assertion.firstException().hasRootCauseInstanceOf(InterruptedException.class);
+
+ assertThat(redirectHandler.getRedirectionCount()).isEqualTo(1);
+ }
+
+ private static Stream> times(int times, Callable request)
+ {
+ return Stream.generate(() -> request)
+ .limit(times);
+ }
+
private static Optional buildAuthentication(String challengeHeader)
{
return toAuthentication(getUnauthorizedResponse(challengeHeader));
@@ -157,4 +315,50 @@ private static Response getUnauthorizedResponse(String challengeHeader, Request
.header(WWW_AUTHENTICATE, challengeHeader)
.build();
}
+
+ static class ConcurrentRequestAssertion
+ {
+ private final List exceptions = new ArrayList<>();
+ private final List requests = new ArrayList<>();
+
+ public ConcurrentRequestAssertion(List> requests)
+ {
+ for (Future request : requests) {
+ try {
+ this.requests.add(request.get());
+ }
+ catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw new RuntimeException(e);
+ }
+ catch (CancellationException ex) {
+ exceptions.add(ex);
+ }
+ catch (ExecutionException ex) {
+ checkState(ex.getCause() != null, "Missing cause on ExecutionException " + ex.getMessage());
+
+ exceptions.add(ex.getCause());
+ }
+ }
+ }
+
+ ThrowableAssert firstException()
+ {
+ return exceptions.stream()
+ .findFirst()
+ .map(ThrowableAssert::new)
+ .orElseGet(() -> new ThrowableAssert(() -> null));
+ }
+
+ void assertThatNoExceptionsHasBeenThrown()
+ {
+ assertThat(exceptions)
+ .isEmpty();
+ }
+
+ ListAssert requests()
+ {
+ return assertThat(requests);
+ }
+ }
}
diff --git a/client/trino-jdbc/pom.xml b/client/trino-jdbc/pom.xml
index 74692b96692c..2fd3bbb69af6 100644
--- a/client/trino-jdbc/pom.xml
+++ b/client/trino-jdbc/pom.xml
@@ -5,7 +5,7 @@
io.trino
trino-root
- 355-SNAPSHOT
+ 356-SNAPSHOT
../../pom.xml
diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/ConnectionProperties.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/ConnectionProperties.java
index c3db9662374b..5618e47bd1b4 100644
--- a/client/trino-jdbc/src/main/java/io/trino/jdbc/ConnectionProperties.java
+++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/ConnectionProperties.java
@@ -55,6 +55,7 @@ enum SslVerificationMode
public static final ConnectionProperty HTTP_PROXY = new HttpProxy();
public static final ConnectionProperty APPLICATION_NAME_PREFIX = new ApplicationNamePrefix();
public static final ConnectionProperty DISABLE_COMPRESSION = new DisableCompression();
+ public static final ConnectionProperty ASSUME_LITERAL_NAMES_IN_METADATA_CALLS_FOR_NON_CONFORMING_CLIENTS = new AssumeLiteralNamesInMetadataCallsForNonConformingClients();
public static final ConnectionProperty SSL = new Ssl();
public static final ConnectionProperty SSL_VERIFICATION = new SslVerification();
public static final ConnectionProperty SSL_KEY_STORE_PATH = new SslKeyStorePath();
@@ -73,6 +74,7 @@ enum SslVerificationMode
public static final ConnectionProperty ACCESS_TOKEN = new AccessToken();
public static final ConnectionProperty EXTERNAL_AUTHENTICATION = new ExternalAuthentication();
public static final ConnectionProperty EXTERNAL_AUTHENTICATION_TIMEOUT = new ExternalAuthenticationTimeout();
+ public static final ConnectionProperty EXTERNAL_AUTHENTICATION_TOKEN_CACHE = new ExternalAuthenticationTokenCache();
public static final ConnectionProperty> EXTRA_CREDENTIALS = new ExtraCredentials();
public static final ConnectionProperty CLIENT_INFO = new ClientInfo();
public static final ConnectionProperty CLIENT_TAGS = new ClientTags();
@@ -89,6 +91,7 @@ enum SslVerificationMode
.add(HTTP_PROXY)
.add(APPLICATION_NAME_PREFIX)
.add(DISABLE_COMPRESSION)
+ .add(ASSUME_LITERAL_NAMES_IN_METADATA_CALLS_FOR_NON_CONFORMING_CLIENTS)
.add(SSL)
.add(SSL_VERIFICATION)
.add(SSL_KEY_STORE_PATH)
@@ -113,6 +116,7 @@ enum SslVerificationMode
.add(SOURCE)
.add(EXTERNAL_AUTHENTICATION)
.add(EXTERNAL_AUTHENTICATION_TIMEOUT)
+ .add(EXTERNAL_AUTHENTICATION_TOKEN_CACHE)
.build();
private static final Map> KEY_LOOKUP = unmodifiableMap(ALL_PROPERTIES.stream()
@@ -273,6 +277,15 @@ public DisableCompression()
}
}
+ private static class AssumeLiteralNamesInMetadataCallsForNonConformingClients
+ extends AbstractConnectionProperty
+ {
+ public AssumeLiteralNamesInMetadataCallsForNonConformingClients()
+ {
+ super("assumeLiteralNamesInMetadataCallsForNonConformingClients", NOT_REQUIRED, ALLOWED, BOOLEAN_CONVERTER);
+ }
+ }
+
private static class Ssl
extends AbstractConnectionProperty
{
@@ -461,6 +474,15 @@ public ExternalAuthenticationTimeout()
}
}
+ private static class ExternalAuthenticationTokenCache
+ extends AbstractConnectionProperty
+ {
+ public ExternalAuthenticationTokenCache()
+ {
+ super("externalAuthenticationTokenCache", Optional.of(KnownTokenCache.NONE.name()), NOT_REQUIRED, ALLOWED, KnownTokenCache::valueOf);
+ }
+ }
+
private static class ExtraCredentials
extends AbstractConnectionProperty>
{
diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/KnownTokenCache.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/KnownTokenCache.java
new file mode 100644
index 000000000000..6c3dde57d8c7
--- /dev/null
+++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/KnownTokenCache.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.jdbc;
+
+import io.trino.client.auth.external.KnownToken;
+
+public enum KnownTokenCache
+{
+ NONE {
+ @Override
+ KnownToken create()
+ {
+ return KnownToken.local();
+ }
+ },
+ MEMORY {
+ @Override
+ KnownToken create()
+ {
+ return KnownToken.memoryCached();
+ }
+ };
+
+ abstract KnownToken create();
+}
diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/NonRegisteringTrinoDriver.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/NonRegisteringTrinoDriver.java
index cb63d2a18ac9..f529737efe2f 100644
--- a/client/trino-jdbc/src/main/java/io/trino/jdbc/NonRegisteringTrinoDriver.java
+++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/NonRegisteringTrinoDriver.java
@@ -24,7 +24,6 @@
import java.util.Properties;
import java.util.logging.Logger;
-import static io.trino.client.OkHttpUtil.setupChannelSocket;
import static io.trino.client.OkHttpUtil.userAgent;
import static io.trino.jdbc.DriverInfo.DRIVER_NAME;
import static io.trino.jdbc.DriverInfo.DRIVER_VERSION;
@@ -112,9 +111,6 @@ private static OkHttpClient newHttpClient()
{
OkHttpClient.Builder builder = new OkHttpClient.Builder()
.addInterceptor(userAgent(DRIVER_NAME + "/" + DRIVER_VERSION));
-
- // Enable socket factory only for pre JDK 11
- setupChannelSocket(builder);
return builder.build();
}
}
diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java
index 81c2b8cc148a..bea84728d960 100644
--- a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java
+++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java
@@ -91,6 +91,7 @@ public class TrinoConnection
private final String user;
private final Optional sessionUser;
private final boolean compressionDisabled;
+ private final boolean assumeLiteralNamesInMetadataCallsForNonConformingClients;
private final Map extraCredentials;
private final Optional applicationNamePrefix;
private final Optional source;
@@ -115,6 +116,7 @@ public class TrinoConnection
this.source = uri.getSource();
this.extraCredentials = uri.getExtraCredentials();
this.compressionDisabled = uri.isCompressionDisabled();
+ this.assumeLiteralNamesInMetadataCallsForNonConformingClients = uri.isAssumeLiteralNamesInMetadataCallsForNonConformingClients();
this.queryExecutor = requireNonNull(queryExecutor, "queryExecutor is null");
uri.getClientInfo().ifPresent(tags -> clientInfo.put(CLIENT_INFO, tags));
uri.getClientTags().ifPresent(tags -> clientInfo.put(CLIENT_TAGS, tags));
@@ -238,7 +240,7 @@ public boolean isClosed()
public DatabaseMetaData getMetaData()
throws SQLException
{
- return new TrinoDatabaseMetaData(this);
+ return new TrinoDatabaseMetaData(this, assumeLiteralNamesInMetadataCallsForNonConformingClients);
}
@Override
diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoDatabaseMetaData.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoDatabaseMetaData.java
index 49033d3ca876..9782234da699 100644
--- a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoDatabaseMetaData.java
+++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoDatabaseMetaData.java
@@ -20,6 +20,8 @@
import io.trino.client.ClientTypeSignatureParameter;
import io.trino.client.Column;
+import javax.annotation.Nullable;
+
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
@@ -31,6 +33,7 @@
import java.util.List;
import java.util.stream.Stream;
+import static com.google.common.base.Verify.verify;
import static com.google.common.collect.Lists.newArrayList;
import static io.trino.client.ClientTypeSignature.VARCHAR_UNBOUNDED_LENGTH;
import static io.trino.jdbc.DriverInfo.DRIVER_NAME;
@@ -46,10 +49,12 @@ public class TrinoDatabaseMetaData
private static final String SEARCH_STRING_ESCAPE = "\\";
private final TrinoConnection connection;
+ private final boolean assumeLiteralNamesInMetadataCallsForNonConformingClients;
- TrinoDatabaseMetaData(TrinoConnection connection)
+ TrinoDatabaseMetaData(TrinoConnection connection, boolean assumeLiteralNamesInMetadataCallsForNonConformingClients)
{
this.connection = requireNonNull(connection, "connection is null");
+ this.assumeLiteralNamesInMetadataCallsForNonConformingClients = assumeLiteralNamesInMetadataCallsForNonConformingClients;
}
@Override
@@ -897,6 +902,8 @@ public boolean dataDefinitionIgnoredInTransactions()
public ResultSet getProcedures(String catalog, String schemaPattern, String procedureNamePattern)
throws SQLException
{
+ schemaPattern = escapeIfNecessary(schemaPattern);
+ procedureNamePattern = escapeIfNecessary(procedureNamePattern);
return selectEmpty("" +
"SELECT PROCEDURE_CAT, PROCEDURE_SCHEM, PROCEDURE_NAME,\n " +
" null, null, null, REMARKS, PROCEDURE_TYPE, SPECIFIC_NAME\n" +
@@ -908,6 +915,9 @@ public ResultSet getProcedures(String catalog, String schemaPattern, String proc
public ResultSet getProcedureColumns(String catalog, String schemaPattern, String procedureNamePattern, String columnNamePattern)
throws SQLException
{
+ schemaPattern = escapeIfNecessary(schemaPattern);
+ procedureNamePattern = escapeIfNecessary(procedureNamePattern);
+ columnNamePattern = escapeIfNecessary(columnNamePattern);
return selectEmpty("" +
"SELECT PROCEDURE_CAT, PROCEDURE_SCHEM, PROCEDURE_NAME, " +
" COLUMN_NAME, COLUMN_TYPE, DATA_TYPE, TYPE_NAME,\n" +
@@ -922,6 +932,8 @@ public ResultSet getProcedureColumns(String catalog, String schemaPattern, Strin
public ResultSet getTables(String catalog, String schemaPattern, String tableNamePattern, String[] types)
throws SQLException
{
+ schemaPattern = escapeIfNecessary(schemaPattern);
+ tableNamePattern = escapeIfNecessary(tableNamePattern);
StringBuilder query = new StringBuilder("" +
"SELECT TABLE_CAT, TABLE_SCHEM, TABLE_NAME, TABLE_TYPE, REMARKS,\n" +
" TYPE_CAT, TYPE_SCHEM, TYPE_NAME, " +
@@ -974,6 +986,9 @@ public ResultSet getTableTypes()
public ResultSet getColumns(String catalog, String schemaPattern, String tableNamePattern, String columnNamePattern)
throws SQLException
{
+ schemaPattern = escapeIfNecessary(schemaPattern);
+ tableNamePattern = escapeIfNecessary(tableNamePattern);
+ columnNamePattern = escapeIfNecessary(columnNamePattern);
StringBuilder query = new StringBuilder("" +
"SELECT TABLE_CAT, TABLE_SCHEM, TABLE_NAME, COLUMN_NAME, DATA_TYPE,\n" +
" TYPE_NAME, COLUMN_SIZE, BUFFER_LENGTH, DECIMAL_DIGITS, NUM_PREC_RADIX,\n" +
@@ -999,6 +1014,7 @@ public ResultSet getColumns(String catalog, String schemaPattern, String tableNa
public ResultSet getColumnPrivileges(String catalog, String schema, String table, String columnNamePattern)
throws SQLException
{
+ columnNamePattern = escapeIfNecessary(columnNamePattern);
throw new SQLFeatureNotSupportedException("privileges not supported");
}
@@ -1006,6 +1022,8 @@ public ResultSet getColumnPrivileges(String catalog, String schema, String table
public ResultSet getTablePrivileges(String catalog, String schemaPattern, String tableNamePattern)
throws SQLException
{
+ schemaPattern = escapeIfNecessary(schemaPattern);
+ tableNamePattern = escapeIfNecessary(tableNamePattern);
throw new SQLFeatureNotSupportedException("privileges not supported");
}
@@ -1168,6 +1186,8 @@ public boolean supportsBatchUpdates()
public ResultSet getUDTs(String catalog, String schemaPattern, String typeNamePattern, int[] types)
throws SQLException
{
+ schemaPattern = escapeIfNecessary(schemaPattern);
+ typeNamePattern = escapeIfNecessary(typeNamePattern);
return selectEmpty("" +
"SELECT TYPE_CAT, TYPE_SCHEM, TYPE_NAME,\n" +
" CLASS_NAME, DATA_TYPE, REMARKS, BASE_TYPE\n" +
@@ -1214,6 +1234,8 @@ public boolean supportsGetGeneratedKeys()
public ResultSet getSuperTypes(String catalog, String schemaPattern, String typeNamePattern)
throws SQLException
{
+ schemaPattern = escapeIfNecessary(schemaPattern);
+ typeNamePattern = escapeIfNecessary(typeNamePattern);
return selectEmpty("" +
"SELECT TYPE_CAT, TYPE_SCHEM, TYPE_NAME,\n" +
" SUPERTYPE_CAT, SUPERTYPE_SCHEM, SUPERTYPE_NAME\n" +
@@ -1225,6 +1247,8 @@ public ResultSet getSuperTypes(String catalog, String schemaPattern, String type
public ResultSet getSuperTables(String catalog, String schemaPattern, String tableNamePattern)
throws SQLException
{
+ schemaPattern = escapeIfNecessary(schemaPattern);
+ tableNamePattern = escapeIfNecessary(tableNamePattern);
return selectEmpty("" +
"SELECT TABLE_CAT, TABLE_SCHEM, TABLE_NAME, SUPERTABLE_NAME\n" +
"FROM system.jdbc.super_tables\n" +
@@ -1235,6 +1259,9 @@ public ResultSet getSuperTables(String catalog, String schemaPattern, String tab
public ResultSet getAttributes(String catalog, String schemaPattern, String typeNamePattern, String attributeNamePattern)
throws SQLException
{
+ schemaPattern = escapeIfNecessary(schemaPattern);
+ typeNamePattern = escapeIfNecessary(typeNamePattern);
+ attributeNamePattern = escapeIfNecessary(attributeNamePattern);
return selectEmpty("" +
"SELECT TYPE_CAT, TYPE_SCHEM, TYPE_NAME, ATTR_NAME, DATA_TYPE,\n" +
" ATTR_TYPE_NAME, ATTR_SIZE, DECIMAL_DIGITS, NUM_PREC_RADIX, NULLABLE,\n" +
@@ -1332,6 +1359,7 @@ public RowIdLifetime getRowIdLifetime()
public ResultSet getSchemas(String catalog, String schemaPattern)
throws SQLException
{
+ schemaPattern = escapeIfNecessary(schemaPattern);
StringBuilder query = new StringBuilder("" +
"SELECT TABLE_SCHEM, TABLE_CATALOG\n" +
"FROM system.jdbc.schemas");
@@ -1391,6 +1419,8 @@ public ResultSet getClientInfoProperties()
public ResultSet getFunctions(String catalog, String schemaPattern, String functionNamePattern)
throws SQLException
{
+ schemaPattern = escapeIfNecessary(schemaPattern);
+ functionNamePattern = escapeIfNecessary(functionNamePattern);
// TODO: implement this
throw new NotImplementedException("DatabaseMetaData", "getFunctions");
}
@@ -1399,6 +1429,9 @@ public ResultSet getFunctions(String catalog, String schemaPattern, String funct
public ResultSet getFunctionColumns(String catalog, String schemaPattern, String functionNamePattern, String columnNamePattern)
throws SQLException
{
+ schemaPattern = escapeIfNecessary(schemaPattern);
+ functionNamePattern = escapeIfNecessary(functionNamePattern);
+ columnNamePattern = escapeIfNecessary(columnNamePattern);
// TODO: implement this
throw new NotImplementedException("DatabaseMetaData", "getFunctionColumns");
}
@@ -1407,6 +1440,9 @@ public ResultSet getFunctionColumns(String catalog, String schemaPattern, String
public ResultSet getPseudoColumns(String catalog, String schemaPattern, String tableNamePattern, String columnNamePattern)
throws SQLException
{
+ schemaPattern = escapeIfNecessary(schemaPattern);
+ tableNamePattern = escapeIfNecessary(tableNamePattern);
+ columnNamePattern = escapeIfNecessary(columnNamePattern);
return selectEmpty("" +
"SELECT TABLE_CAT, TABLE_SCHEM, TABLE_NAME, COLUMN_NAME, DATA_TYPE,\n" +
" COLUMN_SIZE, DECIMAL_DIGITS, NUM_PREC_RADIX, COLUMN_USAGE, REMARKS,\n" +
@@ -1485,6 +1521,23 @@ private static void optionalStringInFilter(List filters, String columnNa
filters.add(filter.toString());
}
+ @Nullable
+ private String escapeIfNecessary(@Nullable String namePattern)
+ {
+ return escapeIfNecessary(assumeLiteralNamesInMetadataCallsForNonConformingClients, namePattern);
+ }
+
+ @Nullable
+ static String escapeIfNecessary(boolean assumeLiteralNamesInMetadataCallsForNonConformingClients, @Nullable String namePattern)
+ {
+ if (namePattern == null || !assumeLiteralNamesInMetadataCallsForNonConformingClients) {
+ return namePattern;
+ }
+ //noinspection ConstantConditions
+ verify(SEARCH_STRING_ESCAPE.equals("\\"));
+ return namePattern.replaceAll("[_%\\\\]", "\\\\$0");
+ }
+
private static void optionalStringLikeFilter(List filters, String columnName, String value)
{
if (value != null) {
diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoDriverUri.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoDriverUri.java
index 9a2af8cb55d0..b4c8e6121d64 100644
--- a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoDriverUri.java
+++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoDriverUri.java
@@ -53,11 +53,13 @@
import static io.trino.client.OkHttpUtil.tokenAuth;
import static io.trino.jdbc.ConnectionProperties.ACCESS_TOKEN;
import static io.trino.jdbc.ConnectionProperties.APPLICATION_NAME_PREFIX;
+import static io.trino.jdbc.ConnectionProperties.ASSUME_LITERAL_NAMES_IN_METADATA_CALLS_FOR_NON_CONFORMING_CLIENTS;
import static io.trino.jdbc.ConnectionProperties.CLIENT_INFO;
import static io.trino.jdbc.ConnectionProperties.CLIENT_TAGS;
import static io.trino.jdbc.ConnectionProperties.DISABLE_COMPRESSION;
import static io.trino.jdbc.ConnectionProperties.EXTERNAL_AUTHENTICATION;
import static io.trino.jdbc.ConnectionProperties.EXTERNAL_AUTHENTICATION_TIMEOUT;
+import static io.trino.jdbc.ConnectionProperties.EXTERNAL_AUTHENTICATION_TOKEN_CACHE;
import static io.trino.jdbc.ConnectionProperties.EXTRA_CREDENTIALS;
import static io.trino.jdbc.ConnectionProperties.HTTP_PROXY;
import static io.trino.jdbc.ConnectionProperties.KERBEROS_CONFIG_PATH;
@@ -236,6 +238,12 @@ public boolean isCompressionDisabled()
return DISABLE_COMPRESSION.getValue(properties).orElse(false);
}
+ public boolean isAssumeLiteralNamesInMetadataCallsForNonConformingClients()
+ throws SQLException
+ {
+ return ASSUME_LITERAL_NAMES_IN_METADATA_CALLS_FOR_NON_CONFORMING_CLIENTS.getValue(properties).orElse(false);
+ }
+
public void setupClient(OkHttpClient.Builder builder)
throws SQLException
{
@@ -310,7 +318,9 @@ public void setupClient(OkHttpClient.Builder builder)
.map(value -> Duration.ofMillis(value.toMillis()))
.orElse(Duration.ofMinutes(2));
- ExternalAuthenticator authenticator = new ExternalAuthenticator(REDIRECT_HANDLER.get(), poller, timeout);
+ KnownTokenCache knownTokenCache = EXTERNAL_AUTHENTICATION_TOKEN_CACHE.getValue(properties).get();
+
+ ExternalAuthenticator authenticator = new ExternalAuthenticator(REDIRECT_HANDLER.get(), poller, knownTokenCache.create(), timeout);
builder.authenticator(authenticator);
builder.addInterceptor(authenticator);
diff --git a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestQueryExecutor.java b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestQueryExecutor.java
index d800eb6aeb16..d3154921f8cf 100644
--- a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestQueryExecutor.java
+++ b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestQueryExecutor.java
@@ -29,6 +29,7 @@
import static com.google.common.net.HttpHeaders.CONTENT_TYPE;
import static io.airlift.json.JsonCodec.jsonCodec;
import static io.trino.client.NodeVersion.UNKNOWN;
+import static java.util.concurrent.TimeUnit.MINUTES;
import static org.testng.Assert.assertEquals;
@Test(singleThreaded = true)
@@ -57,7 +58,7 @@ public void teardown()
public void testGetServerInfo()
throws Exception
{
- ServerInfo expected = new ServerInfo(UNKNOWN, "test", true, false, Optional.of(Duration.valueOf("2m")));
+ ServerInfo expected = new ServerInfo(UNKNOWN, "test", true, false, Optional.of(new Duration(2, MINUTES)));
server.enqueue(new MockResponse()
.addHeader(CONTENT_TYPE, "application/json")
@@ -67,7 +68,7 @@ public void testGetServerInfo()
ServerInfo actual = executor.getServerInfo(server.url("/v1/info").uri());
assertEquals(actual.getEnvironment(), "test");
- assertEquals(actual.getUptime(), Optional.of(Duration.valueOf("2m")));
+ assertEquals(actual.getUptime(), Optional.of(new Duration(2, MINUTES)));
assertEquals(server.getRequestCount(), 1);
assertEquals(server.takeRequest().getPath(), "/v1/info");
diff --git a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDatabaseMetaData.java b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDatabaseMetaData.java
index 26efb600646b..924347f50f57 100644
--- a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDatabaseMetaData.java
+++ b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDatabaseMetaData.java
@@ -987,7 +987,6 @@ public void testGetSuperTypes()
}
@Test
- @SuppressWarnings("resource")
public void testGetSchemasMetadataCalls()
throws Exception
{
@@ -995,6 +994,7 @@ public void testGetSchemasMetadataCalls()
// No filter
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getSchemas(null, null),
list("TABLE_CATALOG", "TABLE_SCHEM")),
@@ -1003,6 +1003,7 @@ public void testGetSchemasMetadataCalls()
// Equality predicate on catalog name
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getSchemas(COUNTING_CATALOG, null),
list("TABLE_CATALOG", "TABLE_SCHEM")),
@@ -1015,6 +1016,7 @@ public void testGetSchemasMetadataCalls()
// Equality predicate on schema name
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getSchemas(COUNTING_CATALOG, "test\\_schema%"),
list("TABLE_CATALOG", "TABLE_SCHEM")),
@@ -1026,6 +1028,7 @@ public void testGetSchemasMetadataCalls()
// LIKE predicate on schema name
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getSchemas(COUNTING_CATALOG, "test_sch_ma1"),
list("TABLE_CATALOG", "TABLE_SCHEM")),
@@ -1035,6 +1038,7 @@ public void testGetSchemasMetadataCalls()
// Empty schema name
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getSchemas(COUNTING_CATALOG, ""),
list("TABLE_CATALOG", "TABLE_SCHEM")),
@@ -1044,6 +1048,7 @@ public void testGetSchemasMetadataCalls()
// catalog does not exist
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getSchemas("wrong", null),
list("TABLE_CATALOG", "TABLE_SCHEM")),
@@ -1052,7 +1057,6 @@ public void testGetSchemasMetadataCalls()
}
@Test
- @SuppressWarnings("resource")
public void testGetTablesMetadataCalls()
throws Exception
{
@@ -1060,6 +1064,7 @@ public void testGetTablesMetadataCalls()
// No filter
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getTables(null, null, null, null),
list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "TABLE_TYPE")),
@@ -1069,6 +1074,7 @@ public void testGetTablesMetadataCalls()
// Equality predicate on catalog name
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getTables(COUNTING_CATALOG, null, null, null),
list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "TABLE_TYPE")),
@@ -1078,6 +1084,7 @@ public void testGetTablesMetadataCalls()
// Equality predicate on schema name
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getTables(COUNTING_CATALOG, "test\\_schema1", null, null),
list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "TABLE_TYPE")),
@@ -1090,6 +1097,7 @@ public void testGetTablesMetadataCalls()
// LIKE predicate on schema name
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getTables(COUNTING_CATALOG, "test_sch_ma1", null, null),
list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "TABLE_TYPE")),
@@ -1103,6 +1111,7 @@ public void testGetTablesMetadataCalls()
// Equality predicate on table name
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getTables(COUNTING_CATALOG, null, "test\\_table1", null),
list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "TABLE_TYPE")),
@@ -1115,6 +1124,7 @@ public void testGetTablesMetadataCalls()
// LIKE predicate on table name
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getTables(COUNTING_CATALOG, null, "test_t_ble1", null),
list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "TABLE_TYPE")),
@@ -1127,14 +1137,17 @@ public void testGetTablesMetadataCalls()
// Equality predicate on schema name and table name
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getTables(COUNTING_CATALOG, "test\\_schema1", "test\\_table1", null),
list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "TABLE_TYPE")),
list(list(COUNTING_CATALOG, "test_schema1", "test_table1", "TABLE")),
- new MetadataCallsCount());
+ new MetadataCallsCount()
+ .withGetTableHandleCount(1));
// LIKE predicate on schema name and table name
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getTables(COUNTING_CATALOG, "test_schema1", "test_table1", null),
list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "TABLE_TYPE")),
@@ -1145,6 +1158,7 @@ public void testGetTablesMetadataCalls()
// catalog does not exist
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getTables("wrong", null, null, null),
list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "TABLE_TYPE")),
@@ -1153,6 +1167,7 @@ public void testGetTablesMetadataCalls()
// empty schema name
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getTables(COUNTING_CATALOG, "", null, null),
list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "TABLE_TYPE")),
@@ -1163,6 +1178,7 @@ public void testGetTablesMetadataCalls()
// empty table name
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getTables(COUNTING_CATALOG, null, "", null),
list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "TABLE_TYPE")),
@@ -1173,6 +1189,7 @@ public void testGetTablesMetadataCalls()
// no table types selected
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getTables(COUNTING_CATALOG, null, null, new String[0]),
list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "TABLE_TYPE")),
@@ -1181,7 +1198,6 @@ public void testGetTablesMetadataCalls()
}
@Test
- @SuppressWarnings("resource")
public void testGetColumnsMetadataCalls()
throws Exception
{
@@ -1189,6 +1205,7 @@ public void testGetColumnsMetadataCalls()
// No filter
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getColumns(null, null, null, null),
list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "COLUMN_NAME", "TYPE_NAME")),
@@ -1199,6 +1216,7 @@ public void testGetColumnsMetadataCalls()
// Equality predicate on catalog name
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getColumns(COUNTING_CATALOG, null, null, null),
list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "COLUMN_NAME", "TYPE_NAME")),
@@ -1209,6 +1227,7 @@ public void testGetColumnsMetadataCalls()
// Equality predicate on catalog name, schema name and table name
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getColumns(COUNTING_CATALOG, "test\\_schema1", "test\\_table1", null),
list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "COLUMN_NAME", "TYPE_NAME")),
@@ -1221,6 +1240,7 @@ public void testGetColumnsMetadataCalls()
// Equality predicate on catalog name, schema name, table name and column name
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getColumns(COUNTING_CATALOG, "test\\_schema1", "test\\_table1", "column\\_17"),
list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "COLUMN_NAME", "TYPE_NAME")),
@@ -1231,6 +1251,7 @@ public void testGetColumnsMetadataCalls()
// Equality predicate on catalog name, LIKE predicate on schema name, table name and column name
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getColumns(COUNTING_CATALOG, "test_schema1", "test_table1", "column_17"),
list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "COLUMN_NAME", "TYPE_NAME")),
@@ -1242,6 +1263,7 @@ public void testGetColumnsMetadataCalls()
// LIKE predicate on schema name and table name, but no predicate on catalog name
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getColumns(null, "test_schema1", "test_table1", null),
list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "COLUMN_NAME", "TYPE_NAME")),
@@ -1255,6 +1277,7 @@ public void testGetColumnsMetadataCalls()
// LIKE predicate on schema name, but no predicate on catalog name and table name
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getColumns(null, "test_schema1", null, null),
list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "COLUMN_NAME", "TYPE_NAME")),
@@ -1270,6 +1293,7 @@ public void testGetColumnsMetadataCalls()
// LIKE predicate on table name, but no predicate on catalog name and schema name
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getColumns(null, null, "test_table1", null),
list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "COLUMN_NAME", "TYPE_NAME")),
@@ -1281,10 +1305,12 @@ public void testGetColumnsMetadataCalls()
new MetadataCallsCount()
.withListSchemasCount(3)
.withListTablesCount(8)
+ .withGetTableHandleCount(2)
.withGetColumnsCount(2));
// Equality predicate on schema name and table name, but no predicate on catalog name
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getColumns(null, "test\\_schema1", "test\\_table1", null),
list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "COLUMN_NAME", "TYPE_NAME")),
@@ -1297,6 +1323,7 @@ public void testGetColumnsMetadataCalls()
// catalog does not exist
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getColumns("wrong", null, null, null),
list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "COLUMN_NAME", "TYPE_NAME")),
@@ -1305,6 +1332,7 @@ public void testGetColumnsMetadataCalls()
// schema does not exist
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getColumns(COUNTING_CATALOG, "wrong\\_schema1", "test\\_table1", null),
list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "COLUMN_NAME", "TYPE_NAME")),
@@ -1314,6 +1342,7 @@ public void testGetColumnsMetadataCalls()
// schema does not exist
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getColumns(COUNTING_CATALOG, "wrong_schema1", "test_table1", null),
list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "COLUMN_NAME", "TYPE_NAME")),
@@ -1325,6 +1354,7 @@ public void testGetColumnsMetadataCalls()
// empty schema name
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getColumns(COUNTING_CATALOG, "", null, null),
list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "COLUMN_NAME", "TYPE_NAME")),
@@ -1336,6 +1366,7 @@ public void testGetColumnsMetadataCalls()
// empty table name
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getColumns(COUNTING_CATALOG, null, "", null),
list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "COLUMN_NAME", "TYPE_NAME")),
@@ -1347,6 +1378,7 @@ public void testGetColumnsMetadataCalls()
// empty column name
assertMetadataCalls(
+ connection,
readMetaData(
databaseMetaData -> databaseMetaData.getColumns(COUNTING_CATALOG, null, null, ""),
list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "COLUMN_NAME", "TYPE_NAME")),
@@ -1357,6 +1389,102 @@ public void testGetColumnsMetadataCalls()
.withGetColumnsCount(3000));
}
+ @Test
+ public void testAssumeLiteralMetadataCalls()
+ throws Exception
+ {
+ try (Connection connection = DriverManager.getConnection(
+ format("jdbc:trino://%s?assumeLiteralNamesInMetadataCallsForNonConformingClients=true", server.getAddress()),
+ "admin",
+ null)) {
+ // getTables's schema name pattern treated as literal
+ assertMetadataCalls(
+ connection,
+ readMetaData(
+ databaseMetaData -> databaseMetaData.getTables(COUNTING_CATALOG, "test_schema1", null, null),
+ list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "TABLE_TYPE")),
+ countingMockConnector.getAllTables()
+ .filter(schemaTableName -> schemaTableName.getSchemaName().equals("test_schema1"))
+ .map(schemaTableName -> list(COUNTING_CATALOG, schemaTableName.getSchemaName(), schemaTableName.getTableName(), "TABLE"))
+ .collect(toImmutableList()),
+ new MetadataCallsCount()
+ .withListSchemasCount(0)
+ .withListTablesCount(1));
+
+ // getTables's schema and table name patterns treated as literals
+ assertMetadataCalls(
+ connection,
+ readMetaData(
+ databaseMetaData -> databaseMetaData.getTables(COUNTING_CATALOG, "test_schema1", "test_table1", null),
+ list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "TABLE_TYPE")),
+ list(list(COUNTING_CATALOG, "test_schema1", "test_table1", "TABLE")),
+ new MetadataCallsCount()
+ .withGetTableHandleCount(1));
+
+ // no matches in getTables call as table name pattern treated as literal
+ assertMetadataCalls(
+ connection,
+ readMetaData(
+ databaseMetaData -> databaseMetaData.getTables(COUNTING_CATALOG, "test_schema_", null, null),
+ list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "TABLE_TYPE")),
+ list(),
+ new MetadataCallsCount()
+ .withListTablesCount(1));
+
+ // getColumns's schema and table name patterns treated as literals
+ assertMetadataCalls(
+ connection,
+ readMetaData(
+ databaseMetaData -> databaseMetaData.getColumns(COUNTING_CATALOG, "test_schema1", "test_table1", null),
+ list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "COLUMN_NAME", "TYPE_NAME")),
+ IntStream.range(0, 100)
+ .mapToObj(i -> list(COUNTING_CATALOG, "test_schema1", "test_table1", "column_" + i, "varchar"))
+ .collect(toImmutableList()),
+ new MetadataCallsCount()
+ .withListTablesCount(1)
+ .withGetColumnsCount(1));
+
+ // getColumns's schema, table and column name patterns treated as literals
+ assertMetadataCalls(
+ connection,
+ readMetaData(
+ databaseMetaData -> databaseMetaData.getColumns(COUNTING_CATALOG, "test_schema1", "test_table1", "column_17"),
+ list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "COLUMN_NAME", "TYPE_NAME")),
+ list(list(COUNTING_CATALOG, "test_schema1", "test_table1", "column_17", "varchar")),
+ new MetadataCallsCount()
+ .withListTablesCount(1)
+ .withGetColumnsCount(1));
+
+ // no matches in getColumns call as table name pattern treated as literal
+ assertMetadataCalls(
+ connection,
+ readMetaData(
+ databaseMetaData -> databaseMetaData.getColumns(COUNTING_CATALOG, "test_schema1", "test_table_", null),
+ list("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "COLUMN_NAME", "TYPE_NAME")),
+ list(),
+ new MetadataCallsCount()
+ .withListTablesCount(1));
+ }
+ }
+
+ @Test
+ public void testEscapeIfNecessary()
+ {
+ assertEquals(TrinoDatabaseMetaData.escapeIfNecessary(false, null), null);
+ assertEquals(TrinoDatabaseMetaData.escapeIfNecessary(false, "a"), "a");
+ assertEquals(TrinoDatabaseMetaData.escapeIfNecessary(false, "abc_def"), "abc_def");
+ assertEquals(TrinoDatabaseMetaData.escapeIfNecessary(false, "abc__de_f"), "abc__de_f");
+ assertEquals(TrinoDatabaseMetaData.escapeIfNecessary(false, "abc%def"), "abc%def");
+ assertEquals(TrinoDatabaseMetaData.escapeIfNecessary(false, "abc\\_def"), "abc\\_def");
+
+ assertEquals(TrinoDatabaseMetaData.escapeIfNecessary(true, null), null);
+ assertEquals(TrinoDatabaseMetaData.escapeIfNecessary(true, "a"), "a");
+ assertEquals(TrinoDatabaseMetaData.escapeIfNecessary(true, "abc_def"), "abc\\_def");
+ assertEquals(TrinoDatabaseMetaData.escapeIfNecessary(true, "abc__de_f"), "abc\\_\\_de\\_f");
+ assertEquals(TrinoDatabaseMetaData.escapeIfNecessary(true, "abc%def"), "abc\\%def");
+ assertEquals(TrinoDatabaseMetaData.escapeIfNecessary(true, "abc\\_def"), "abc\\\\\\_def");
+ }
+
private static void assertColumnSpec(ResultSet rs, int dataType, Long precision, Long numPrecRadix, String typeName)
throws SQLException
{
@@ -1396,19 +1524,23 @@ private Set captureQueries(Callable> action)
.collect(toImmutableSet());
}
- private void assertMetadataCalls(MetaDataCallback extends Collection>> callback, MetadataCallsCount expectedMetadataCallsCount)
- throws Exception
+ private void assertMetadataCalls(Connection connection, MetaDataCallback extends Collection>> callback, MetadataCallsCount expectedMetadataCallsCount)
{
assertMetadataCalls(
+ connection,
callback,
actual -> {},
expectedMetadataCallsCount);
}
- private void assertMetadataCalls(MetaDataCallback extends Collection>> callback, Collection> expected, MetadataCallsCount expectedMetadataCallsCount)
- throws Exception
+ private void assertMetadataCalls(
+ Connection connection,
+ MetaDataCallback extends Collection>> callback,
+ Collection> expected,
+ MetadataCallsCount expectedMetadataCallsCount)
{
assertMetadataCalls(
+ connection,
callback,
actual -> assertThat(ImmutableMultiset.copyOf(requireNonNull(actual, "actual is null")))
.isEqualTo(ImmutableMultiset.copyOf(requireNonNull(expected, "expected is null"))),
@@ -1416,23 +1548,20 @@ private void assertMetadataCalls(MetaDataCallback extends Collection>> callback,
Consumer>> resultsVerification,
MetadataCallsCount expectedMetadataCallsCount)
- throws Exception
{
- MetadataCallsCount actualMetadataCallsCount;
- try (Connection connection = createConnection()) {
- actualMetadataCallsCount = countingMockConnector.runCounting(() -> {
- try {
- Collection> actual = callback.apply(connection.getMetaData());
- resultsVerification.accept(actual);
- }
- catch (SQLException e) {
- throw new RuntimeException(e);
- }
- });
- }
+ MetadataCallsCount actualMetadataCallsCount = countingMockConnector.runCounting(() -> {
+ try {
+ Collection> actual = callback.apply(connection.getMetaData());
+ resultsVerification.accept(actual);
+ }
+ catch (SQLException e) {
+ throw new RuntimeException(e);
+ }
+ });
assertEquals(actualMetadataCallsCount, expectedMetadataCallsCount);
}
diff --git a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDriverImpersonateUser.java b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDriverImpersonateUser.java
index 30876bc445cd..1a23e854a4e3 100644
--- a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDriverImpersonateUser.java
+++ b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDriverImpersonateUser.java
@@ -23,6 +23,9 @@
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
import java.security.Principal;
import java.sql.Connection;
import java.sql.DriverManager;
@@ -47,9 +50,13 @@ public class TestTrinoDriverImpersonateUser
@BeforeClass
public void setup()
+ throws IOException
{
+ Path passwordConfigDummy = Files.createTempFile("passwordConfigDummy", null);
+ passwordConfigDummy.toFile().deleteOnExit();
server = TestingTrinoServer.builder()
.setProperties(ImmutableMap.builder()
+ .put("password-authenticator.config-files", passwordConfigDummy.toString())
.put("http-server.authentication.type", "password")
.put("http-server.https.enabled", "true")
.put("http-server.https.keystore.path", getResource("localhost.keystore").getPath())
@@ -57,7 +64,7 @@ public void setup()
.build())
.build();
- server.getInstance(Key.get(PasswordAuthenticatorManager.class)).setAuthenticator(TestTrinoDriverImpersonateUser::authenticate);
+ server.getInstance(Key.get(PasswordAuthenticatorManager.class)).setAuthenticators(TestTrinoDriverImpersonateUser::authenticate);
}
private static Principal authenticate(String user, String password)
diff --git a/core/trino-main/pom.xml b/core/trino-main/pom.xml
index 83baff416725..521dee25ed1b 100644
--- a/core/trino-main/pom.xml
+++ b/core/trino-main/pom.xml
@@ -5,7 +5,7 @@
io.trino
trino-root
- 355-SNAPSHOT
+ 356-SNAPSHOT
../../pom.xml
diff --git a/core/trino-main/src/main/java/io/trino/event/QueryMonitor.java b/core/trino-main/src/main/java/io/trino/event/QueryMonitor.java
index d276de194c89..ae193965100e 100644
--- a/core/trino-main/src/main/java/io/trino/event/QueryMonitor.java
+++ b/core/trino-main/src/main/java/io/trino/event/QueryMonitor.java
@@ -44,6 +44,7 @@
import io.trino.operator.TaskStats;
import io.trino.server.BasicQueryInfo;
import io.trino.spi.QueryId;
+import io.trino.spi.eventlistener.OutputColumnMetadata;
import io.trino.spi.eventlistener.QueryCompletedEvent;
import io.trino.spi.eventlistener.QueryContext;
import io.trino.spi.eventlistener.QueryCreatedEvent;
@@ -56,6 +57,7 @@
import io.trino.spi.eventlistener.StageCpuDistribution;
import io.trino.spi.resourcegroups.QueryType;
import io.trino.spi.resourcegroups.ResourceGroupId;
+import io.trino.sql.analyzer.Analysis;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.sql.planner.plan.PlanNode;
@@ -76,6 +78,8 @@
import java.util.OptionalLong;
import java.util.stream.Collectors;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.execution.QueryState.QUEUED;
import static io.trino.execution.StageInfo.getAllStages;
import static io.trino.sql.planner.planprinter.PlanPrinter.textDistributedPlan;
@@ -362,11 +366,21 @@ private static QueryIOMetadata getQueryIOMetadata(QueryInfo queryInfo)
.map(TableFinishInfo.class::cast)
.findFirst();
+ Optional> outputColumnsMetadata = queryInfo.getOutput().get().getColumns()
+ .map(columns -> columns.stream()
+ .map(column -> new OutputColumnMetadata(
+ column.getColumn().getName(),
+ column.getSourceColumns().stream()
+ .map(Analysis.SourceColumn::getColumnDetail)
+ .collect(toImmutableSet())))
+ .collect(toImmutableList()));
+
output = Optional.of(
new QueryOutputMetadata(
queryInfo.getOutput().get().getCatalogName(),
queryInfo.getOutput().get().getSchema(),
queryInfo.getOutput().get().getTable(),
+ outputColumnsMetadata,
tableFinishInfo.map(TableFinishInfo::getConnectorOutputMetadata),
tableFinishInfo.map(TableFinishInfo::isJsonLengthLimitExceeded)));
}
diff --git a/core/trino-main/src/main/java/io/trino/execution/CreateMaterializedViewTask.java b/core/trino-main/src/main/java/io/trino/execution/CreateMaterializedViewTask.java
index fa08dde438a2..a26bb61e4885 100644
--- a/core/trino-main/src/main/java/io/trino/execution/CreateMaterializedViewTask.java
+++ b/core/trino-main/src/main/java/io/trino/execution/CreateMaterializedViewTask.java
@@ -99,7 +99,7 @@ public ListenableFuture> execute(
.map(field -> new ConnectorMaterializedViewDefinition.Column(field.getName().get(), field.getType().getTypeId()))
.collect(toImmutableList());
- Optional owner = Optional.of(session.getUser());
+ String owner = session.getUser();
CatalogName catalogName = metadata.getCatalogHandle(session, name.getCatalogName())
.orElseThrow(() -> new TrinoException(NOT_FOUND, "Catalog does not exist: " + name.getCatalogName()));
diff --git a/core/trino-main/src/main/java/io/trino/execution/CreateTableTask.java b/core/trino-main/src/main/java/io/trino/execution/CreateTableTask.java
index 84ffa69d33de..0ea34af2dcf4 100644
--- a/core/trino-main/src/main/java/io/trino/execution/CreateTableTask.java
+++ b/core/trino-main/src/main/java/io/trino/execution/CreateTableTask.java
@@ -16,6 +16,7 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.ListenableFuture;
import io.trino.Session;
import io.trino.connector.CatalogName;
@@ -31,6 +32,8 @@
import io.trino.spi.security.AccessDeniedException;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeNotFoundException;
+import io.trino.sql.analyzer.Output;
+import io.trino.sql.analyzer.OutputColumn;
import io.trino.sql.tree.ColumnDefinition;
import io.trino.sql.tree.CreateTable;
import io.trino.sql.tree.Expression;
@@ -47,8 +50,10 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
+import java.util.function.Consumer;
import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.util.concurrent.Futures.immediateFuture;
import static io.trino.metadata.MetadataUtil.createQualifiedObjectName;
@@ -96,11 +101,11 @@ public ListenableFuture> execute(
List parameters,
WarningCollector warningCollector)
{
- return internalExecute(statement, metadata, accessControl, stateMachine.getSession(), parameters);
+ return internalExecute(statement, metadata, accessControl, stateMachine.getSession(), parameters, output -> stateMachine.setOutput(Optional.of(output)));
}
@VisibleForTesting
- ListenableFuture> internalExecute(CreateTable statement, Metadata metadata, AccessControl accessControl, Session session, List parameters)
+ ListenableFuture> internalExecute(CreateTable statement, Metadata metadata, AccessControl accessControl, Session session, List parameters, Consumer outputConsumer)
{
checkArgument(!statement.getElements().isEmpty(), "no columns for table");
@@ -240,6 +245,13 @@ else if (element instanceof LikeClause) {
throw e;
}
}
+ outputConsumer.accept(new Output(
+ tableName.getCatalogName(),
+ tableName.getSchemaName(),
+ tableName.getObjectName(),
+ Optional.of(tableMetadata.getColumns().stream()
+ .map(column -> new OutputColumn(new Column(column.getName(), column.getType().toString()), ImmutableSet.of()))
+ .collect(toImmutableList()))));
return immediateFuture(null);
}
diff --git a/core/trino-main/src/main/java/io/trino/metadata/FunctionRegistry.java b/core/trino-main/src/main/java/io/trino/metadata/FunctionRegistry.java
index 4f2512edee59..79b59e7e5d69 100644
--- a/core/trino-main/src/main/java/io/trino/metadata/FunctionRegistry.java
+++ b/core/trino-main/src/main/java/io/trino/metadata/FunctionRegistry.java
@@ -192,7 +192,6 @@
import io.trino.operator.scalar.timestamp.TimestampToTimestampWithTimeZoneCast;
import io.trino.operator.scalar.timestamp.TimestampToVarcharCast;
import io.trino.operator.scalar.timestamp.ToIso8601;
-import io.trino.operator.scalar.timestamp.ToUnixTime;
import io.trino.operator.scalar.timestamp.VarcharToTimestampCast;
import io.trino.operator.scalar.timestamp.WithTimeZone;
import io.trino.operator.scalar.timestamptz.AtTimeZone;
@@ -632,7 +631,6 @@ public FunctionRegistry(
.scalar(VarcharToTimestampCast.class)
.scalar(LocalTimestamp.class)
.scalar(DateTrunc.class)
- .scalar(ToUnixTime.class)
.scalar(HumanReadableSeconds.class)
.scalar(ToIso8601.class)
.scalar(WithTimeZone.class)
diff --git a/core/trino-main/src/main/java/io/trino/metadata/Metadata.java b/core/trino-main/src/main/java/io/trino/metadata/Metadata.java
index f6f33ec33588..bb76cd1e54a8 100644
--- a/core/trino-main/src/main/java/io/trino/metadata/Metadata.java
+++ b/core/trino-main/src/main/java/io/trino/metadata/Metadata.java
@@ -114,8 +114,19 @@ public interface Metadata
Optional getInfo(Session session, TableHandle handle);
+ /**
+ * Return table schema definition for the specified table handle.
+ * Table schema definition is a set of information
+ * required by semantic analyzer to analyze the query.
+ * @see {@link #getTableMetadata(Session, TableHandle)}
+ *
+ * @throws RuntimeException if table handle is no longer valid
+ */
+ TableSchema getTableSchema(Session session, TableHandle tableHandle);
+
/**
* Return the metadata for the specified table handle.
+ * @see {@link #getTableSchema(Session, TableHandle)} which is less expsensive.
*
* @throws RuntimeException if table handle is no longer valid
*/
diff --git a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java
index 9807e9d9a495..b65adbca7665 100644
--- a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java
+++ b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java
@@ -15,11 +15,15 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Joiner;
+import com.google.common.cache.CacheBuilder;
+import com.google.common.cache.CacheLoader;
+import com.google.common.cache.LoadingCache;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multimap;
+import com.google.common.util.concurrent.UncheckedExecutionException;
import io.airlift.slice.Slice;
import io.trino.Session;
import io.trino.client.NodeVersion;
@@ -69,6 +73,7 @@
import io.trino.spi.connector.ConnectorTableLayoutResult;
import io.trino.spi.connector.ConnectorTableMetadata;
import io.trino.spi.connector.ConnectorTableProperties;
+import io.trino.spi.connector.ConnectorTableSchema;
import io.trino.spi.connector.ConnectorTransactionHandle;
import io.trino.spi.connector.ConnectorViewDefinition;
import io.trino.spi.connector.Constraint;
@@ -200,6 +205,9 @@ public final class MetadataManager
private final ResolvedFunctionDecoder functionDecoder;
+ private final LoadingCache operatorCache;
+ private final LoadingCache coercionCache;
+
@Inject
public MetadataManager(
FeaturesConfig featuresConfig,
@@ -247,6 +255,23 @@ public MetadataManager(
verifyTypes();
functionDecoder = new ResolvedFunctionDecoder(this::getType);
+
+ operatorCache = CacheBuilder.newBuilder()
+ .maximumSize(1000)
+ .build(CacheLoader.from(key -> {
+ String name = mangleOperatorName(key.getOperatorType());
+ return resolveFunction(QualifiedName.of(name), fromTypes(key.getArgumentTypes()));
+ }));
+
+ coercionCache = CacheBuilder.newBuilder()
+ .maximumSize(1000)
+ .build(CacheLoader.from(key -> {
+ String name = mangleOperatorName(key.getOperatorType());
+ Type fromType = key.getFromType();
+ Type toType = key.getToType();
+ Signature signature = new Signature(name, toType.getTypeSignature(), ImmutableList.of(fromType.getTypeSignature()));
+ return resolve(functionResolver.resolveCoercion(functions.get(QualifiedName.of(name)), signature));
+ }));
}
public static MetadataManager createTestMetadataManager()
@@ -508,6 +533,16 @@ public Optional getInfo(Session session, TableHandle handle)
return metadata.getInfo(handle.getConnectorHandle());
}
+ @Override
+ public TableSchema getTableSchema(Session session, TableHandle tableHandle)
+ {
+ CatalogName catalogName = tableHandle.getCatalogName();
+ ConnectorMetadata metadata = getMetadata(session, catalogName);
+ ConnectorTableSchema tableSchema = metadata.getTableSchema(session.toConnectorSession(catalogName), tableHandle.getConnectorHandle());
+
+ return new TableSchema(catalogName, tableSchema);
+ }
+
@Override
public TableMetadata getTableMetadata(Session session, TableHandle tableHandle)
{
@@ -1873,11 +1908,15 @@ public ResolvedFunction resolveOperator(OperatorType operatorType, List extend
throws OperatorNotFoundException
{
try {
- return resolveFunction(QualifiedName.of(mangleOperatorName(operatorType)), fromTypes(argumentTypes));
+ return operatorCache.getUnchecked(new OperatorCacheKey(operatorType, argumentTypes));
}
- catch (TrinoException e) {
- if (e.getErrorCode().getCode() == FUNCTION_NOT_FOUND.toErrorCode().getCode()) {
- throw new OperatorNotFoundException(operatorType, argumentTypes, e);
+ catch (UncheckedExecutionException e) {
+ if (e.getCause() instanceof TrinoException) {
+ TrinoException cause = (TrinoException) e.getCause();
+ if (cause.getErrorCode().getCode() == FUNCTION_NOT_FOUND.toErrorCode().getCode()) {
+ throw new OperatorNotFoundException(operatorType, argumentTypes, cause);
+ }
+ throw cause;
}
throw e;
}
@@ -1888,12 +1927,15 @@ public ResolvedFunction getCoercion(OperatorType operatorType, Type fromType, Ty
{
checkArgument(operatorType == OperatorType.CAST || operatorType == OperatorType.SATURATED_FLOOR_CAST);
try {
- String name = mangleOperatorName(operatorType);
- return resolve(functionResolver.resolveCoercion(functions.get(QualifiedName.of(name)), new Signature(name, toType.getTypeSignature(), ImmutableList.of(fromType.getTypeSignature()))));
+ return coercionCache.getUnchecked(new CoercionCacheKey(operatorType, fromType, toType));
}
- catch (TrinoException e) {
- if (e.getErrorCode().getCode() == FUNCTION_IMPLEMENTATION_MISSING.toErrorCode().getCode()) {
- throw new OperatorNotFoundException(operatorType, ImmutableList.of(fromType), toType.getTypeSignature(), e);
+ catch (UncheckedExecutionException e) {
+ if (e.getCause() instanceof TrinoException) {
+ TrinoException cause = (TrinoException) e.getCause();
+ if (cause.getErrorCode().getCode() == FUNCTION_IMPLEMENTATION_MISSING.toErrorCode().getCode()) {
+ throw new OperatorNotFoundException(operatorType, ImmutableList.of(fromType), toType.getTypeSignature(), cause);
+ }
+ throw cause;
}
throw e;
}
@@ -2303,4 +2345,96 @@ private synchronized void finish()
}
}
}
+
+ private static class OperatorCacheKey
+ {
+ private final OperatorType operatorType;
+ private final List extends Type> argumentTypes;
+
+ private OperatorCacheKey(OperatorType operatorType, List extends Type> argumentTypes)
+ {
+ this.operatorType = requireNonNull(operatorType, "operatorType is null");
+ this.argumentTypes = ImmutableList.copyOf(requireNonNull(argumentTypes, "argumentTypes is null"));
+ }
+
+ public OperatorType getOperatorType()
+ {
+ return operatorType;
+ }
+
+ public List extends Type> getArgumentTypes()
+ {
+ return argumentTypes;
+ }
+
+ @Override
+ public int hashCode()
+ {
+ return Objects.hash(operatorType, argumentTypes);
+ }
+
+ @Override
+ public boolean equals(Object obj)
+ {
+ if (this == obj) {
+ return true;
+ }
+ if (!(obj instanceof OperatorCacheKey)) {
+ return false;
+ }
+ OperatorCacheKey other = (OperatorCacheKey) obj;
+ return Objects.equals(this.operatorType, other.operatorType) &&
+ Objects.equals(this.argumentTypes, other.argumentTypes);
+ }
+ }
+
+ private static class CoercionCacheKey
+ {
+ private final OperatorType operatorType;
+ private final Type fromType;
+ private final Type toType;
+
+ private CoercionCacheKey(OperatorType operatorType, Type fromType, Type toType)
+ {
+ this.operatorType = requireNonNull(operatorType, "operatorType is null");
+ this.fromType = requireNonNull(fromType, "fromType is null");
+ this.toType = requireNonNull(toType, "toType is null");
+ }
+
+ public OperatorType getOperatorType()
+ {
+ return operatorType;
+ }
+
+ public Type getFromType()
+ {
+ return fromType;
+ }
+
+ public Type getToType()
+ {
+ return toType;
+ }
+
+ @Override
+ public int hashCode()
+ {
+ return Objects.hash(operatorType, fromType, toType);
+ }
+
+ @Override
+ public boolean equals(Object obj)
+ {
+ if (this == obj) {
+ return true;
+ }
+ if (!(obj instanceof CoercionCacheKey)) {
+ return false;
+ }
+ CoercionCacheKey other = (CoercionCacheKey) obj;
+ return Objects.equals(this.operatorType, other.operatorType) &&
+ Objects.equals(this.fromType, other.fromType) &&
+ Objects.equals(this.toType, other.toType);
+ }
+ }
}
diff --git a/core/trino-main/src/main/java/io/trino/metadata/TableSchema.java b/core/trino-main/src/main/java/io/trino/metadata/TableSchema.java
new file mode 100644
index 000000000000..a29498a1cbd0
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/metadata/TableSchema.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.metadata;
+
+import io.trino.connector.CatalogName;
+import io.trino.spi.connector.ColumnSchema;
+import io.trino.spi.connector.ConnectorTableSchema;
+import io.trino.spi.connector.SchemaTableName;
+
+import java.util.List;
+
+import static com.google.common.collect.MoreCollectors.toOptional;
+import static java.util.Objects.requireNonNull;
+
+public final class TableSchema
+{
+ private final CatalogName catalogName;
+ private final ConnectorTableSchema tableSchema;
+
+ public TableSchema(CatalogName catalogName, ConnectorTableSchema tableSchema)
+ {
+ requireNonNull(catalogName, "catalog is null");
+ requireNonNull(tableSchema, "metadata is null");
+
+ this.catalogName = catalogName;
+ this.tableSchema = tableSchema;
+ }
+
+ public QualifiedObjectName getQualifiedName()
+ {
+ return new QualifiedObjectName(catalogName.getCatalogName(), tableSchema.getTable().getSchemaName(), tableSchema.getTable().getTableName());
+ }
+
+ public CatalogName getCatalogName()
+ {
+ return catalogName;
+ }
+
+ public ConnectorTableSchema getTableSchema()
+ {
+ return tableSchema;
+ }
+
+ public SchemaTableName getTable()
+ {
+ return tableSchema.getTable();
+ }
+
+ public List getColumns()
+ {
+ return tableSchema.getColumns();
+ }
+
+ public ColumnSchema getColumn(String name)
+ {
+ return tableSchema.getColumns().stream()
+ .filter(columnMetadata -> columnMetadata.getName().equals(name))
+ .collect(toOptional())
+ .orElseThrow(() -> new IllegalArgumentException("Invalid column name: " + name));
+ }
+}
diff --git a/core/trino-main/src/main/java/io/trino/operator/LookupJoinOperator.java b/core/trino-main/src/main/java/io/trino/operator/LookupJoinOperator.java
index 4f22d5493845..a6839b0a8881 100644
--- a/core/trino-main/src/main/java/io/trino/operator/LookupJoinOperator.java
+++ b/core/trino-main/src/main/java/io/trino/operator/LookupJoinOperator.java
@@ -72,6 +72,7 @@ public class LookupJoinOperator
implements AdapterWorkProcessorOperator
{
private final ListenableFuture lookupSourceProviderFuture;
+ private final boolean waitForBuild;
private final PageBuffer pageBuffer;
private final WorkProcessor pages;
private final SpillingJoinProcessor joinProcessor;
@@ -82,6 +83,7 @@ public class LookupJoinOperator
List buildOutputTypes,
JoinType joinType,
boolean outputSingleMatch,
+ boolean waitForBuild,
LookupSourceFactory lookupSourceFactory,
JoinProbeFactory joinProbeFactory,
Runnable afterClose,
@@ -92,8 +94,9 @@ public class LookupJoinOperator
Optional> sourcePages)
{
this.statisticsCounter = new JoinStatisticsCounter(joinType);
+ this.waitForBuild = waitForBuild;
lookupSourceProviderFuture = lookupSourceFactory.createLookupSourceProvider();
- pageBuffer = new PageBuffer(lookupSourceProviderFuture);
+ pageBuffer = new PageBuffer();
joinProcessor = new SpillingJoinProcessor(
processorContext,
afterClose,
@@ -102,6 +105,7 @@ public class LookupJoinOperator
buildOutputTypes,
joinType,
outputSingleMatch,
+ waitForBuild,
hashGenerator,
joinProbeFactory,
lookupSourceFactory,
@@ -121,7 +125,7 @@ public Optional getOperatorInfo()
@Override
public boolean needsInput()
{
- return lookupSourceProviderFuture.isDone() && pageBuffer.isEmpty() && !pageBuffer.isFinished();
+ return (!waitForBuild || lookupSourceProviderFuture.isDone()) && pageBuffer.isEmpty() && !pageBuffer.isFinished();
}
@Override
@@ -518,9 +522,11 @@ private static class SpillingJoinProcessor
private final List buildOutputTypes;
private final JoinType joinType;
private final boolean outputSingleMatch;
+ private final boolean waitForBuild;
private final HashGenerator hashGenerator;
private final JoinProbeFactory joinProbeFactory;
private final LookupSourceFactory lookupSourceFactory;
+ private final ListenableFuture lookupSourceProvider;
private final JoinStatisticsCounter statisticsCounter;
private final PageJoiner sourcePagesJoiner;
private final WorkProcessor joinedSourcePages;
@@ -544,6 +550,7 @@ private SpillingJoinProcessor(
List buildOutputTypes,
JoinType joinType,
boolean outputSingleMatch,
+ boolean waitForBuild,
HashGenerator hashGenerator,
JoinProbeFactory joinProbeFactory,
LookupSourceFactory lookupSourceFactory,
@@ -559,9 +566,11 @@ private SpillingJoinProcessor(
this.buildOutputTypes = requireNonNull(buildOutputTypes, "buildOutputTypes is null");
this.joinType = requireNonNull(joinType, "joinType is null");
this.outputSingleMatch = outputSingleMatch;
+ this.waitForBuild = waitForBuild;
this.hashGenerator = requireNonNull(hashGenerator, "hashGenerator is null");
this.joinProbeFactory = requireNonNull(joinProbeFactory, "joinProbeFactory is null");
this.lookupSourceFactory = requireNonNull(lookupSourceFactory, "lookupSourceFactory is null");
+ this.lookupSourceProvider = requireNonNull(lookupSourceProvider, "lookupSourceProvider is null");
this.statisticsCounter = requireNonNull(statisticsCounter, "statisticsCounter is null");
sourcePagesJoiner = new PageJoiner(
processorContext,
@@ -582,6 +591,12 @@ private SpillingJoinProcessor(
@Override
public ProcessState> process()
{
+ // wait for build side to be completed before fetching any probe data
+ // TODO: fix support for probe short-circuit: https://github.com/trinodb/trino/issues/3957
+ if (waitForBuild && !lookupSourceProvider.isDone()) {
+ return ProcessState.blocked(lookupSourceProvider);
+ }
+
if (!joinedSourcePages.isFinished()) {
return ProcessState.ofResult(joinedSourcePages);
}
diff --git a/core/trino-main/src/main/java/io/trino/operator/LookupJoinOperatorFactory.java b/core/trino-main/src/main/java/io/trino/operator/LookupJoinOperatorFactory.java
index 196a48e50a8d..d8e5d828e4b8 100644
--- a/core/trino-main/src/main/java/io/trino/operator/LookupJoinOperatorFactory.java
+++ b/core/trino-main/src/main/java/io/trino/operator/LookupJoinOperatorFactory.java
@@ -46,6 +46,7 @@ public class LookupJoinOperatorFactory
private final List buildOutputTypes;
private final JoinType joinType;
private final boolean outputSingleMatch;
+ private final boolean waitForBuild;
private final JoinProbeFactory joinProbeFactory;
private final Optional outerOperatorFactoryResult;
private final JoinBridgeManager extends LookupSourceFactory> joinBridgeManager;
@@ -64,6 +65,7 @@ public LookupJoinOperatorFactory(
List buildOutputTypes,
JoinType joinType,
boolean outputSingleMatch,
+ boolean waitForBuild,
JoinProbeFactory joinProbeFactory,
BlockTypeOperators blockTypeOperators,
OptionalInt totalOperatorsCount,
@@ -77,6 +79,7 @@ public LookupJoinOperatorFactory(
this.buildOutputTypes = ImmutableList.copyOf(requireNonNull(buildOutputTypes, "buildOutputTypes is null"));
this.joinType = requireNonNull(joinType, "joinType is null");
this.outputSingleMatch = outputSingleMatch;
+ this.waitForBuild = waitForBuild;
this.joinProbeFactory = requireNonNull(joinProbeFactory, "joinProbeFactory is null");
this.joinBridgeManager = lookupSourceFactoryManager;
@@ -123,6 +126,7 @@ private LookupJoinOperatorFactory(LookupJoinOperatorFactory other)
buildOutputTypes = other.buildOutputTypes;
joinType = other.joinType;
outputSingleMatch = other.outputSingleMatch;
+ waitForBuild = other.waitForBuild;
joinProbeFactory = other.joinProbeFactory;
joinBridgeManager = other.joinBridgeManager;
outerOperatorFactoryResult = other.outerOperatorFactoryResult;
@@ -193,6 +197,7 @@ public WorkProcessorOperator create(ProcessorContext processorContext, WorkProce
buildOutputTypes,
joinType,
outputSingleMatch,
+ waitForBuild,
lookupSourceFactory,
joinProbeFactory,
() -> joinBridgeManager.probeOperatorClosed(processorContext.getLifespan()),
@@ -215,6 +220,7 @@ public AdapterWorkProcessorOperator createAdapterOperator(ProcessorContext proce
buildOutputTypes,
joinType,
outputSingleMatch,
+ waitForBuild,
lookupSourceFactory,
joinProbeFactory,
() -> joinBridgeManager.probeOperatorClosed(processorContext.getLifespan()),
diff --git a/core/trino-main/src/main/java/io/trino/operator/LookupJoinOperators.java b/core/trino-main/src/main/java/io/trino/operator/LookupJoinOperators.java
index 857009866945..ebce8869f22b 100644
--- a/core/trino-main/src/main/java/io/trino/operator/LookupJoinOperators.java
+++ b/core/trino-main/src/main/java/io/trino/operator/LookupJoinOperators.java
@@ -49,6 +49,7 @@ public OperatorFactory innerJoin(
JoinBridgeManager extends LookupSourceFactory> lookupSourceFactory,
List probeTypes,
boolean outputSingleMatch,
+ boolean waitForBuild,
List probeJoinChannel,
OptionalInt probeHashChannel,
Optional> probeOutputChannels,
@@ -66,6 +67,7 @@ public OperatorFactory innerJoin(
probeOutputChannels.orElse(rangeList(probeTypes.size())),
JoinType.INNER,
outputSingleMatch,
+ waitForBuild,
totalOperatorsCount,
partitioningSpillerFactory,
blockTypeOperators);
@@ -94,6 +96,7 @@ public OperatorFactory probeOuterJoin(
probeOutputChannels.orElse(rangeList(probeTypes.size())),
JoinType.PROBE_OUTER,
outputSingleMatch,
+ false,
totalOperatorsCount,
partitioningSpillerFactory,
blockTypeOperators);
@@ -104,6 +107,7 @@ public OperatorFactory lookupOuterJoin(
PlanNodeId planNodeId,
JoinBridgeManager extends LookupSourceFactory> lookupSourceFactory,
List probeTypes,
+ boolean waitForBuild,
List probeJoinChannel,
OptionalInt probeHashChannel,
Optional> probeOutputChannels,
@@ -121,6 +125,7 @@ public OperatorFactory lookupOuterJoin(
probeOutputChannels.orElse(rangeList(probeTypes.size())),
JoinType.LOOKUP_OUTER,
false,
+ waitForBuild,
totalOperatorsCount,
partitioningSpillerFactory,
blockTypeOperators);
@@ -148,6 +153,7 @@ public OperatorFactory fullOuterJoin(
probeOutputChannels.orElse(rangeList(probeTypes.size())),
JoinType.FULL_OUTER,
false,
+ false,
totalOperatorsCount,
partitioningSpillerFactory,
blockTypeOperators);
@@ -170,6 +176,7 @@ private OperatorFactory createJoinOperatorFactory(
List probeOutputChannels,
JoinType joinType,
boolean outputSingleMatch,
+ boolean waitForBuild,
OptionalInt totalOperatorsCount,
PartitioningSpillerFactory partitioningSpillerFactory,
BlockTypeOperators blockTypeOperators)
@@ -187,6 +194,7 @@ private OperatorFactory createJoinOperatorFactory(
lookupSourceFactoryManager.getBuildOutputTypes(),
joinType,
outputSingleMatch,
+ waitForBuild,
new JoinProbeFactory(probeOutputChannels.stream().mapToInt(i -> i).toArray(), probeJoinChannel, probeHashChannel),
blockTypeOperators,
totalOperatorsCount,
diff --git a/core/trino-main/src/main/java/io/trino/operator/PartitionedLookupSourceFactory.java b/core/trino-main/src/main/java/io/trino/operator/PartitionedLookupSourceFactory.java
index 50ca998a707d..19ef6fbe79d8 100644
--- a/core/trino-main/src/main/java/io/trino/operator/PartitionedLookupSourceFactory.java
+++ b/core/trino-main/src/main/java/io/trino/operator/PartitionedLookupSourceFactory.java
@@ -212,7 +212,7 @@ public void setPartitionSpilledLookupSourceHandle(int partitionIndex, SpilledLoo
lock.writeLock().lock();
try {
- if (destroyed.isDone()) {
+ if (partitionsNoLongerNeeded.isDone()) {
spilledLookupSourceHandle.dispose();
return;
}
@@ -302,6 +302,14 @@ public ListenableFuture>> finishPr
try {
if (!spillingInfo.hasSpilled()) {
finishedProbeOperators++;
+ if (lookupJoinsCount.isPresent()) {
+ checkState(finishedProbeOperators <= lookupJoinsCount.getAsInt(), "%s probe operators finished out of %s declared", finishedProbeOperators, lookupJoinsCount.getAsInt());
+ if (finishedProbeOperators == lookupJoinsCount.getAsInt()) {
+ // We can dispose partitions now since right outer is not supported with spill and lookupJoinsCount should be absent
+ freePartitions();
+ }
+ }
+
return immediateFuture(new PartitionedConsumption<>(
1,
emptyList(),
@@ -322,7 +330,7 @@ public ListenableFuture>> finishPr
finishedProbeOperators++;
if (finishedProbeOperators == operatorsCount) {
- // We can dispose partitions now since as right outer is not supported with spill
+ // We can dispose partitions now since right outer is not supported with spill
freePartitions();
verify(!partitionedConsumption.isDone());
partitionedConsumption.set(new PartitionedConsumption<>(
diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/DateTimeFunctions.java b/core/trino-main/src/main/java/io/trino/operator/scalar/DateTimeFunctions.java
index 188e68462eeb..b6a96ad32876 100644
--- a/core/trino-main/src/main/java/io/trino/operator/scalar/DateTimeFunctions.java
+++ b/core/trino-main/src/main/java/io/trino/operator/scalar/DateTimeFunctions.java
@@ -24,7 +24,6 @@
import io.trino.spi.function.LiteralParameters;
import io.trino.spi.function.ScalarFunction;
import io.trino.spi.function.SqlType;
-import io.trino.spi.type.LongTimestamp;
import io.trino.spi.type.LongTimestampWithTimeZone;
import io.trino.spi.type.StandardTypes;
import io.trino.spi.type.TimeZoneKey;
@@ -55,7 +54,6 @@
import static io.trino.spi.type.Timestamps.NANOSECONDS_PER_SECOND;
import static io.trino.spi.type.UnscaledDecimal128Arithmetic.rescale;
import static io.trino.spi.type.UnscaledDecimal128Arithmetic.unscaledDecimalToBigInteger;
-import static io.trino.type.DateTimes.MICROSECONDS_PER_SECOND;
import static io.trino.type.DateTimes.PICOSECONDS_PER_NANOSECOND;
import static io.trino.type.DateTimes.PICOSECONDS_PER_SECOND;
import static io.trino.type.DateTimes.scaleEpochMillisToMicros;
@@ -121,10 +119,11 @@ public static Slice currentTimeZone(ConnectorSession session)
}
@ScalarFunction("from_unixtime")
- @SqlType("timestamp(3)")
- public static long fromUnixTime(@SqlType(StandardTypes.DOUBLE) double unixTime)
+ @SqlType("timestamp(3) with time zone")
+ public static long fromUnixTime(ConnectorSession session, @SqlType(StandardTypes.DOUBLE) double unixTime)
{
- return Math.round(unixTime * MICROSECONDS_PER_SECOND);
+ // TODO (https://github.com/trinodb/trino/issues/5781)
+ return packDateTimeWithZone(Math.round(unixTime * 1000), session.getTimeZoneKey());
}
@ScalarFunction("from_unixtime")
@@ -155,9 +154,10 @@ public static final class FromUnixtimeNanosDecimal
private FromUnixtimeNanosDecimal() {}
@LiteralParameters({"p", "s"})
- @SqlType("timestamp(9)")
- public static LongTimestamp fromLong(@LiteralParameter("s") long scale, @SqlType("decimal(p, s)") Slice unixTimeNanos)
+ @SqlType("timestamp(9) with time zone")
+ public static LongTimestampWithTimeZone fromLong(@LiteralParameter("s") long scale, ConnectorSession session, @SqlType("decimal(p, s)") Slice unixTimeNanos)
{
+ // TODO (https://github.com/trinodb/trino/issues/5781)
BigInteger unixTimeNanosInt = unscaledDecimalToBigInteger(rescale(unixTimeNanos, -(int) scale));
long epochSeconds = unixTimeNanosInt.divide(BigInteger.valueOf(NANOSECONDS_PER_SECOND)).longValue();
long nanosOfSecond = unixTimeNanosInt.remainder(BigInteger.valueOf(NANOSECONDS_PER_SECOND)).longValue();
@@ -167,27 +167,28 @@ public static LongTimestamp fromLong(@LiteralParameter("s") long scale, @SqlType
epochSeconds -= 1;
picosOfSecond += PICOSECONDS_PER_SECOND;
}
- return DateTimes.longTimestamp(epochSeconds, picosOfSecond);
+ return DateTimes.longTimestampWithTimeZone(epochSeconds, picosOfSecond, session.getTimeZoneKey().getZoneId());
}
@LiteralParameters({"p", "s"})
- @SqlType("timestamp(9)")
- public static LongTimestamp fromShort(@LiteralParameter("s") long scale, @SqlType("decimal(p, s)") long unixTimeNanos)
+ @SqlType("timestamp(9) with time zone")
+ public static LongTimestampWithTimeZone fromShort(@LiteralParameter("s") long scale, ConnectorSession session, @SqlType("decimal(p, s)") long unixTimeNanos)
{
+ // TODO (https://github.com/trinodb/trino/issues/5781)
long roundedUnixTimeNanos = MathFunctions.Round.roundShort(scale, unixTimeNanos);
- return fromUnixtimeNanosLong(roundedUnixTimeNanos);
+ return fromUnixtimeNanosLong(session, roundedUnixTimeNanos);
}
}
@ScalarFunction("from_unixtime_nanos")
- @SqlType("timestamp(9)")
- public static LongTimestamp fromUnixtimeNanosLong(@SqlType(StandardTypes.BIGINT) long unixTimeNanos)
+ @SqlType("timestamp(9) with time zone")
+ public static LongTimestampWithTimeZone fromUnixtimeNanosLong(ConnectorSession session, @SqlType(StandardTypes.BIGINT) long unixTimeNanos)
{
long epochSeconds = floorDiv(unixTimeNanos, NANOSECONDS_PER_SECOND);
long nanosOfSecond = floorMod(unixTimeNanos, NANOSECONDS_PER_SECOND);
long picosOfSecond = nanosOfSecond * PICOSECONDS_PER_NANOSECOND;
- return DateTimes.longTimestamp(epochSeconds, picosOfSecond);
+ return DateTimes.longTimestampWithTimeZone(epochSeconds, picosOfSecond, session.getTimeZoneKey().getZoneId());
}
@ScalarFunction("to_iso8601")
diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/timestamp/ToUnixTime.java b/core/trino-main/src/main/java/io/trino/operator/scalar/timestamp/ToUnixTime.java
deleted file mode 100644
index 22db5bd4b81e..000000000000
--- a/core/trino-main/src/main/java/io/trino/operator/scalar/timestamp/ToUnixTime.java
+++ /dev/null
@@ -1,43 +0,0 @@
-/*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package io.trino.operator.scalar.timestamp;
-
-import io.trino.spi.function.LiteralParameters;
-import io.trino.spi.function.ScalarFunction;
-import io.trino.spi.function.SqlType;
-import io.trino.spi.type.LongTimestamp;
-import io.trino.spi.type.StandardTypes;
-
-import static io.trino.type.DateTimes.MICROSECONDS_PER_SECOND;
-import static io.trino.type.DateTimes.PICOSECONDS_PER_SECOND;
-
-@ScalarFunction("to_unixtime")
-public final class ToUnixTime
-{
- private ToUnixTime() {}
-
- @LiteralParameters("p")
- @SqlType(StandardTypes.DOUBLE)
- public static double toUnixTime(@SqlType("timestamp(p)") long timestamp)
- {
- return timestamp * 1.0 / MICROSECONDS_PER_SECOND;
- }
-
- @LiteralParameters("p")
- @SqlType(StandardTypes.DOUBLE)
- public static double toUnixTime(@SqlType("timestamp(p)") LongTimestamp timestamp)
- {
- return timestamp.getEpochMicros() * 1.0 / MICROSECONDS_PER_SECOND + timestamp.getPicosOfMicro() * 1.0 / PICOSECONDS_PER_SECOND;
- }
-}
diff --git a/core/trino-main/src/main/java/io/trino/server/PluginManager.java b/core/trino-main/src/main/java/io/trino/server/PluginManager.java
index 8fc144b40976..c04d601f84e1 100644
--- a/core/trino-main/src/main/java/io/trino/server/PluginManager.java
+++ b/core/trino-main/src/main/java/io/trino/server/PluginManager.java
@@ -42,6 +42,7 @@
import java.net.URL;
import java.util.List;
+import java.util.Optional;
import java.util.ServiceLoader;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;
@@ -67,7 +68,7 @@ public class PluginManager
private final MetadataManager metadataManager;
private final ResourceGroupManager> resourceGroupManager;
private final AccessControlManager accessControlManager;
- private final PasswordAuthenticatorManager passwordAuthenticatorManager;
+ private final Optional passwordAuthenticatorManager;
private final CertificateAuthenticatorManager certificateAuthenticatorManager;
private final EventListenerManager eventListenerManager;
private final GroupProviderManager groupProviderManager;
@@ -82,7 +83,7 @@ public PluginManager(
MetadataManager metadataManager,
ResourceGroupManager> resourceGroupManager,
AccessControlManager accessControlManager,
- PasswordAuthenticatorManager passwordAuthenticatorManager,
+ Optional passwordAuthenticatorManager,
CertificateAuthenticatorManager certificateAuthenticatorManager,
EventListenerManager eventListenerManager,
GroupProviderManager groupProviderManager,
@@ -191,10 +192,12 @@ private void installPluginInternal(Plugin plugin, Supplier duplicat
accessControlManager.addSystemAccessControlFactory(accessControlFactory);
}
- for (PasswordAuthenticatorFactory authenticatorFactory : plugin.getPasswordAuthenticatorFactories()) {
- log.info("Registering password authenticator %s", authenticatorFactory.getName());
- passwordAuthenticatorManager.addPasswordAuthenticatorFactory(authenticatorFactory);
- }
+ passwordAuthenticatorManager.ifPresent(authenticationManager -> {
+ for (PasswordAuthenticatorFactory authenticatorFactory : plugin.getPasswordAuthenticatorFactories()) {
+ log.info("Registering password authenticator %s", authenticatorFactory.getName());
+ authenticationManager.addPasswordAuthenticatorFactory(authenticatorFactory);
+ }
+ });
for (CertificateAuthenticatorFactory authenticatorFactory : plugin.getCertificateAuthenticatorFactories()) {
log.info("Registering certificate authenticator %s", authenticatorFactory.getName());
diff --git a/core/trino-main/src/main/java/io/trino/server/Server.java b/core/trino-main/src/main/java/io/trino/server/Server.java
index 559430890508..6bed0e26e676 100644
--- a/core/trino-main/src/main/java/io/trino/server/Server.java
+++ b/core/trino-main/src/main/java/io/trino/server/Server.java
@@ -17,7 +17,10 @@
import com.google.common.base.StandardSystemProperty;
import com.google.common.collect.ImmutableList;
import com.google.inject.Injector;
+import com.google.inject.Key;
import com.google.inject.Module;
+import com.google.inject.TypeLiteral;
+import com.google.inject.util.Types;
import io.airlift.bootstrap.ApplicationConfigurationException;
import io.airlift.bootstrap.Bootstrap;
import io.airlift.discovery.client.Announcer;
@@ -56,6 +59,7 @@
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
+import java.util.Optional;
import java.util.Set;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
@@ -124,7 +128,8 @@ private void doStart(String trinoVersion)
injector.getInstance(SessionPropertyDefaults.class).loadConfigurationManager();
injector.getInstance(ResourceGroupManager.class).loadConfigurationManager();
injector.getInstance(AccessControlManager.class).loadSystemAccessControl();
- injector.getInstance(PasswordAuthenticatorManager.class).loadPasswordAuthenticator();
+ injector.getInstance(optionalKey(PasswordAuthenticatorManager.class))
+ .ifPresent(PasswordAuthenticatorManager::loadPasswordAuthenticator);
injector.getInstance(EventListenerManager.class).loadEventListeners();
injector.getInstance(GroupProviderManager.class).loadConfiguredGroupProvider();
injector.getInstance(CertificateAuthenticatorManager.class).loadCertificateAuthenticator();
@@ -152,6 +157,12 @@ private void doStart(String trinoVersion)
}
}
+ @SuppressWarnings("unchecked")
+ private static Key> optionalKey(Class type)
+ {
+ return Key.get((TypeLiteral>) TypeLiteral.get(Types.newParameterizedType(Optional.class, type)));
+ }
+
private static void addMessages(StringBuilder output, String type, List messages)
{
if (messages.isEmpty()) {
diff --git a/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java b/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java
index 5f6dc43155e1..6318e5c67ba7 100644
--- a/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java
+++ b/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java
@@ -159,6 +159,7 @@ public final class HttpRemoteTask
private final PartitionedSplitCountTracker partitionedSplitCountTracker;
+ private final AtomicBoolean started = new AtomicBoolean(false);
private final AtomicBoolean aborting = new AtomicBoolean(false);
public HttpRemoteTask(
@@ -318,6 +319,7 @@ public void start()
{
try (SetThreadName ignored = new SetThreadName("HttpRemoteTask-%s", taskId)) {
// to start we just need to trigger an update
+ started.set(true);
scheduleUpdate();
dynamicFiltersFetcher.start();
@@ -493,9 +495,9 @@ private synchronized void processTaskUpdate(TaskInfo newValue, List
pendingSourceSplitCount -= removed;
}
}
- updateSplitQueueSpace();
-
+ // Update node level split tracker before split queue space to ensure it's up to date before waking up the scheduler
partitionedSplitCountTracker.setPartitionedSplitCount(getPartitionedSplitCount());
+ updateSplitQueueSpace();
}
private void updateTaskInfo(TaskInfo taskInfo)
@@ -513,7 +515,7 @@ private synchronized void sendUpdate()
{
TaskStatus taskStatus = getTaskStatus();
// don't update if the task hasn't been started yet or if it is already finished
- if (!needsUpdate.get() || taskStatus.getState().isDone()) {
+ if (!started.get() || !needsUpdate.get() || taskStatus.getState().isDone()) {
return;
}
diff --git a/core/trino-main/src/main/java/io/trino/server/security/AbstractBearerAuthenticator.java b/core/trino-main/src/main/java/io/trino/server/security/AbstractBearerAuthenticator.java
index 6eef6513945d..99cbdd4aebf0 100644
--- a/core/trino-main/src/main/java/io/trino/server/security/AbstractBearerAuthenticator.java
+++ b/core/trino-main/src/main/java/io/trino/server/security/AbstractBearerAuthenticator.java
@@ -43,24 +43,12 @@ protected AbstractBearerAuthenticator(String principalField, UserMapping userMap
public Identity authenticate(ContainerRequestContext request)
throws AuthenticationException
{
- List headers = request.getHeaders().get(AUTHORIZATION);
- if (headers == null || headers.size() == 0) {
- throw needAuthentication(request, null);
- }
- if (headers.size() > 1) {
- throw new IllegalArgumentException(format("Multiple %s headers detected: %s, where only single %s header is supported", AUTHORIZATION, headers, AUTHORIZATION));
- }
-
- String header = headers.get(0);
- int space = header.indexOf(' ');
- if ((space < 0) || !header.substring(0, space).equalsIgnoreCase("bearer")) {
- throw needAuthentication(request, null);
- }
- String token = header.substring(space + 1).trim();
- if (token.isEmpty()) {
- throw needAuthentication(request, null);
- }
+ return authenticate(request, extractToken(request));
+ }
+ public Identity authenticate(ContainerRequestContext request, String token)
+ throws AuthenticationException
+ {
try {
Jws claimsJws = parseClaimsJws(token);
String principal = claimsJws.getBody().get(principalField, String.class);
@@ -80,6 +68,29 @@ public Identity authenticate(ContainerRequestContext request)
}
}
+ public String extractToken(ContainerRequestContext request)
+ throws AuthenticationException
+ {
+ List headers = request.getHeaders().get(AUTHORIZATION);
+ if (headers == null || headers.size() == 0) {
+ throw needAuthentication(request, null);
+ }
+ if (headers.size() > 1) {
+ throw new IllegalArgumentException(format("Multiple %s headers detected: %s, where only single %s header is supported", AUTHORIZATION, headers, AUTHORIZATION));
+ }
+
+ String header = headers.get(0);
+ int space = header.indexOf(' ');
+ if ((space < 0) || !header.substring(0, space).equalsIgnoreCase("bearer")) {
+ throw needAuthentication(request, null);
+ }
+ String token = header.substring(space + 1).trim();
+ if (token.isEmpty()) {
+ throw needAuthentication(request, null);
+ }
+ return token;
+ }
+
protected abstract Jws parseClaimsJws(String jws);
protected abstract AuthenticationException needAuthentication(ContainerRequestContext request, String message);
diff --git a/core/trino-main/src/main/java/io/trino/server/security/AuthenticationFilter.java b/core/trino-main/src/main/java/io/trino/server/security/AuthenticationFilter.java
index 65d4eae0f16d..35f84ce16a9d 100644
--- a/core/trino-main/src/main/java/io/trino/server/security/AuthenticationFilter.java
+++ b/core/trino-main/src/main/java/io/trino/server/security/AuthenticationFilter.java
@@ -24,9 +24,11 @@
import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.container.ContainerRequestFilter;
+import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
+import java.util.stream.Stream;
import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.server.ServletSecurityUtils.sendWwwAuthenticate;
@@ -86,10 +88,17 @@ else if (insecureAuthenticationOverHttpAllowed) {
authenticatedIdentity = authenticator.authenticate(request);
}
catch (AuthenticationException e) {
- if (e.getMessage() != null) {
- messages.add(e.getMessage());
- }
- e.getAuthenticateHeader().ifPresent(authenticateHeaders::add);
+ // Some authenticators (e.g. password) nest multiple internal authenticators.
+ // Exceptions from additional failed login attempts are suppressed in the first exception
+ Stream.concat(Stream.of(e), Arrays.stream(e.getSuppressed()))
+ .filter(ex -> ex instanceof AuthenticationException)
+ .map(AuthenticationException.class::cast)
+ .forEach(ex -> {
+ if (ex.getMessage() != null) {
+ messages.add(ex.getMessage());
+ }
+ ex.getAuthenticateHeader().ifPresent(authenticateHeaders::add);
+ });
continue;
}
diff --git a/core/trino-main/src/main/java/io/trino/server/security/PasswordAuthenticator.java b/core/trino-main/src/main/java/io/trino/server/security/PasswordAuthenticator.java
index 95f780e7b9e7..bf4dd094b657 100644
--- a/core/trino-main/src/main/java/io/trino/server/security/PasswordAuthenticator.java
+++ b/core/trino-main/src/main/java/io/trino/server/security/PasswordAuthenticator.java
@@ -21,6 +21,7 @@
import java.security.Principal;
+import static com.google.common.base.Verify.verify;
import static io.trino.server.security.BasicAuthCredentials.extractBasicAuthCredentials;
import static io.trino.server.security.UserMapping.createUserMapping;
import static java.util.Objects.requireNonNull;
@@ -47,22 +48,34 @@ public Identity authenticate(ContainerRequestContext request)
{
BasicAuthCredentials basicAuthCredentials = extractBasicAuthCredentials(request)
.orElseThrow(() -> needAuthentication(null));
- try {
- Principal principal = authenticatorManager.getAuthenticator().createAuthenticatedPrincipal(
- basicAuthCredentials.getUser(),
- basicAuthCredentials.getPassword()
- .orElseThrow(() -> new AuthenticationException("Malformed credentials: password is empty")));
- String authenticatedUser = userMapping.mapUser(principal.toString());
- return Identity.forUser(authenticatedUser)
- .withPrincipal(principal)
- .build();
- }
- catch (UserMappingException | AccessDeniedException e) {
- throw needAuthentication(e.getMessage());
- }
- catch (RuntimeException e) {
- throw new RuntimeException("Authentication error", e);
+ String user = basicAuthCredentials.getUser();
+ String password = basicAuthCredentials.getPassword()
+ .orElseThrow(() -> new AuthenticationException("Malformed credentials: password is empty"));
+
+ AuthenticationException exception = null;
+ for (io.trino.spi.security.PasswordAuthenticator authenticator : authenticatorManager.getAuthenticators()) {
+ try {
+ Principal principal = authenticator.createAuthenticatedPrincipal(user, password);
+ String authenticatedUser = userMapping.mapUser(principal.toString());
+ return Identity.forUser(authenticatedUser)
+ .withPrincipal(principal)
+ .build();
+ }
+ catch (UserMappingException | AccessDeniedException e) {
+ if (exception == null) {
+ exception = needAuthentication(e.getMessage());
+ }
+ else {
+ exception.addSuppressed(needAuthentication(e.getMessage()));
+ }
+ }
+ catch (RuntimeException e) {
+ throw new RuntimeException("Authentication error", e);
+ }
}
+
+ verify(exception != null, "exception not set");
+ throw exception;
}
private static AuthenticationException needAuthentication(String message)
diff --git a/core/trino-main/src/main/java/io/trino/server/security/PasswordAuthenticatorConfig.java b/core/trino-main/src/main/java/io/trino/server/security/PasswordAuthenticatorConfig.java
index 1ed99ecc2601..6fb163eb1575 100644
--- a/core/trino-main/src/main/java/io/trino/server/security/PasswordAuthenticatorConfig.java
+++ b/core/trino-main/src/main/java/io/trino/server/security/PasswordAuthenticatorConfig.java
@@ -13,16 +13,27 @@
*/
package io.trino.server.security;
+import com.google.common.base.Splitter;
+import com.google.common.collect.ImmutableList;
import io.airlift.configuration.Config;
+import io.airlift.configuration.ConfigDescription;
import io.airlift.configuration.validation.FileExists;
+import javax.validation.constraints.NotEmpty;
+import javax.validation.constraints.NotNull;
+
import java.io.File;
+import java.util.List;
import java.util.Optional;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+
public class PasswordAuthenticatorConfig
{
+ private static final Splitter SPLITTER = Splitter.on(',').trimResults().omitEmptyStrings();
private Optional userMappingPattern = Optional.empty();
private Optional userMappingFile = Optional.empty();
+ private List passwordAuthenticatorFiles = ImmutableList.of(new File("etc/password-authenticator.properties"));
public Optional getUserMappingPattern()
{
@@ -47,4 +58,21 @@ public PasswordAuthenticatorConfig setUserMappingFile(File userMappingFile)
this.userMappingFile = Optional.ofNullable(userMappingFile);
return this;
}
+
+ @NotNull
+ @NotEmpty(message = "At least one password authenticator config file is required")
+ public List<@FileExists File> getPasswordAuthenticatorFiles()
+ {
+ return passwordAuthenticatorFiles;
+ }
+
+ @Config("password-authenticator.config-files")
+ @ConfigDescription("Ordered list of password authenticator config files")
+ public PasswordAuthenticatorConfig setPasswordAuthenticatorFiles(String passwordAuthenticatorFiles)
+ {
+ this.passwordAuthenticatorFiles = SPLITTER.splitToList(passwordAuthenticatorFiles).stream()
+ .map(File::new)
+ .collect(toImmutableList());
+ return this;
+ }
}
diff --git a/core/trino-main/src/main/java/io/trino/server/security/PasswordAuthenticatorManager.java b/core/trino-main/src/main/java/io/trino/server/security/PasswordAuthenticatorManager.java
index b11c39b94b51..20575ebfb14c 100644
--- a/core/trino-main/src/main/java/io/trino/server/security/PasswordAuthenticatorManager.java
+++ b/core/trino-main/src/main/java/io/trino/server/security/PasswordAuthenticatorManager.java
@@ -14,13 +14,18 @@
package io.trino.server.security;
import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
+import com.google.inject.Inject;
import io.airlift.log.Logger;
import io.trino.spi.security.PasswordAuthenticator;
import io.trino.spi.security.PasswordAuthenticatorFactory;
import java.io.File;
+import java.io.IOException;
+import java.io.UncheckedIOException;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
@@ -36,12 +41,20 @@ public class PasswordAuthenticatorManager
{
private static final Logger log = Logger.get(PasswordAuthenticatorManager.class);
- private static final File CONFIG_FILE = new File("etc/password-authenticator.properties");
private static final String NAME_PROPERTY = "password-authenticator.name";
+ private final List configFiles;
private final AtomicBoolean required = new AtomicBoolean();
private final Map factories = new ConcurrentHashMap<>();
- private final AtomicReference authenticator = new AtomicReference<>();
+ private final AtomicReference> authenticators = new AtomicReference<>();
+
+ @Inject
+ public PasswordAuthenticatorManager(PasswordAuthenticatorConfig config)
+ {
+ requireNonNull(config, "config is null");
+ this.configFiles = ImmutableList.copyOf(config.getPasswordAuthenticatorFiles());
+ checkArgument(!configFiles.isEmpty(), "password authenticator files list is empty");
+ }
public void setRequired()
{
@@ -56,18 +69,31 @@ public void addPasswordAuthenticatorFactory(PasswordAuthenticatorFactory factory
public boolean isLoaded()
{
- return authenticator.get() != null;
+ return authenticators.get() != null;
}
public void loadPasswordAuthenticator()
- throws Exception
{
if (!required.get()) {
return;
}
- File configFile = CONFIG_FILE.getAbsoluteFile();
- Map properties = new HashMap<>(loadPropertiesFrom(configFile.getPath()));
+ ImmutableList.Builder authenticators = ImmutableList.builder();
+ for (File configFile : configFiles) {
+ authenticators.add(loadAuthenticator(configFile.getAbsoluteFile()));
+ }
+ this.authenticators.set(authenticators.build());
+ }
+
+ private PasswordAuthenticator loadAuthenticator(File configFile)
+ {
+ Map properties;
+ try {
+ properties = new HashMap<>(loadPropertiesFrom(configFile.getPath()));
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
String name = properties.remove(NAME_PROPERTY);
checkState(!isNullOrEmpty(name), "Password authenticator configuration %s does not contain '%s'", configFile, NAME_PROPERTY);
@@ -77,23 +103,21 @@ public void loadPasswordAuthenticator()
PasswordAuthenticatorFactory factory = factories.get(name);
checkState(factory != null, "Password authenticator '%s' is not registered", name);
- PasswordAuthenticator authenticator = factory.create(ImmutableMap.copyOf(properties));
- this.authenticator.set(requireNonNull(authenticator, "authenticator is null"));
-
log.info("-- Loaded password authenticator %s --", name);
+ return factory.create(ImmutableMap.copyOf(properties));
}
- public PasswordAuthenticator getAuthenticator()
+ public List getAuthenticators()
{
- checkState(isLoaded(), "authenticator was not loaded");
- return authenticator.get();
+ checkState(isLoaded(), "authenticators were not loaded");
+ return authenticators.get();
}
@VisibleForTesting
- public void setAuthenticator(PasswordAuthenticator authenticator)
+ public void setAuthenticators(PasswordAuthenticator... authenticators)
{
- if (!this.authenticator.compareAndSet(null, authenticator)) {
- throw new IllegalStateException("authenticator already loaded");
+ if (!this.authenticators.compareAndSet(null, ImmutableList.copyOf(authenticators))) {
+ throw new IllegalStateException("authenticators already loaded");
}
}
}
diff --git a/core/trino-main/src/main/java/io/trino/server/security/ServerSecurityModule.java b/core/trino-main/src/main/java/io/trino/server/security/ServerSecurityModule.java
index ba729ef62eea..b2539d3247ae 100644
--- a/core/trino-main/src/main/java/io/trino/server/security/ServerSecurityModule.java
+++ b/core/trino-main/src/main/java/io/trino/server/security/ServerSecurityModule.java
@@ -61,7 +61,7 @@ protected void setup(Binder binder)
.internalOnlyResource(DynamicAnnouncementResource.class)
.internalOnlyResource(StoreResource.class);
- binder.bind(PasswordAuthenticatorManager.class).in(Scopes.SINGLETON);
+ newOptionalBinder(binder, PasswordAuthenticatorManager.class);
binder.bind(CertificateAuthenticatorManager.class).in(Scopes.SINGLETON);
insecureHttpAuthenticationDefaults();
@@ -73,7 +73,10 @@ protected void setup(Binder binder)
configBinder(certificateBinder).bindConfig(CertificateConfig.class);
}));
installAuthenticator("kerberos", KerberosAuthenticator.class, KerberosConfig.class);
- installAuthenticator("password", PasswordAuthenticator.class, PasswordAuthenticatorConfig.class);
+ install(authenticatorModule("password", PasswordAuthenticator.class, used -> {
+ configBinder(binder).bindConfig(PasswordAuthenticatorConfig.class);
+ binder.bind(PasswordAuthenticatorManager.class).in(Scopes.SINGLETON);
+ }));
install(authenticatorModule("jwt", JwtAuthenticator.class, new JwtAuthenticatorSupportModule()));
install(authenticatorModule("oauth2", OAuth2Authenticator.class, new OAuth2AuthenticationSupportModule()));
diff --git a/core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2CallbackResource.java b/core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2CallbackResource.java
index 6ceb412b91aa..be65f8e6e3c9 100644
--- a/core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2CallbackResource.java
+++ b/core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2CallbackResource.java
@@ -88,20 +88,13 @@ public Response callback(
// Note: the Web UI may be disabled, so REST requests can not redirect to a success or error page inside of the Web UI
if (error != null) {
- LOG.debug(
- "OAuth server returned an error: error=%s, error_description=%s, error_uri=%s, state=%s",
- error,
- errorDescription,
- errorUri,
- state);
-
- passErrorToTokenExchange(
- authId,
- "OAuth server returned an error: error=%s, error_description=%s, error_uri=%s, state=%s",
- error,
- errorDescription,
- errorUri,
- state);
+ LOG.debug("OAuth server returned an error: error=%s, error_description=%s, error_uri=%s, state=%s", error, errorDescription, errorUri, state);
+
+ if (tokenExchange.isPresent() && authId.isPresent()) {
+ tokenExchange.get().setTokenExchangeError(
+ authId.get(),
+ format("OAuth server returned an error: error=%s, error_description=%s, error_uri=%s, state=%s", error, errorDescription, errorUri, state));
+ }
return Response.ok()
.entity(service.getCallbackErrorHtml(error))
.build();
@@ -117,8 +110,9 @@ public Response callback(
}
catch (ChallengeFailedException | RuntimeException e) {
LOG.debug(e, "Authentication response could not be verified: state=%s", state);
-
- passErrorToTokenExchange(authId, "Authentication response could not be verified: state=%s", state);
+ if (tokenExchange.isPresent() && authId.isPresent()) {
+ tokenExchange.get().setTokenExchangeError(authId.get(), format("Authentication response could not be verified: state=%s", state));
+ }
return Response.ok()
.entity(service.getInternalFailureHtml("Authentication response could not be verified"))
.build();
@@ -146,12 +140,4 @@ public Response callback(
}
return builder.build();
}
-
- private void passErrorToTokenExchange(Optional authId, String format, String... args)
- {
- if (tokenExchange.isEmpty() || authId.isEmpty()) {
- return;
- }
- tokenExchange.orElseThrow().setTokenExchangeError(authId.orElseThrow(), format(format, args));
- }
}
diff --git a/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java b/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java
index 41bf994aea11..42bb7fb6cd6b 100644
--- a/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java
+++ b/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java
@@ -69,6 +69,7 @@
import io.trino.server.security.ServerSecurityModule;
import io.trino.spi.Plugin;
import io.trino.spi.QueryId;
+import io.trino.spi.eventlistener.EventListener;
import io.trino.spi.security.GroupProvider;
import io.trino.spi.security.SystemAccessControl;
import io.trino.split.PageSourceManager;
@@ -194,7 +195,8 @@ private TestingTrinoServer(
Optional discoveryUri,
Module additionalModule,
Optional baseDataDir,
- List systemAccessControls)
+ List systemAccessControls,
+ List eventListeners)
{
this.coordinator = coordinator;
@@ -317,6 +319,9 @@ private TestingTrinoServer(
accessControl.setSystemAccessControls(systemAccessControls);
+ EventListenerManager eventListenerManager = injector.getInstance(EventListenerManager.class);
+ eventListeners.forEach(eventListenerManager::addEventListener);
+
announcer.forceAnnounce();
refreshNodes();
@@ -593,6 +598,7 @@ public static class Builder
private Module additionalModule = EMPTY_MODULE;
private Optional baseDataDir = Optional.empty();
private List systemAccessControls = ImmutableList.of();
+ private List eventListeners = ImmutableList.of();
public Builder setCoordinator(boolean coordinator)
{
@@ -636,6 +642,12 @@ public Builder setSystemAccessControls(List systemAccessCon
return this;
}
+ public Builder setEventListeners(List eventListeners)
+ {
+ this.eventListeners = ImmutableList.copyOf(requireNonNull(eventListeners, "eventListeners is null"));
+ return this;
+ }
+
public TestingTrinoServer build()
{
return new TestingTrinoServer(
@@ -645,7 +657,8 @@ public TestingTrinoServer build()
discoveryUri,
additionalModule,
baseDataDir,
- systemAccessControls);
+ systemAccessControls,
+ eventListeners);
}
}
}
diff --git a/core/trino-main/src/main/java/io/trino/server/ui/FormUiAuthenticatorModule.java b/core/trino-main/src/main/java/io/trino/server/ui/FormUiAuthenticatorModule.java
index 7b0b208a2fec..cf7f6cff65bb 100644
--- a/core/trino-main/src/main/java/io/trino/server/ui/FormUiAuthenticatorModule.java
+++ b/core/trino-main/src/main/java/io/trino/server/ui/FormUiAuthenticatorModule.java
@@ -17,6 +17,7 @@
import com.google.inject.Key;
import com.google.inject.Module;
import io.trino.server.security.Authenticator;
+import io.trino.server.security.PasswordAuthenticatorConfig;
import io.trino.server.security.PasswordAuthenticatorManager;
import static com.google.inject.Scopes.SINGLETON;
@@ -37,10 +38,11 @@ public FormUiAuthenticatorModule(boolean usePasswordManager)
@Override
public void configure(Binder binder)
{
- binder.bind(PasswordAuthenticatorManager.class).in(SINGLETON);
binder.bind(FormWebUiAuthenticationFilter.class).in(SINGLETON);
binder.bind(WebUiAuthenticationFilter.class).to(FormWebUiAuthenticationFilter.class).in(SINGLETON);
if (usePasswordManager) {
+ binder.bind(PasswordAuthenticatorManager.class).in(SINGLETON);
+ configBinder(binder).bindConfig(PasswordAuthenticatorConfig.class);
binder.bind(FormAuthenticator.class).to(PasswordManagerFormAuthenticator.class).in(SINGLETON);
}
else {
diff --git a/core/trino-main/src/main/java/io/trino/server/ui/PasswordManagerFormAuthenticator.java b/core/trino-main/src/main/java/io/trino/server/ui/PasswordManagerFormAuthenticator.java
index be068ea58e70..6bea02bb9831 100644
--- a/core/trino-main/src/main/java/io/trino/server/ui/PasswordManagerFormAuthenticator.java
+++ b/core/trino-main/src/main/java/io/trino/server/ui/PasswordManagerFormAuthenticator.java
@@ -21,6 +21,8 @@
import javax.inject.Inject;
+import java.util.List;
+
import static java.util.Objects.requireNonNull;
public class PasswordManagerFormAuthenticator
@@ -66,17 +68,20 @@ public boolean isValidCredential(String username, String password, boolean secur
return insecureAuthenticationOverHttpAllowed && password == null;
}
- PasswordAuthenticator authenticator = passwordAuthenticatorManager.getAuthenticator();
- try {
- authenticator.createAuthenticatedPrincipal(username, password);
- return true;
- }
- catch (AccessDeniedException e) {
- return false;
- }
- catch (RuntimeException e) {
- log.debug(e, "Error authenticating user for Web UI");
- return false;
+ List authenticators = passwordAuthenticatorManager.getAuthenticators();
+ for (PasswordAuthenticator authenticator : authenticators) {
+ try {
+ authenticator.createAuthenticatedPrincipal(username, password);
+ return true;
+ }
+ catch (AccessDeniedException e) {
+ // Try another one
+ }
+ catch (RuntimeException e) {
+ log.debug(e, "Error authenticating user for Web UI");
+ }
}
+
+ return false;
}
}
diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java
index 46e28434568c..94f96381c0dd 100644
--- a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java
+++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java
@@ -13,11 +13,13 @@
*/
package io.trino.sql.analyzer;
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.ArrayListMultimap;
-import com.google.common.collect.HashMultimap;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multiset;
@@ -32,6 +34,7 @@
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ColumnMetadata;
import io.trino.spi.connector.ConnectorTableMetadata;
+import io.trino.spi.eventlistener.ColumnDetail;
import io.trino.spi.eventlistener.ColumnInfo;
import io.trino.spi.eventlistener.RoutineInfo;
import io.trino.spi.eventlistener.TableInfo;
@@ -121,9 +124,6 @@ public class Analysis
// a map of users to the columns per table that they access
private final Map>> tableColumnReferences = new LinkedHashMap<>();
- // Track referenced fields from source relation node
- private final Multimap, Field> referencedFields = HashMultimap.create();
-
private final Map, List> aggregates = new LinkedHashMap<>();
private final Map, List> orderByAggregates = new LinkedHashMap<>();
private final Map, GroupingSetAnalysis> groupingSets = new LinkedHashMap<>();
@@ -194,6 +194,8 @@ public class Analysis
// row id field for update/delete queries
private final Map, FieldReference> rowIdField = new LinkedHashMap<>();
+ private final Multimap originColumnDetails = ArrayListMultimap.create();
+ private final Multimap, Field> fieldLineage = ArrayListMultimap.create();
public Analysis(@Nullable Statement root, Map, Expression> parameters, boolean isDescribe)
{
@@ -216,14 +218,14 @@ public Optional getTarget()
{
return target.map(target -> {
QualifiedObjectName name = target.getName();
- return new Output(name.getCatalogName(), name.getSchemaName(), name.getObjectName());
+ return new Output(name.getCatalogName(), name.getSchemaName(), name.getObjectName(), target.getColumns());
});
}
- public void setUpdateType(String updateType, QualifiedObjectName targetName, Optional targetTable)
+ public void setUpdateType(String updateType, QualifiedObjectName targetName, Optional targetTable, Optional> targetColumns)
{
this.updateType = updateType;
- this.target = Optional.of(new UpdateTarget(targetName, targetTable));
+ this.target = Optional.of(new UpdateTarget(targetName, targetTable, targetColumns));
}
public void resetUpdateType()
@@ -850,11 +852,6 @@ public void addEmptyColumnReferencesForTable(AccessControl accessControl, Identi
tableColumnReferences.computeIfAbsent(accessControlInfo, k -> new LinkedHashMap<>()).computeIfAbsent(table, k -> new HashSet<>());
}
- public void addReferencedFields(Multimap, Field> references)
- {
- referencedFields.putAll(references);
- }
-
public Map>> getTableColumnReferences()
{
return tableColumnReferences;
@@ -965,6 +962,28 @@ public List getRoutines()
.collect(toImmutableList());
}
+ public void addSourceColumns(Field field, Set sourceColumn)
+ {
+ originColumnDetails.putAll(field, sourceColumn);
+ }
+
+ public Set getSourceColumns(Field field)
+ {
+ return ImmutableSet.copyOf(originColumnDetails.get(field));
+ }
+
+ public void addExpressionFields(Expression expression, Collection fields)
+ {
+ fieldLineage.putAll(NodeRef.of(expression), fields);
+ }
+
+ public Set getExpressionSourceColumns(Expression expression)
+ {
+ return fieldLineage.get(NodeRef.of(expression)).stream()
+ .flatMap(field -> getSourceColumns(field).stream())
+ .collect(toImmutableSet());
+ }
+
public void setRowIdField(Table table, FieldReference field)
{
rowIdField.put(NodeRef.of(table), field);
@@ -1487,6 +1506,56 @@ public Scope getAccessControlScope()
}
}
+ public static class SourceColumn
+ {
+ private final QualifiedObjectName tableName;
+ private final String columnName;
+
+ @JsonCreator
+ public SourceColumn(@JsonProperty("tableName") QualifiedObjectName tableName, @JsonProperty("columnName") String columnName)
+ {
+ this.tableName = requireNonNull(tableName, "tableName is null");
+ this.columnName = requireNonNull(columnName, "columnName is null");
+ }
+
+ @JsonProperty
+ public QualifiedObjectName getTableName()
+ {
+ return tableName;
+ }
+
+ @JsonProperty
+ public String getColumnName()
+ {
+ return columnName;
+ }
+
+ public ColumnDetail getColumnDetail()
+ {
+ return new ColumnDetail(tableName.getCatalogName(), tableName.getSchemaName(), tableName.getObjectName(), columnName);
+ }
+
+ @Override
+ public int hashCode()
+ {
+ return Objects.hash(tableName, columnName);
+ }
+
+ @Override
+ public boolean equals(Object obj)
+ {
+ if (obj == this) {
+ return true;
+ }
+ if ((obj == null) || (getClass() != obj.getClass())) {
+ return false;
+ }
+ SourceColumn entry = (SourceColumn) obj;
+ return Objects.equals(tableName, entry.tableName) &&
+ Objects.equals(columnName, entry.columnName);
+ }
+ }
+
private static class RoutineEntry
{
private final ResolvedFunction function;
@@ -1513,11 +1582,13 @@ private static class UpdateTarget
{
private final QualifiedObjectName name;
private final Optional table;
+ private final Optional> columns;
- public UpdateTarget(QualifiedObjectName name, Optional table)
+ public UpdateTarget(QualifiedObjectName name, Optional table, Optional> columns)
{
this.name = requireNonNull(name, "name is null");
this.table = requireNonNull(table, "table is null");
+ this.columns = requireNonNull(columns, "columns is null").map(ImmutableList::copyOf);
}
public QualifiedObjectName getName()
@@ -1529,5 +1600,10 @@ public Optional getTable()
{
return table;
}
+
+ public Optional> getColumns()
+ {
+ return columns;
+ }
}
}
diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java
index a73d722d24d5..2eba8f272590 100644
--- a/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java
+++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java
@@ -256,6 +256,7 @@ public class ExpressionAnalyzer
private final CorrelationSupport correlationSupport;
private final Function getPreanalyzedType;
private final Function getResolvedWindow;
+ private final List sourceFields = new ArrayList<>();
public ExpressionAnalyzer(
Metadata metadata,
@@ -401,6 +402,11 @@ public Multimap, Field> getReferencedFields()
return referencedFields;
}
+ public List getSourceFields()
+ {
+ return sourceFields;
+ }
+
private class Visitor
extends StackableAstVisitor
{
@@ -507,6 +513,8 @@ private Type handleResolvedField(Expression node, ResolvedField resolvedField, S
tableColumnReferences.put(field.getOriginTable().get(), field.getOriginColumnName().get());
}
+ sourceFields.add(field);
+
fieldId.getRelationId()
.getSourceNode()
.ifPresent(source -> referencedFields.put(NodeRef.of(source), field));
@@ -1571,6 +1579,8 @@ else if (previousNode instanceof QuantifiedComparisonExpression) {
scalarSubqueries.add(NodeRef.of(node));
}
+ sourceFields.add(queryScope.getRelationType().getFieldByIndex(0));
+
Type type = getOnlyElement(queryScope.getRelationType().getVisibleFields()).getType();
return setExpressionType(node, type);
}
@@ -1973,6 +1983,7 @@ public static ExpressionAnalysis analyzeExpression(
analyzer.analyze(expression, scope);
updateAnalysis(analysis, analyzer, session, accessControl);
+ analysis.addExpressionFields(expression, analyzer.getSourceFields());
return new ExpressionAnalysis(
analyzer.getExpressionTypes(),
@@ -2030,7 +2041,6 @@ private static void updateAnalysis(Analysis analysis, ExpressionAnalyzer analyze
analysis.addColumnReferences(analyzer.getColumnReferences());
analysis.addLambdaArgumentReferences(analyzer.getLambdaArgumentReferences());
analysis.addTableColumnReferences(accessControl, session.getIdentity(), analyzer.getTableColumnReferences());
- analysis.addReferencedFields(analyzer.getReferencedFields());
}
public static ExpressionAnalyzer create(
diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/FeaturesConfig.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/FeaturesConfig.java
index 8f23911c8169..b61d9a716e0a 100644
--- a/core/trino-main/src/main/java/io/trino/sql/analyzer/FeaturesConfig.java
+++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/FeaturesConfig.java
@@ -906,6 +906,7 @@ public int getMaxRecursionDepth()
}
@Config("max-recursion-depth")
+ @ConfigDescription("Maximum recursion depth for recursive common table expression")
public FeaturesConfig setMaxRecursionDepth(int maxRecursionDepth)
{
this.maxRecursionDepth = maxRecursionDepth;
diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/Output.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/Output.java
index 91def9382de8..abb4adcea85b 100644
--- a/core/trino-main/src/main/java/io/trino/sql/analyzer/Output.java
+++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/Output.java
@@ -15,10 +15,13 @@
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
+import com.google.common.collect.ImmutableList;
import javax.annotation.concurrent.Immutable;
+import java.util.List;
import java.util.Objects;
+import java.util.Optional;
import static java.util.Objects.requireNonNull;
@@ -28,16 +31,19 @@ public final class Output
private final String catalogName;
private final String schema;
private final String table;
+ private final Optional> columns;
@JsonCreator
public Output(
@JsonProperty("catalogName") String catalogName,
@JsonProperty("schema") String schema,
- @JsonProperty("table") String table)
+ @JsonProperty("table") String table,
+ @JsonProperty("columns") Optional> columns)
{
this.catalogName = requireNonNull(catalogName, "catalogName is null");
this.schema = requireNonNull(schema, "schema is null");
this.table = requireNonNull(table, "table is null");
+ this.columns = requireNonNull(columns, "columns is null").map(ImmutableList::copyOf);
}
@JsonProperty
@@ -58,6 +64,12 @@ public String getTable()
return table;
}
+ @JsonProperty
+ public Optional> getColumns()
+ {
+ return columns;
+ }
+
@Override
public boolean equals(Object o)
{
@@ -70,12 +82,13 @@ public boolean equals(Object o)
Output output = (Output) o;
return Objects.equals(catalogName, output.catalogName) &&
Objects.equals(schema, output.schema) &&
- Objects.equals(table, output.table);
+ Objects.equals(table, output.table) &&
+ Objects.equals(columns, output.columns);
}
@Override
public int hashCode()
{
- return Objects.hash(catalogName, schema, table);
+ return Objects.hash(catalogName, schema, table, columns);
}
}
diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/OutputColumn.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/OutputColumn.java
new file mode 100644
index 000000000000..7682e79be713
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/OutputColumn.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.sql.analyzer;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.google.common.collect.ImmutableSet;
+import io.trino.execution.Column;
+import io.trino.sql.analyzer.Analysis.SourceColumn;
+
+import javax.annotation.concurrent.Immutable;
+
+import java.util.Objects;
+import java.util.Set;
+
+import static java.util.Objects.requireNonNull;
+
+@Immutable
+public final class OutputColumn
+{
+ private final Column column;
+ private final Set sourceColumns;
+
+ @JsonCreator
+ public OutputColumn(@JsonProperty("column") Column column, @JsonProperty("sourceColumns") Set sourceColumns)
+ {
+ this.column = requireNonNull(column, "column is null");
+ this.sourceColumns = ImmutableSet.copyOf(requireNonNull(sourceColumns, "sourceColumns is null"));
+ }
+
+ @JsonProperty
+ public Column getColumn()
+ {
+ return column;
+ }
+
+ @JsonProperty
+ public Set getSourceColumns()
+ {
+ return sourceColumns;
+ }
+
+ @Override
+ public int hashCode()
+ {
+ return Objects.hash(column, sourceColumns);
+ }
+
+ @Override
+ public boolean equals(Object obj)
+ {
+ if (obj == this) {
+ return true;
+ }
+ if ((obj == null) || (getClass() != obj.getClass())) {
+ return false;
+ }
+ OutputColumn entry = (OutputColumn) obj;
+ return Objects.equals(column, entry.column) &&
+ Objects.equals(sourceColumns, entry.sourceColumns);
+ }
+}
diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java
index 6f7945018479..a5251e702a74 100644
--- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java
+++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java
@@ -20,8 +20,10 @@
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
+import com.google.common.collect.Streams;
import io.trino.Session;
import io.trino.connector.CatalogName;
+import io.trino.execution.Column;
import io.trino.execution.warnings.WarningCollector;
import io.trino.metadata.FunctionKind;
import io.trino.metadata.FunctionMetadata;
@@ -32,6 +34,7 @@
import io.trino.metadata.ResolvedFunction;
import io.trino.metadata.TableHandle;
import io.trino.metadata.TableMetadata;
+import io.trino.metadata.TableSchema;
import io.trino.security.AccessControl;
import io.trino.security.AllowAllAccessControl;
import io.trino.security.ViewAccessControl;
@@ -40,6 +43,7 @@
import io.trino.spi.connector.CatalogSchemaName;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ColumnMetadata;
+import io.trino.spi.connector.ColumnSchema;
import io.trino.spi.connector.ConnectorMaterializedViewDefinition;
import io.trino.spi.connector.ConnectorTableMetadata;
import io.trino.spi.connector.ConnectorViewDefinition;
@@ -61,6 +65,7 @@
import io.trino.sql.analyzer.Analysis.GroupingSetAnalysis;
import io.trino.sql.analyzer.Analysis.ResolvedWindow;
import io.trino.sql.analyzer.Analysis.SelectExpression;
+import io.trino.sql.analyzer.Analysis.SourceColumn;
import io.trino.sql.analyzer.Analysis.UnnestAnalysis;
import io.trino.sql.analyzer.Scope.AsteriskedIdentifierChainBasis;
import io.trino.sql.parser.ParsingException;
@@ -397,8 +402,6 @@ protected Scope visitInsert(Insert insert, Optional scope)
// analyze the query that creates the data
Scope queryScope = analyze(insert.getQuery(), createScope(scope));
- analysis.setUpdateType("INSERT", targetTable, Optional.empty());
-
// verify the insert destination columns match the query
Optional targetTableHandle = metadata.getTableHandle(session, targetTable);
if (targetTableHandle.isEmpty()) {
@@ -410,20 +413,20 @@ protected Scope visitInsert(Insert insert, Optional scope)
throw semanticException(NOT_SUPPORTED, insert, "Insert into table with a row filter is not supported");
}
- TableMetadata tableMetadata = metadata.getTableMetadata(session, targetTableHandle.get());
+ TableSchema tableSchema = metadata.getTableSchema(session, targetTableHandle.get());
- List columns = tableMetadata.getColumns().stream()
+ List columns = tableSchema.getColumns().stream()
.filter(column -> !column.isHidden())
.collect(toImmutableList());
- for (ColumnMetadata column : columns) {
+ for (ColumnSchema column : columns) {
if (!accessControl.getColumnMasks(session.toSecurityContext(), targetTable, column.getName(), column.getType()).isEmpty()) {
throw semanticException(NOT_SUPPORTED, insert, "Insert into table with column masks is not supported");
}
}
List tableColumns = columns.stream()
- .map(ColumnMetadata::getName)
+ .map(ColumnSchema::getName)
.collect(toImmutableList());
// analyze target table layout, table columns should contain all partition columns
@@ -462,7 +465,7 @@ protected Scope visitInsert(Insert insert, Optional scope)
newTableLayout));
List tableTypes = insertColumns.stream()
- .map(insertColumn -> tableMetadata.getColumn(insertColumn).getType())
+ .map(insertColumn -> tableSchema.getColumn(insertColumn).getType())
.collect(toImmutableList());
List queryTypes = queryScope.getRelationType().getVisibleFields().stream()
@@ -477,6 +480,22 @@ protected Scope visitInsert(Insert insert, Optional scope)
Joiner.on(", ").join(queryTypes));
}
+ Stream columnStream = Streams.zip(
+ insertColumns.stream(),
+ tableTypes.stream()
+ .map(Type::toString),
+ Column::new);
+
+ analysis.setUpdateType(
+ "INSERT",
+ targetTable,
+ Optional.empty(),
+ Optional.of(Streams.zip(
+ columnStream,
+ queryScope.getRelationType().getVisibleFields().stream(),
+ (column, field) -> new OutputColumn(column, analysis.getSourceColumns(field)))
+ .collect(toImmutableList())));
+
return createAndAssignScope(insert, scope, Field.newUnqualified("rows", BIGINT));
}
@@ -490,7 +509,7 @@ protected Scope visitRefreshMaterializedView(RefreshMaterializedView refreshMate
throw semanticException(TABLE_NOT_FOUND, refreshMaterializedView, "Materialized view '%s' does not exist", name);
}
- Optional storageName = getMaterializedViewStorageTableName(name);
+ Optional storageName = getMaterializedViewStorageTableName(optionalView.get(), name);
if (storageName.isEmpty()) {
throw semanticException(TABLE_NOT_FOUND, refreshMaterializedView, "Storage Table '%s' for materialized view '%s' does not exist", storageName, name);
@@ -502,20 +521,13 @@ protected Scope visitRefreshMaterializedView(RefreshMaterializedView refreshMate
Query query = parseView(optionalView.get().getOriginalSql(), name, refreshMaterializedView);
Scope queryScope = process(query, scope);
- analysis.setUpdateType("REFRESH MATERIALIZED VIEW", targetTable, Optional.empty());
-
// verify the insert destination columns match the query
Optional targetTableHandle = metadata.getTableHandle(session, targetTable);
if (targetTableHandle.isEmpty()) {
throw semanticException(TABLE_NOT_FOUND, refreshMaterializedView, "Table '%s' does not exist", targetTable);
}
- if (targetTableHandle.isPresent() && metadata.getMaterializedViewFreshness(session, name).isMaterializedViewFresh()) {
- analysis.setSkipMaterializedViewRefresh(true);
- }
- else {
- analysis.setSkipMaterializedViewRefresh(false);
- }
+ analysis.setSkipMaterializedViewRefresh(metadata.getMaterializedViewFreshness(session, name).isMaterializedViewFresh());
TableMetadata tableMetadata = metadata.getTableMetadata(session, targetTableHandle.get());
List insertColumns = tableMetadata.getColumns().stream()
@@ -545,6 +557,22 @@ protected Scope visitRefreshMaterializedView(RefreshMaterializedView refreshMate
"Query: [" + Joiner.on(", ").join(queryTypes) + "]");
}
+ Stream columns = Streams.zip(
+ insertColumns.stream(),
+ tableTypes.stream()
+ .map(Type::toString),
+ Column::new);
+
+ analysis.setUpdateType(
+ "REFRESH MATERIALIZED VIEW",
+ targetTable,
+ Optional.empty(),
+ Optional.of(Streams.zip(
+ columns,
+ queryScope.getRelationType().getVisibleFields().stream(),
+ (column, field) -> new OutputColumn(column, analysis.getSourceColumns(field)))
+ .collect(toImmutableList())));
+
return createAndAssignScope(refreshMaterializedView, scope, Field.newUnqualified("rows", BIGINT));
}
@@ -642,7 +670,7 @@ protected Scope visitDelete(Delete node, Optional scope)
Scope tableScope = analyzer.analyzeForUpdate(table, scope, UpdateKind.DELETE);
node.getWhere().ifPresent(where -> analyzeWhere(node, tableScope, where));
- analysis.setUpdateType("DELETE", tableName, Optional.of(table));
+ analysis.setUpdateType("DELETE", tableName, Optional.of(table), Optional.empty());
return createAndAssignScope(node, scope, Field.newUnqualified("rows", BIGINT));
}
@@ -651,7 +679,7 @@ protected Scope visitDelete(Delete node, Optional scope)
protected Scope visitAnalyze(Analyze node, Optional scope)
{
QualifiedObjectName tableName = createQualifiedObjectName(session, node, node.getTableName());
- analysis.setUpdateType("ANALYZE", tableName, Optional.empty());
+ analysis.setUpdateType("ANALYZE", tableName, Optional.empty(), Optional.empty());
// verify the target table exists and it's not a view
if (metadata.getView(session, tableName).isPresent()) {
@@ -696,7 +724,6 @@ protected Scope visitCreateTableAsSelect(CreateTableAsSelect node, Optional targetTableHandle = metadata.getTableHandle(session, targetTable);
if (targetTableHandle.isPresent()) {
@@ -707,6 +734,7 @@ protected Scope visitCreateTableAsSelect(CreateTableAsSelect node, Optional columns = ImmutableList.builder();
// analyze target table columns and column aliases
+ ImmutableList.Builder outputColumns = ImmutableList.builder();
if (node.getColumnAliases().isPresent()) {
validateColumnAliases(node.getColumnAliases().get(), queryScope.getRelationType().getVisibleFieldCount());
@@ -730,7 +759,9 @@ protected Scope visitCreateTableAsSelect(CreateTableAsSelect node, Optional new ColumnMetadata(field.getName().get(), field.getType()))
.collect(toImmutableList()));
+ queryScope.getRelationType().getVisibleFields().stream()
+ .map(this::createOutputColumn)
+ .forEach(outputColumns::add);
}
// create target table metadata
@@ -783,6 +817,12 @@ protected Scope visitCreateTableAsSelect(CreateTableAsSelect node, Optional scope)
{
QualifiedObjectName viewName = createQualifiedObjectName(session, node, node.getName());
- analysis.setUpdateType("CREATE VIEW", viewName, Optional.empty());
// analyze the query that creates the view
StatementAnalyzer analyzer = new StatementAnalyzer(analysis, metadata, sqlParser, groupProvider, accessControl, session, warningCollector, CorrelationSupport.ALLOWED);
@@ -801,6 +840,14 @@ protected Scope visitCreateView(CreateView node, Optional scope)
validateColumns(node, queryScope.getRelationType());
+ analysis.setUpdateType(
+ "CREATE VIEW",
+ viewName,
+ Optional.empty(),
+ Optional.of(queryScope.getRelationType().getVisibleFields().stream()
+ .map(this::createOutputColumn)
+ .collect(toImmutableList())));
+
return createAndAssignScope(node, scope);
}
@@ -975,7 +1022,6 @@ protected Scope visitCall(Call node, Optional scope)
protected Scope visitCreateMaterializedView(CreateMaterializedView node, Optional scope)
{
QualifiedObjectName viewName = createQualifiedObjectName(session, node, node.getName());
- analysis.setUpdateType("CREATE MATERIALIZED VIEW", viewName, Optional.empty());
if (node.isReplace() && node.isNotExists()) {
throw semanticException(NOT_SUPPORTED, node, "'CREATE OR REPLACE' and 'IF NOT EXISTS' clauses can not be used together");
@@ -994,6 +1040,15 @@ protected Scope visitCreateMaterializedView(CreateMaterializedView node, Optiona
validateColumns(node, queryScope.getRelationType());
+ analysis.setUpdateType(
+ "CREATE MATERIALIZED VIEW",
+ viewName,
+ Optional.empty(),
+ Optional.of(
+ queryScope.getRelationType().getVisibleFields().stream()
+ .map(this::createOutputColumn)
+ .collect(toImmutableList())));
+
return createAndAssignScope(node, scope);
}
@@ -1119,7 +1174,7 @@ protected Scope visitQuery(Query node, Optional scope)
@Override
protected Scope visitUnnest(Unnest node, Optional scope)
{
- ImmutableMap.Builder, List> mappings = ImmutableMap., List>builder();
+ ImmutableMap.Builder, List> mappings = ImmutableMap.builder();
ImmutableList.Builder outputFields = ImmutableList.builder();
for (Expression expression : node.getExpressions()) {
@@ -1170,14 +1225,9 @@ protected Scope visitLateral(Lateral node, Optional scope)
return createAndAssignScope(node, scope, queryScope.getRelationType());
}
- private Optional getMaterializedViewStorageTableName(QualifiedObjectName name)
+ private Optional getMaterializedViewStorageTableName(ConnectorMaterializedViewDefinition viewDefinition, QualifiedObjectName name)
{
- Optional optionalView = metadata.getMaterializedView(session, name);
- if (optionalView.isEmpty()) {
- return Optional.empty();
- }
-
- String storageTable = optionalView.get().getStorageTable();
+ String storageTable = viewDefinition.getStorageTable();
if (storageTable == null || storageTable.isEmpty()) {
return Optional.empty();
}
@@ -1217,7 +1267,7 @@ protected Scope visitTable(Table table, Optional scope)
if (optionalMaterializedView.isPresent()) {
if (metadata.getMaterializedViewFreshness(session, name).isMaterializedViewFresh()) {
// If materialized view is current, answer the query using the storage table
- Optional storageName = getMaterializedViewStorageTableName(name);
+ Optional storageName = getMaterializedViewStorageTableName(optionalMaterializedView.get(), name);
if (storageName.isPresent()) {
tableHandle = metadata.getTableHandle(session, createQualifiedObjectName(session, table, storageName.get()));
}
@@ -1245,12 +1295,12 @@ protected Scope visitTable(Table table, Optional scope)
}
throw semanticException(TABLE_NOT_FOUND, table, "Table '%s' does not exist", name);
}
- TableMetadata tableMetadata = metadata.getTableMetadata(session, tableHandle.get());
+ TableSchema tableSchema = metadata.getTableSchema(session, tableHandle.get());
Map columnHandles = metadata.getColumnHandles(session, tableHandle.get());
// TODO: discover columns lazily based on where they are needed (to support connectors that can't enumerate all tables)
ImmutableList.Builder fields = ImmutableList.builder();
- for (ColumnMetadata column : tableMetadata.getColumns()) {
+ for (ColumnSchema column : tableSchema.getColumns()) {
Field field = Field.newQualified(
table.getName(),
Optional.of(column.getName()),
@@ -1263,6 +1313,7 @@ protected Scope visitTable(Table table, Optional scope)
ColumnHandle columnHandle = columnHandles.get(column.getName());
checkArgument(columnHandle != null, "Unknown field %s", field);
analysis.setColumn(field, columnHandle);
+ analysis.addSourceColumns(field, ImmutableSet.of(new SourceColumn(name, column.getName())));
}
if (updateKind.isPresent()) {
@@ -1394,48 +1445,18 @@ private Scope createScopeForCommonTableExpression(Table table, Optional s
return createAndAssignScope(table, scope, fields);
}
- private Scope createScopeForView(Table table, QualifiedObjectName name, Optional scope, ConnectorViewDefinition view)
+ private Scope createScopeForMaterializedView(Table table, QualifiedObjectName name, Optional scope, ConnectorMaterializedViewDefinition view)
{
- Statement statement = analysis.getStatement();
- if (statement instanceof CreateView) {
- CreateView viewStatement = (CreateView) statement;
- QualifiedObjectName viewNameFromStatement = createQualifiedObjectName(session, viewStatement, viewStatement.getName());
- if (viewStatement.isReplace() && viewNameFromStatement.equals(name)) {
- throw semanticException(VIEW_IS_RECURSIVE, table, "Statement would create a recursive view");
- }
- }
- if (analysis.hasTableInView(table)) {
- throw semanticException(VIEW_IS_RECURSIVE, table, "View is recursive");
- }
-
- Query query = parseView(view.getOriginalSql(), name, table);
- analysis.registerNamedQuery(table, query);
- analysis.registerTableForView(table);
- RelationType descriptor = analyzeView(query, name, view.getCatalog(), view.getSchema(), view.getOwner(), table);
- analysis.unregisterTableForView();
-
- checkViewStaleness(view.getColumns(), descriptor.getVisibleFields(), name, table)
- .ifPresent(explanation -> { throw semanticException(VIEW_IS_STALE, table, "View '%s' is stale or in invalid state: %s", name, explanation); });
-
- // Derive the type of the view from the stored definition, not from the analysis of the underlying query.
- // This is needed in case the underlying table(s) changed and the query in the view now produces types that
- // are implicitly coercible to the declared view types.
- List outputFields = view.getColumns().stream()
- .map(column -> Field.newQualified(
- table.getName(),
- Optional.of(column.getName()),
- getViewColumnType(column, name, table),
- false,
- Optional.of(name),
- Optional.of(column.getName()),
- false))
- .collect(toImmutableList());
-
- analysis.addRelationCoercion(table, outputFields.stream().map(Field::getType).toArray(Type[]::new));
-
- analyzeFiltersAndMasks(table, name, Optional.empty(), outputFields, session.getIdentity().getUser());
-
- return createAndAssignScope(table, scope, outputFields);
+ checkArgument(view.getOwner().isPresent(), "owner must be present");
+ return createScopeForView(
+ table,
+ name,
+ scope,
+ view.getOriginalSql(),
+ view.getCatalog(),
+ view.getSchema(),
+ view.getOwner(),
+ translateMaterializedViewColumns(view.getColumns()));
}
private List translateMaterializedViewColumns(List materializedViewColumns)
@@ -1447,9 +1468,29 @@ private List translateMaterializedViewColumn
return viewColumns;
}
- private Scope createScopeForMaterializedView(Table table, QualifiedObjectName name, Optional scope, ConnectorMaterializedViewDefinition view)
+ private Scope createScopeForView(Table table, QualifiedObjectName name, Optional scope, ConnectorViewDefinition view)
+ {
+ return createScopeForView(table, name, scope, view.getOriginalSql(), view.getCatalog(), view.getSchema(), view.getOwner(), view.getColumns());
+ }
+
+ private Scope createScopeForView(
+ Table table,
+ QualifiedObjectName name,
+ Optional scope,
+ String originalSql,
+ Optional catalog,
+ Optional schema,
+ Optional owner,
+ List columns)
{
Statement statement = analysis.getStatement();
+ if (statement instanceof CreateView) {
+ CreateView viewStatement = (CreateView) statement;
+ QualifiedObjectName viewNameFromStatement = createQualifiedObjectName(session, viewStatement, viewStatement.getName());
+ if (viewStatement.isReplace() && viewNameFromStatement.equals(name)) {
+ throw semanticException(VIEW_IS_RECURSIVE, table, "Statement would create a recursive view");
+ }
+ }
if (statement instanceof CreateMaterializedView) {
CreateMaterializedView viewStatement = (CreateMaterializedView) statement;
QualifiedObjectName viewNameFromStatement = createQualifiedObjectName(session, viewStatement, viewStatement.getName());
@@ -1458,23 +1499,22 @@ private Scope createScopeForMaterializedView(Table table, QualifiedObjectName na
}
}
if (analysis.hasTableInView(table)) {
- throw semanticException(VIEW_IS_RECURSIVE, table, "Materialized View is recursive");
+ throw semanticException(VIEW_IS_RECURSIVE, table, "View is recursive");
}
- Query query = parseView(view.getOriginalSql(), name, table);
+ Query query = parseView(originalSql, name, table);
analysis.registerNamedQuery(table, query);
analysis.registerTableForView(table);
- RelationType descriptor = analyzeView(query, name, view.getCatalog(), view.getSchema(), view.getOwner(), table);
+ RelationType descriptor = analyzeView(query, name, catalog, schema, owner, table);
analysis.unregisterTableForView();
- List viewColumns = translateMaterializedViewColumns(view.getColumns());
- checkViewStaleness(viewColumns, descriptor.getVisibleFields(), name, table)
- .ifPresent(explanation -> { throw semanticException(VIEW_IS_STALE, table, "Materialized View '%s' is stale or in invalid state: %s", name, explanation); });
+ checkViewStaleness(columns, descriptor.getVisibleFields(), name, table)
+ .ifPresent(explanation -> { throw semanticException(VIEW_IS_STALE, table, "View '%s' is stale or in invalid state: %s", name, explanation); });
- // Derive the type of the materialized view from the stored definition, not from the analysis of the underlying query.
- // This is needed in case the underlying table(s) changed and the query in the materialized view now produces types that
- // are implicitly coercible to the declared materialized view types.
- List outputFields = viewColumns.stream()
+ // Derive the type of the view from the stored definition, not from the analysis of the underlying query.
+ // This is needed in case the underlying table(s) changed and the query in the view now produces types that
+ // are implicitly coercible to the declared view types.
+ List outputFields = columns.stream()
.map(column -> Field.newQualified(
table.getName(),
Optional.of(column.getName()),
@@ -1489,6 +1529,7 @@ private Scope createScopeForMaterializedView(Table table, QualifiedObjectName na
analyzeFiltersAndMasks(table, name, Optional.empty(), outputFields, session.getIdentity().getUser());
+ outputFields.forEach(field -> analysis.addSourceColumns(field, ImmutableSet.of(new SourceColumn(name, field.getName().orElseThrow()))));
return createAndAssignScope(table, scope, outputFields);
}
@@ -1879,7 +1920,6 @@ protected Scope visitUpdate(Update update, Optional scope)
List updatedColumns = allColumns.stream()
.filter(column -> assignmentTargets.contains(column.getName()))
.collect(toImmutableList());
- analysis.setUpdateType("UPDATE", tableName, Optional.of(table));
analysis.setUpdatedColumns(updatedColumns);
// Analyzer checks for select permissions but UPDATE has a separate permission, so disable access checks
@@ -1930,6 +1970,14 @@ protected Scope visitUpdate(Update update, Optional scope)
analysis.recordSubqueries(update, analyses.get(index));
}
+ analysis.setUpdateType(
+ "UPDATE",
+ tableName,
+ Optional.of(table),
+ Optional.of(updatedColumns.stream()
+ .map(column -> new OutputColumn(new Column(column.getName(), column.getType().toString()), ImmutableSet.of()))
+ .collect(toImmutableList())));
+
return createAndAssignScope(update, scope, Field.newUnqualified("rows", BIGINT));
}
@@ -2498,7 +2546,9 @@ private Scope computeAndAssignOutputScope(QuerySpecification node, Optional fromReferences = findReferences(from, withQuery.getName());
- if (fromReferences.size() == 0) {
+ if (fromReferences.isEmpty()) {
throw semanticException(INVALID_RECURSIVE_REFERENCE, stepReferences.get(0), "recursive reference outside of FROM clause of the step relation of recursion");
}
@@ -3629,6 +3688,11 @@ private Scope.Builder scopeBuilder(Optional parentScope)
return scopeBuilder;
}
+
+ private OutputColumn createOutputColumn(Field field)
+ {
+ return new OutputColumn(new Column(field.getName().orElseThrow(), field.getType().toString()), analysis.getSourceColumns(field));
+ }
}
private Session createViewSession(Optional catalog, Optional schema, Identity identity, SqlPath path)
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/DistributedExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/DistributedExecutionPlanner.java
index 5f130f4ed359..0ca6bc3844d0 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/DistributedExecutionPlanner.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/DistributedExecutionPlanner.java
@@ -19,8 +19,8 @@
import io.trino.Session;
import io.trino.execution.TableInfo;
import io.trino.metadata.Metadata;
-import io.trino.metadata.TableMetadata;
import io.trino.metadata.TableProperties;
+import io.trino.metadata.TableSchema;
import io.trino.operator.StageExecutionDescriptor;
import io.trino.server.DynamicFilterService;
import io.trino.spi.connector.DynamicFilter;
@@ -149,9 +149,9 @@ private StageExecutionPlan doPlan(SubPlan root, Session session, ImmutableList.B
private TableInfo getTableInfo(TableScanNode node, Session session)
{
- TableMetadata tableMetadata = metadata.getTableMetadata(session, node.getTable());
+ TableSchema tableSchema = metadata.getTableSchema(session, node.getTable());
TableProperties tableProperties = metadata.getTableProperties(session, node.getTable());
- return new TableInfo(tableMetadata.getQualifiedName(), tableProperties.getPredicate());
+ return new TableInfo(tableSchema.getQualifiedName(), tableProperties.getPredicate());
}
private final class Visitor
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/InputExtractor.java b/core/trino-main/src/main/java/io/trino/sql/planner/InputExtractor.java
index 39ea0222ad4a..8db19fcf1447 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/InputExtractor.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/InputExtractor.java
@@ -63,7 +63,7 @@ private static Column createColumn(ColumnMetadata columnMetadata)
private Input createInput(Session session, TableHandle table, Set columns, PlanFragmentId fragmentId, PlanNodeId planNodeId)
{
- SchemaTableName schemaTable = metadata.getTableMetadata(session, table).getTable();
+ SchemaTableName schemaTable = metadata.getTableSchema(session, table).getTable();
Optional inputMetadata = metadata.getInfo(session, table);
return new Input(table.getCatalogName().getCatalogName(), schemaTable.getSchemaName(), schemaTable.getTableName(), inputMetadata, ImmutableList.copyOf(columns), fragmentId, planNodeId);
}
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java
index 767b7407c971..49f700050df1 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java
@@ -1743,6 +1743,7 @@ public PhysicalOperation visitIndexJoin(IndexJoinNode node, LocalExecutionPlanCo
lookupSourceFactoryManager,
probeSource.getTypes(),
false,
+ false,
probeChannels,
probeHashChannel,
Optional.empty(),
@@ -2110,11 +2111,23 @@ private PhysicalOperation createLookupJoin(
PhysicalOperation probeSource = probeNode.accept(this, context);
// Plan build
- boolean spillEnabled = isSpillEnabled(session) && node.isSpillable().orElseThrow(() -> new IllegalArgumentException("spillable not yet set"));
+ boolean buildOuter = node.getType() == RIGHT || node.getType() == FULL;
+ boolean spillEnabled = isSpillEnabled(session)
+ && node.isSpillable().orElseThrow(() -> new IllegalArgumentException("spillable not yet set"))
+ && probeSource.getPipelineExecutionStrategy() == UNGROUPED_EXECUTION
+ && !buildOuter;
JoinBridgeManager