diff --git a/docs/src/main/sphinx/connector/cassandra.rst b/docs/src/main/sphinx/connector/cassandra.rst index 8d99f955b382..c6b77ab4e2e0 100644 --- a/docs/src/main/sphinx/connector/cassandra.rst +++ b/docs/src/main/sphinx/connector/cassandra.rst @@ -76,9 +76,11 @@ Property Name Description This is a global setting used for all connections, regardless of the user connected to Trino. -``cassandra.protocol-version`` It is possible to override the protocol version for older Cassandra clusters. - By default, the values from the highest protocol version the driver can use. - Possible values include ``V2``, ``V3`` and ``V4``. +``cassandra.protocol-version`` It is possible to override the protocol version for older Cassandra + clusters. + By default, the value corresponds to the default protocol version + used in the underlying Cassandra java driver. + Possible values include ``V3``, ``V4``, ``V5``, ``V6``. ================================================== ====================================================================== .. note:: diff --git a/plugin/trino-cassandra/pom.xml b/plugin/trino-cassandra/pom.xml index 5ef3f425c6f7..8a2688190f3e 100644 --- a/plugin/trino-cassandra/pom.xml +++ b/plugin/trino-cassandra/pom.xml @@ -14,6 +14,8 @@ ${project.parent.basedir} + 4.14.0 + 1.5.1 @@ -22,11 +24,6 @@ trino-plugin-toolkit - - io.trino.cassandra - cassandra-driver - - io.airlift bootstrap @@ -49,12 +46,37 @@ io.airlift - security + units - io.airlift - units + com.datastax.oss + java-driver-core + ${dep.casandra.version} + + + org.ow2.asm + asm-analysis + + + + + + com.datastax.oss + java-driver-query-builder + ${dep.casandra.version} + + + com.github.spotbugs + spotbugs-annotations + + + + + + com.datastax.oss + native-protocol + ${dep.native-protocol.version} @@ -93,11 +115,6 @@ validation-api - - joda-time - joda-time - - org.weakref jmxutils diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/BackoffRetryPolicy.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/BackoffRetryPolicy.java index 574687b83e98..524191008eb7 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/BackoffRetryPolicy.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/BackoffRetryPolicy.java @@ -13,63 +13,105 @@ */ package io.trino.plugin.cassandra; -import com.datastax.driver.core.Cluster; -import com.datastax.driver.core.ConsistencyLevel; -import com.datastax.driver.core.Statement; -import com.datastax.driver.core.WriteType; -import com.datastax.driver.core.exceptions.DriverException; -import com.datastax.driver.core.policies.DefaultRetryPolicy; -import com.datastax.driver.core.policies.RetryPolicy; +import com.datastax.oss.driver.api.core.ConsistencyLevel; +import com.datastax.oss.driver.api.core.context.DriverContext; +import com.datastax.oss.driver.api.core.retry.RetryDecision; +import com.datastax.oss.driver.api.core.retry.RetryPolicy; +import com.datastax.oss.driver.api.core.servererrors.CoordinatorException; +import com.datastax.oss.driver.api.core.servererrors.DefaultWriteType; +import com.datastax.oss.driver.api.core.servererrors.WriteType; +import com.datastax.oss.driver.api.core.session.Request; +import io.airlift.log.Logger; import java.util.concurrent.ThreadLocalRandom; public class BackoffRetryPolicy implements RetryPolicy { - public static final BackoffRetryPolicy INSTANCE = new BackoffRetryPolicy(); + private static final Logger log = Logger.get(BackoffRetryPolicy.class); - private BackoffRetryPolicy() {} + private final String logPrefix; + + public BackoffRetryPolicy(DriverContext context, String profileName) + { + this.logPrefix = (context != null ? context.getSessionName() : null) + "|" + profileName; + } + + @Override + public RetryDecision onReadTimeout(Request request, ConsistencyLevel consistencyLevel, int blockFor, int received, boolean dataPresent, int retryCount) + { + RetryDecision decision = + (retryCount == 0 && received >= blockFor && !dataPresent) + ? RetryDecision.RETRY_SAME + : RetryDecision.RETHROW; + + if (decision == RetryDecision.RETRY_SAME) { + log.debug( + "[%s] Retrying on read timeout on same host (consistency: %s, required responses: %s, received responses: %s, data retrieved: %s, retries: %s)", + logPrefix, + consistencyLevel, + blockFor, + received, + false, + retryCount); + } + + return decision; + } + + @Override + public RetryDecision onWriteTimeout(Request request, ConsistencyLevel consistencyLevel, WriteType writeType, int blockFor, int received, int retryCount) + { + RetryDecision decision = + (retryCount == 0 && writeType == DefaultWriteType.BATCH_LOG) + ? RetryDecision.RETRY_SAME + : RetryDecision.RETHROW; + + if (decision == RetryDecision.RETRY_SAME && log.isDebugEnabled()) { + log.debug( + "[%s] Retrying on write timeout on same host (consistency: %s, write type: %s, required acknowledgments: %s, received acknowledgments: %s, retries: %s)", + logPrefix, + consistencyLevel, + writeType, + blockFor, + received, + retryCount); + } + return decision; + } @Override - public RetryDecision onUnavailable(Statement statement, ConsistencyLevel consistencyLevel, int requiredReplica, int aliveReplica, int retries) + public RetryDecision onUnavailable(Request request, ConsistencyLevel consistencyLevel, int required, int alive, int retries) { if (retries >= 10) { - return RetryDecision.rethrow(); + return RetryDecision.RETHROW; } try { int jitter = ThreadLocalRandom.current().nextInt(100); int delay = (100 * (retries + 1)) + jitter; Thread.sleep(delay); - return RetryDecision.retry(consistencyLevel); + return RetryDecision.RETRY_SAME; } catch (InterruptedException e) { Thread.currentThread().interrupt(); - return RetryDecision.rethrow(); + return RetryDecision.RETHROW; } } @Override - public RetryDecision onReadTimeout(Statement statement, ConsistencyLevel cl, int requiredResponses, int receivedResponses, boolean dataRetrieved, int nbRetry) + public RetryDecision onRequestAborted(Request request, Throwable error, int retryCount) { - return DefaultRetryPolicy.INSTANCE.onReadTimeout(statement, cl, requiredResponses, receivedResponses, dataRetrieved, nbRetry); + return RetryDecision.RETHROW; } @Override - public RetryDecision onWriteTimeout(Statement statement, ConsistencyLevel cl, WriteType writeType, int requiredAcks, int receivedAcks, int nbRetry) + public RetryDecision onErrorResponse(Request request, CoordinatorException error, int retryCount) { - return DefaultRetryPolicy.INSTANCE.onWriteTimeout(statement, cl, writeType, requiredAcks, receivedAcks, nbRetry); + log.debug(error, "[%s] Retrying on node error on next host (retries: %s)", logPrefix, retryCount); + return RetryDecision.RETRY_NEXT; } - @Override - public RetryDecision onRequestError(Statement statement, ConsistencyLevel cl, DriverException e, int nbRetry) - { - return RetryDecision.tryNextHost(cl); - } - - @Override - public void init(Cluster cluster) {} - @Override public void close() {} } diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClientConfig.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClientConfig.java index 6a82d87f0e4e..2e0858bd51fb 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClientConfig.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClientConfig.java @@ -13,9 +13,10 @@ */ package io.trino.plugin.cassandra; -import com.datastax.driver.core.ConsistencyLevel; -import com.datastax.driver.core.ProtocolVersion; -import com.datastax.driver.core.SocketOptions; +import com.datastax.oss.driver.api.core.ConsistencyLevel; +import com.datastax.oss.driver.api.core.DefaultConsistencyLevel; +import com.datastax.oss.driver.api.core.DefaultProtocolVersion; +import com.datastax.oss.driver.api.core.ProtocolVersion; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; import io.airlift.configuration.Config; @@ -58,8 +59,8 @@ public class CassandraClientConfig private boolean allowDropTable; private String username; private String password; - private Duration clientReadTimeout = new Duration(SocketOptions.DEFAULT_READ_TIMEOUT_MILLIS, MILLISECONDS); - private Duration clientConnectTimeout = new Duration(SocketOptions.DEFAULT_CONNECT_TIMEOUT_MILLIS, MILLISECONDS); + private Duration clientReadTimeout = new Duration(12_000, MILLISECONDS); + private Duration clientConnectTimeout = new Duration(5_000, MILLISECONDS); private Integer clientSoLinger; private RetryPolicyType retryPolicy = RetryPolicyType.DEFAULT; private boolean useDCAware; @@ -119,7 +120,7 @@ public ConsistencyLevel getConsistencyLevel() } @Config("cassandra.consistency-level") - public CassandraClientConfig setConsistencyLevel(ConsistencyLevel level) + public CassandraClientConfig setConsistencyLevel(DefaultConsistencyLevel level) { this.consistencyLevel = level; return this; @@ -411,7 +412,7 @@ public ProtocolVersion getProtocolVersion() } @Config("cassandra.protocol-version") - public CassandraClientConfig setProtocolVersion(ProtocolVersion version) + public CassandraClientConfig setProtocolVersion(DefaultProtocolVersion version) { this.protocolVersion = version; return this; diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClientModule.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClientModule.java index ce8709dc0146..f26e53f4d74f 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClientModule.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClientModule.java @@ -13,17 +13,12 @@ */ package io.trino.plugin.cassandra; -import com.datastax.driver.core.Cluster; -import com.datastax.driver.core.JdkSSLOptions; -import com.datastax.driver.core.QueryOptions; -import com.datastax.driver.core.SocketOptions; -import com.datastax.driver.core.policies.ConstantSpeculativeExecutionPolicy; -import com.datastax.driver.core.policies.DCAwareRoundRobinPolicy; -import com.datastax.driver.core.policies.ExponentialReconnectionPolicy; -import com.datastax.driver.core.policies.LoadBalancingPolicy; -import com.datastax.driver.core.policies.RoundRobinPolicy; -import com.datastax.driver.core.policies.TokenAwarePolicy; -import com.datastax.driver.core.policies.WhiteListPolicy; +import com.datastax.oss.driver.api.core.CqlSession; +import com.datastax.oss.driver.api.core.CqlSessionBuilder; +import com.datastax.oss.driver.api.core.config.DefaultDriverOption; +import com.datastax.oss.driver.api.core.config.DriverConfigLoader; +import com.datastax.oss.driver.api.core.config.ProgrammaticDriverConfigLoaderBuilder; +import com.datastax.oss.driver.internal.core.loadbalancing.DefaultLoadBalancingPolicy; import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.deser.std.FromStringDeserializer; import com.google.inject.Binder; @@ -31,7 +26,6 @@ import com.google.inject.Provides; import com.google.inject.Scopes; import io.airlift.json.JsonCodec; -import io.airlift.security.pem.PemReader; import io.trino.spi.TrinoException; import io.trino.spi.type.Type; import io.trino.spi.type.TypeId; @@ -40,20 +34,14 @@ import javax.inject.Inject; import javax.inject.Singleton; import javax.net.ssl.SSLContext; -import javax.security.auth.x500.X500Principal; import java.io.File; -import java.io.FileInputStream; import java.io.IOException; -import java.io.InputStream; +import java.net.InetAddress; import java.net.InetSocketAddress; +import java.net.UnknownHostException; import java.security.GeneralSecurityException; -import java.security.KeyStore; -import java.security.cert.Certificate; -import java.security.cert.CertificateExpiredException; -import java.security.cert.CertificateNotYetValidException; -import java.security.cert.X509Certificate; -import java.util.ArrayList; +import java.time.Duration; import java.util.List; import java.util.Optional; @@ -64,7 +52,6 @@ import static io.trino.plugin.base.ssl.SslUtils.createSSLContext; import static io.trino.plugin.cassandra.CassandraErrorCode.CASSANDRA_SSL_INITIALIZATION_FAILURE; import static java.lang.Math.toIntExact; -import static java.util.Collections.list; import static java.util.Objects.requireNonNull; public class CassandraClientModule @@ -123,80 +110,70 @@ public static CassandraSession createCassandraSession(CassandraClientConfig conf requireNonNull(config, "config is null"); requireNonNull(extraColumnMetadataCodec, "extraColumnMetadataCodec is null"); - Cluster.Builder clusterBuilder = Cluster.builder(); + CqlSessionBuilder cqlSessionBuilder = CqlSession.builder(); + ProgrammaticDriverConfigLoaderBuilder driverConfigLoaderBuilder = DriverConfigLoader.programmaticBuilder(); + // allow the retrieval of metadata for the system keyspaces + driverConfigLoaderBuilder.withStringList(DefaultDriverOption.METADATA_SCHEMA_REFRESHED_KEYSPACES, List.of()); + if (config.getProtocolVersion() != null) { - clusterBuilder.withProtocolVersion(config.getProtocolVersion()); + driverConfigLoaderBuilder.withString(DefaultDriverOption.PROTOCOL_VERSION, config.getProtocolVersion().name()); } List contactPoints = requireNonNull(config.getContactPoints(), "contactPoints is null"); checkArgument(!contactPoints.isEmpty(), "empty contactPoints"); - clusterBuilder.withPort(config.getNativeProtocolPort()); - clusterBuilder.withReconnectionPolicy(new ExponentialReconnectionPolicy(500, 10000)); - clusterBuilder.withRetryPolicy(config.getRetryPolicy().getPolicy()); - LoadBalancingPolicy loadPolicy = new RoundRobinPolicy(); + driverConfigLoaderBuilder.withString(DefaultDriverOption.RECONNECTION_POLICY_CLASS, com.datastax.oss.driver.internal.core.connection.ExponentialReconnectionPolicy.class.getName()); + driverConfigLoaderBuilder.withDuration(DefaultDriverOption.RECONNECTION_BASE_DELAY, Duration.ofMillis(500)); + driverConfigLoaderBuilder.withDuration(DefaultDriverOption.RECONNECTION_MAX_DELAY, Duration.ofMillis(10_000)); + driverConfigLoaderBuilder.withString(DefaultDriverOption.RETRY_POLICY_CLASS, config.getRetryPolicy().getPolicyClass().getName()); + driverConfigLoaderBuilder.withString(DefaultDriverOption.LOAD_BALANCING_POLICY_CLASS, DefaultLoadBalancingPolicy.class.getName()); if (config.isUseDCAware()) { requireNonNull(config.getDcAwareLocalDC(), "DCAwarePolicy localDC is null"); - DCAwareRoundRobinPolicy.Builder builder = DCAwareRoundRobinPolicy.builder() - .withLocalDc(config.getDcAwareLocalDC()); + driverConfigLoaderBuilder.withString(DefaultDriverOption.LOAD_BALANCING_LOCAL_DATACENTER, config.getDcAwareLocalDC()); + if (config.getDcAwareUsedHostsPerRemoteDc() > 0) { - builder.withUsedHostsPerRemoteDc(config.getDcAwareUsedHostsPerRemoteDc()); + driverConfigLoaderBuilder.withInt(DefaultDriverOption.LOAD_BALANCING_DC_FAILOVER_MAX_NODES_PER_REMOTE_DC, config.getDcAwareUsedHostsPerRemoteDc()); if (config.isDcAwareAllowRemoteDCsForLocal()) { - builder.allowRemoteDCsForLocalConsistencyLevel(); + driverConfigLoaderBuilder.withBoolean(DefaultDriverOption.LOAD_BALANCING_DC_FAILOVER_ALLOW_FOR_LOCAL_CONSISTENCY_LEVELS, true); } } - loadPolicy = builder.build(); - } - - if (config.isUseTokenAware()) { - loadPolicy = new TokenAwarePolicy(loadPolicy, config.isTokenAwareShuffleReplicas()); - } - - if (!config.getAllowedAddresses().isEmpty()) { - checkArgument(!config.getAllowedAddresses().isEmpty(), "empty AllowListAddresses"); - List allowList = new ArrayList<>(); - for (String point : config.getAllowedAddresses()) { - allowList.add(new InetSocketAddress(point, config.getNativeProtocolPort())); - } - loadPolicy = new WhiteListPolicy(loadPolicy, allowList); } - clusterBuilder.withLoadBalancingPolicy(loadPolicy); - - SocketOptions socketOptions = new SocketOptions(); - socketOptions.setReadTimeoutMillis(toIntExact(config.getClientReadTimeout().toMillis())); - socketOptions.setConnectTimeoutMillis(toIntExact(config.getClientConnectTimeout().toMillis())); + driverConfigLoaderBuilder.withDuration(DefaultDriverOption.REQUEST_TIMEOUT, Duration.ofMillis(toIntExact(config.getClientReadTimeout().toMillis()))); + driverConfigLoaderBuilder.withDuration(DefaultDriverOption.CONNECTION_CONNECT_TIMEOUT, Duration.ofMillis(toIntExact(config.getClientConnectTimeout().toMillis()))); if (config.getClientSoLinger() != null) { - socketOptions.setSoLinger(config.getClientSoLinger()); + driverConfigLoaderBuilder.withInt(DefaultDriverOption.SOCKET_LINGER_INTERVAL, config.getClientSoLinger()); } if (config.isTlsEnabled()) { buildSslContext(config.getKeystorePath(), config.getKeystorePassword(), config.getTruststorePath(), config.getTruststorePassword()) - .ifPresent(context -> clusterBuilder.withSSL(JdkSSLOptions.builder().withSSLContext(context).build())); + .ifPresent(cqlSessionBuilder::withSslContext); } - clusterBuilder.withSocketOptions(socketOptions); if (config.getUsername() != null && config.getPassword() != null) { - clusterBuilder.withCredentials(config.getUsername(), config.getPassword()); + cqlSessionBuilder.withAuthCredentials(config.getUsername(), config.getPassword()); } - QueryOptions options = new QueryOptions(); - options.setFetchSize(config.getFetchSize()); - options.setConsistencyLevel(config.getConsistencyLevel()); - clusterBuilder.withQueryOptions(options); + driverConfigLoaderBuilder.withInt(DefaultDriverOption.REQUEST_PAGE_SIZE, config.getFetchSize()); + driverConfigLoaderBuilder.withString(DefaultDriverOption.REQUEST_CONSISTENCY, config.getConsistencyLevel().name()); if (config.getSpeculativeExecutionLimit().isPresent()) { - clusterBuilder.withSpeculativeExecutionPolicy(new ConstantSpeculativeExecutionPolicy( - config.getSpeculativeExecutionDelay().toMillis(), // delay before a new execution is launched - config.getSpeculativeExecutionLimit().get())); // maximum number of executions + driverConfigLoaderBuilder.withString(DefaultDriverOption.SPECULATIVE_EXECUTION_POLICY_CLASS, com.datastax.oss.driver.internal.core.specex.ConstantSpeculativeExecutionPolicy.class.getName()); + // maximum number of executions + driverConfigLoaderBuilder.withInt(DefaultDriverOption.SPECULATIVE_EXECUTION_MAX, config.getSpeculativeExecutionLimit().get()); + // delay before a new execution is launched + driverConfigLoaderBuilder.withDuration(DefaultDriverOption.SPECULATIVE_EXECUTION_DELAY, Duration.ofMillis(config.getSpeculativeExecutionDelay().toMillis())); } + cqlSessionBuilder.withConfigLoader(driverConfigLoaderBuilder.build()); + return new CassandraSession( extraColumnMetadataCodec, - new ReopeningCluster(() -> { - contactPoints.forEach(clusterBuilder::addContactPoint); - return clusterBuilder.build(); - }), + () -> { + contactPoints.forEach(contactPoint -> cqlSessionBuilder.addContactPoint( + createInetSocketAddress(contactPoint, config.getNativeProtocolPort()))); + return cqlSessionBuilder.build(); + }, config.getNoHostAvailableRetryTimeout()); } @@ -218,52 +195,13 @@ private static Optional buildSslContext( } } - private static KeyStore loadTrustStore(File trustStorePath, Optional trustStorePassword) - throws IOException, GeneralSecurityException + private static InetSocketAddress createInetSocketAddress(String contactPoint, int port) { - KeyStore trustStore = KeyStore.getInstance(KeyStore.getDefaultType()); try { - // attempt to read the trust store as a PEM file - List certificateChain = PemReader.readCertificateChain(trustStorePath); - if (!certificateChain.isEmpty()) { - trustStore.load(null, null); - for (X509Certificate certificate : certificateChain) { - X500Principal principal = certificate.getSubjectX500Principal(); - trustStore.setCertificateEntry(principal.getName(), certificate); - } - return trustStore; - } - } - catch (IOException | GeneralSecurityException ignored) { + return new InetSocketAddress(InetAddress.getByName(contactPoint), port); } - - try (InputStream in = new FileInputStream(trustStorePath)) { - trustStore.load(in, trustStorePassword.map(String::toCharArray).orElse(null)); - } - return trustStore; - } - - private static void validateCertificates(KeyStore keyStore) - throws GeneralSecurityException - { - for (String alias : list(keyStore.aliases())) { - if (!keyStore.isKeyEntry(alias)) { - continue; - } - Certificate certificate = keyStore.getCertificate(alias); - if (!(certificate instanceof X509Certificate)) { - continue; - } - - try { - ((X509Certificate) certificate).checkValidity(); - } - catch (CertificateExpiredException e) { - throw new CertificateExpiredException("KeyStore certificate is expired: " + e.getMessage()); - } - catch (CertificateNotYetValidException e) { - throw new CertificateNotYetValidException("KeyStore certificate is not yet valid: " + e.getMessage()); - } + catch (UnknownHostException e) { + throw new IllegalArgumentException("Failed to add contact point: " + contactPoint, e); } } } diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClusteringPredicatesExtractor.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClusteringPredicatesExtractor.java index fe17c8a9d45f..6c062539c0fa 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClusteringPredicatesExtractor.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClusteringPredicatesExtractor.java @@ -13,7 +13,7 @@ */ package io.trino.plugin.cassandra; -import com.datastax.driver.core.VersionNumber; +import com.datastax.oss.driver.api.core.Version; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -36,7 +36,7 @@ public class CassandraClusteringPredicatesExtractor private final ClusteringPushDownResult clusteringPushDownResult; private final TupleDomain predicates; - public CassandraClusteringPredicatesExtractor(List clusteringColumns, TupleDomain predicates, VersionNumber cassandraVersion) + public CassandraClusteringPredicatesExtractor(List clusteringColumns, TupleDomain predicates, Version cassandraVersion) { this.predicates = requireNonNull(predicates, "predicates is null"); this.clusteringPushDownResult = getClusteringKeysSet(clusteringColumns, predicates, requireNonNull(cassandraVersion, "cassandraVersion is null")); @@ -52,7 +52,7 @@ public TupleDomain getUnenforcedConstraints() return predicates.filter(((columnHandle, domain) -> !clusteringPushDownResult.hasBeenFullyPushed(columnHandle))); } - private static ClusteringPushDownResult getClusteringKeysSet(List clusteringColumns, TupleDomain predicates, VersionNumber cassandraVersion) + private static ClusteringPushDownResult getClusteringKeysSet(List clusteringColumns, TupleDomain predicates, Version cassandraVersion) { ImmutableSet.Builder fullyPushedColumnPredicates = ImmutableSet.builder(); ImmutableList.Builder clusteringColumnSql = ImmutableList.builder(); @@ -127,9 +127,9 @@ private static ClusteringPushDownResult getClusteringKeysSet(List clusteringColumns, VersionNumber cassandraVersion, int currentlyProcessedClusteringColumn) + private static boolean isInExpressionNotAllowed(List clusteringColumns, Version cassandraVersion, int currentlyProcessedClusteringColumn) { - return cassandraVersion.compareTo(VersionNumber.parse("2.2.0")) < 0 && currentlyProcessedClusteringColumn != (clusteringColumns.size() - 1); + return cassandraVersion.compareTo(Version.parse("2.2.0")) < 0 && currentlyProcessedClusteringColumn != (clusteringColumns.size() - 1); } private static String toCqlLiteral(CassandraColumnHandle columnHandle, Object value) diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraMetadata.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraMetadata.java index f9692c04d1ed..cf1f9a45a54c 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraMetadata.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraMetadata.java @@ -50,7 +50,7 @@ import java.util.OptionalLong; import java.util.stream.Collectors; -import static com.datastax.driver.core.querybuilder.QueryBuilder.truncate; +import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.truncate; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.MoreCollectors.toOptional; @@ -59,6 +59,8 @@ import static io.trino.plugin.cassandra.util.CassandraCqlUtils.cqlNameToSqlName; import static io.trino.plugin.cassandra.util.CassandraCqlUtils.quoteStringLiteral; import static io.trino.plugin.cassandra.util.CassandraCqlUtils.validColumnName; +import static io.trino.plugin.cassandra.util.CassandraCqlUtils.validSchemaName; +import static io.trino.plugin.cassandra.util.CassandraCqlUtils.validTableName; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.StandardErrorCode.PERMISSION_DENIED; import static java.lang.String.format; @@ -329,7 +331,7 @@ public Optional finishCreateTable(ConnectorSession sess public void truncateTable(ConnectorSession session, ConnectorTableHandle tableHandle) { CassandraTableHandle table = (CassandraTableHandle) tableHandle; - cassandraSession.execute(truncate(table.getSchemaName(), table.getTableName())); + cassandraSession.execute(truncate(validSchemaName(table.getSchemaName()), validTableName(table.getTableName())).build()); } @Override diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraPageSink.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraPageSink.java index 1804d593a631..948b63758c2d 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraPageSink.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraPageSink.java @@ -13,12 +13,15 @@ */ package io.trino.plugin.cassandra; -import com.datastax.driver.core.BatchStatement; -import com.datastax.driver.core.LocalDate; -import com.datastax.driver.core.PreparedStatement; -import com.datastax.driver.core.ProtocolVersion; -import com.datastax.driver.core.querybuilder.Insert; +import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.cql.BatchStatement; +import com.datastax.oss.driver.api.core.cql.BatchStatementBuilder; +import com.datastax.oss.driver.api.core.cql.DefaultBatchType; +import com.datastax.oss.driver.api.core.cql.PreparedStatement; +import com.datastax.oss.driver.api.core.cql.SimpleStatement; +import com.datastax.oss.driver.api.querybuilder.term.Term; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Shorts; import com.google.common.primitives.SignedBytes; import io.airlift.slice.Slice; @@ -29,20 +32,19 @@ import io.trino.spi.type.Type; import io.trino.spi.type.UuidType; import io.trino.spi.type.VarcharType; -import org.joda.time.format.DateTimeFormatter; -import org.joda.time.format.ISODateTimeFormat; -import java.sql.Timestamp; +import java.time.Instant; +import java.time.LocalDate; +import java.time.format.DateTimeFormatter; import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.UUID; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.TimeUnit; import java.util.function.Function; -import static com.datastax.driver.core.querybuilder.QueryBuilder.bindMarker; -import static com.datastax.driver.core.querybuilder.QueryBuilder.insertInto; +import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.bindMarker; +import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.insertInto; import static com.google.common.base.Preconditions.checkArgument; import static io.trino.plugin.cassandra.util.CassandraCqlUtils.ID_COLUMN_NAME; import static io.trino.plugin.cassandra.util.CassandraCqlUtils.validColumnName; @@ -69,15 +71,13 @@ public class CassandraPageSink implements ConnectorPageSink { - private static final DateTimeFormatter DATE_FORMATTER = ISODateTimeFormat.date().withZoneUTC(); - private final CassandraSession cassandraSession; private final PreparedStatement insert; private final List columnTypes; private final boolean generateUuid; private final int batchSize; private final Function toCassandraDate; - private final BatchStatement batchStatement = new BatchStatement(); + private final BatchStatementBuilder batchStatement = BatchStatement.builder(DefaultBatchType.LOGGED); public CassandraPageSink( CassandraSession cassandraSession, @@ -97,23 +97,26 @@ public CassandraPageSink( this.generateUuid = generateUuid; this.batchSize = batchSize; - if (protocolVersion.toInt() <= ProtocolVersion.V3.toInt()) { - this.toCassandraDate = value -> DATE_FORMATTER.print(TimeUnit.DAYS.toMillis(value)); + if (protocolVersion.getCode() <= ProtocolVersion.V3.getCode()) { + toCassandraDate = value -> DateTimeFormatter.ISO_LOCAL_DATE.format(LocalDate.ofEpochDay(toIntExact(value))); } else { - this.toCassandraDate = value -> LocalDate.fromDaysSinceEpoch(toIntExact(value)); + toCassandraDate = value -> LocalDate.ofEpochDay(toIntExact(value)); } - Insert insert = insertInto(validSchemaName(schemaName), validTableName(tableName)); + ImmutableMap.Builder parameters = ImmutableMap.builder(); if (generateUuid) { - insert.value(ID_COLUMN_NAME, bindMarker()); + parameters.put(ID_COLUMN_NAME, bindMarker()); } for (int i = 0; i < columnNames.size(); i++) { String columnName = columnNames.get(i); checkArgument(columnName != null, "columnName is null at position: %s", i); - insert.value(validColumnName(columnName), bindMarker()); + parameters.put(validColumnName(columnName), bindMarker()); } - this.insert = cassandraSession.prepare(insert); + SimpleStatement insertStatement = insertInto(validSchemaName(schemaName), validTableName(tableName)) + .values(parameters.buildOrThrow()) + .build(); + this.insert = cassandraSession.prepare(insertStatement); } @Override @@ -129,11 +132,11 @@ public CompletableFuture appendPage(Page page) appendColumn(values, page, position, channel); } - batchStatement.add(insert.bind(values.toArray())); + batchStatement.addStatement(insert.bind(values.toArray())); - if (batchStatement.size() >= batchSize) { - cassandraSession.execute(batchStatement); - batchStatement.clear(); + if (batchStatement.getStatementsCount() >= batchSize) { + cassandraSession.execute(batchStatement.build()); + batchStatement.clearStatements(); } } return NOT_BLOCKED; @@ -171,7 +174,7 @@ else if (DATE.equals(type)) { values.add(toCassandraDate.apply(type.getLong(block, position))); } else if (TIMESTAMP_TZ_MILLIS.equals(type)) { - values.add(new Timestamp(unpackMillisUtc(type.getLong(block, position)))); + values.add(Instant.ofEpochMilli(unpackMillisUtc(type.getLong(block, position)))); } else if (type instanceof VarcharType) { values.add(type.getSlice(block, position).toStringUtf8()); @@ -190,9 +193,9 @@ else if (UuidType.UUID.equals(type)) { @Override public CompletableFuture> finish() { - if (batchStatement.size() > 0) { - cassandraSession.execute(batchStatement); - batchStatement.clear(); + if (batchStatement.getStatementsCount() > 0) { + cassandraSession.execute(batchStatement.build()); + batchStatement.clearStatements(); } // the committer does not need any additional info diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraRecordCursor.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraRecordCursor.java index d4d30102d153..986a48eb82a1 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraRecordCursor.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraRecordCursor.java @@ -13,8 +13,8 @@ */ package io.trino.plugin.cassandra; -import com.datastax.driver.core.ResultSet; -import com.datastax.driver.core.Row; +import com.datastax.oss.driver.api.core.cql.ResultSet; +import com.datastax.oss.driver.api.core.cql.Row; import io.airlift.slice.Slice; import io.trino.plugin.cassandra.CassandraType.Kind; import io.trino.spi.connector.RecordCursor; @@ -45,8 +45,9 @@ public CassandraRecordCursor(CassandraSession cassandraSession, List (CassandraColumnHandle) column) .collect(toList()); - String selectCql = CassandraCqlUtils.selectFrom(cassandraTable, cassandraColumns).getQueryString(); + String selectCql = CassandraCqlUtils.selectFrom(cassandraTable, cassandraColumns).asCql(); StringBuilder sb = new StringBuilder(selectCql); if (sb.charAt(sb.length() - 1) == ';') { sb.setLength(sb.length() - 1); diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraSession.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraSession.java index d48d0b2de8e2..67677b6323f1 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraSession.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraSession.java @@ -13,30 +13,35 @@ */ package io.trino.plugin.cassandra; -import com.datastax.driver.core.AbstractTableMetadata; -import com.datastax.driver.core.Cluster; -import com.datastax.driver.core.ColumnMetadata; -import com.datastax.driver.core.DataType; -import com.datastax.driver.core.Host; -import com.datastax.driver.core.IndexMetadata; -import com.datastax.driver.core.KeyspaceMetadata; -import com.datastax.driver.core.MaterializedViewMetadata; -import com.datastax.driver.core.PreparedStatement; -import com.datastax.driver.core.ProtocolVersion; -import com.datastax.driver.core.RegularStatement; -import com.datastax.driver.core.ResultSet; -import com.datastax.driver.core.Row; -import com.datastax.driver.core.Session; -import com.datastax.driver.core.Statement; -import com.datastax.driver.core.TableMetadata; -import com.datastax.driver.core.TokenRange; -import com.datastax.driver.core.VersionNumber; -import com.datastax.driver.core.exceptions.NoHostAvailableException; -import com.datastax.driver.core.policies.ReconnectionPolicy; -import com.datastax.driver.core.policies.ReconnectionPolicy.ReconnectionSchedule; -import com.datastax.driver.core.querybuilder.Clause; -import com.datastax.driver.core.querybuilder.QueryBuilder; -import com.datastax.driver.core.querybuilder.Select; +import com.datastax.oss.driver.api.core.AllNodesFailedException; +import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.CqlSession; +import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.Version; +import com.datastax.oss.driver.api.core.connection.ReconnectionPolicy; +import com.datastax.oss.driver.api.core.cql.PreparedStatement; +import com.datastax.oss.driver.api.core.cql.ResultSet; +import com.datastax.oss.driver.api.core.cql.Row; +import com.datastax.oss.driver.api.core.cql.SimpleStatement; +import com.datastax.oss.driver.api.core.cql.Statement; +import com.datastax.oss.driver.api.core.metadata.Node; +import com.datastax.oss.driver.api.core.metadata.schema.ColumnMetadata; +import com.datastax.oss.driver.api.core.metadata.schema.IndexMetadata; +import com.datastax.oss.driver.api.core.metadata.schema.KeyspaceMetadata; +import com.datastax.oss.driver.api.core.metadata.schema.RelationMetadata; +import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata; +import com.datastax.oss.driver.api.core.metadata.schema.ViewMetadata; +import com.datastax.oss.driver.api.core.metadata.token.TokenRange; +import com.datastax.oss.driver.api.core.type.DataType; +import com.datastax.oss.driver.api.core.type.ListType; +import com.datastax.oss.driver.api.core.type.MapType; +import com.datastax.oss.driver.api.core.type.SetType; +import com.datastax.oss.driver.api.core.type.TupleType; +import com.datastax.oss.driver.api.core.type.UserDefinedType; +import com.datastax.oss.driver.api.querybuilder.QueryBuilder; +import com.datastax.oss.driver.api.querybuilder.relation.Relation; +import com.datastax.oss.driver.api.querybuilder.select.Select; +import com.datastax.oss.driver.api.querybuilder.term.Term; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Ordering; @@ -53,8 +58,11 @@ import io.trino.spi.predicate.NullableValue; import io.trino.spi.predicate.TupleDomain; +import java.io.Closeable; import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.Collection; +import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -62,10 +70,11 @@ import java.util.Optional; import java.util.Set; import java.util.function.Supplier; +import java.util.stream.IntStream; import java.util.stream.Stream; -import static com.datastax.driver.core.querybuilder.QueryBuilder.eq; -import static com.datastax.driver.core.querybuilder.QueryBuilder.select; +import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.literal; +import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.selectFrom; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Suppliers.memoize; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -77,6 +86,7 @@ import static io.trino.plugin.cassandra.CassandraType.toCassandraType; import static io.trino.plugin.cassandra.util.CassandraCqlUtils.selectDistinctFrom; import static io.trino.plugin.cassandra.util.CassandraCqlUtils.validSchemaName; +import static io.trino.plugin.cassandra.util.CassandraCqlUtils.validTableName; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static java.lang.String.format; import static java.util.Comparator.comparing; @@ -86,27 +96,26 @@ import static java.util.stream.Collectors.toList; public class CassandraSession + implements Closeable { private static final Logger log = Logger.get(CassandraSession.class); private static final String SYSTEM = "system"; private static final String SIZE_ESTIMATES = "size_estimates"; - private static final VersionNumber PARTITION_FETCH_WITH_IN_PREDICATE_VERSION = VersionNumber.parse("2.2"); + private static final Version PARTITION_FETCH_WITH_IN_PREDICATE_VERSION = Version.parse("2.2"); private final JsonCodec> extraColumnMetadataCodec; - private final Cluster cluster; - private final Supplier session; + private final Supplier session; private final Duration noHostAvailableRetryTimeout; - public CassandraSession(JsonCodec> extraColumnMetadataCodec, Cluster cluster, Duration noHostAvailableRetryTimeout) + public CassandraSession(JsonCodec> extraColumnMetadataCodec, Supplier sessionSupplier, Duration noHostAvailableRetryTimeout) { this.extraColumnMetadataCodec = requireNonNull(extraColumnMetadataCodec, "extraColumnMetadataCodec is null"); - this.cluster = requireNonNull(cluster, "cluster is null"); this.noHostAvailableRetryTimeout = requireNonNull(noHostAvailableRetryTimeout, "noHostAvailableRetryTimeout is null"); - this.session = memoize(cluster::connect); + this.session = memoize(sessionSupplier::get); } - public VersionNumber getCassandraVersion() + public Version getCassandraVersion() { ResultSet result = executeWithSession(session -> session.execute("select release_version from system.local")); Row versionRow = result.one(); @@ -115,51 +124,60 @@ public VersionNumber getCassandraVersion() "Please make sure that the Cassandra cluster is up and running, " + "and that the contact points are specified correctly."); } - return VersionNumber.parse(versionRow.getString("release_version")); + return Version.parse(versionRow.getString("release_version")); } public ProtocolVersion getProtocolVersion() { - return executeWithSession(session -> session.getCluster().getConfiguration().getProtocolOptions().getProtocolVersion()); + return executeWithSession(session -> session.getContext().getProtocolVersion()); } public String getPartitioner() { - return executeWithSession(session -> session.getCluster().getMetadata().getPartitioner()); + return executeWithSession(session -> session.getMetadata().getTokenMap() + .orElseThrow() + .getPartitionerName()); } public Set getTokenRanges() { - return executeWithSession(session -> session.getCluster().getMetadata().getTokenRanges()); + return executeWithSession(session -> session.getMetadata().getTokenMap() + .orElseThrow() + .getTokenRanges()); } - public Set getReplicas(String caseSensitiveSchemaName, TokenRange tokenRange) + public Set getReplicas(String caseSensitiveSchemaName, TokenRange tokenRange) { requireNonNull(caseSensitiveSchemaName, "caseSensitiveSchemaName is null"); requireNonNull(tokenRange, "tokenRange is null"); return executeWithSession(session -> - session.getCluster().getMetadata().getReplicas(validSchemaName(caseSensitiveSchemaName), tokenRange)); + session.getMetadata() + .getTokenMap() + .map(tokenMap -> tokenMap.getReplicas(validSchemaName(caseSensitiveSchemaName), tokenRange)) + .orElse(ImmutableSet.of())); } - public Set getReplicas(String caseSensitiveSchemaName, ByteBuffer partitionKey) + public Set getReplicas(String caseSensitiveSchemaName, ByteBuffer partitionKey) { requireNonNull(caseSensitiveSchemaName, "caseSensitiveSchemaName is null"); requireNonNull(partitionKey, "partitionKey is null"); return executeWithSession(session -> - session.getCluster().getMetadata().getReplicas(validSchemaName(caseSensitiveSchemaName), partitionKey)); + session.getMetadata().getTokenMap() + .map(tokenMap -> tokenMap.getReplicas(validSchemaName(caseSensitiveSchemaName), partitionKey)) + .orElse(ImmutableSet.of())); } public String getCaseSensitiveSchemaName(String caseInsensitiveSchemaName) { - return getKeyspaceByCaseInsensitiveName(caseInsensitiveSchemaName).getName(); + return getKeyspaceByCaseInsensitiveName(caseInsensitiveSchemaName).getName().asInternal(); } public List getCaseSensitiveSchemaNames() { ImmutableList.Builder builder = ImmutableList.builder(); - List keyspaces = executeWithSession(session -> session.getCluster().getMetadata().getKeyspaces()); - for (KeyspaceMetadata meta : keyspaces) { - builder.add(meta.getName()); + Map keyspaces = executeWithSession(session -> session.getMetadata().getKeyspaces()); + for (KeyspaceMetadata meta : keyspaces.values()) { + builder.add(meta.getName().asInternal()); } return builder.build(); } @@ -169,11 +187,11 @@ public List getCaseSensitiveTableNames(String caseInsensitiveSchemaName) { KeyspaceMetadata keyspace = getKeyspaceByCaseInsensitiveName(caseInsensitiveSchemaName); ImmutableList.Builder builder = ImmutableList.builder(); - for (TableMetadata table : keyspace.getTables()) { - builder.add(table.getName()); + for (TableMetadata table : keyspace.getTables().values()) { + builder.add(table.getName().asInternal()); } - for (MaterializedViewMetadata materializedView : keyspace.getMaterializedViews()) { - builder.add(materializedView.getName()); + for (ViewMetadata materializedView : keyspace.getViews().values()) { + builder.add(materializedView.getName().asInternal()); } return builder.build(); } @@ -182,20 +200,20 @@ public CassandraTable getTable(SchemaTableName schemaTableName) throws TableNotFoundException { KeyspaceMetadata keyspace = getKeyspaceByCaseInsensitiveName(schemaTableName.getSchemaName()); - AbstractTableMetadata tableMeta = getTableMetadata(keyspace, schemaTableName.getTableName()); + RelationMetadata tableMeta = getTableMetadata(keyspace, schemaTableName.getTableName()); List columnNames = new ArrayList<>(); - List columns = tableMeta.getColumns(); + Collection columns = tableMeta.getColumns().values(); checkColumnNames(columns); for (ColumnMetadata columnMetadata : columns) { - columnNames.add(columnMetadata.getName()); + columnNames.add(columnMetadata.getName().asInternal()); } // check if there is a comment to establish column ordering - String comment = tableMeta.getOptions().getComment(); + Object comment = tableMeta.getOptions().get(CqlIdentifier.fromInternal("comment")); Set hiddenColumns = ImmutableSet.of(); - if (comment != null && comment.startsWith(PRESTO_COMMENT_METADATA)) { - String columnOrderingString = comment.substring(PRESTO_COMMENT_METADATA.length()); + if (comment instanceof String && ((String) comment).startsWith(PRESTO_COMMENT_METADATA)) { + String columnOrderingString = ((String) comment).substring(PRESTO_COMMENT_METADATA.length()); // column ordering List extras = extraColumnMetadataCodec.fromJson(columnOrderingString); @@ -218,28 +236,28 @@ public CassandraTable getTable(SchemaTableName schemaTableName) ImmutableList.Builder columnHandles = ImmutableList.builder(); // add primary keys first - Set primaryKeySet = new HashSet<>(); + Set primaryKeySet = new HashSet<>(); for (ColumnMetadata columnMeta : tableMeta.getPartitionKey()) { primaryKeySet.add(columnMeta.getName()); - boolean hidden = hiddenColumns.contains(columnMeta.getName()); - CassandraColumnHandle columnHandle = buildColumnHandle(tableMeta, columnMeta, true, false, columnNames.indexOf(columnMeta.getName()), hidden) - .orElseThrow(() -> new TrinoException(NOT_SUPPORTED, "Unsupported partition key type: " + columnMeta.getType().getName())); + boolean hidden = hiddenColumns.contains(columnMeta.getName().asInternal()); + CassandraColumnHandle columnHandle = buildColumnHandle(tableMeta, columnMeta, true, false, columnNames.indexOf(columnMeta.getName().asInternal()), hidden) + .orElseThrow(() -> new TrinoException(NOT_SUPPORTED, "Unsupported partition key type: " + columnMeta.getType().asCql(false, false))); columnHandles.add(columnHandle); } // add clustering columns - for (ColumnMetadata columnMeta : tableMeta.getClusteringColumns()) { + for (ColumnMetadata columnMeta : tableMeta.getClusteringColumns().keySet()) { primaryKeySet.add(columnMeta.getName()); - boolean hidden = hiddenColumns.contains(columnMeta.getName()); - Optional columnHandle = buildColumnHandle(tableMeta, columnMeta, false, true, columnNames.indexOf(columnMeta.getName()), hidden); + boolean hidden = hiddenColumns.contains(columnMeta.getName().asInternal()); + Optional columnHandle = buildColumnHandle(tableMeta, columnMeta, false, true, columnNames.indexOf(columnMeta.getName().asInternal()), hidden); columnHandle.ifPresent(columnHandles::add); } // add other columns for (ColumnMetadata columnMeta : columns) { if (!primaryKeySet.contains(columnMeta.getName())) { - boolean hidden = hiddenColumns.contains(columnMeta.getName()); - Optional columnHandle = buildColumnHandle(tableMeta, columnMeta, false, false, columnNames.indexOf(columnMeta.getName()), hidden); + boolean hidden = hiddenColumns.contains(columnMeta.getName().asInternal()); + Optional columnHandle = buildColumnHandle(tableMeta, columnMeta, false, false, columnNames.indexOf(columnMeta.getName().asInternal()), hidden); columnHandle.ifPresent(columnHandles::add); } } @@ -248,19 +266,21 @@ public CassandraTable getTable(SchemaTableName schemaTableName) .sorted(comparing(CassandraColumnHandle::getOrdinalPosition)) .collect(toList()); - CassandraTableHandle tableHandle = new CassandraTableHandle(tableMeta.getKeyspace().getName(), tableMeta.getName()); + CassandraTableHandle tableHandle = new CassandraTableHandle(tableMeta.getKeyspace().asInternal(), tableMeta.getName().asInternal()); return new CassandraTable(tableHandle, sortedColumnHandles); } private KeyspaceMetadata getKeyspaceByCaseInsensitiveName(String caseInsensitiveSchemaName) throws SchemaNotFoundException { - List keyspaces = executeWithSession(session -> session.getCluster().getMetadata().getKeyspaces()); + Collection keyspaces = executeWithSession(session -> session.getMetadata().getKeyspaces()).values(); KeyspaceMetadata result = null; // Ensure that the error message is deterministic - List sortedKeyspaces = Ordering.from(comparing(KeyspaceMetadata::getName)).immutableSortedCopy(keyspaces); + List sortedKeyspaces = keyspaces.stream() + .sorted(Comparator.comparing(keyspaceMetadata -> keyspaceMetadata.getName().asInternal())) + .collect(toImmutableList()); for (KeyspaceMetadata keyspace : sortedKeyspaces) { - if (keyspace.getName().equalsIgnoreCase(caseInsensitiveSchemaName)) { + if (keyspace.getName().asInternal().equalsIgnoreCase(caseInsensitiveSchemaName)) { if (result != null) { throw new TrinoException( NOT_SUPPORTED, @@ -276,21 +296,21 @@ private KeyspaceMetadata getKeyspaceByCaseInsensitiveName(String caseInsensitive return result; } - private static AbstractTableMetadata getTableMetadata(KeyspaceMetadata keyspace, String caseInsensitiveTableName) + private static RelationMetadata getTableMetadata(KeyspaceMetadata keyspace, String caseInsensitiveTableName) { - List tables = Stream.concat( - keyspace.getTables().stream(), - keyspace.getMaterializedViews().stream()) - .filter(table -> table.getName().equalsIgnoreCase(caseInsensitiveTableName)) + List tables = Stream.concat( + keyspace.getTables().values().stream(), + keyspace.getViews().values().stream()) + .filter(table -> table.getName().asInternal().equalsIgnoreCase(caseInsensitiveTableName)) .collect(toImmutableList()); if (tables.size() == 0) { - throw new TableNotFoundException(new SchemaTableName(keyspace.getName(), caseInsensitiveTableName)); + throw new TableNotFoundException(new SchemaTableName(keyspace.getName().asInternal(), caseInsensitiveTableName)); } if (tables.size() == 1) { return tables.get(0); } String tableNames = tables.stream() - .map(AbstractTableMetadata::getName) + .map(metadata -> metadata.getName().asInternal()) .sorted() .collect(joining(", ")); throw new TrinoException( @@ -302,14 +322,14 @@ private static AbstractTableMetadata getTableMetadata(KeyspaceMetadata keyspace, public boolean isMaterializedView(SchemaTableName schemaTableName) { KeyspaceMetadata keyspace = getKeyspaceByCaseInsensitiveName(schemaTableName.getSchemaName()); - return keyspace.getMaterializedView(schemaTableName.getTableName()) != null; + return keyspace.getView(validTableName(schemaTableName.getTableName())).isPresent(); } - private static void checkColumnNames(List columns) + private static void checkColumnNames(Collection columns) { Map lowercaseNameToColumnMap = new HashMap<>(); for (ColumnMetadata column : columns) { - String lowercaseName = column.getName().toLowerCase(ENGLISH); + String lowercaseName = column.getName().asInternal().toLowerCase(ENGLISH); if (lowercaseNameToColumnMap.containsKey(lowercaseName)) { throw new TrinoException( NOT_SUPPORTED, @@ -320,15 +340,15 @@ private static void checkColumnNames(List columns) } } - private Optional buildColumnHandle(AbstractTableMetadata tableMetadata, ColumnMetadata columnMeta, boolean partitionKey, boolean clusteringKey, int ordinalPosition, boolean hidden) + private Optional buildColumnHandle(RelationMetadata tableMetadata, ColumnMetadata columnMeta, boolean partitionKey, boolean clusteringKey, int ordinalPosition, boolean hidden) { Optional cassandraType = toCassandraType(columnMeta.getType()); if (cassandraType.isEmpty()) { - log.debug("Unsupported column type: %s", columnMeta.getType().getName()); + log.debug("Unsupported column type: %s", columnMeta.getType().asCql(false, false)); return Optional.empty(); } - List typeArgs = columnMeta.getType().getTypeArguments(); + List typeArgs = getTypeArguments(columnMeta.getType()); for (DataType typeArgument : typeArgs) { if (!isFullySupported(typeArgument)) { log.debug("%s column has unsupported type: %s", columnMeta.getName(), typeArgument); @@ -336,17 +356,17 @@ private Optional buildColumnHandle(AbstractTableMetadata } } boolean indexed = false; - SchemaTableName schemaTableName = new SchemaTableName(tableMetadata.getKeyspace().getName(), tableMetadata.getName()); + SchemaTableName schemaTableName = new SchemaTableName(tableMetadata.getKeyspace().asInternal(), tableMetadata.getName().asInternal()); if (!isMaterializedView(schemaTableName)) { TableMetadata table = (TableMetadata) tableMetadata; - for (IndexMetadata idx : table.getIndexes()) { - if (idx.getTarget().equals(columnMeta.getName())) { + for (IndexMetadata idx : table.getIndexes().values()) { + if (idx.getTarget().equals(columnMeta.getName().asInternal())) { indexed = true; break; } } } - return Optional.of(new CassandraColumnHandle(columnMeta.getName(), ordinalPosition, cassandraType.get(), partitionKey, clusteringKey, indexed, hidden)); + return Optional.of(new CassandraColumnHandle(columnMeta.getName().asInternal(), ordinalPosition, cassandraType.get(), partitionKey, clusteringKey, indexed, hidden)); } /** @@ -393,7 +413,7 @@ public List getPartitions(CassandraTable table, List session.execute(cql)); } - public PreparedStatement prepare(RegularStatement statement) + public PreparedStatement prepare(SimpleStatement statement) { - log.debug("Execute RegularStatement: %s", statement); + log.debug("Execute SimpleStatement: %s", statement); return executeWithSession(session -> session.prepare(statement)); } @@ -448,11 +468,11 @@ private Iterable queryPartitionKeysWithInClauses(CassandraTable table, List CassandraTableHandle tableHandle = table.getTableHandle(); List partitionKeyColumns = table.getPartitionKeyColumns(); - Select partitionKeys = selectDistinctFrom(tableHandle, partitionKeyColumns); - addWhereInClauses(partitionKeys.where(), partitionKeyColumns, filterPrefixes); + Select partitionKeys = selectDistinctFrom(tableHandle, partitionKeyColumns) + .where(getInRelations(partitionKeyColumns, filterPrefixes)); log.debug("Execute cql for partition keys with IN clauses: %s", partitionKeys); - return execute(partitionKeys).all(); + return execute(partitionKeys.build()).all(); } private Iterable queryPartitionKeysLegacyWithMultipleQueries(CassandraTable table, List> filterPrefixes) @@ -464,11 +484,11 @@ private Iterable queryPartitionKeysLegacyWithMultipleQueries(CassandraTable ImmutableList.Builder rowList = ImmutableList.builder(); for (List combination : filterCombinations) { - Select partitionKeys = selectDistinctFrom(tableHandle, partitionKeyColumns); - addWhereClause(partitionKeys.where(), partitionKeyColumns, combination); + Select partitionKeys = selectDistinctFrom(tableHandle, partitionKeyColumns) + .where(getEqualityRelations(partitionKeyColumns, combination)); log.debug("Execute cql for partition keys with multiple queries: %s", partitionKeys); - List resultRows = execute(partitionKeys).all(); + List resultRows = execute(partitionKeys.build()).all(); if (resultRows != null && !resultRows.isEmpty()) { rowList.addAll(resultRows); } @@ -477,36 +497,45 @@ private Iterable queryPartitionKeysLegacyWithMultipleQueries(CassandraTable return rowList.build(); } - private static void addWhereInClauses(Select.Where where, List partitionKeyColumns, List> filterPrefixes) + private static List getInRelations(List partitionKeyColumns, List> filterPrefixes) { - for (int i = 0; i < filterPrefixes.size(); i++) { - CassandraColumnHandle column = partitionKeyColumns.get(i); - List values = filterPrefixes.get(i) - .stream() - .map(value -> column.getCassandraType().getJavaValue(value)) - .collect(toList()); - Clause clause = QueryBuilder.in(CassandraCqlUtils.validColumnName(column.getName()), values); - where.and(clause); - } + return IntStream + .range(0, Math.min(partitionKeyColumns.size(), filterPrefixes.size())) + .mapToObj(i -> getInRelation(partitionKeyColumns.get(i), filterPrefixes.get(i))) + .collect(toImmutableList()); } - private static void addWhereClause(Select.Where where, List partitionKeyColumns, List filterPrefix) + private static Relation getInRelation(CassandraColumnHandle column, Set filterPrefixes) { - for (int i = 0; i < filterPrefix.size(); i++) { - CassandraColumnHandle column = partitionKeyColumns.get(i); - Object value = column.getCassandraType().getJavaValue(filterPrefix.get(i)); - Clause clause = QueryBuilder.eq(CassandraCqlUtils.validColumnName(column.getName()), value); - where.and(clause); - } + List values = filterPrefixes + .stream() + .map(value -> column.getCassandraType().getJavaValue(value)) + .map(QueryBuilder::literal) + .collect(toList()); + + return Relation.column(CassandraCqlUtils.validColumnName(column.getName())).in(values); + } + + private static List getEqualityRelations(List partitionKeyColumns, List filterPrefix) + { + return IntStream + .range(0, Math.min(partitionKeyColumns.size(), filterPrefix.size())) + .mapToObj(i -> { + CassandraColumnHandle column = partitionKeyColumns.get(i); + Object value = column.getCassandraType().getJavaValue(filterPrefix.get(i)); + return Relation.column(CassandraCqlUtils.validColumnName(column.getName())).isEqualTo(literal(value)); + }) + .collect(toImmutableList()); } public List getSizeEstimates(String keyspaceName, String tableName) { checkSizeEstimatesTableExist(); - Statement statement = select("partitions_count") - .from(SYSTEM, SIZE_ESTIMATES) - .where(eq("keyspace_name", keyspaceName)) - .and(eq("table_name", tableName)); + SimpleStatement statement = selectFrom(SYSTEM, SIZE_ESTIMATES) + .column("partitions_count") + .where(Relation.column("keyspace_name").isEqualTo(literal(keyspaceName)), + Relation.column("table_name").isEqualTo(literal(tableName))) + .build(); ResultSet result = executeWithSession(session -> session.execute(statement)); ImmutableList.Builder estimates = ImmutableList.builder(); @@ -520,31 +549,31 @@ public List getSizeEstimates(String keyspaceName, String tableName private void checkSizeEstimatesTableExist() { - KeyspaceMetadata keyspaceMetadata = executeWithSession(session -> session.getCluster().getMetadata().getKeyspace(SYSTEM)); - checkState(keyspaceMetadata != null, "system keyspace metadata must not be null"); - TableMetadata table = keyspaceMetadata.getTable(SIZE_ESTIMATES); - if (table == null) { + Optional keyspaceMetadata = executeWithSession(session -> session.getMetadata().getKeyspace(SYSTEM)); + checkState(keyspaceMetadata.isPresent(), "system keyspace metadata must not be null"); + Optional sizeEstimatesTableMetadata = keyspaceMetadata.flatMap(metadata -> metadata.getTable(SIZE_ESTIMATES)); + if (sizeEstimatesTableMetadata.isEmpty()) { throw new TrinoException(NOT_SUPPORTED, "Cassandra versions prior to 2.1.5 are not supported"); } } private T executeWithSession(SessionCallable sessionCallable) { - ReconnectionPolicy reconnectionPolicy = cluster.getConfiguration().getPolicies().getReconnectionPolicy(); - ReconnectionSchedule schedule = reconnectionPolicy.newSchedule(); + ReconnectionPolicy reconnectionPolicy = session.get().getContext().getReconnectionPolicy(); + ReconnectionPolicy.ReconnectionSchedule schedule = reconnectionPolicy.newControlConnectionSchedule(false); long deadline = System.currentTimeMillis() + noHostAvailableRetryTimeout.toMillis(); while (true) { try { return sessionCallable.executeWithSession(session.get()); } - catch (NoHostAvailableException e) { + catch (AllNodesFailedException e) { long timeLeft = deadline - System.currentTimeMillis(); if (timeLeft <= 0) { throw e; } else { - long delay = Math.min(schedule.nextDelayMs(), timeLeft); - log.warn("%s", e.getCustomMessage(10, true, true)); + long delay = Math.min(schedule.nextDelay().toMillis(), timeLeft); + log.warn(e.getMessage()); log.warn("Reconnecting in %dms", delay); try { Thread.sleep(delay); @@ -558,8 +587,40 @@ private T executeWithSession(SessionCallable sessionCallable) } } + private List getTypeArguments(DataType dataType) + { + if (dataType instanceof UserDefinedType) { + return ImmutableList.copyOf(((UserDefinedType) dataType).getFieldTypes()); + } + + if (dataType instanceof MapType) { + MapType mapType = (MapType) dataType; + return ImmutableList.of(mapType.getKeyType(), mapType.getValueType()); + } + + if (dataType instanceof ListType) { + return ImmutableList.of(((ListType) dataType).getElementType()); + } + + if (dataType instanceof TupleType) { + return ImmutableList.copyOf(((TupleType) dataType).getComponentTypes()); + } + + if (dataType instanceof SetType) { + return ImmutableList.of(((SetType) dataType).getElementType()); + } + + return ImmutableList.of(); + } + + @Override + public void close() + { + session.get().close(); + } + private interface SessionCallable { - T executeWithSession(Session session); + T executeWithSession(CqlSession session); } } diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraSplitManager.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraSplitManager.java index 2d601b4db3de..4f94df8432ef 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraSplitManager.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraSplitManager.java @@ -13,7 +13,12 @@ */ package io.trino.plugin.cassandra; -import com.datastax.driver.core.Host; +import com.datastax.oss.driver.api.core.metadata.Node; +import com.datastax.oss.driver.api.core.metadata.token.TokenRange; +import com.datastax.oss.driver.internal.core.metadata.token.Murmur3Token; +import com.datastax.oss.driver.internal.core.metadata.token.Murmur3TokenRange; +import com.datastax.oss.driver.internal.core.metadata.token.RandomToken; +import com.datastax.oss.driver.internal.core.metadata.token.RandomTokenRange; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import io.airlift.log.Logger; @@ -130,7 +135,7 @@ private List getSplitsByTokenRange(CassandraTable table, String ImmutableList.Builder builder = ImmutableList.builder(); List tokenSplits = tokenSplitMgr.getSplits(schema, tableName, sessionSplitsPerNode); for (CassandraTokenSplitManager.TokenSplit tokenSplit : tokenSplits) { - String condition = buildTokenCondition(tokenExpression, tokenSplit.getStartToken(), tokenSplit.getEndToken()); + String condition = buildTokenCondition(tokenExpression, tokenSplit.getTokenRange()); List addresses = new HostAddressFactory().hostAddressNamesToHostAddressList(tokenSplit.getHosts()); CassandraSplit split = new CassandraSplit(partitionId, condition, addresses); builder.add(split); @@ -139,9 +144,24 @@ private List getSplitsByTokenRange(CassandraTable table, String return builder.build(); } - private static String buildTokenCondition(String tokenExpression, String startToken, String endToken) + private static String buildTokenCondition(String tokenExpression, TokenRange tokenRange) { - return tokenExpression + " > " + startToken + " AND " + tokenExpression + " <= " + endToken; + Number startTokenValue; + Number endTokenValue; + if (tokenRange instanceof Murmur3TokenRange) { + Murmur3TokenRange murmur3TokenRange = (Murmur3TokenRange) tokenRange; + startTokenValue = ((Murmur3Token) murmur3TokenRange.getStart()).getValue(); + endTokenValue = ((Murmur3Token) murmur3TokenRange.getEnd()).getValue(); + } + else if (tokenRange instanceof RandomTokenRange) { + RandomTokenRange randomTokenRange = (RandomTokenRange) tokenRange; + startTokenValue = ((RandomToken) randomTokenRange.getStart()).getValue(); + endTokenValue = ((RandomToken) randomTokenRange.getEnd()).getValue(); + } + else { + throw new IllegalStateException(format("Unsupported token range class %s", tokenRange.getClass().getName())); + } + return tokenExpression + " > " + startTokenValue + " AND " + tokenExpression + " <= " + endTokenValue; } private List getSplitsForPartitions(CassandraTableHandle cassTableHandle, List partitions, String clusteringPredicates) @@ -167,8 +187,8 @@ private List getSplitsForPartitions(CassandraTableHandle cassTab Map, List> hostMap = new HashMap<>(); for (CassandraPartition cassandraPartition : partitions) { - Set hosts = cassandraSession.getReplicas(schema, cassandraPartition.getKeyAsByteBuffer()); - List addresses = hostAddressFactory.toHostAddressList(hosts); + Set nodes = cassandraSession.getReplicas(schema, cassandraPartition.getKeyAsByteBuffer()); + List addresses = hostAddressFactory.toHostAddressList(nodes); if (singlePartitionKeyColumn) { // host ip addresses ImmutableSet.Builder sb = ImmutableSet.builder(); diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraTokenSplitManager.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraTokenSplitManager.java index d9e22328076c..fb00c66f46ae 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraTokenSplitManager.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraTokenSplitManager.java @@ -13,8 +13,8 @@ */ package io.trino.plugin.cassandra; -import com.datastax.driver.core.Host; -import com.datastax.driver.core.TokenRange; +import com.datastax.oss.driver.api.core.metadata.Node; +import com.datastax.oss.driver.api.core.metadata.token.TokenRange; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import io.trino.spi.TrinoException; @@ -92,7 +92,7 @@ public List getSplits(String keyspace, String table, Optional continue; } - double tokenRangeRingFraction = tokenRing.get().getRingFraction(tokenRange.getStart().toString(), tokenRange.getEnd().toString()); + double tokenRangeRingFraction = tokenRing.get().getRingFraction(tokenRange.getStart(), tokenRange.getEnd()); long partitionsCountEstimate = round(totalPartitionsCount * tokenRangeRingFraction); checkState(partitionsCountEstimate >= 0, "unexpected partitions count estimate: %s", partitionsCountEstimate); int subSplitCount = max(toIntExact(partitionsCountEstimate / splitSize), 1); @@ -135,41 +135,35 @@ public long getTotalPartitionsCount(String keyspace, String table, Optional getEndpoints(String keyspace, TokenRange tokenRange) { - Set endpoints = session.getReplicas(keyspace, tokenRange); + Set endpoints = session.getReplicas(keyspace, tokenRange); + return unmodifiableList(endpoints.stream() - .map(Host::toString) + .map(endpoint -> endpoint.getEndPoint().resolve().toString()) .collect(toList())); } private static TokenSplit createSplit(TokenRange range, List endpoints) { checkArgument(!range.isEmpty(), "tokenRange must not be empty"); - String startToken = range.getStart().toString(); - String endToken = range.getEnd().toString(); - return new TokenSplit(startToken, endToken, endpoints); + requireNonNull(range.getStart(), "tokenRange.start is null"); + requireNonNull(range.getEnd(), "tokenRange.end is null"); + return new TokenSplit(range, endpoints); } public static class TokenSplit { - private String startToken; - private String endToken; + private TokenRange tokenRange; private List hosts; - public TokenSplit(String startToken, String endToken, List hosts) + public TokenSplit(TokenRange tokenRange, List hosts) { - this.startToken = requireNonNull(startToken, "startToken is null"); - this.endToken = requireNonNull(endToken, "endToken is null"); + this.tokenRange = requireNonNull(tokenRange, "tokenRange is null"); this.hosts = ImmutableList.copyOf(requireNonNull(hosts, "hosts is null")); } - public String getStartToken() - { - return startToken; - } - - public String getEndToken() + public TokenRange getTokenRange() { - return endToken; + return tokenRange; } public List getHosts() @@ -181,8 +175,8 @@ public List getHosts() public String toString() { return toStringHelper(this) - .add("startToken", startToken) - .add("endToken", endToken) + .add("startToken", tokenRange.getStart()) + .add("endToken", tokenRange.getEnd()) .add("hosts", hosts) .toString(); } diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraType.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraType.java index 663188809387..d16b00c7dcaa 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraType.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraType.java @@ -13,16 +13,20 @@ */ package io.trino.plugin.cassandra; -import com.datastax.driver.core.DataType; -import com.datastax.driver.core.GettableByIndexData; -import com.datastax.driver.core.LocalDate; -import com.datastax.driver.core.ProtocolVersion; -import com.datastax.driver.core.Row; -import com.datastax.driver.core.TupleType; -import com.datastax.driver.core.TupleValue; -import com.datastax.driver.core.UDTValue; -import com.datastax.driver.core.UserType; -import com.datastax.driver.core.utils.Bytes; +import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.cql.Row; +import com.datastax.oss.driver.api.core.data.GettableByIndex; +import com.datastax.oss.driver.api.core.data.TupleValue; +import com.datastax.oss.driver.api.core.data.UdtValue; +import com.datastax.oss.driver.api.core.type.DataType; +import com.datastax.oss.driver.api.core.type.ListType; +import com.datastax.oss.driver.api.core.type.MapType; +import com.datastax.oss.driver.api.core.type.SetType; +import com.datastax.oss.driver.api.core.type.TupleType; +import com.datastax.oss.driver.api.core.type.UserDefinedType; +import com.datastax.oss.protocol.internal.ProtocolConstants; +import com.datastax.oss.protocol.internal.util.Bytes; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.annotations.VisibleForTesting; @@ -54,8 +58,10 @@ import java.math.BigDecimal; import java.math.BigInteger; import java.nio.ByteBuffer; +import java.time.Instant; +import java.time.LocalDate; +import java.util.Arrays; import java.util.Collection; -import java.util.Date; import java.util.List; import java.util.Map; import java.util.Objects; @@ -65,12 +71,12 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.net.InetAddresses.toAddrString; import static io.airlift.slice.Slices.utf8Slice; import static io.airlift.slice.Slices.wrappedBuffer; import static io.trino.plugin.cassandra.util.CassandraCqlUtils.quoteStringLiteral; import static io.trino.plugin.cassandra.util.CassandraCqlUtils.quoteStringLiteralForJson; +import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.type.DateTimeEncoding.packDateTimeWithZone; import static io.trino.spi.type.DateTimeEncoding.unpackMillisUtc; @@ -161,56 +167,54 @@ public String getName() public static Optional toCassandraType(DataType dataType) { - switch (dataType.getName()) { - case ASCII: + switch (dataType.getProtocolCode()) { + case ProtocolConstants.DataType.ASCII: return Optional.of(CassandraTypes.ASCII); - case BIGINT: + case ProtocolConstants.DataType.BIGINT: return Optional.of(CassandraTypes.BIGINT); - case BLOB: + case ProtocolConstants.DataType.BLOB: return Optional.of(CassandraTypes.BLOB); - case BOOLEAN: + case ProtocolConstants.DataType.BOOLEAN: return Optional.of(CassandraTypes.BOOLEAN); - case COUNTER: + case ProtocolConstants.DataType.COUNTER: return Optional.of(CassandraTypes.COUNTER); - case CUSTOM: + case ProtocolConstants.DataType.CUSTOM: return Optional.of(CassandraTypes.CUSTOM); - case DATE: + case ProtocolConstants.DataType.DATE: return Optional.of(CassandraTypes.DATE); - case DECIMAL: + case ProtocolConstants.DataType.DECIMAL: return Optional.of(CassandraTypes.DECIMAL); - case DOUBLE: + case ProtocolConstants.DataType.DOUBLE: return Optional.of(CassandraTypes.DOUBLE); - case FLOAT: + case ProtocolConstants.DataType.FLOAT: return Optional.of(CassandraTypes.FLOAT); - case INET: + case ProtocolConstants.DataType.INET: return Optional.of(CassandraTypes.INET); - case INT: + case ProtocolConstants.DataType.INT: return Optional.of(CassandraTypes.INT); - case LIST: + case ProtocolConstants.DataType.LIST: return Optional.of(CassandraTypes.LIST); - case MAP: + case ProtocolConstants.DataType.MAP: return Optional.of(CassandraTypes.MAP); - case SET: + case ProtocolConstants.DataType.SET: return Optional.of(CassandraTypes.SET); - case SMALLINT: + case ProtocolConstants.DataType.SMALLINT: return Optional.of(CassandraTypes.SMALLINT); - case TEXT: - return Optional.of(CassandraTypes.TEXT); - case TIMESTAMP: + case ProtocolConstants.DataType.TIMESTAMP: return Optional.of(CassandraTypes.TIMESTAMP); - case TIMEUUID: + case ProtocolConstants.DataType.TIMEUUID: return Optional.of(CassandraTypes.TIMEUUID); - case TINYINT: + case ProtocolConstants.DataType.TINYINT: return Optional.of(CassandraTypes.TINYINT); - case TUPLE: + case ProtocolConstants.DataType.TUPLE: return createTypeForTuple(dataType); - case UDT: + case ProtocolConstants.DataType.UDT: return createTypeForUserType(dataType); - case UUID: + case ProtocolConstants.DataType.UUID: return Optional.of(CassandraTypes.UUID); - case VARCHAR: + case ProtocolConstants.DataType.VARCHAR: return Optional.of(CassandraTypes.VARCHAR); - case VARINT: + case ProtocolConstants.DataType.VARINT: return Optional.of(CassandraTypes.VARINT); default: return Optional.empty(); @@ -242,16 +246,21 @@ private static Optional createTypeForTuple(DataType dataType) private static Optional createTypeForUserType(DataType dataType) { - UserType userType = (UserType) dataType; + UserDefinedType userDefinedType = (UserDefinedType) dataType; // Using ImmutableMap is important as we exploit the fact that entries iteration order matches the order of putting values via builder ImmutableMap.Builder argumentTypes = ImmutableMap.builder(); - for (UserType.Field field : userType) { - Optional cassandraType = CassandraType.toCassandraType(field.getType()); + + List fieldNames = userDefinedType.getFieldNames(); + List fieldTypes = userDefinedType.getFieldTypes(); + if (fieldNames.size() != fieldTypes.size()) { + throw new TrinoException(GENERIC_INTERNAL_ERROR, format("Mismatch between the number of field names (%s) and the number of field types (%s) for the data type %s", fieldNames.size(), fieldTypes.size(), dataType)); + } + for (int i = 0; i < fieldNames.size(); i++) { + Optional cassandraType = CassandraType.toCassandraType(fieldTypes.get(i)); if (cassandraType.isEmpty()) { return Optional.empty(); } - - argumentTypes.put(field.getName(), cassandraType.get()); + argumentTypes.put(fieldNames.get(i).toString(), cassandraType.get()); } RowType trinoType = RowType.from( @@ -264,10 +273,10 @@ private static Optional createTypeForUserType(DataType dataType) public NullableValue getColumnValue(Row row, int position) { - return getColumnValue(row, position, () -> row.getColumnDefinitions().getType(position)); + return getColumnValue(row, position, () -> row.getColumnDefinitions().get(position).getType()); } - public NullableValue getColumnValue(GettableByIndexData row, int position, Supplier dataTypeSupplier) + public NullableValue getColumnValue(GettableByIndex row, int position, Supplier dataTypeSupplier) { if (row.isNull(position)) { return NullableValue.asNull(trinoType); @@ -288,30 +297,31 @@ public NullableValue getColumnValue(GettableByIndexData row, int position, Suppl case COUNTER: return NullableValue.of(trinoType, row.getLong(position)); case BOOLEAN: - return NullableValue.of(trinoType, row.getBool(position)); + return NullableValue.of(trinoType, row.getBoolean(position)); case DOUBLE: return NullableValue.of(trinoType, row.getDouble(position)); case FLOAT: return NullableValue.of(trinoType, (long) floatToRawIntBits(row.getFloat(position))); case DECIMAL: - return NullableValue.of(trinoType, row.getDecimal(position).doubleValue()); + return NullableValue.of(trinoType, row.getBigDecimal(position).doubleValue()); case UUID: case TIMEUUID: - return NullableValue.of(trinoType, javaUuidToTrinoUuid(row.getUUID(position))); + return NullableValue.of(trinoType, javaUuidToTrinoUuid(row.getUuid(position))); case TIMESTAMP: - return NullableValue.of(trinoType, packDateTimeWithZone(row.getTimestamp(position).getTime(), TimeZoneKey.UTC_KEY)); + return NullableValue.of(trinoType, packDateTimeWithZone(row.getInstant(position).toEpochMilli(), TimeZoneKey.UTC_KEY)); case DATE: - return NullableValue.of(trinoType, (long) row.getDate(position).getDaysSinceEpoch()); + return NullableValue.of(trinoType, row.getLocalDate(position).toEpochDay()); case INET: - return NullableValue.of(trinoType, utf8Slice(toAddrString(row.getInet(position)))); + return NullableValue.of(trinoType, utf8Slice(toAddrString(row.getInetAddress(position)))); case VARINT: - return NullableValue.of(trinoType, utf8Slice(row.getVarint(position).toString())); + return NullableValue.of(trinoType, utf8Slice(row.getBigInteger(position).toString())); case BLOB: case CUSTOM: return NullableValue.of(trinoType, wrappedBuffer(row.getBytesUnsafe(position))); case SET: + return NullableValue.of(trinoType, utf8Slice(buildArrayValueFromSetType(row, position, dataTypeSupplier.get()))); case LIST: - return NullableValue.of(trinoType, utf8Slice(buildArrayValue(row, position, dataTypeSupplier.get()))); + return NullableValue.of(trinoType, utf8Slice(buildArrayValueFromListType(row, position, dataTypeSupplier.get()))); case MAP: return NullableValue.of(trinoType, utf8Slice(buildMapValue(row, position, dataTypeSupplier.get()))); case TUPLE: @@ -322,12 +332,11 @@ public NullableValue getColumnValue(GettableByIndexData row, int position, Suppl throw new IllegalStateException("Handling of type " + this + " is not implemented"); } - private static String buildMapValue(GettableByIndexData row, int position, DataType dataType) + private static String buildMapValue(GettableByIndex row, int position, DataType dataType) { - checkArgument(dataType.getTypeArguments().size() == 2, "Expected two type arguments, got: %s", dataType.getTypeArguments()); - DataType keyType = dataType.getTypeArguments().get(0); - DataType valueType = dataType.getTypeArguments().get(1); - return buildMapValue((Map) row.getObject(position), keyType, valueType); + checkArgument(dataType instanceof MapType, "Expected to deal with an instance of %s class, got: %s", MapType.class, dataType); + MapType mapType = (MapType) dataType; + return buildMapValue((Map) row.getObject(position), mapType.getKeyType(), mapType.getValueType()); } private static String buildMapValue(Map cassandraMap, DataType keyType, DataType valueType) @@ -346,10 +355,18 @@ private static String buildMapValue(Map cassandraMap, DataType keyType, Da return sb.toString(); } - private static String buildArrayValue(GettableByIndexData row, int position, DataType dataType) + private static String buildArrayValueFromSetType(GettableByIndex row, int position, DataType type) + { + checkArgument(type instanceof SetType, "Expected to deal with an instance of %s class, got: %s", SetType.class, type); + SetType setType = (SetType) type; + return buildArrayValue((Collection) row.getObject(position), setType.getElementType()); + } + + private static String buildArrayValueFromListType(GettableByIndex row, int position, DataType type) { - DataType elementType = getOnlyElement(dataType.getTypeArguments()); - return buildArrayValue((Collection) row.getObject(position), elementType); + checkArgument(type instanceof ListType, "Expected to deal with an instance of %s class, got: %s", ListType.class, type); + ListType listType = (ListType) type; + return buildArrayValue((Collection) row.getObject(position), listType.getElementType()); } @VisibleForTesting @@ -367,7 +384,7 @@ static String buildArrayValue(Collection cassandraCollection, DataType elemen return sb.toString(); } - private Block buildTupleValue(GettableByIndexData row, int position) + private Block buildTupleValue(GettableByIndex row, int position) { verify(this.kind == Kind.TUPLE, "Not a TUPLE type"); TupleValue tupleValue = row.getTupleValue(position); @@ -385,17 +402,17 @@ private Block buildTupleValue(GettableByIndexData row, int position) return (Block) this.trinoType.getObject(blockBuilder, 0); } - private Block buildUserTypeValue(GettableByIndexData row, int position) + private Block buildUserTypeValue(GettableByIndex row, int position) { verify(this.kind == Kind.UDT, "Not a user defined type: %s", this.kind); - UDTValue udtValue = row.getUDTValue(position); - String[] fieldNames = udtValue.getType().getFieldNames().toArray(String[]::new); + UdtValue udtValue = row.getUdtValue(position); RowBlockBuilder blockBuilder = (RowBlockBuilder) this.trinoType.createBlockBuilder(null, 1); SingleRowBlockWriter singleRowBlockWriter = blockBuilder.beginBlockEntry(); int tuplePosition = 0; + List udtTypeFieldTypes = udtValue.getType().getFieldTypes(); for (CassandraType argumentType : this.getArgumentTypes()) { int finalTuplePosition = tuplePosition; - NullableValue value = argumentType.getColumnValue(udtValue, tuplePosition, () -> udtValue.getType().getFieldType(fieldNames[finalTuplePosition])); + NullableValue value = argumentType.getColumnValue(udtValue, tuplePosition, () -> udtTypeFieldTypes.get(finalTuplePosition)); writeNativeValue(argumentType.getTrinoType(), singleRowBlockWriter, value.getValue()); tuplePosition++; } @@ -432,18 +449,18 @@ public String getColumnValueForCql(Row row, int position) case FLOAT: return Float.toString(row.getFloat(position)); case DECIMAL: - return row.getDecimal(position).toString(); + return row.getBigDecimal(position).toString(); case UUID: case TIMEUUID: - return row.getUUID(position).toString(); + return row.getUuid(position).toString(); case TIMESTAMP: - return Long.toString(row.getTimestamp(position).getTime()); + return Long.toString(row.getInstant(position).toEpochMilli()); case DATE: - return quoteStringLiteral(row.getDate(position).toString()); + return quoteStringLiteral(row.getLocalDate(position).toString()); case INET: - return quoteStringLiteral(toAddrString(row.getInet(position))); + return quoteStringLiteral(toAddrString(row.getInetAddress(position))); case VARINT: - return row.getVarint(position).toString(); + return row.getBigInteger(position).toString(); case BLOB: case CUSTOM: return Bytes.toHexString(row.getBytesUnsafe(position)); @@ -463,7 +480,7 @@ public String getColumnValueForCql(Row row, int position) public String toCqlLiteral(Object trinoNativeValue) { if (kind == Kind.DATE) { - LocalDate date = LocalDate.fromDaysSinceEpoch(toIntExact((long) trinoNativeValue)); + LocalDate date = LocalDate.ofEpochDay(toIntExact((long) trinoNativeValue)); return quoteStringLiteral(date.toString()); } if (kind == Kind.TIMESTAMP) { @@ -525,10 +542,17 @@ private static String objectToJson(Object cassandraValue, DataType dataType) case DECIMAL: return cassandraValue.toString(); case LIST: + checkArgument(dataType instanceof ListType, "Expected to deal with an instance of %s class, got: %s", ListType.class, dataType); + ListType listType = (ListType) dataType; + return buildArrayValue((Collection) cassandraValue, listType.getElementType()); case SET: - return buildArrayValue((Collection) cassandraValue, getOnlyElement(dataType.getTypeArguments())); + checkArgument(dataType instanceof SetType, "Expected to deal with an instance of %s class, got: %s", SetType.class, dataType); + SetType setType = (SetType) dataType; + return buildArrayValue((Collection) cassandraValue, setType.getElementType()); case MAP: - return buildMapValue((Map) cassandraValue, dataType.getTypeArguments().get(0), dataType.getTypeArguments().get(1)); + checkArgument(dataType instanceof MapType, "Expected to deal with an instance of %s class, got: %s", MapType.class, dataType); + MapType mapType = (MapType) dataType; + return buildMapValue((Map) cassandraValue, mapType.getKeyType(), mapType.getValueType()); } throw new IllegalStateException("Unsupported type: " + cassandraType); } @@ -560,9 +584,9 @@ public Object getJavaValue(Object trinoNativeValue) // Otherwise partition id doesn't match return new BigDecimal(trinoNativeValue.toString()); case TIMESTAMP: - return new Date(unpackMillisUtc((Long) trinoNativeValue)); + return Instant.ofEpochMilli(unpackMillisUtc((Long) trinoNativeValue)); case DATE: - return LocalDate.fromDaysSinceEpoch(((Long) trinoNativeValue).intValue()); + return LocalDate.ofEpochDay(((Long) trinoNativeValue).intValue()); case UUID: case TIMEUUID: return trinoUuidToJavaUuid((Slice) trinoNativeValue); @@ -620,8 +644,31 @@ public static boolean isFullySupported(DataType dataType) return false; } - return dataType.getTypeArguments().stream() - .allMatch(CassandraType::isFullySupported); + if (dataType instanceof UserDefinedType) { + return ((UserDefinedType) dataType).getFieldTypes().stream() + .allMatch(CassandraType::isFullySupported); + } + + if (dataType instanceof MapType) { + MapType mapType = (MapType) dataType; + return Arrays.stream(new DataType[] {mapType.getKeyType(), mapType.getValueType()}) + .allMatch(CassandraType::isFullySupported); + } + + if (dataType instanceof ListType) { + return CassandraType.isFullySupported(((ListType) dataType).getElementType()); + } + + if (dataType instanceof TupleType) { + return ((TupleType) dataType).getComponentTypes().stream() + .allMatch(CassandraType::isFullySupported); + } + + if (dataType instanceof SetType) { + return CassandraType.isFullySupported(((SetType) dataType).getElementType()); + } + + return true; } public static CassandraType toCassandraType(Type type, ProtocolVersion protocolVersion) @@ -651,7 +698,7 @@ public static CassandraType toCassandraType(Type type, ProtocolVersion protocolV return CassandraTypes.TEXT; } if (type.equals(DateType.DATE)) { - return protocolVersion.toInt() <= ProtocolVersion.V3.toInt() + return protocolVersion.getCode() <= ProtocolVersion.V3.getCode() ? CassandraTypes.TEXT : CassandraTypes.DATE; } diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/FallthroughRetryPolicy.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/FallthroughRetryPolicy.java new file mode 100644 index 000000000000..daad9a084fb5 --- /dev/null +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/FallthroughRetryPolicy.java @@ -0,0 +1,69 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.trino.plugin.cassandra; + +import com.datastax.oss.driver.api.core.ConsistencyLevel; +import com.datastax.oss.driver.api.core.context.DriverContext; +import com.datastax.oss.driver.api.core.retry.RetryDecision; +import com.datastax.oss.driver.api.core.retry.RetryPolicy; +import com.datastax.oss.driver.api.core.servererrors.CoordinatorException; +import com.datastax.oss.driver.api.core.servererrors.WriteType; +import com.datastax.oss.driver.api.core.session.Request; + +/** + * A retry policy that never retries (nor ignores). + *

+ * All of the methods of this retry policy unconditionally return {@link RetryDecision#RETHROW}. + * If this policy is used, retry logic will have to be implemented in business code. + */ +public class FallthroughRetryPolicy + implements RetryPolicy +{ + // Required by Cassandra driver library for instantiation + public FallthroughRetryPolicy(DriverContext context, String profileName) {} + + @Override + public RetryDecision onReadTimeout(Request request, ConsistencyLevel consistencyLevel, int blockFor, int received, boolean dataPresent, int retryCount) + { + return RetryDecision.RETHROW; + } + + @Override + public RetryDecision onWriteTimeout(Request request, ConsistencyLevel consistencyLevel, WriteType writeType, int blockFor, int received, int retryCount) + { + return RetryDecision.RETHROW; + } + + @Override + public RetryDecision onUnavailable(Request request, ConsistencyLevel consistencyLevel, int required, int alive, int retries) + { + return RetryDecision.RETHROW; + } + + @Override + public RetryDecision onRequestAborted(Request request, Throwable error, int retryCount) + { + return RetryDecision.RETHROW; + } + + @Override + public RetryDecision onErrorResponse(Request request, CoordinatorException error, int retryCount) + { + return RetryDecision.RETHROW; + } + + @Override + public void close() {} +} diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/Murmur3PartitionerTokenRing.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/Murmur3PartitionerTokenRing.java index 48cbdcf8db4f..42228051b096 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/Murmur3PartitionerTokenRing.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/Murmur3PartitionerTokenRing.java @@ -13,6 +13,9 @@ */ package io.trino.plugin.cassandra; +import com.datastax.oss.driver.api.core.metadata.token.Token; +import com.datastax.oss.driver.internal.core.metadata.token.Murmur3Token; + import java.math.BigInteger; import static java.math.BigInteger.ZERO; @@ -29,16 +32,16 @@ public final class Murmur3PartitionerTokenRing private Murmur3PartitionerTokenRing() {} @Override - public double getRingFraction(String start, String end) + public double getRingFraction(Token start, Token end) { return getTokenCountInRange(start, end).doubleValue() / TOTAL_TOKEN_COUNT.doubleValue(); } @Override - public BigInteger getTokenCountInRange(String startToken, String endToken) + public BigInteger getTokenCountInRange(Token startToken, Token endToken) { - long start = Long.parseLong(startToken); - long end = Long.parseLong(endToken); + long start = ((Murmur3Token) startToken).getValue(); + long end = ((Murmur3Token) endToken).getValue(); if (start == end) { if (start == MIN_TOKEN) { diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/RandomPartitionerTokenRing.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/RandomPartitionerTokenRing.java index 6e23875bd097..ef8d5acdcaef 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/RandomPartitionerTokenRing.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/RandomPartitionerTokenRing.java @@ -13,6 +13,9 @@ */ package io.trino.plugin.cassandra; +import com.datastax.oss.driver.api.core.metadata.token.Token; +import com.datastax.oss.driver.internal.core.metadata.token.RandomToken; + import java.math.BigInteger; import static com.google.common.base.Preconditions.checkArgument; @@ -30,17 +33,17 @@ public final class RandomPartitionerTokenRing private RandomPartitionerTokenRing() {} @Override - public double getRingFraction(String start, String end) + public double getRingFraction(Token start, Token end) { return getTokenCountInRange(start, end).doubleValue() / TOTAL_TOKEN_COUNT.doubleValue(); } @Override - public BigInteger getTokenCountInRange(String startToken, String endToken) + public BigInteger getTokenCountInRange(Token startToken, Token endToken) { - BigInteger start = new BigInteger(startToken); + BigInteger start = ((RandomToken) startToken).getValue(); checkTokenBounds(start); - BigInteger end = new BigInteger(endToken); + BigInteger end = ((RandomToken) endToken).getValue(); checkTokenBounds(end); if (start.equals(end)) { diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/ReopeningCluster.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/ReopeningCluster.java deleted file mode 100644 index 287fdd837f05..000000000000 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/ReopeningCluster.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.plugin.cassandra; - -import com.datastax.driver.core.CloseFuture; -import com.datastax.driver.core.Cluster; -import com.datastax.driver.core.DelegatingCluster; -import io.airlift.log.Logger; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; - -import java.util.function.Supplier; - -import static com.google.common.base.Preconditions.checkState; -import static com.google.common.base.Verify.verify; -import static java.util.Objects.requireNonNull; - -@ThreadSafe -public class ReopeningCluster - extends DelegatingCluster -{ - private static final Logger log = Logger.get(ReopeningCluster.class); - - @GuardedBy("this") - private Cluster delegate; - @GuardedBy("this") - private boolean closed; - - private final Supplier supplier; - - public ReopeningCluster(Supplier supplier) - { - this.supplier = requireNonNull(supplier, "supplier is null"); - } - - @Override - protected synchronized Cluster delegate() - { - checkState(!closed, "Cluster has been closed"); - - if (delegate == null) { - delegate = supplier.get(); - } - - if (delegate.isClosed()) { - log.warn("Cluster has been closed internally"); - delegate = supplier.get(); - } - - verify(!delegate.isClosed(), "Newly created cluster has been immediately closed"); - - return delegate; - } - - @Override - public synchronized void close() - { - closed = true; - if (delegate != null) { - delegate.close(); - delegate = null; - } - } - - @Override - public synchronized boolean isClosed() - { - return closed; - } - - @Override - public synchronized CloseFuture closeAsync() - { - throw new UnsupportedOperationException(); - } -} diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/RetryPolicyType.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/RetryPolicyType.java index 8f6ecaee0c06..ff9af5b26caf 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/RetryPolicyType.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/RetryPolicyType.java @@ -13,29 +13,28 @@ */ package io.trino.plugin.cassandra; -import com.datastax.driver.core.policies.DefaultRetryPolicy; -import com.datastax.driver.core.policies.DowngradingConsistencyRetryPolicy; -import com.datastax.driver.core.policies.FallthroughRetryPolicy; -import com.datastax.driver.core.policies.RetryPolicy; +import com.datastax.oss.driver.api.core.retry.RetryPolicy; +import com.datastax.oss.driver.internal.core.retry.ConsistencyDowngradingRetryPolicy; +import com.datastax.oss.driver.internal.core.retry.DefaultRetryPolicy; import static java.util.Objects.requireNonNull; public enum RetryPolicyType { - DEFAULT(DefaultRetryPolicy.INSTANCE), - BACKOFF(BackoffRetryPolicy.INSTANCE), - DOWNGRADING_CONSISTENCY(DowngradingConsistencyRetryPolicy.INSTANCE), - FALLTHROUGH(FallthroughRetryPolicy.INSTANCE); + DEFAULT(DefaultRetryPolicy.class), + BACKOFF(BackoffRetryPolicy.class), + DOWNGRADING_CONSISTENCY(ConsistencyDowngradingRetryPolicy.class), + FALLTHROUGH(FallthroughRetryPolicy.class); - private final RetryPolicy policy; + private final Class policyClass; - RetryPolicyType(RetryPolicy policy) + RetryPolicyType(Class policyClass) { - this.policy = requireNonNull(policy, "policy is null"); + this.policyClass = requireNonNull(policyClass, "policyClass is null"); } - public RetryPolicy getPolicy() + public Class getPolicyClass() { - return policy; + return policyClass; } } diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/TokenRing.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/TokenRing.java index 73fddab0fb5a..c5fb22c1f0f8 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/TokenRing.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/TokenRing.java @@ -13,6 +13,8 @@ */ package io.trino.plugin.cassandra; +import com.datastax.oss.driver.api.core.metadata.token.Token; + import java.math.BigInteger; import java.util.Optional; @@ -26,12 +28,12 @@ public interface TokenRing * @param startToken exclusive * @param endToken inclusive */ - double getRingFraction(String startToken, String endToken); + double getRingFraction(Token startToken, Token endToken); /** * Returns token count in a given range */ - BigInteger getTokenCountInRange(String startToken, String endToken); + BigInteger getTokenCountInRange(Token startToken, Token endToken); static Optional createForPartitioner(String partitioner) { diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/util/CassandraCqlUtils.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/util/CassandraCqlUtils.java index 726cf2ce5a46..175202759cf1 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/util/CassandraCqlUtils.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/util/CassandraCqlUtils.java @@ -13,9 +13,9 @@ */ package io.trino.plugin.cassandra.util; -import com.datastax.driver.core.querybuilder.QueryBuilder; -import com.datastax.driver.core.querybuilder.Select; -import com.datastax.driver.core.querybuilder.Select.Selection; +import com.datastax.oss.driver.api.querybuilder.QueryBuilder; +import com.datastax.oss.driver.api.querybuilder.select.Select; +import com.datastax.oss.driver.api.querybuilder.select.SelectFrom; import com.fasterxml.jackson.core.io.JsonStringEncoder; import com.google.common.collect.ImmutableList; import io.trino.plugin.cassandra.CassandraColumnHandle; @@ -26,7 +26,8 @@ import java.util.ArrayList; import java.util.List; -import static com.datastax.driver.core.Metadata.quote; +import static com.datastax.oss.driver.internal.core.util.Strings.doubleQuote; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.String.format; import static java.nio.charset.StandardCharsets.UTF_8; @@ -39,12 +40,12 @@ private CassandraCqlUtils() {} public static String validSchemaName(String identifier) { - return quote(identifier); + return doubleQuote(identifier); } public static String validTableName(String identifier) { - return quote(identifier); + return doubleQuote(identifier); } public static String validColumnName(String identifier) @@ -53,7 +54,7 @@ public static String validColumnName(String identifier) return "\"\""; } - return quote(identifier); + return doubleQuote(identifier); } public static String quoteStringLiteral(String string) @@ -100,30 +101,33 @@ public static String sqlNameToCqlName(String name) return name; } - public static Selection select(List columns) + public static List selection(List columns) { - Selection selection = QueryBuilder.select(); - for (CassandraColumnHandle column : columns) { - selection.column(validColumnName(column.getName())); - } - return selection; + return columns.stream() + .map(column -> validColumnName(column.getName())) + .collect(toImmutableList()); } public static Select selectFrom(CassandraTableHandle tableHandle, List columns) { - return from(select(columns), tableHandle); + SelectFrom selectFrom = from(tableHandle); + return columns.isEmpty() ? selectFrom.all() : selectFrom.columns(selection(columns)); } - public static Select from(Selection selection, CassandraTableHandle tableHandle) + public static SelectFrom from(CassandraTableHandle tableHandle) { String schema = validSchemaName(tableHandle.getSchemaName()); String table = validTableName(tableHandle.getTableName()); - return selection.from(schema, table); + return QueryBuilder.selectFrom(schema, table); } public static Select selectDistinctFrom(CassandraTableHandle tableHandle, List columns) { - return from(select(columns).distinct(), tableHandle); + SelectFrom selectFrom = from(tableHandle).distinct(); + if (columns.isEmpty()) { + return selectFrom.all(); + } + return selectFrom.columns(selection(columns)); } private static String getWhereCondition(String partition, String clusteringKeyPredicates) diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/util/HostAddressFactory.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/util/HostAddressFactory.java index f861150cec73..f4eb8cbaab35 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/util/HostAddressFactory.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/util/HostAddressFactory.java @@ -13,29 +13,43 @@ */ package io.trino.plugin.cassandra.util; -import com.datastax.driver.core.Host; +import com.datastax.oss.driver.api.core.metadata.Node; import io.trino.spi.HostAddress; +import io.trino.spi.TrinoException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; +import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static java.lang.String.format; + public class HostAddressFactory { private final Map hostMap = new HashMap<>(); - public HostAddress toHostAddress(Host host) + public HostAddress toHostAddress(Node node) { - return toHostAddress(host.getAddress().getHostAddress()); + SocketAddress address = node.getEndPoint().resolve(); + if (address instanceof InetSocketAddress) { + return toHostAddress(((InetSocketAddress) address).getAddress().getHostAddress()); + } + throw new TrinoException( + GENERIC_INTERNAL_ERROR, + format( + "Only endpoints which resolve to a InetSocketAddress are supported. Resolving to socket addresses of type %s is not supported", + address.getClass().getName())); } - public List toHostAddressList(Collection hosts) + public List toHostAddressList(Collection nodes) { - ArrayList list = new ArrayList<>(hosts.size()); - for (Host host : hosts) { - list.add(toHostAddress(host)); + ArrayList list = new ArrayList<>(nodes.size()); + for (Node node : nodes) { + list.add(toHostAddress(node)); } return list; } diff --git a/plugin/trino-cassandra/src/test/java/com/datastax/driver/core/TestHost.java b/plugin/trino-cassandra/src/test/java/com/datastax/driver/core/TestHost.java deleted file mode 100644 index 4eccfd9e0f6e..000000000000 --- a/plugin/trino-cassandra/src/test/java/com/datastax/driver/core/TestHost.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.datastax.driver.core; - -import java.net.InetSocketAddress; - -public class TestHost - extends Host -{ - public TestHost(InetSocketAddress address) - { - super(address, new ConvictionPolicy.DefaultConvictionPolicy.Factory(), Cluster.builder().addContactPoints("localhost").build().manager); - } -} diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/BaseCassandraConnectorSmokeTest.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/BaseCassandraConnectorSmokeTest.java index afa2489f59ba..e8a5a27455f0 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/BaseCassandraConnectorSmokeTest.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/BaseCassandraConnectorSmokeTest.java @@ -15,6 +15,7 @@ import io.trino.testing.BaseConnectorSmokeTest; import io.trino.testing.TestingConnectorBehavior; +import io.trino.testing.sql.TestTable; import org.testng.annotations.Test; import java.time.ZoneId; @@ -22,6 +23,7 @@ import static io.trino.plugin.cassandra.CassandraTestingUtils.TABLE_DELETE_DATA; import static java.lang.String.format; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; public abstract class BaseCassandraConnectorSmokeTest @@ -80,4 +82,13 @@ public void testRowLevelDelete() assertQuery("SELECT COUNT(*) FROM " + keyspaceAndTable, "VALUES 14"); } + + @Test + public void testInsertDate() + { + try (TestTable table = new TestTable(getQueryRunner()::execute, "test_insert_", "(a_date date)")) { + assertUpdate("INSERT INTO " + table.getName() + " (a_date) VALUES ( DATE '2020-05-11')", 1); + assertThat(query("SELECT a_date FROM " + table.getName())).matches("VALUES (DATE '2020-05-11')"); + } + } } diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/CassandraQueryRunner.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/CassandraQueryRunner.java index fe2f1417f826..758feddd9c93 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/CassandraQueryRunner.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/CassandraQueryRunner.java @@ -21,6 +21,7 @@ import io.trino.testing.DistributedQueryRunner; import io.trino.tpch.TpchTable; +import java.util.HashMap; import java.util.Map; import static io.trino.plugin.cassandra.CassandraTestingUtils.createKeyspace; @@ -35,12 +36,13 @@ private CassandraQueryRunner() {} public static DistributedQueryRunner createCassandraQueryRunner(CassandraServer server, TpchTable... tables) throws Exception { - return createCassandraQueryRunner(server, ImmutableMap.of(), ImmutableList.copyOf(tables)); + return createCassandraQueryRunner(server, ImmutableMap.of(), ImmutableMap.of(), ImmutableList.copyOf(tables)); } public static DistributedQueryRunner createCassandraQueryRunner( CassandraServer server, Map extraProperties, + Map connectorProperties, Iterable> tables) throws Exception { @@ -51,11 +53,15 @@ public static DistributedQueryRunner createCassandraQueryRunner( queryRunner.installPlugin(new TpchPlugin()); queryRunner.createCatalog("tpch", "tpch"); + connectorProperties = new HashMap<>(ImmutableMap.copyOf(connectorProperties)); + connectorProperties.putIfAbsent("cassandra.contact-points", server.getHost()); + connectorProperties.putIfAbsent("cassandra.native-protocol-port", Integer.toString(server.getPort())); + connectorProperties.putIfAbsent("cassandra.load-policy.use-dc-aware", "true"); + connectorProperties.putIfAbsent("cassandra.load-policy.dc-aware.local-dc", "datacenter1"); + connectorProperties.putIfAbsent("cassandra.allow-drop-table", "true"); + queryRunner.installPlugin(new CassandraPlugin()); - queryRunner.createCatalog("cassandra", "cassandra", ImmutableMap.of( - "cassandra.contact-points", server.getHost(), - "cassandra.native-protocol-port", Integer.toString(server.getPort()), - "cassandra.allow-drop-table", "true")); + queryRunner.createCatalog("cassandra", "cassandra", connectorProperties); createKeyspace(server.getSession(), "tpch"); copyTpchTables(queryRunner, "tpch", TINY_SCHEMA_NAME, createCassandraSession("tpch"), tables); @@ -80,6 +86,7 @@ public static void main(String[] args) DistributedQueryRunner queryRunner = createCassandraQueryRunner( new CassandraServer(), ImmutableMap.of("http-server.http.port", "8080"), + ImmutableMap.of(), TpchTable.getTables()); Logger log = Logger.get(CassandraQueryRunner.class); diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/CassandraServer.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/CassandraServer.java index 9014e2fa3378..f1cc33f78390 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/CassandraServer.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/CassandraServer.java @@ -13,10 +13,13 @@ */ package io.trino.plugin.cassandra; -import com.datastax.driver.core.Cluster; -import com.datastax.driver.core.ResultSet; -import com.datastax.driver.core.Row; -import com.google.common.collect.ImmutableList; +import com.datastax.oss.driver.api.core.CqlSession; +import com.datastax.oss.driver.api.core.CqlSessionBuilder; +import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.config.DriverConfigLoader; +import com.datastax.oss.driver.api.core.config.ProgrammaticDriverConfigLoaderBuilder; +import com.datastax.oss.driver.api.core.cql.ResultSet; +import com.datastax.oss.driver.api.core.cql.Row; import com.google.common.io.Resources; import io.airlift.json.JsonCodec; import io.airlift.log.Logger; @@ -31,7 +34,10 @@ import java.util.List; import java.util.concurrent.TimeoutException; -import static com.datastax.driver.core.ProtocolVersion.V3; +import static com.datastax.oss.driver.api.core.config.DefaultDriverOption.CONTROL_CONNECTION_AGREEMENT_TIMEOUT; +import static com.datastax.oss.driver.api.core.config.DefaultDriverOption.METADATA_SCHEMA_REFRESHED_KEYSPACES; +import static com.datastax.oss.driver.api.core.config.DefaultDriverOption.PROTOCOL_VERSION; +import static com.datastax.oss.driver.api.core.config.DefaultDriverOption.REQUEST_TIMEOUT; import static com.google.common.io.Files.write; import static com.google.common.io.Resources.getResource; import static java.lang.String.format; @@ -73,24 +79,29 @@ public CassandraServer(String cassandraVersion) .withCopyFileToContainer(forHostPath(prepareCassandraYaml()), "/etc/cassandra/cassandra.yaml"); this.dockerContainer.start(); - Cluster.Builder clusterBuilder = Cluster.builder() - .withProtocolVersion(V3) - .withClusterName("TestCluster") - .addContactPointsWithPorts(ImmutableList.of( - new InetSocketAddress(this.dockerContainer.getContainerIpAddress(), this.dockerContainer.getMappedPort(PORT)))) - .withMaxSchemaAgreementWaitSeconds(30); + ProgrammaticDriverConfigLoaderBuilder driverConfigLoaderBuilder = DriverConfigLoader.programmaticBuilder(); + driverConfigLoaderBuilder.withDuration(REQUEST_TIMEOUT, java.time.Duration.ofSeconds(12)); + driverConfigLoaderBuilder.withString(PROTOCOL_VERSION, ProtocolVersion.V3.name()); + driverConfigLoaderBuilder.withDuration(CONTROL_CONNECTION_AGREEMENT_TIMEOUT, java.time.Duration.ofSeconds(30)); + // allow the retrieval of metadata for the system keyspaces + driverConfigLoaderBuilder.withStringList(METADATA_SCHEMA_REFRESHED_KEYSPACES, List.of()); + + CqlSessionBuilder cqlSessionBuilder = CqlSession.builder() + .withApplicationName("TestCluster") + .addContactPoint(new InetSocketAddress(this.dockerContainer.getContainerIpAddress(), this.dockerContainer.getMappedPort(PORT))) + .withLocalDatacenter("datacenter1") + .withConfigLoader(driverConfigLoaderBuilder.build()); - ReopeningCluster cluster = new ReopeningCluster(clusterBuilder::build); CassandraSession session = new CassandraSession( JsonCodec.listJsonCodec(ExtraColumnMetadata.class), - cluster, + cqlSessionBuilder::build, new Duration(1, MINUTES)); try { checkConnectivity(session); } catch (RuntimeException e) { - cluster.close(); + session.close(); this.dockerContainer.stop(); throw e; } @@ -173,6 +184,9 @@ private void refreshSizeEstimates() @Override public void close() { + if (session != null) { + session.close(); + } dockerContainer.close(); } } diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/CassandraTestingUtils.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/CassandraTestingUtils.java index 1ba1fb198bfd..c33d56c708e5 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/CassandraTestingUtils.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/CassandraTestingUtils.java @@ -13,9 +13,8 @@ */ package io.trino.plugin.cassandra; -import com.datastax.driver.core.LocalDate; -import com.datastax.driver.core.querybuilder.Insert; -import com.datastax.driver.core.querybuilder.QueryBuilder; +import com.datastax.oss.driver.api.core.cql.SimpleStatement; +import com.datastax.oss.driver.api.querybuilder.QueryBuilder; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -26,9 +25,13 @@ import java.math.BigDecimal; import java.math.BigInteger; import java.nio.ByteBuffer; +import java.time.Instant; +import java.time.LocalDate; +import java.time.ZoneId; import java.util.Date; import java.util.UUID; +import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.literal; import static java.lang.String.format; import static org.testng.Assert.assertEquals; @@ -167,28 +170,29 @@ private static void createTableUserDefinedType(CassandraSession session, SchemaT private static void insertTestData(CassandraSession session, SchemaTableName table, Date date, int rowsCount) { for (int rowNumber = 1; rowNumber <= rowsCount; rowNumber++) { - Insert insert = QueryBuilder.insertInto(table.getSchemaName(), table.getTableName()) - .value("key", "key " + rowNumber) - .value("typeuuid", UUID.fromString(format("00000000-0000-0000-0000-%012d", rowNumber))) - .value("typetinyint", rowNumber) - .value("typesmallint", rowNumber) - .value("typeinteger", rowNumber) - .value("typelong", rowNumber + 1000) - .value("typebytes", ByteBuffer.wrap(Ints.toByteArray(rowNumber)).asReadOnlyBuffer()) - .value("typedate", LocalDate.fromMillisSinceEpoch(date.getTime())) - .value("typetimestamp", date) - .value("typeansi", "ansi " + rowNumber) - .value("typeboolean", rowNumber % 2 == 0) - .value("typedecimal", new BigDecimal(Math.pow(2, rowNumber))) - .value("typedouble", Math.pow(4, rowNumber)) - .value("typefloat", (float) Math.pow(8, rowNumber)) - .value("typeinet", InetAddresses.forString("127.0.0.1")) - .value("typevarchar", "varchar " + rowNumber) - .value("typevarint", BigInteger.TEN.pow(rowNumber)) - .value("typetimeuuid", UUID.fromString(format("d2177dd0-eaa2-11de-a572-001b779c76e%d", rowNumber))) - .value("typelist", ImmutableList.of("list-value-1" + rowNumber, "list-value-2" + rowNumber)) - .value("typemap", ImmutableMap.of(rowNumber, rowNumber + 1L, rowNumber + 2, rowNumber + 3L)) - .value("typeset", ImmutableSet.of(false, true)); + SimpleStatement insert = QueryBuilder.insertInto(table.getSchemaName(), table.getTableName()) + .value("key", literal("key " + rowNumber)) + .value("typeuuid", literal(UUID.fromString(format("00000000-0000-0000-0000-%012d", rowNumber)))) + .value("typetinyint", literal(rowNumber)) + .value("typesmallint", literal(rowNumber)) + .value("typeinteger", literal(rowNumber)) + .value("typelong", literal(rowNumber + 1000)) + .value("typebytes", literal(ByteBuffer.wrap(Ints.toByteArray(rowNumber)).asReadOnlyBuffer())) + .value("typedate", literal(LocalDate.ofInstant(Instant.ofEpochMilli(date.getTime()), ZoneId.systemDefault()))) + .value("typetimestamp", literal(Instant.ofEpochMilli(date.getTime()))) + .value("typeansi", literal("ansi " + rowNumber)) + .value("typeboolean", literal(rowNumber % 2 == 0)) + .value("typedecimal", literal(new BigDecimal(Math.pow(2, rowNumber)))) + .value("typedouble", literal(Math.pow(4, rowNumber))) + .value("typefloat", literal((float) Math.pow(8, rowNumber))) + .value("typeinet", literal(InetAddresses.forString("127.0.0.1"))) + .value("typevarchar", literal("varchar " + rowNumber)) + .value("typevarint", literal(BigInteger.TEN.pow(rowNumber))) + .value("typetimeuuid", literal(UUID.fromString(format("d2177dd0-eaa2-11de-a572-001b779c76e%d", rowNumber)))) + .value("typelist", literal(ImmutableList.of("list-value-1" + rowNumber, "list-value-2" + rowNumber))) + .value("typemap", literal(ImmutableMap.of(rowNumber, rowNumber + 1L, rowNumber + 2, rowNumber + 3L))) + .value("typeset", literal(ImmutableSet.of(false, true))) + .build(); session.execute(insert); } @@ -230,26 +234,27 @@ private static void insertIntoTableDeleteData(CassandraSession session, SchemaTa 9 | 9 | clust_one_9 | null */ for (int rowNumber = 1; rowNumber < 10; rowNumber++) { - Insert insert = QueryBuilder.insertInto(table.getSchemaName(), table.getTableName()) - .value("partition_one", rowNumber) - .value("partition_two", rowNumber) - .value("clust_one", "clust_one_" + rowNumber); + SimpleStatement insert = QueryBuilder.insertInto(table.getSchemaName(), table.getTableName()) + .value("partition_one", literal(rowNumber)) + .value("partition_two", literal(rowNumber)) + .value("clust_one", literal("clust_one_" + rowNumber)) + .build(); session.execute(insert); } session.execute(QueryBuilder.insertInto(table.getSchemaName(), table.getTableName()) - .value("partition_one", 1L).value("partition_two", 1).value("clust_one", "clust_one_" + 2)); + .value("partition_one", literal(1L)).value("partition_two", literal(1)).value("clust_one", literal("clust_one_" + 2)).build()); session.execute(QueryBuilder.insertInto(table.getSchemaName(), table.getTableName()) - .value("partition_one", 1L).value("partition_two", 1).value("clust_one", "clust_one_" + 3)); + .value("partition_one", literal(1L)).value("partition_two", literal(1)).value("clust_one", literal("clust_one_" + 3)).build()); session.execute(QueryBuilder.insertInto(table.getSchemaName(), table.getTableName()) - .value("partition_one", 1L).value("partition_two", 2).value("clust_one", "clust_one_" + 1)); + .value("partition_one", literal(1L)).value("partition_two", literal(2)).value("clust_one", literal("clust_one_" + 1)).build()); session.execute(QueryBuilder.insertInto(table.getSchemaName(), table.getTableName()) - .value("partition_one", 1L).value("partition_two", 2).value("clust_one", "clust_one_" + 2)); + .value("partition_one", literal(1L)).value("partition_two", literal(2)).value("clust_one", literal("clust_one_" + 2)).build()); session.execute(QueryBuilder.insertInto(table.getSchemaName(), table.getTableName()) - .value("partition_one", 1L).value("partition_two", 2).value("clust_one", "clust_one_" + 3)); + .value("partition_one", literal(1L)).value("partition_two", literal(2)).value("clust_one", literal("clust_one_" + 3)).build()); session.execute(QueryBuilder.insertInto(table.getSchemaName(), table.getTableName()) - .value("partition_one", 2L).value("partition_two", 2).value("clust_one", "clust_one_" + 1)); + .value("partition_one", literal(2L)).value("partition_two", literal(2)).value("clust_one", literal("clust_one_" + 1)).build()); assertEquals(session.execute("SELECT COUNT(*) FROM " + table).all().get(0).getLong(0), 15); } diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/ScyllaQueryRunner.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/ScyllaQueryRunner.java index 30c00d42e3c6..8091bcd1ea1f 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/ScyllaQueryRunner.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/ScyllaQueryRunner.java @@ -53,6 +53,8 @@ public static DistributedQueryRunner createScyllaQueryRunner( connectorProperties.putIfAbsent("cassandra.contact-points", server.getHost()); connectorProperties.putIfAbsent("cassandra.native-protocol-port", Integer.toString(server.getPort())); connectorProperties.putIfAbsent("cassandra.allow-drop-table", "true"); + connectorProperties.putIfAbsent("cassandra.load-policy.use-dc-aware", "true"); + connectorProperties.putIfAbsent("cassandra.load-policy.dc-aware.local-dc", "datacenter1"); queryRunner.installPlugin(new CassandraPlugin()); queryRunner.createCatalog("cassandra", "cassandra", connectorProperties); diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraClientConfig.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraClientConfig.java index df3b3706fd4c..6b76353cd42c 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraClientConfig.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraClientConfig.java @@ -13,8 +13,8 @@ */ package io.trino.plugin.cassandra; -import com.datastax.driver.core.ConsistencyLevel; -import com.datastax.driver.core.SocketOptions; +import com.datastax.oss.driver.api.core.DefaultConsistencyLevel; +import com.datastax.oss.driver.api.core.DefaultProtocolVersion; import com.google.common.collect.ImmutableMap; import io.airlift.units.Duration; import org.testng.annotations.Test; @@ -24,7 +24,6 @@ import java.nio.file.Path; import java.util.Map; -import static com.datastax.driver.core.ProtocolVersion.V2; import static io.airlift.configuration.testing.ConfigAssertions.assertFullMapping; import static io.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; import static io.airlift.configuration.testing.ConfigAssertions.recordDefaults; @@ -39,7 +38,7 @@ public void testDefaults() { assertRecordedDefaults(recordDefaults(CassandraClientConfig.class) .setFetchSize(5_000) - .setConsistencyLevel(ConsistencyLevel.ONE) + .setConsistencyLevel(DefaultConsistencyLevel.ONE) .setContactPoints("") .setNativeProtocolPort(9042) .setPartitionSizeForBatchSelect(100) @@ -49,8 +48,8 @@ public void testDefaults() .setAllowDropTable(false) .setUsername(null) .setPassword(null) - .setClientReadTimeout(new Duration(SocketOptions.DEFAULT_READ_TIMEOUT_MILLIS, MILLISECONDS)) - .setClientConnectTimeout(new Duration(SocketOptions.DEFAULT_CONNECT_TIMEOUT_MILLIS, MILLISECONDS)) + .setClientReadTimeout(new Duration(12_000, MILLISECONDS)) + .setClientConnectTimeout(new Duration(5_000, MILLISECONDS)) .setClientSoLinger(null) .setRetryPolicy(RetryPolicyType.DEFAULT) .setUseDCAware(false) @@ -93,7 +92,7 @@ public void testExplicitPropertyMappings() .put("cassandra.client.read-timeout", "11ms") .put("cassandra.client.connect-timeout", "22ms") .put("cassandra.client.so-linger", "33") - .put("cassandra.retry-policy", "BACKOFF") + .put("cassandra.retry-policy", "DOWNGRADING_CONSISTENCY") .put("cassandra.load-policy.use-dc-aware", "true") .put("cassandra.load-policy.dc-aware.local-dc", "dc1") .put("cassandra.load-policy.dc-aware.used-hosts-per-remote-dc", "1") @@ -104,7 +103,7 @@ public void testExplicitPropertyMappings() .put("cassandra.no-host-available-retry-timeout", "3m") .put("cassandra.speculative-execution.limit", "10") .put("cassandra.speculative-execution.delay", "101s") - .put("cassandra.protocol-version", "V2") + .put("cassandra.protocol-version", "V3") .put("cassandra.tls.enabled", "true") .put("cassandra.tls.keystore-path", keystoreFile.toString()) .put("cassandra.tls.keystore-password", "keystore-password") @@ -116,7 +115,7 @@ public void testExplicitPropertyMappings() .setContactPoints("host1", "host2") .setNativeProtocolPort(9999) .setFetchSize(10_000) - .setConsistencyLevel(ConsistencyLevel.TWO) + .setConsistencyLevel(DefaultConsistencyLevel.TWO) .setPartitionSizeForBatchSelect(77) .setSplitSize(1_025) .setBatchSize(999) @@ -127,7 +126,7 @@ public void testExplicitPropertyMappings() .setClientReadTimeout(new Duration(11, MILLISECONDS)) .setClientConnectTimeout(new Duration(22, MILLISECONDS)) .setClientSoLinger(33) - .setRetryPolicy(RetryPolicyType.BACKOFF) + .setRetryPolicy(RetryPolicyType.DOWNGRADING_CONSISTENCY) .setUseDCAware(true) .setDcAwareLocalDC("dc1") .setDcAwareUsedHostsPerRemoteDc(1) @@ -138,7 +137,7 @@ public void testExplicitPropertyMappings() .setNoHostAvailableRetryTimeout(new Duration(3, MINUTES)) .setSpeculativeExecutionLimit(10) .setSpeculativeExecutionDelay(new Duration(101, SECONDS)) - .setProtocolVersion(V2) + .setProtocolVersion(DefaultProtocolVersion.V3) .setTlsEnabled(true) .setKeystorePath(keystoreFile.toFile()) .setKeystorePassword("keystore-password") diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnector.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnector.java index aa9d9cd2e8f0..5d35ceedee8b 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnector.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnector.java @@ -13,7 +13,7 @@ */ package io.trino.plugin.cassandra; -import com.datastax.driver.core.utils.Bytes; +import com.datastax.oss.protocol.internal.util.Bytes; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Shorts; @@ -120,6 +120,8 @@ public void setup() Connector connector = connectorFactory.create("test", ImmutableMap.of( "cassandra.contact-points", server.getHost(), + "cassandra.load-policy.use-dc-aware", "true", + "cassandra.load-policy.dc-aware.local-dc", "datacenter1", "cassandra.native-protocol-port", Integer.toString(server.getPort())), new TestingConnectorContext()); diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnectorTest.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnectorTest.java index d9a8823ef36e..8ec93a8d2707 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnectorTest.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnectorTest.java @@ -40,8 +40,8 @@ import java.util.List; import java.util.Optional; -import static com.datastax.driver.core.utils.Bytes.toHexString; -import static com.datastax.driver.core.utils.Bytes.toRawHexString; +import static com.datastax.oss.driver.api.core.data.ByteUtils.toHexString; +import static com.google.common.io.BaseEncoding.base16; import static io.trino.plugin.cassandra.CassandraQueryRunner.createCassandraQueryRunner; import static io.trino.plugin.cassandra.CassandraQueryRunner.createCassandraSession; import static io.trino.plugin.cassandra.TestCassandraTable.clusterColumn; @@ -132,7 +132,7 @@ protected QueryRunner createQueryRunner() server = closeAfterClass(new CassandraServer()); session = server.getSession(); session.execute("CREATE KEYSPACE IF NOT EXISTS " + KEYSPACE + " WITH REPLICATION = {'class':'SimpleStrategy', 'replication_factor': 1}"); - return createCassandraQueryRunner(server, ImmutableMap.of(), REQUIRED_TPCH_TABLES); + return createCassandraQueryRunner(server, ImmutableMap.of(), ImmutableMap.of(), REQUIRED_TPCH_TABLES); } @Override @@ -394,7 +394,7 @@ public void testPartitionKeyPredicate() " AND typesmallint = 7" + " AND typeinteger = 7" + " AND typelong = 1007" + - " AND typebytes = from_hex('" + toRawHexString(ByteBuffer.wrap(Ints.toByteArray(7))) + "')" + + " AND typebytes = from_hex('" + base16().encode(Ints.toByteArray(7)) + "')" + " AND typedate = DATE '1970-01-01'" + " AND typetimestamp = TIMESTAMP '1970-01-01 03:04:05Z'" + " AND typeansi = 'ansi 7'" + @@ -467,7 +467,7 @@ public void testSelect() rowNumber -> String.valueOf(rowNumber), rowNumber -> String.valueOf(rowNumber), rowNumber -> String.valueOf(rowNumber + 1000), - rowNumber -> toHexString(ByteBuffer.wrap(Ints.toByteArray(rowNumber))), + rowNumber -> toHexString(ByteBuffer.wrap(Ints.toByteArray(rowNumber)).asReadOnlyBuffer()), rowNumber -> format("'%s'", DateTimeFormatter.ofPattern("uuuu-MM-dd").format(TIMESTAMP_VALUE)), rowNumber -> format("'%s'", DateTimeFormatter.ofPattern("uuuu-MM-dd HH:mm:ss.SSSZ").format(TIMESTAMP_VALUE)), rowNumber -> format("'ansi %d'", rowNumber), diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraLatestConnectorSmokeTest.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraLatestConnectorSmokeTest.java index a67fa75a4491..2109ce2cc262 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraLatestConnectorSmokeTest.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraLatestConnectorSmokeTest.java @@ -31,6 +31,6 @@ protected QueryRunner createQueryRunner() CassandraServer server = closeAfterClass(new CassandraServer("3.11.10")); CassandraSession session = server.getSession(); createTestTables(session, KEYSPACE, Timestamp.from(TIMESTAMP_VALUE.toInstant())); - return createCassandraQueryRunner(server, ImmutableMap.of(), REQUIRED_TPCH_TABLES); + return createCassandraQueryRunner(server, ImmutableMap.of(), ImmutableMap.of(), REQUIRED_TPCH_TABLES); } } diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraProtocolVersionV3ConnectorSmokeTest.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraProtocolVersionV3ConnectorSmokeTest.java new file mode 100644 index 000000000000..7dd1953b3ab7 --- /dev/null +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraProtocolVersionV3ConnectorSmokeTest.java @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.cassandra; + +import com.google.common.collect.ImmutableMap; +import io.trino.testing.QueryRunner; +import io.trino.testing.sql.TestTable; +import org.testng.annotations.Test; + +import java.sql.Timestamp; + +import static io.trino.plugin.cassandra.CassandraQueryRunner.createCassandraQueryRunner; +import static io.trino.plugin.cassandra.CassandraTestingUtils.createTestTables; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestCassandraProtocolVersionV3ConnectorSmokeTest + extends BaseCassandraConnectorSmokeTest +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + CassandraServer server = closeAfterClass(new CassandraServer()); + CassandraSession session = server.getSession(); + createTestTables(session, KEYSPACE, Timestamp.from(TIMESTAMP_VALUE.toInstant())); + return createCassandraQueryRunner(server, ImmutableMap.of(), ImmutableMap.of("cassandra.protocol-version", "V3"), REQUIRED_TPCH_TABLES); + } + + @Test + @Override + public void testInsertDate() + { + try (TestTable table = new TestTable(getQueryRunner()::execute, "test_insert_", "(a_date date)")) { + assertUpdate("INSERT INTO " + table.getName() + " (a_date) VALUES ('2020-05-11')", 1); + assertThat(query("SELECT a_date FROM " + table.getName())).matches("VALUES (CAST('2020-05-11' AS varchar))"); + } + } +} diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraType.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraType.java index b94bf1c6aaee..099c3aad9946 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraType.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraType.java @@ -13,7 +13,7 @@ */ package io.trino.plugin.cassandra; -import com.datastax.driver.core.DataType; +import com.datastax.oss.driver.api.core.type.DataTypes; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.core.JsonToken; import com.fasterxml.jackson.databind.ObjectMapper; @@ -29,13 +29,13 @@ public class TestCassandraType @Test public void testJsonArrayEncoding() { - assertTrue(isValidJson(CassandraType.buildArrayValue(Lists.newArrayList("one", "two", "three\""), DataType.varchar()))); - assertTrue(isValidJson(CassandraType.buildArrayValue(Lists.newArrayList(1, 2, 3), DataType.cint()))); - assertTrue(isValidJson(CassandraType.buildArrayValue(Lists.newArrayList(100000L, 200000000L, 3000000000L), DataType.bigint()))); - assertTrue(isValidJson(CassandraType.buildArrayValue(Lists.newArrayList(1.0, 2.0, 3.0), DataType.cdouble()))); - assertTrue(isValidJson(CassandraType.buildArrayValue(Lists.newArrayList((short) -32768, (short) 0, (short) 32767), DataType.smallint()))); - assertTrue(isValidJson(CassandraType.buildArrayValue(Lists.newArrayList((byte) -128, (byte) 0, (byte) 127), DataType.tinyint()))); - assertTrue(isValidJson(CassandraType.buildArrayValue(Lists.newArrayList("1970-01-01", "5555-06-15", "9999-12-31"), DataType.date()))); + assertTrue(isValidJson(CassandraType.buildArrayValue(Lists.newArrayList("one", "two", "three\""), DataTypes.TEXT))); + assertTrue(isValidJson(CassandraType.buildArrayValue(Lists.newArrayList(1, 2, 3), DataTypes.INT))); + assertTrue(isValidJson(CassandraType.buildArrayValue(Lists.newArrayList(100000L, 200000000L, 3000000000L), DataTypes.BIGINT))); + assertTrue(isValidJson(CassandraType.buildArrayValue(Lists.newArrayList(1.0, 2.0, 3.0), DataTypes.DOUBLE))); + assertTrue(isValidJson(CassandraType.buildArrayValue(Lists.newArrayList((short) -32768, (short) 0, (short) 32767), DataTypes.SMALLINT))); + assertTrue(isValidJson(CassandraType.buildArrayValue(Lists.newArrayList((byte) -128, (byte) 0, (byte) 127), DataTypes.TINYINT))); + assertTrue(isValidJson(CassandraType.buildArrayValue(Lists.newArrayList("1970-01-01", "5555-06-15", "9999-12-31"), DataTypes.DATE))); } private static void continueWhileNotNull(JsonParser parser, JsonToken token) diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraTypeMapping.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraTypeMapping.java index e21501df6a7b..480233200942 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraTypeMapping.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraTypeMapping.java @@ -140,6 +140,7 @@ protected QueryRunner createQueryRunner() return createCassandraQueryRunner( server, ImmutableMap.of(), + ImmutableMap.of(), ImmutableList.of()); } diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestMurmur3PartitionerTokenRing.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestMurmur3PartitionerTokenRing.java index da4de7a83a58..9c252a727d44 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestMurmur3PartitionerTokenRing.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestMurmur3PartitionerTokenRing.java @@ -13,6 +13,7 @@ */ package io.trino.plugin.cassandra; +import com.datastax.oss.driver.internal.core.metadata.token.Murmur3Token; import org.testng.annotations.Test; import java.math.BigInteger; @@ -28,24 +29,24 @@ public class TestMurmur3PartitionerTokenRing @Test public void testGetTokenCountInRange() { - assertEquals(tokenRing.getTokenCountInRange("0", "1"), ONE); - assertEquals(tokenRing.getTokenCountInRange("-1", "1"), new BigInteger("2")); - assertEquals(tokenRing.getTokenCountInRange("-100", "100"), new BigInteger("200")); - assertEquals(tokenRing.getTokenCountInRange("0", "10"), new BigInteger("10")); - assertEquals(tokenRing.getTokenCountInRange("1", "11"), new BigInteger("10")); - assertEquals(tokenRing.getTokenCountInRange("0", "0"), ZERO); - assertEquals(tokenRing.getTokenCountInRange("1", "1"), ZERO); - assertEquals(tokenRing.getTokenCountInRange(Long.toString(Long.MIN_VALUE), Long.toString(Long.MIN_VALUE)), BigInteger.valueOf(2).pow(64).subtract(ONE)); - assertEquals(tokenRing.getTokenCountInRange("1", "0"), BigInteger.valueOf(2).pow(64).subtract(BigInteger.valueOf(2))); + assertEquals(tokenRing.getTokenCountInRange(new Murmur3Token(0), new Murmur3Token(1)), ONE); + assertEquals(tokenRing.getTokenCountInRange(new Murmur3Token(-1), new Murmur3Token(1)), new BigInteger("2")); + assertEquals(tokenRing.getTokenCountInRange(new Murmur3Token(-100), new Murmur3Token(100)), new BigInteger("200")); + assertEquals(tokenRing.getTokenCountInRange(new Murmur3Token(0), new Murmur3Token(10)), new BigInteger("10")); + assertEquals(tokenRing.getTokenCountInRange(new Murmur3Token(1), new Murmur3Token(11)), new BigInteger("10")); + assertEquals(tokenRing.getTokenCountInRange(new Murmur3Token(0), new Murmur3Token(0)), ZERO); + assertEquals(tokenRing.getTokenCountInRange(new Murmur3Token(1), new Murmur3Token(1)), ZERO); + assertEquals(tokenRing.getTokenCountInRange(new Murmur3Token(Long.MIN_VALUE), new Murmur3Token(Long.MIN_VALUE)), BigInteger.valueOf(2).pow(64).subtract(ONE)); + assertEquals(tokenRing.getTokenCountInRange(new Murmur3Token(1), new Murmur3Token(0)), BigInteger.valueOf(2).pow(64).subtract(BigInteger.valueOf(2))); } @Test public void testGetRingFraction() { - assertEquals(tokenRing.getRingFraction("1", "1"), 0.0, 0.001); - assertEquals(tokenRing.getRingFraction("1", "0"), 1.0, 0.001); - assertEquals(tokenRing.getRingFraction("0", Long.toString(Long.MAX_VALUE)), 0.5, 0.001); - assertEquals(tokenRing.getRingFraction(Long.toString(Long.MIN_VALUE), Long.toString(Long.MAX_VALUE)), 1.0, 0.001); - assertEquals(tokenRing.getRingFraction(Long.toString(Long.MIN_VALUE), Long.toString(Long.MIN_VALUE)), 1.0, 0.001); + assertEquals(tokenRing.getRingFraction(new Murmur3Token(1), new Murmur3Token(1)), 0.0, 0.001); + assertEquals(tokenRing.getRingFraction(new Murmur3Token(1), new Murmur3Token(0)), 1.0, 0.001); + assertEquals(tokenRing.getRingFraction(new Murmur3Token(0), new Murmur3Token(Long.MAX_VALUE)), 0.5, 0.001); + assertEquals(tokenRing.getRingFraction(new Murmur3Token(Long.MIN_VALUE), new Murmur3Token(Long.MAX_VALUE)), 1.0, 0.001); + assertEquals(tokenRing.getRingFraction(new Murmur3Token(Long.MIN_VALUE), new Murmur3Token(Long.MIN_VALUE)), 1.0, 0.001); } } diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestRandomPartitionerTokenRing.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestRandomPartitionerTokenRing.java index 5bf74141dbd6..4776da96f169 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestRandomPartitionerTokenRing.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestRandomPartitionerTokenRing.java @@ -13,6 +13,7 @@ */ package io.trino.plugin.cassandra; +import com.datastax.oss.driver.internal.core.metadata.token.RandomToken; import org.testng.annotations.Test; import java.math.BigInteger; @@ -28,24 +29,34 @@ public class TestRandomPartitionerTokenRing @Test public void testGetRingFraction() { - assertEquals(tokenRing.getTokenCountInRange("0", "1"), ONE); - assertEquals(tokenRing.getTokenCountInRange("0", "200"), new BigInteger("200")); - assertEquals(tokenRing.getTokenCountInRange("0", "10"), new BigInteger("10")); - assertEquals(tokenRing.getTokenCountInRange("1", "11"), new BigInteger("10")); - assertEquals(tokenRing.getTokenCountInRange("0", "0"), ZERO); - assertEquals(tokenRing.getTokenCountInRange("-1", "-1"), BigInteger.valueOf(2).pow(127).add(ONE)); - assertEquals(tokenRing.getTokenCountInRange("1", "0"), BigInteger.valueOf(2).pow(127)); + assertEquals(tokenRing.getTokenCountInRange(randomToken(0), randomToken(1)), ONE); + assertEquals(tokenRing.getTokenCountInRange(randomToken(0), randomToken(200)), new BigInteger("200")); + assertEquals(tokenRing.getTokenCountInRange(randomToken(0), randomToken(10)), new BigInteger("10")); + assertEquals(tokenRing.getTokenCountInRange(randomToken(1), randomToken(11)), new BigInteger("10")); + assertEquals(tokenRing.getTokenCountInRange(randomToken(0), randomToken(0)), ZERO); + assertEquals(tokenRing.getTokenCountInRange(randomToken(-1), randomToken(-1)), BigInteger.valueOf(2).pow(127).add(ONE)); + assertEquals(tokenRing.getTokenCountInRange(randomToken(1), randomToken(0)), BigInteger.valueOf(2).pow(127)); } @Test public void testGetTokenCountInRange() { - assertEquals(tokenRing.getRingFraction("0", "0"), 0.0, 0.001); - assertEquals(tokenRing.getRingFraction("1", "0"), 1.0, 0.001); - assertEquals(tokenRing.getRingFraction("-1", "-1"), 1.0, 0.001); - assertEquals(tokenRing.getRingFraction("0", BigInteger.valueOf(2).pow(126).toString()), 0.5, 0.001); - assertEquals(tokenRing.getRingFraction(BigInteger.valueOf(2).pow(126).toString(), BigInteger.valueOf(2).pow(127).toString()), 0.5, 0.001); - assertEquals(tokenRing.getRingFraction("0", BigInteger.valueOf(2).pow(126).toString()), 0.5, 0.001); - assertEquals(tokenRing.getRingFraction("0", BigInteger.valueOf(2).pow(127).toString()), 1.0, 0.001); + assertEquals(tokenRing.getRingFraction(randomToken(0), randomToken(0)), 0.0, 0.001); + assertEquals(tokenRing.getRingFraction(randomToken(1), randomToken(0)), 1.0, 0.001); + assertEquals(tokenRing.getRingFraction(randomToken(-1), randomToken(-1)), 1.0, 0.001); + assertEquals(tokenRing.getRingFraction(randomToken(0), randomToken(BigInteger.valueOf(2).pow(126))), 0.5, 0.001); + assertEquals(tokenRing.getRingFraction(randomToken(BigInteger.valueOf(2).pow(126)), randomToken(BigInteger.valueOf(2).pow(127))), 0.5, 0.001); + assertEquals(tokenRing.getRingFraction(randomToken(0), randomToken(BigInteger.valueOf(2).pow(126))), 0.5, 0.001); + assertEquals(tokenRing.getRingFraction(randomToken(0), randomToken(BigInteger.valueOf(2).pow(127))), 1.0, 0.001); + } + + private static RandomToken randomToken(long value) + { + return randomToken(BigInteger.valueOf(value)); + } + + private static RandomToken randomToken(BigInteger value) + { + return new RandomToken(value); } } diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestingScyllaServer.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestingScyllaServer.java index 758b48df9b4a..73da7463bcec 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestingScyllaServer.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestingScyllaServer.java @@ -13,8 +13,11 @@ */ package io.trino.plugin.cassandra; -import com.datastax.driver.core.Cluster; -import com.google.common.collect.ImmutableList; +import com.datastax.oss.driver.api.core.CqlSession; +import com.datastax.oss.driver.api.core.CqlSessionBuilder; +import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.config.DriverConfigLoader; +import com.datastax.oss.driver.api.core.config.ProgrammaticDriverConfigLoaderBuilder; import io.airlift.json.JsonCodec; import io.airlift.log.Logger; import io.airlift.units.Duration; @@ -25,7 +28,10 @@ import java.util.List; import java.util.concurrent.TimeoutException; -import static com.datastax.driver.core.ProtocolVersion.V3; +import static com.datastax.oss.driver.api.core.config.DefaultDriverOption.CONTROL_CONNECTION_AGREEMENT_TIMEOUT; +import static com.datastax.oss.driver.api.core.config.DefaultDriverOption.METADATA_SCHEMA_REFRESHED_KEYSPACES; +import static com.datastax.oss.driver.api.core.config.DefaultDriverOption.PROTOCOL_VERSION; +import static com.datastax.oss.driver.api.core.config.DefaultDriverOption.REQUEST_TIMEOUT; import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MINUTES; @@ -45,27 +51,35 @@ public class TestingScyllaServer private final CassandraSession session; public TestingScyllaServer() + throws Exception { this("2.2.0"); } public TestingScyllaServer(String version) + throws Exception { container = new GenericContainer<>("scylladb/scylla:" + version) .withCommand("--smp", "1") // Limit SMP to run in a machine having many cores https://github.com/scylladb/scylla/issues/5638 .withExposedPorts(PORT); container.start(); - Cluster.Builder clusterBuilder = Cluster.builder() - .withProtocolVersion(V3) - .withClusterName("TestCluster") - .addContactPointsWithPorts(ImmutableList.of( - new InetSocketAddress(container.getContainerIpAddress(), container.getMappedPort(PORT)))) - .withMaxSchemaAgreementWaitSeconds(60); + ProgrammaticDriverConfigLoaderBuilder config = DriverConfigLoader.programmaticBuilder(); + config.withDuration(REQUEST_TIMEOUT, java.time.Duration.ofSeconds(12)); + config.withString(PROTOCOL_VERSION, ProtocolVersion.V3.name()); + config.withDuration(CONTROL_CONNECTION_AGREEMENT_TIMEOUT, java.time.Duration.ofSeconds(30)); + // allow the retrieval of metadata for the system keyspaces + config.withStringList(METADATA_SCHEMA_REFRESHED_KEYSPACES, List.of()); + + CqlSessionBuilder cqlSessionBuilder = CqlSession.builder() + .withApplicationName("TestCluster") + .addContactPoint(new InetSocketAddress(this.container.getContainerIpAddress(), this.container.getMappedPort(PORT))) + .withLocalDatacenter("datacenter1") + .withConfigLoader(config.build()); session = new CassandraSession( JsonCodec.listJsonCodec(ExtraColumnMetadata.class), - new ReopeningCluster(clusterBuilder::build), + cqlSessionBuilder::build, new Duration(1, MINUTES)); } @@ -117,6 +131,9 @@ private void refreshSizeEstimates() @Override public void close() { + if (session != null) { + session.close(); + } container.close(); } } diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/util/TestCassandraClusteringPredicatesExtractor.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/util/TestCassandraClusteringPredicatesExtractor.java index 386dfa1d8d7f..a986a5fa51dd 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/util/TestCassandraClusteringPredicatesExtractor.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/util/TestCassandraClusteringPredicatesExtractor.java @@ -13,7 +13,7 @@ */ package io.trino.plugin.cassandra.util; -import com.datastax.driver.core.VersionNumber; +import com.datastax.oss.driver.api.core.Version; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.plugin.cassandra.CassandraClusteringPredicatesExtractor; @@ -37,7 +37,7 @@ public class TestCassandraClusteringPredicatesExtractor private static CassandraColumnHandle col3; private static CassandraColumnHandle col4; private static CassandraTable cassandraTable; - private static VersionNumber cassandraVersion; + private static Version cassandraVersion; @BeforeTest public void setUp() @@ -50,7 +50,7 @@ public void setUp() cassandraTable = new CassandraTable( new CassandraTableHandle("test", "records"), ImmutableList.of(col1, col2, col3, col4)); - cassandraVersion = VersionNumber.parse("2.1.5"); + cassandraVersion = Version.parse("2.1.5"); } @Test diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/util/TestHostAddressFactory.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/util/TestHostAddressFactory.java index 2422e56ca6bb..8a52fd6d4b33 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/util/TestHostAddressFactory.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/util/TestHostAddressFactory.java @@ -13,8 +13,13 @@ */ package io.trino.plugin.cassandra.util; -import com.datastax.driver.core.Host; -import com.datastax.driver.core.TestHost; +import com.datastax.oss.driver.api.core.config.DriverConfigLoader; +import com.datastax.oss.driver.api.core.metadata.Node; +import com.datastax.oss.driver.api.core.session.ProgrammaticArguments; +import com.datastax.oss.driver.internal.core.context.DefaultDriverContext; +import com.datastax.oss.driver.internal.core.context.InternalDriverContext; +import com.datastax.oss.driver.internal.core.metadata.DefaultEndPoint; +import com.datastax.oss.driver.internal.core.metadata.DefaultNode; import com.google.common.collect.ImmutableSet; import io.trino.spi.HostAddress; import org.testng.annotations.Test; @@ -32,17 +37,26 @@ public class TestHostAddressFactory public void testToHostAddressList() throws Exception { - Set hosts = ImmutableSet.of( - new TestHost( - new InetSocketAddress( - InetAddress.getByAddress(new byte[] { - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 - }), - 3000)), - new TestHost(new InetSocketAddress(InetAddress.getByAddress(new byte[] {1, 2, 3, 4}), 3000))); + DriverConfigLoader driverConfigLoader = DriverConfigLoader.programmaticBuilder().build(); + ProgrammaticArguments args = + ProgrammaticArguments.builder().build(); + InternalDriverContext driverContext = new DefaultDriverContext(driverConfigLoader, args); + + Set nodes = ImmutableSet.of( + new DefaultNode( + new DefaultEndPoint( + new InetSocketAddress( + InetAddress.getByAddress(new byte[] { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + }), + 3000)), + driverContext), + new DefaultNode( + new DefaultEndPoint(new InetSocketAddress(InetAddress.getByAddress(new byte[] {1, 2, 3, 4}), 3000)), + driverContext)); HostAddressFactory hostAddressFactory = new HostAddressFactory(); - List list = hostAddressFactory.toHostAddressList(hosts); + List list = hostAddressFactory.toHostAddressList(nodes); assertEquals(list.toString(), "[[102:304:506:708:90a:b0c:d0e:f10], 1.2.3.4]"); } diff --git a/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/environment/singlenode-cassandra/cassandra.properties b/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/environment/singlenode-cassandra/cassandra.properties index d9d63db5e763..40cbca6ba00b 100644 --- a/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/environment/singlenode-cassandra/cassandra.properties +++ b/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/environment/singlenode-cassandra/cassandra.properties @@ -1,3 +1,5 @@ connector.name=cassandra cassandra.contact-points=cassandra cassandra.allow-drop-table=true +cassandra.load-policy.use-dc-aware=true +cassandra.load-policy.dc-aware.local-dc=datacenter1