diff --git a/docs/src/main/sphinx/connector/cassandra.rst b/docs/src/main/sphinx/connector/cassandra.rst index b3a5bbdd2a36..1980d3293466 100644 --- a/docs/src/main/sphinx/connector/cassandra.rst +++ b/docs/src/main/sphinx/connector/cassandra.rst @@ -214,7 +214,7 @@ DATE DATE DECIMAL DOUBLE DOUBLE DOUBLE FLOAT REAL -INET VARCHAR(45) +INET IPADDRESS INT INTEGER LIST VARCHAR MAP VARCHAR diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClientModule.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClientModule.java index f26e53f4d74f..3d3bed74a9a5 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClientModule.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClientModule.java @@ -77,6 +77,7 @@ public void configure(Binder binder) binder.bind(CassandraPageSinkProvider.class).in(Scopes.SINGLETON); binder.bind(CassandraPartitionManager.class).in(Scopes.SINGLETON); binder.bind(CassandraSessionProperties.class).in(Scopes.SINGLETON); + binder.bind(CassandraTypeManager.class).in(Scopes.SINGLETON); configBinder(binder).bindConfig(CassandraClientConfig.class); @@ -105,7 +106,7 @@ protected Type _deserialize(String value, DeserializationContext context) @Singleton @Provides - public static CassandraSession createCassandraSession(CassandraClientConfig config, JsonCodec> extraColumnMetadataCodec) + public static CassandraSession createCassandraSession(CassandraTypeManager cassandraTypeManager, CassandraClientConfig config, JsonCodec> extraColumnMetadataCodec) { requireNonNull(config, "config is null"); requireNonNull(extraColumnMetadataCodec, "extraColumnMetadataCodec is null"); @@ -168,6 +169,7 @@ public static CassandraSession createCassandraSession(CassandraClientConfig conf cqlSessionBuilder.withConfigLoader(driverConfigLoaderBuilder.build()); return new CassandraSession( + cassandraTypeManager, extraColumnMetadataCodec, () -> { contactPoints.forEach(contactPoint -> cqlSessionBuilder.addContactPoint( diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClusteringPredicatesExtractor.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClusteringPredicatesExtractor.java index 6c062539c0fa..2e3e5db9c98f 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClusteringPredicatesExtractor.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClusteringPredicatesExtractor.java @@ -33,11 +33,13 @@ public class CassandraClusteringPredicatesExtractor { + private final CassandraTypeManager cassandraTypeManager; private final ClusteringPushDownResult clusteringPushDownResult; private final TupleDomain predicates; - public CassandraClusteringPredicatesExtractor(List clusteringColumns, TupleDomain predicates, Version cassandraVersion) + public CassandraClusteringPredicatesExtractor(CassandraTypeManager cassandraTypeManager, List clusteringColumns, TupleDomain predicates, Version cassandraVersion) { + this.cassandraTypeManager = requireNonNull(cassandraTypeManager, "cassandraTypeManager is null"); this.predicates = requireNonNull(predicates, "predicates is null"); this.clusteringPushDownResult = getClusteringKeysSet(clusteringColumns, predicates, requireNonNull(cassandraVersion, "cassandraVersion is null")); } @@ -52,7 +54,7 @@ public TupleDomain getUnenforcedConstraints() return predicates.filter(((columnHandle, domain) -> !clusteringPushDownResult.hasBeenFullyPushed(columnHandle))); } - private static ClusteringPushDownResult getClusteringKeysSet(List clusteringColumns, TupleDomain predicates, Version cassandraVersion) + private ClusteringPushDownResult getClusteringKeysSet(List clusteringColumns, TupleDomain predicates, Version cassandraVersion) { ImmutableSet.Builder fullyPushedColumnPredicates = ImmutableSet.builder(); ImmutableList.Builder clusteringColumnSql = ImmutableList.builder(); @@ -101,7 +103,7 @@ private static ClusteringPushDownResult getClusteringKeysSet(List cassandraTypeManager.toCqlLiteral(columnHandle.getCassandraType(), value)) .collect(joining(",")); fullyPushedColumnPredicates.add(columnHandle); return CassandraCqlUtils.validColumnName(columnHandle.getName()) + " IN (" + inValues + " )"; @@ -132,12 +134,12 @@ private static boolean isInExpressionNotAllowed(List clus return cassandraVersion.compareTo(Version.parse("2.2.0")) < 0 && currentlyProcessedClusteringColumn != (clusteringColumns.size() - 1); } - private static String toCqlLiteral(CassandraColumnHandle columnHandle, Object value) + private String toCqlLiteral(CassandraColumnHandle columnHandle, Object value) { - return columnHandle.getCassandraType().toCqlLiteral(value); + return cassandraTypeManager.toCqlLiteral(columnHandle.getCassandraType(), value); } - private static String translateRangeIntoCql(CassandraColumnHandle columnHandle, Range range) + private String translateRangeIntoCql(CassandraColumnHandle columnHandle, Range range) { if (columnHandle.getCassandraType().getKind() == CassandraType.Kind.TUPLE || columnHandle.getCassandraType().getKind() == CassandraType.Kind.UDT) { // Building CQL literals for TUPLE and UDT type is not supported diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraMetadata.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraMetadata.java index a961cd110c78..9bbdad28d956 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraMetadata.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraMetadata.java @@ -55,7 +55,6 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.MoreCollectors.toOptional; -import static io.trino.plugin.cassandra.CassandraType.toCassandraType; import static io.trino.plugin.cassandra.util.CassandraCqlUtils.ID_COLUMN_NAME; import static io.trino.plugin.cassandra.util.CassandraCqlUtils.cqlNameToSqlName; import static io.trino.plugin.cassandra.util.CassandraCqlUtils.quoteStringLiteral; @@ -80,16 +79,19 @@ public class CassandraMetadata private final boolean allowDropTable; private final JsonCodec> extraColumnMetadataCodec; + private final CassandraTypeManager cassandraTypeManager; @Inject public CassandraMetadata( CassandraSession cassandraSession, CassandraPartitionManager partitionManager, JsonCodec> extraColumnMetadataCodec, + CassandraTypeManager cassandraTypeManager, CassandraClientConfig config) { this.partitionManager = requireNonNull(partitionManager, "partitionManager is null"); this.cassandraSession = requireNonNull(cassandraSession, "cassandraSession is null"); + this.cassandraTypeManager = requireNonNull(cassandraTypeManager, "cassandraTypeManager is null"); this.allowDropTable = requireNonNull(config, "config is null").getAllowDropTable(); this.extraColumnMetadataCodec = requireNonNull(extraColumnMetadataCodec, "extraColumnMetadataCodec is null"); } @@ -221,6 +223,7 @@ public Optional> applyFilter(C } else { CassandraClusteringPredicatesExtractor clusteringPredicatesExtractor = new CassandraClusteringPredicatesExtractor( + cassandraTypeManager, cassandraSession.getTable(handle.getSchemaTableName()).getClusteringKeyColumns(), partitionResult.getUnenforcedConstraint(), cassandraSession.getCassandraVersion()); @@ -316,7 +319,7 @@ private CassandraOutputTableHandle createTable(ConnectorTableMetadata tableMetad queryBuilder.append(", ") .append(validColumnName(name)) .append(" ") - .append(toCassandraType(type, cassandraSession.getProtocolVersion()).getName().toLowerCase(ENGLISH)); + .append(cassandraTypeManager.toCassandraType(type, cassandraSession.getProtocolVersion()).getName().toLowerCase(ENGLISH)); } queryBuilder.append(") "); diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraPageSink.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraPageSink.java index 948b63758c2d..35cd7c581e91 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraPageSink.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraPageSink.java @@ -22,6 +22,7 @@ import com.datastax.oss.driver.api.querybuilder.term.Term; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.net.InetAddresses; import com.google.common.primitives.Shorts; import com.google.common.primitives.SignedBytes; import io.airlift.slice.Slice; @@ -71,6 +72,7 @@ public class CassandraPageSink implements ConnectorPageSink { + private final CassandraTypeManager cassandraTypeManager; private final CassandraSession cassandraSession; private final PreparedStatement insert; private final List columnTypes; @@ -80,6 +82,7 @@ public class CassandraPageSink private final BatchStatementBuilder batchStatement = BatchStatement.builder(DefaultBatchType.LOGGED); public CassandraPageSink( + CassandraTypeManager cassandraTypeManager, CassandraSession cassandraSession, ProtocolVersion protocolVersion, String schemaName, @@ -89,6 +92,7 @@ public CassandraPageSink( boolean generateUuid, int batchSize) { + this.cassandraTypeManager = requireNonNull(cassandraTypeManager, "cassandraTypeManager is null"); this.cassandraSession = requireNonNull(cassandraSession, "cassandraSession"); requireNonNull(schemaName, "schemaName is null"); requireNonNull(tableName, "tableName is null"); @@ -185,6 +189,9 @@ else if (VARBINARY.equals(type)) { else if (UuidType.UUID.equals(type)) { values.add(trinoUuidToJavaUuid(type.getSlice(block, position))); } + else if (cassandraTypeManager.isIpAddressType(type)) { + values.add(InetAddresses.forString((String) type.getObjectValue(null, block, position))); + } else { throw new TrinoException(NOT_SUPPORTED, "Unsupported column type: " + type.getDisplayName()); } diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraPageSinkProvider.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraPageSinkProvider.java index c53912ba9285..accf159e9014 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraPageSinkProvider.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraPageSinkProvider.java @@ -28,12 +28,17 @@ public class CassandraPageSinkProvider implements ConnectorPageSinkProvider { + private final CassandraTypeManager cassandraTypeManager; private final CassandraSession cassandraSession; private final int batchSize; @Inject - public CassandraPageSinkProvider(CassandraSession cassandraSession, CassandraClientConfig cassandraClientConfig) + public CassandraPageSinkProvider( + CassandraTypeManager cassandraTypeManager, + CassandraSession cassandraSession, + CassandraClientConfig cassandraClientConfig) { + this.cassandraTypeManager = requireNonNull(cassandraTypeManager, "cassandraTypeManager is null"); this.cassandraSession = requireNonNull(cassandraSession, "cassandraSession is null"); this.batchSize = requireNonNull(cassandraClientConfig, "cassandraClientConfig is null").getBatchSize(); } @@ -46,6 +51,7 @@ public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHa CassandraOutputTableHandle handle = (CassandraOutputTableHandle) tableHandle; return new CassandraPageSink( + cassandraTypeManager, cassandraSession, cassandraSession.getProtocolVersion(), handle.getSchemaName(), @@ -64,6 +70,7 @@ public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHa CassandraInsertTableHandle handle = (CassandraInsertTableHandle) tableHandle; return new CassandraPageSink( + cassandraTypeManager, cassandraSession, cassandraSession.getProtocolVersion(), handle.getSchemaName(), diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraPartitionManager.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraPartitionManager.java index f5cb0164a62a..1e8f11db259b 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraPartitionManager.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraPartitionManager.java @@ -45,11 +45,13 @@ public class CassandraPartitionManager private static final Logger log = Logger.get(CassandraPartitionManager.class); private final CassandraSession cassandraSession; + private final CassandraTypeManager cassandraTypeManager; @Inject - public CassandraPartitionManager(CassandraSession cassandraSession) + public CassandraPartitionManager(CassandraSession cassandraSession, CassandraTypeManager cassandraTypeManager) { this.cassandraSession = requireNonNull(cassandraSession, "cassandraSession is null"); + this.cassandraTypeManager = requireNonNull(cassandraTypeManager, "cassandraTypeManager is null"); } public CassandraPartitionResult getPartitions(CassandraTableHandle cassandraTableHandle, TupleDomain tupleDomain) @@ -98,7 +100,7 @@ public CassandraPartitionResult getPartitions(CassandraTableHandle cassandraTabl if (column.isIndexed() && domain.isSingleValue()) { sb.append(CassandraCqlUtils.validColumnName(column.getName())) .append(" = ") - .append(column.getCassandraType().toCqlLiteral(entry.getValue().getSingleValue())); + .append(cassandraTypeManager.toCqlLiteral(column.getCassandraType(), entry.getValue().getSingleValue())); indexedColumns.add(column); // Only one indexed column predicate can be pushed down. break; @@ -132,7 +134,7 @@ private List getCassandraPartitions(CassandraTable table, Tu return cassandraSession.getPartitions(table, partitionKeysList); } - private static List> getPartitionKeysList(CassandraTable table, TupleDomain tupleDomain) + private List> getPartitionKeysList(CassandraTable table, TupleDomain tupleDomain) { ImmutableList.Builder> partitionColumnValues = ImmutableList.builder(); for (CassandraColumnHandle columnHandle : table.getPartitionKeyColumns()) { @@ -159,7 +161,7 @@ private static List> getPartitionKeysList(CassandraTable table, Tupl Object value = range.getSingleValue(); CassandraType valueType = columnHandle.getCassandraType(); - if (valueType.isSupportedPartitionKey()) { + if (cassandraTypeManager.isSupportedPartitionKey(valueType.getKind())) { columnValues.add(value); } } diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraRecordCursor.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraRecordCursor.java index 986a48eb82a1..164e876ab909 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraRecordCursor.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraRecordCursor.java @@ -32,12 +32,14 @@ public class CassandraRecordCursor implements RecordCursor { private final List cassandraTypes; + private final CassandraTypeManager cassandraTypeManager; private final ResultSet rs; private Row currentRow; - public CassandraRecordCursor(CassandraSession cassandraSession, List cassandraTypes, String cql) + public CassandraRecordCursor(CassandraSession cassandraSession, CassandraTypeManager cassandraTypeManager, List cassandraTypes, String cql) { this.cassandraTypes = cassandraTypes; + this.cassandraTypeManager = cassandraTypeManager; rs = cassandraSession.execute(cql); currentRow = null; } @@ -126,7 +128,7 @@ public Slice getSlice(int i) if (getCassandraType(i).getKind() == Kind.TIMESTAMP) { throw new IllegalArgumentException("Timestamp column can not be accessed with getSlice"); } - NullableValue value = cassandraTypes.get(i).getColumnValue(currentRow, i); + NullableValue value = cassandraTypeManager.getColumnValue(cassandraTypes.get(i), currentRow, i); if (value.getValue() instanceof Slice) { return (Slice) value.getValue(); } @@ -140,7 +142,7 @@ public Object getObject(int i) switch (cassandraType.getKind()) { case TUPLE: case UDT: - return cassandraType.getColumnValue(currentRow, i).getValue(); + return cassandraTypeManager.getColumnValue(cassandraType, currentRow, i).getValue(); default: throw new IllegalArgumentException("getObject cannot be called for " + cassandraType); } diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraRecordSet.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraRecordSet.java index f7ff61121ffc..7ae25cbefe69 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraRecordSet.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraRecordSet.java @@ -28,13 +28,15 @@ public class CassandraRecordSet implements RecordSet { private final CassandraSession cassandraSession; + private final CassandraTypeManager cassandraTypeManager; private final String cql; private final List cassandraTypes; private final List columnTypes; - public CassandraRecordSet(CassandraSession cassandraSession, String cql, List cassandraColumns) + public CassandraRecordSet(CassandraSession cassandraSession, CassandraTypeManager cassandraTypeManager, String cql, List cassandraColumns) { this.cassandraSession = requireNonNull(cassandraSession, "cassandraSession is null"); + this.cassandraTypeManager = requireNonNull(cassandraTypeManager, "cassandraTypeManager is null"); this.cql = requireNonNull(cql, "cql is null"); requireNonNull(cassandraColumns, "cassandraColumns is null"); @@ -51,7 +53,7 @@ public List getColumnTypes() @Override public RecordCursor cursor() { - return new CassandraRecordCursor(cassandraSession, cassandraTypes, cql); + return new CassandraRecordCursor(cassandraSession, cassandraTypeManager, cassandraTypes, cql); } private static List transformList(List list, Function function) diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraRecordSetProvider.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraRecordSetProvider.java index 178272758827..dd1b4fcaf531 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraRecordSetProvider.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraRecordSetProvider.java @@ -36,11 +36,13 @@ public class CassandraRecordSetProvider private static final Logger log = Logger.get(CassandraRecordSetProvider.class); private final CassandraSession cassandraSession; + private final CassandraTypeManager cassandraTypeManager; @Inject - public CassandraRecordSetProvider(CassandraSession cassandraSession) + public CassandraRecordSetProvider(CassandraSession cassandraSession, CassandraTypeManager cassandraTypeManager) { this.cassandraSession = requireNonNull(cassandraSession, "cassandraSession is null"); + this.cassandraTypeManager = requireNonNull(cassandraTypeManager); } @Override @@ -62,6 +64,6 @@ public RecordSet getRecordSet(ConnectorTransactionHandle transaction, ConnectorS String cql = sb.toString(); log.debug("Creating record set: %s", cql); - return new CassandraRecordSet(cassandraSession, cql, cassandraColumns); + return new CassandraRecordSet(cassandraSession, cassandraTypeManager, cql, cassandraColumns); } } diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraSession.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraSession.java index 12ae07bd764d..ec0e8fcec1f1 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraSession.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraSession.java @@ -82,8 +82,6 @@ import static com.google.common.collect.Iterables.transform; import static io.trino.plugin.cassandra.CassandraErrorCode.CASSANDRA_VERSION_ERROR; import static io.trino.plugin.cassandra.CassandraMetadata.PRESTO_COMMENT_METADATA; -import static io.trino.plugin.cassandra.CassandraType.isFullySupported; -import static io.trino.plugin.cassandra.CassandraType.toCassandraType; import static io.trino.plugin.cassandra.util.CassandraCqlUtils.selectDistinctFrom; import static io.trino.plugin.cassandra.util.CassandraCqlUtils.validSchemaName; import static io.trino.plugin.cassandra.util.CassandraCqlUtils.validTableName; @@ -104,12 +102,18 @@ public class CassandraSession private static final String SIZE_ESTIMATES = "size_estimates"; private static final Version PARTITION_FETCH_WITH_IN_PREDICATE_VERSION = Version.parse("2.2"); + private final CassandraTypeManager cassandraTypeManager; private final JsonCodec> extraColumnMetadataCodec; private final Supplier session; private final Duration noHostAvailableRetryTimeout; - public CassandraSession(JsonCodec> extraColumnMetadataCodec, Supplier sessionSupplier, Duration noHostAvailableRetryTimeout) + public CassandraSession( + CassandraTypeManager cassandraTypeManager, + JsonCodec> extraColumnMetadataCodec, + Supplier sessionSupplier, + Duration noHostAvailableRetryTimeout) { + this.cassandraTypeManager = requireNonNull(cassandraTypeManager, "cassandraTypeManager is null"); this.extraColumnMetadataCodec = requireNonNull(extraColumnMetadataCodec, "extraColumnMetadataCodec is null"); this.noHostAvailableRetryTimeout = requireNonNull(noHostAvailableRetryTimeout, "noHostAvailableRetryTimeout is null"); this.session = memoize(sessionSupplier::get); @@ -342,7 +346,7 @@ private static void checkColumnNames(Collection columns) private Optional buildColumnHandle(RelationMetadata tableMetadata, ColumnMetadata columnMeta, boolean partitionKey, boolean clusteringKey, int ordinalPosition, boolean hidden) { - Optional cassandraType = toCassandraType(columnMeta.getType()); + Optional cassandraType = cassandraTypeManager.toCassandraType(columnMeta.getType()); if (cassandraType.isEmpty()) { log.debug("Unsupported column type: %s", columnMeta.getType().asCql(false, false)); return Optional.empty(); @@ -350,7 +354,7 @@ private Optional buildColumnHandle(RelationMetadata table List typeArgs = getTypeArguments(columnMeta.getType()); for (DataType typeArgument : typeArgs) { - if (!isFullySupported(typeArgument)) { + if (!cassandraTypeManager.isFullySupported(typeArgument)) { log.debug("%s column has unsupported type: %s", columnMeta.getName(), typeArgument); return Optional.empty(); } @@ -420,14 +424,14 @@ public List getPartitions(CassandraTable table, List 0) { stringBuilder.append(" AND "); } stringBuilder.append(CassandraCqlUtils.validColumnName(columnHandle.getName())); stringBuilder.append(" = "); - stringBuilder.append(columnHandle.getCassandraType().getColumnValueForCql(row, i)); + stringBuilder.append(cassandraTypeManager.getColumnValueForCql(columnHandle.getCassandraType(), row, i)); } buffer.flip(); byte[] key = new byte[buffer.limit()]; @@ -492,7 +496,7 @@ private Iterable queryPartitionKeysLegacyWithMultipleQueries(CassandraTable return rowList.build(); } - private static List getInRelations(List partitionKeyColumns, List> filterPrefixes) + private List getInRelations(List partitionKeyColumns, List> filterPrefixes) { return IntStream .range(0, Math.min(partitionKeyColumns.size(), filterPrefixes.size())) @@ -500,24 +504,24 @@ private static List getInRelations(List partiti .collect(toImmutableList()); } - private static Relation getInRelation(CassandraColumnHandle column, Set filterPrefixes) + private Relation getInRelation(CassandraColumnHandle column, Set filterPrefixes) { List values = filterPrefixes .stream() - .map(value -> column.getCassandraType().getJavaValue(value)) + .map(value -> cassandraTypeManager.getJavaValue(column.getCassandraType().getKind(), value)) .map(QueryBuilder::literal) .collect(toList()); return Relation.column(CassandraCqlUtils.validColumnName(column.getName())).in(values); } - private static List getEqualityRelations(List partitionKeyColumns, List filterPrefix) + private List getEqualityRelations(List partitionKeyColumns, List filterPrefix) { return IntStream .range(0, Math.min(partitionKeyColumns.size(), filterPrefix.size())) .mapToObj(i -> { CassandraColumnHandle column = partitionKeyColumns.get(i); - Object value = column.getCassandraType().getJavaValue(filterPrefix.get(i)); + Object value = cassandraTypeManager.getJavaValue(column.getCassandraType().getKind(), filterPrefix.get(i)); return Relation.column(CassandraCqlUtils.validColumnName(column.getName())).isEqualTo(literal(value)); }) .collect(toImmutableList()); diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraSplitManager.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraSplitManager.java index 7fe83552ed15..511cd15c3528 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraSplitManager.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraSplitManager.java @@ -57,18 +57,21 @@ public class CassandraSplitManager private final int partitionSizeForBatchSelect; private final CassandraTokenSplitManager tokenSplitMgr; private final CassandraPartitionManager partitionManager; + private final CassandraTypeManager cassandraTypeManager; @Inject public CassandraSplitManager( CassandraClientConfig cassandraClientConfig, CassandraSession cassandraSession, CassandraTokenSplitManager tokenSplitMgr, - CassandraPartitionManager partitionManager) + CassandraPartitionManager partitionManager, + CassandraTypeManager cassandraTypeManager) { this.cassandraSession = requireNonNull(cassandraSession, "cassandraSession is null"); this.partitionSizeForBatchSelect = cassandraClientConfig.getPartitionSizeForBatchSelect(); this.tokenSplitMgr = tokenSplitMgr; this.partitionManager = requireNonNull(partitionManager, "partitionManager is null"); + this.cassandraTypeManager = requireNonNull(cassandraTypeManager, "cassandraTypeManager is null"); } @Override @@ -114,13 +117,14 @@ public ConnectorSplitSource getSplits( return new FixedSplitSource(splits); } - private static String extractClusteringKeyPredicates(CassandraPartitionResult partitionResult, CassandraTableHandle tableHandle, CassandraSession session) + private String extractClusteringKeyPredicates(CassandraPartitionResult partitionResult, CassandraTableHandle tableHandle, CassandraSession session) { if (partitionResult.isUnpartitioned()) { return ""; } CassandraClusteringPredicatesExtractor clusteringPredicatesExtractor = new CassandraClusteringPredicatesExtractor( + cassandraTypeManager, session.getTable(tableHandle.getSchemaTableName()).getClusteringKeyColumns(), partitionResult.getUnenforcedConstraint(), session.getCassandraVersion()); diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraType.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraType.java index d16b00c7dcaa..658368a10cf1 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraType.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraType.java @@ -13,79 +13,14 @@ */ package io.trino.plugin.cassandra; -import com.datastax.oss.driver.api.core.CqlIdentifier; -import com.datastax.oss.driver.api.core.ProtocolVersion; -import com.datastax.oss.driver.api.core.cql.Row; -import com.datastax.oss.driver.api.core.data.GettableByIndex; -import com.datastax.oss.driver.api.core.data.TupleValue; -import com.datastax.oss.driver.api.core.data.UdtValue; -import com.datastax.oss.driver.api.core.type.DataType; -import com.datastax.oss.driver.api.core.type.ListType; -import com.datastax.oss.driver.api.core.type.MapType; -import com.datastax.oss.driver.api.core.type.SetType; -import com.datastax.oss.driver.api.core.type.TupleType; -import com.datastax.oss.driver.api.core.type.UserDefinedType; -import com.datastax.oss.protocol.internal.ProtocolConstants; -import com.datastax.oss.protocol.internal.util.Bytes; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.common.net.InetAddresses; -import io.airlift.slice.Slice; -import io.trino.spi.TrinoException; -import io.trino.spi.block.Block; -import io.trino.spi.block.RowBlockBuilder; -import io.trino.spi.block.SingleRowBlockWriter; -import io.trino.spi.predicate.NullableValue; -import io.trino.spi.type.BigintType; -import io.trino.spi.type.BooleanType; -import io.trino.spi.type.DateType; -import io.trino.spi.type.DoubleType; -import io.trino.spi.type.IntegerType; -import io.trino.spi.type.RealType; -import io.trino.spi.type.RowType; -import io.trino.spi.type.SmallintType; -import io.trino.spi.type.TimeZoneKey; -import io.trino.spi.type.TimestampWithTimeZoneType; -import io.trino.spi.type.TinyintType; import io.trino.spi.type.Type; -import io.trino.spi.type.UuidType; -import io.trino.spi.type.VarbinaryType; -import io.trino.spi.type.VarcharType; -import java.math.BigDecimal; -import java.math.BigInteger; -import java.nio.ByteBuffer; -import java.time.Instant; -import java.time.LocalDate; -import java.util.Arrays; -import java.util.Collection; import java.util.List; -import java.util.Map; import java.util.Objects; -import java.util.Optional; -import java.util.function.Supplier; -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Verify.verify; -import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.net.InetAddresses.toAddrString; -import static io.airlift.slice.Slices.utf8Slice; -import static io.airlift.slice.Slices.wrappedBuffer; -import static io.trino.plugin.cassandra.util.CassandraCqlUtils.quoteStringLiteral; -import static io.trino.plugin.cassandra.util.CassandraCqlUtils.quoteStringLiteralForJson; -import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; -import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; -import static io.trino.spi.type.DateTimeEncoding.packDateTimeWithZone; -import static io.trino.spi.type.DateTimeEncoding.unpackMillisUtc; -import static io.trino.spi.type.TypeUtils.writeNativeValue; -import static io.trino.spi.type.UuidType.javaUuidToTrinoUuid; -import static io.trino.spi.type.UuidType.trinoUuidToJavaUuid; -import static java.lang.Float.floatToRawIntBits; -import static java.lang.Float.intBitsToFloat; -import static java.lang.Math.toIntExact; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -165,555 +100,6 @@ public String getName() return kind.name(); } - public static Optional toCassandraType(DataType dataType) - { - switch (dataType.getProtocolCode()) { - case ProtocolConstants.DataType.ASCII: - return Optional.of(CassandraTypes.ASCII); - case ProtocolConstants.DataType.BIGINT: - return Optional.of(CassandraTypes.BIGINT); - case ProtocolConstants.DataType.BLOB: - return Optional.of(CassandraTypes.BLOB); - case ProtocolConstants.DataType.BOOLEAN: - return Optional.of(CassandraTypes.BOOLEAN); - case ProtocolConstants.DataType.COUNTER: - return Optional.of(CassandraTypes.COUNTER); - case ProtocolConstants.DataType.CUSTOM: - return Optional.of(CassandraTypes.CUSTOM); - case ProtocolConstants.DataType.DATE: - return Optional.of(CassandraTypes.DATE); - case ProtocolConstants.DataType.DECIMAL: - return Optional.of(CassandraTypes.DECIMAL); - case ProtocolConstants.DataType.DOUBLE: - return Optional.of(CassandraTypes.DOUBLE); - case ProtocolConstants.DataType.FLOAT: - return Optional.of(CassandraTypes.FLOAT); - case ProtocolConstants.DataType.INET: - return Optional.of(CassandraTypes.INET); - case ProtocolConstants.DataType.INT: - return Optional.of(CassandraTypes.INT); - case ProtocolConstants.DataType.LIST: - return Optional.of(CassandraTypes.LIST); - case ProtocolConstants.DataType.MAP: - return Optional.of(CassandraTypes.MAP); - case ProtocolConstants.DataType.SET: - return Optional.of(CassandraTypes.SET); - case ProtocolConstants.DataType.SMALLINT: - return Optional.of(CassandraTypes.SMALLINT); - case ProtocolConstants.DataType.TIMESTAMP: - return Optional.of(CassandraTypes.TIMESTAMP); - case ProtocolConstants.DataType.TIMEUUID: - return Optional.of(CassandraTypes.TIMEUUID); - case ProtocolConstants.DataType.TINYINT: - return Optional.of(CassandraTypes.TINYINT); - case ProtocolConstants.DataType.TUPLE: - return createTypeForTuple(dataType); - case ProtocolConstants.DataType.UDT: - return createTypeForUserType(dataType); - case ProtocolConstants.DataType.UUID: - return Optional.of(CassandraTypes.UUID); - case ProtocolConstants.DataType.VARCHAR: - return Optional.of(CassandraTypes.VARCHAR); - case ProtocolConstants.DataType.VARINT: - return Optional.of(CassandraTypes.VARINT); - default: - return Optional.empty(); - } - } - - private static Optional createTypeForTuple(DataType dataType) - { - TupleType tupleType = (TupleType) dataType; - List> argumentTypesOptionals = tupleType.getComponentTypes().stream() - .map(CassandraType::toCassandraType) - .collect(toImmutableList()); - - if (argumentTypesOptionals.stream().anyMatch(Optional::isEmpty)) { - return Optional.empty(); - } - - List argumentTypes = argumentTypesOptionals.stream() - .map(Optional::get) - .collect(toImmutableList()); - - RowType trinoType = RowType.anonymous( - argumentTypes.stream() - .map(CassandraType::getTrinoType) - .collect(toImmutableList())); - - return Optional.of(new CassandraType(Kind.TUPLE, trinoType, argumentTypes)); - } - - private static Optional createTypeForUserType(DataType dataType) - { - UserDefinedType userDefinedType = (UserDefinedType) dataType; - // Using ImmutableMap is important as we exploit the fact that entries iteration order matches the order of putting values via builder - ImmutableMap.Builder argumentTypes = ImmutableMap.builder(); - - List fieldNames = userDefinedType.getFieldNames(); - List fieldTypes = userDefinedType.getFieldTypes(); - if (fieldNames.size() != fieldTypes.size()) { - throw new TrinoException(GENERIC_INTERNAL_ERROR, format("Mismatch between the number of field names (%s) and the number of field types (%s) for the data type %s", fieldNames.size(), fieldTypes.size(), dataType)); - } - for (int i = 0; i < fieldNames.size(); i++) { - Optional cassandraType = CassandraType.toCassandraType(fieldTypes.get(i)); - if (cassandraType.isEmpty()) { - return Optional.empty(); - } - argumentTypes.put(fieldNames.get(i).toString(), cassandraType.get()); - } - - RowType trinoType = RowType.from( - argumentTypes.buildOrThrow().entrySet().stream() - .map(field -> new RowType.Field(Optional.of(field.getKey()), field.getValue().getTrinoType())) - .collect(toImmutableList())); - - return Optional.of(new CassandraType(Kind.UDT, trinoType, argumentTypes.buildOrThrow().values().stream().collect(toImmutableList()))); - } - - public NullableValue getColumnValue(Row row, int position) - { - return getColumnValue(row, position, () -> row.getColumnDefinitions().get(position).getType()); - } - - public NullableValue getColumnValue(GettableByIndex row, int position, Supplier dataTypeSupplier) - { - if (row.isNull(position)) { - return NullableValue.asNull(trinoType); - } - - switch (kind) { - case ASCII: - case TEXT: - case VARCHAR: - return NullableValue.of(trinoType, utf8Slice(row.getString(position))); - case INT: - return NullableValue.of(trinoType, (long) row.getInt(position)); - case SMALLINT: - return NullableValue.of(trinoType, (long) row.getShort(position)); - case TINYINT: - return NullableValue.of(trinoType, (long) row.getByte(position)); - case BIGINT: - case COUNTER: - return NullableValue.of(trinoType, row.getLong(position)); - case BOOLEAN: - return NullableValue.of(trinoType, row.getBoolean(position)); - case DOUBLE: - return NullableValue.of(trinoType, row.getDouble(position)); - case FLOAT: - return NullableValue.of(trinoType, (long) floatToRawIntBits(row.getFloat(position))); - case DECIMAL: - return NullableValue.of(trinoType, row.getBigDecimal(position).doubleValue()); - case UUID: - case TIMEUUID: - return NullableValue.of(trinoType, javaUuidToTrinoUuid(row.getUuid(position))); - case TIMESTAMP: - return NullableValue.of(trinoType, packDateTimeWithZone(row.getInstant(position).toEpochMilli(), TimeZoneKey.UTC_KEY)); - case DATE: - return NullableValue.of(trinoType, row.getLocalDate(position).toEpochDay()); - case INET: - return NullableValue.of(trinoType, utf8Slice(toAddrString(row.getInetAddress(position)))); - case VARINT: - return NullableValue.of(trinoType, utf8Slice(row.getBigInteger(position).toString())); - case BLOB: - case CUSTOM: - return NullableValue.of(trinoType, wrappedBuffer(row.getBytesUnsafe(position))); - case SET: - return NullableValue.of(trinoType, utf8Slice(buildArrayValueFromSetType(row, position, dataTypeSupplier.get()))); - case LIST: - return NullableValue.of(trinoType, utf8Slice(buildArrayValueFromListType(row, position, dataTypeSupplier.get()))); - case MAP: - return NullableValue.of(trinoType, utf8Slice(buildMapValue(row, position, dataTypeSupplier.get()))); - case TUPLE: - return NullableValue.of(trinoType, buildTupleValue(row, position)); - case UDT: - return NullableValue.of(trinoType, buildUserTypeValue(row, position)); - } - throw new IllegalStateException("Handling of type " + this + " is not implemented"); - } - - private static String buildMapValue(GettableByIndex row, int position, DataType dataType) - { - checkArgument(dataType instanceof MapType, "Expected to deal with an instance of %s class, got: %s", MapType.class, dataType); - MapType mapType = (MapType) dataType; - return buildMapValue((Map) row.getObject(position), mapType.getKeyType(), mapType.getValueType()); - } - - private static String buildMapValue(Map cassandraMap, DataType keyType, DataType valueType) - { - StringBuilder sb = new StringBuilder(); - sb.append("{"); - for (Map.Entry entry : cassandraMap.entrySet()) { - if (sb.length() > 1) { - sb.append(","); - } - sb.append(objectToJson(entry.getKey(), keyType)); - sb.append(":"); - sb.append(objectToJson(entry.getValue(), valueType)); - } - sb.append("}"); - return sb.toString(); - } - - private static String buildArrayValueFromSetType(GettableByIndex row, int position, DataType type) - { - checkArgument(type instanceof SetType, "Expected to deal with an instance of %s class, got: %s", SetType.class, type); - SetType setType = (SetType) type; - return buildArrayValue((Collection) row.getObject(position), setType.getElementType()); - } - - private static String buildArrayValueFromListType(GettableByIndex row, int position, DataType type) - { - checkArgument(type instanceof ListType, "Expected to deal with an instance of %s class, got: %s", ListType.class, type); - ListType listType = (ListType) type; - return buildArrayValue((Collection) row.getObject(position), listType.getElementType()); - } - - @VisibleForTesting - static String buildArrayValue(Collection cassandraCollection, DataType elementType) - { - StringBuilder sb = new StringBuilder(); - sb.append("["); - for (Object value : cassandraCollection) { - if (sb.length() > 1) { - sb.append(","); - } - sb.append(objectToJson(value, elementType)); - } - sb.append("]"); - return sb.toString(); - } - - private Block buildTupleValue(GettableByIndex row, int position) - { - verify(this.kind == Kind.TUPLE, "Not a TUPLE type"); - TupleValue tupleValue = row.getTupleValue(position); - RowBlockBuilder blockBuilder = (RowBlockBuilder) this.trinoType.createBlockBuilder(null, 1); - SingleRowBlockWriter singleRowBlockWriter = blockBuilder.beginBlockEntry(); - int tuplePosition = 0; - for (CassandraType argumentType : this.getArgumentTypes()) { - int finalTuplePosition = tuplePosition; - NullableValue value = argumentType.getColumnValue(tupleValue, tuplePosition, () -> tupleValue.getType().getComponentTypes().get(finalTuplePosition)); - writeNativeValue(argumentType.getTrinoType(), singleRowBlockWriter, value.getValue()); - tuplePosition++; - } - // can I just return singleRowBlockWriter here? It extends AbstractSingleRowBlock and tests pass. - blockBuilder.closeEntry(); - return (Block) this.trinoType.getObject(blockBuilder, 0); - } - - private Block buildUserTypeValue(GettableByIndex row, int position) - { - verify(this.kind == Kind.UDT, "Not a user defined type: %s", this.kind); - UdtValue udtValue = row.getUdtValue(position); - RowBlockBuilder blockBuilder = (RowBlockBuilder) this.trinoType.createBlockBuilder(null, 1); - SingleRowBlockWriter singleRowBlockWriter = blockBuilder.beginBlockEntry(); - int tuplePosition = 0; - List udtTypeFieldTypes = udtValue.getType().getFieldTypes(); - for (CassandraType argumentType : this.getArgumentTypes()) { - int finalTuplePosition = tuplePosition; - NullableValue value = argumentType.getColumnValue(udtValue, tuplePosition, () -> udtTypeFieldTypes.get(finalTuplePosition)); - writeNativeValue(argumentType.getTrinoType(), singleRowBlockWriter, value.getValue()); - tuplePosition++; - } - - blockBuilder.closeEntry(); - return (Block) this.trinoType.getObject(blockBuilder, 0); - } - - // TODO unify with toCqlLiteral - public String getColumnValueForCql(Row row, int position) - { - if (row.isNull(position)) { - return null; - } - - switch (kind) { - case ASCII: - case TEXT: - case VARCHAR: - return quoteStringLiteral(row.getString(position)); - case INT: - return Integer.toString(row.getInt(position)); - case SMALLINT: - return Short.toString(row.getShort(position)); - case TINYINT: - return Byte.toString(row.getByte(position)); - case BIGINT: - case COUNTER: - return Long.toString(row.getLong(position)); - case BOOLEAN: - return Boolean.toString(row.getBool(position)); - case DOUBLE: - return Double.toString(row.getDouble(position)); - case FLOAT: - return Float.toString(row.getFloat(position)); - case DECIMAL: - return row.getBigDecimal(position).toString(); - case UUID: - case TIMEUUID: - return row.getUuid(position).toString(); - case TIMESTAMP: - return Long.toString(row.getInstant(position).toEpochMilli()); - case DATE: - return quoteStringLiteral(row.getLocalDate(position).toString()); - case INET: - return quoteStringLiteral(toAddrString(row.getInetAddress(position))); - case VARINT: - return row.getBigInteger(position).toString(); - case BLOB: - case CUSTOM: - return Bytes.toHexString(row.getBytesUnsafe(position)); - - case LIST: - case SET: - case MAP: - case TUPLE: - case UDT: - // unsupported - break; - } - throw new IllegalStateException("Handling of type " + this + " is not implemented"); - } - - // TODO unify with getColumnValueForCql - public String toCqlLiteral(Object trinoNativeValue) - { - if (kind == Kind.DATE) { - LocalDate date = LocalDate.ofEpochDay(toIntExact((long) trinoNativeValue)); - return quoteStringLiteral(date.toString()); - } - if (kind == Kind.TIMESTAMP) { - return String.valueOf(unpackMillisUtc((Long) trinoNativeValue)); - } - - String value; - if (trinoNativeValue instanceof Slice) { - value = ((Slice) trinoNativeValue).toStringUtf8(); - } - else { - value = trinoNativeValue.toString(); - } - - switch (kind) { - case ASCII: - case TEXT: - case VARCHAR: - return quoteStringLiteral(value); - case INET: - // remove '/' in the string. e.g. /127.0.0.1 - return quoteStringLiteral(value.substring(1)); - default: - return value; - } - } - - private static String objectToJson(Object cassandraValue, DataType dataType) - { - CassandraType cassandraType = toCassandraType(dataType) - .orElseThrow(() -> new IllegalStateException("Unsupported type: " + dataType)); - - switch (cassandraType.kind) { - case ASCII: - case TEXT: - case VARCHAR: - case UUID: - case TIMEUUID: - case TIMESTAMP: - case DATE: - case INET: - case VARINT: - case TUPLE: - case UDT: - return quoteStringLiteralForJson(cassandraValue.toString()); - - case BLOB: - case CUSTOM: - return quoteStringLiteralForJson(Bytes.toHexString((ByteBuffer) cassandraValue)); - - case SMALLINT: - case TINYINT: - case INT: - case BIGINT: - case COUNTER: - case BOOLEAN: - case DOUBLE: - case FLOAT: - case DECIMAL: - return cassandraValue.toString(); - case LIST: - checkArgument(dataType instanceof ListType, "Expected to deal with an instance of %s class, got: %s", ListType.class, dataType); - ListType listType = (ListType) dataType; - return buildArrayValue((Collection) cassandraValue, listType.getElementType()); - case SET: - checkArgument(dataType instanceof SetType, "Expected to deal with an instance of %s class, got: %s", SetType.class, dataType); - SetType setType = (SetType) dataType; - return buildArrayValue((Collection) cassandraValue, setType.getElementType()); - case MAP: - checkArgument(dataType instanceof MapType, "Expected to deal with an instance of %s class, got: %s", MapType.class, dataType); - MapType mapType = (MapType) dataType; - return buildMapValue((Map) cassandraValue, mapType.getKeyType(), mapType.getValueType()); - } - throw new IllegalStateException("Unsupported type: " + cassandraType); - } - - public Object getJavaValue(Object trinoNativeValue) - { - switch (kind) { - case ASCII: - case TEXT: - case VARCHAR: - return ((Slice) trinoNativeValue).toStringUtf8(); - case BIGINT: - case BOOLEAN: - case DOUBLE: - case COUNTER: - return trinoNativeValue; - case INET: - return InetAddresses.forString(((Slice) trinoNativeValue).toStringUtf8()); - case INT: - case SMALLINT: - case TINYINT: - return ((Long) trinoNativeValue).intValue(); - case FLOAT: - // conversion can result in precision lost - return intBitsToFloat(((Long) trinoNativeValue).intValue()); - case DECIMAL: - // conversion can result in precision lost - // Trino uses double for decimal, so to keep the floating point precision, convert it to string. - // Otherwise partition id doesn't match - return new BigDecimal(trinoNativeValue.toString()); - case TIMESTAMP: - return Instant.ofEpochMilli(unpackMillisUtc((Long) trinoNativeValue)); - case DATE: - return LocalDate.ofEpochDay(((Long) trinoNativeValue).intValue()); - case UUID: - case TIMEUUID: - return trinoUuidToJavaUuid((Slice) trinoNativeValue); - case BLOB: - case CUSTOM: - case TUPLE: - case UDT: - return ((Slice) trinoNativeValue).toStringUtf8(); - case VARINT: - return new BigInteger(((Slice) trinoNativeValue).toStringUtf8()); - case SET: - case LIST: - case MAP: - } - throw new IllegalStateException("Back conversion not implemented for " + this); - } - - public boolean isSupportedPartitionKey() - { - switch (kind) { - case ASCII: - case TEXT: - case VARCHAR: - case BIGINT: - case BOOLEAN: - case DOUBLE: - case INET: - case INT: - case TINYINT: - case SMALLINT: - case FLOAT: - case DECIMAL: - case DATE: - case TIMESTAMP: - case UUID: - case TIMEUUID: - return true; - case COUNTER: - case BLOB: - case CUSTOM: - case VARINT: - case SET: - case LIST: - case MAP: - case TUPLE: - case UDT: - default: - return false; - } - } - - public static boolean isFullySupported(DataType dataType) - { - if (toCassandraType(dataType).isEmpty()) { - return false; - } - - if (dataType instanceof UserDefinedType) { - return ((UserDefinedType) dataType).getFieldTypes().stream() - .allMatch(CassandraType::isFullySupported); - } - - if (dataType instanceof MapType) { - MapType mapType = (MapType) dataType; - return Arrays.stream(new DataType[] {mapType.getKeyType(), mapType.getValueType()}) - .allMatch(CassandraType::isFullySupported); - } - - if (dataType instanceof ListType) { - return CassandraType.isFullySupported(((ListType) dataType).getElementType()); - } - - if (dataType instanceof TupleType) { - return ((TupleType) dataType).getComponentTypes().stream() - .allMatch(CassandraType::isFullySupported); - } - - if (dataType instanceof SetType) { - return CassandraType.isFullySupported(((SetType) dataType).getElementType()); - } - - return true; - } - - public static CassandraType toCassandraType(Type type, ProtocolVersion protocolVersion) - { - if (type.equals(BooleanType.BOOLEAN)) { - return CassandraTypes.BOOLEAN; - } - if (type.equals(BigintType.BIGINT)) { - return CassandraTypes.BIGINT; - } - if (type.equals(IntegerType.INTEGER)) { - return CassandraTypes.INT; - } - if (type.equals(SmallintType.SMALLINT)) { - return CassandraTypes.SMALLINT; - } - if (type.equals(TinyintType.TINYINT)) { - return CassandraTypes.TINYINT; - } - if (type.equals(DoubleType.DOUBLE)) { - return CassandraTypes.DOUBLE; - } - if (type.equals(RealType.REAL)) { - return CassandraTypes.FLOAT; - } - if (type instanceof VarcharType) { - return CassandraTypes.TEXT; - } - if (type.equals(DateType.DATE)) { - return protocolVersion.getCode() <= ProtocolVersion.V3.getCode() - ? CassandraTypes.TEXT - : CassandraTypes.DATE; - } - if (type.equals(VarbinaryType.VARBINARY)) { - return CassandraTypes.BLOB; - } - if (type.equals(TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS)) { - return CassandraTypes.TIMESTAMP; - } - if (type.equals(UuidType.UUID)) { - return CassandraTypes.UUID; - } - throw new TrinoException(NOT_SUPPORTED, "Unsupported type: " + type); - } - @Override public boolean equals(Object o) { diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraTypeManager.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraTypeManager.java new file mode 100644 index 000000000000..6e59573e90b0 --- /dev/null +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraTypeManager.java @@ -0,0 +1,706 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.cassandra; + +import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.cql.Row; +import com.datastax.oss.driver.api.core.data.GettableByIndex; +import com.datastax.oss.driver.api.core.data.TupleValue; +import com.datastax.oss.driver.api.core.data.UdtValue; +import com.datastax.oss.driver.api.core.type.DataType; +import com.datastax.oss.driver.api.core.type.ListType; +import com.datastax.oss.driver.api.core.type.MapType; +import com.datastax.oss.driver.api.core.type.SetType; +import com.datastax.oss.driver.api.core.type.TupleType; +import com.datastax.oss.driver.api.core.type.UserDefinedType; +import com.datastax.oss.protocol.internal.ProtocolConstants; +import com.datastax.oss.protocol.internal.util.Bytes; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; +import com.google.common.net.InetAddresses; +import com.google.inject.Inject; +import io.airlift.slice.Slice; +import io.trino.spi.TrinoException; +import io.trino.spi.block.Block; +import io.trino.spi.block.RowBlockBuilder; +import io.trino.spi.block.SingleRowBlockWriter; +import io.trino.spi.predicate.NullableValue; +import io.trino.spi.type.BigintType; +import io.trino.spi.type.BooleanType; +import io.trino.spi.type.DateType; +import io.trino.spi.type.DoubleType; +import io.trino.spi.type.IntegerType; +import io.trino.spi.type.RealType; +import io.trino.spi.type.RowType; +import io.trino.spi.type.SmallintType; +import io.trino.spi.type.StandardTypes; +import io.trino.spi.type.TimeZoneKey; +import io.trino.spi.type.TimestampWithTimeZoneType; +import io.trino.spi.type.TinyintType; +import io.trino.spi.type.Type; +import io.trino.spi.type.TypeManager; +import io.trino.spi.type.TypeSignature; +import io.trino.spi.type.UuidType; +import io.trino.spi.type.VarbinaryType; +import io.trino.spi.type.VarcharType; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.nio.ByteBuffer; +import java.time.Instant; +import java.time.LocalDate; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Supplier; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.net.InetAddresses.toAddrString; +import static io.airlift.slice.Slices.utf8Slice; +import static io.airlift.slice.Slices.wrappedBuffer; +import static io.trino.plugin.cassandra.CassandraType.Kind.DATE; +import static io.trino.plugin.cassandra.CassandraType.Kind.TIMESTAMP; +import static io.trino.plugin.cassandra.CassandraType.Kind.TUPLE; +import static io.trino.plugin.cassandra.CassandraType.Kind.UDT; +import static io.trino.plugin.cassandra.util.CassandraCqlUtils.quoteStringLiteral; +import static io.trino.plugin.cassandra.util.CassandraCqlUtils.quoteStringLiteralForJson; +import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static io.trino.spi.StandardErrorCode.INVALID_CAST_ARGUMENT; +import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; +import static io.trino.spi.type.DateTimeEncoding.packDateTimeWithZone; +import static io.trino.spi.type.DateTimeEncoding.unpackMillisUtc; +import static io.trino.spi.type.TypeUtils.writeNativeValue; +import static io.trino.spi.type.UuidType.javaUuidToTrinoUuid; +import static io.trino.spi.type.UuidType.trinoUuidToJavaUuid; +import static java.lang.Float.floatToRawIntBits; +import static java.lang.Float.intBitsToFloat; +import static java.lang.Math.toIntExact; +import static java.lang.String.format; +import static java.lang.System.arraycopy; +import static java.util.Objects.requireNonNull; + +public class CassandraTypeManager +{ + private final Type ipAddressType; + + @Inject + public CassandraTypeManager(TypeManager typeManager) + { + requireNonNull(typeManager, "typeManager is null"); + this.ipAddressType = typeManager.getType(new TypeSignature(StandardTypes.IPADDRESS)); + } + + public Optional toCassandraType(DataType dataType) + { + switch (dataType.getProtocolCode()) { + case ProtocolConstants.DataType.ASCII: + return Optional.of(CassandraTypes.ASCII); + case ProtocolConstants.DataType.BIGINT: + return Optional.of(CassandraTypes.BIGINT); + case ProtocolConstants.DataType.BLOB: + return Optional.of(CassandraTypes.BLOB); + case ProtocolConstants.DataType.BOOLEAN: + return Optional.of(CassandraTypes.BOOLEAN); + case ProtocolConstants.DataType.COUNTER: + return Optional.of(CassandraTypes.COUNTER); + case ProtocolConstants.DataType.CUSTOM: + return Optional.of(CassandraTypes.CUSTOM); + case ProtocolConstants.DataType.DATE: + return Optional.of(CassandraTypes.DATE); + case ProtocolConstants.DataType.DECIMAL: + return Optional.of(CassandraTypes.DECIMAL); + case ProtocolConstants.DataType.DOUBLE: + return Optional.of(CassandraTypes.DOUBLE); + case ProtocolConstants.DataType.FLOAT: + return Optional.of(CassandraTypes.FLOAT); + case ProtocolConstants.DataType.INET: + return Optional.of(new CassandraType( + CassandraType.Kind.INET, + ipAddressType)); + case ProtocolConstants.DataType.INT: + return Optional.of(CassandraTypes.INT); + case ProtocolConstants.DataType.LIST: + return Optional.of(CassandraTypes.LIST); + case ProtocolConstants.DataType.MAP: + return Optional.of(CassandraTypes.MAP); + case ProtocolConstants.DataType.SET: + return Optional.of(CassandraTypes.SET); + case ProtocolConstants.DataType.SMALLINT: + return Optional.of(CassandraTypes.SMALLINT); + case ProtocolConstants.DataType.TIMESTAMP: + return Optional.of(CassandraTypes.TIMESTAMP); + case ProtocolConstants.DataType.TIMEUUID: + return Optional.of(CassandraTypes.TIMEUUID); + case ProtocolConstants.DataType.TINYINT: + return Optional.of(CassandraTypes.TINYINT); + case ProtocolConstants.DataType.TUPLE: + return createTypeForTuple(dataType); + case ProtocolConstants.DataType.UDT: + return createTypeForUserType(dataType); + case ProtocolConstants.DataType.UUID: + return Optional.of(CassandraTypes.UUID); + case ProtocolConstants.DataType.VARCHAR: + return Optional.of(CassandraTypes.VARCHAR); + case ProtocolConstants.DataType.VARINT: + return Optional.of(CassandraTypes.VARINT); + default: + return Optional.empty(); + } + } + + private Optional createTypeForTuple(DataType dataType) + { + TupleType tupleType = (TupleType) dataType; + List> argumentTypesOptionals = tupleType.getComponentTypes().stream() + .map(componentType -> toCassandraType(componentType)) + .collect(toImmutableList()); + + if (argumentTypesOptionals.stream().anyMatch(Optional::isEmpty)) { + return Optional.empty(); + } + + List argumentTypes = argumentTypesOptionals.stream() + .map(Optional::get) + .collect(toImmutableList()); + + RowType trinoType = RowType.anonymous( + argumentTypes.stream() + .map(CassandraType::getTrinoType) + .collect(toImmutableList())); + + return Optional.of(new CassandraType(TUPLE, trinoType, argumentTypes)); + } + + private Optional createTypeForUserType(DataType dataType) + { + UserDefinedType userDefinedType = (UserDefinedType) dataType; + // Using ImmutableMap is important as we exploit the fact that entries iteration order matches the order of putting values via builder + ImmutableMap.Builder argumentTypes = ImmutableMap.builder(); + + List fieldNames = userDefinedType.getFieldNames(); + List fieldTypes = userDefinedType.getFieldTypes(); + if (fieldNames.size() != fieldTypes.size()) { + throw new TrinoException(GENERIC_INTERNAL_ERROR, format("Mismatch between the number of field names (%s) and the number of field types (%s) for the data type %s", fieldNames.size(), fieldTypes.size(), dataType)); + } + for (int i = 0; i < fieldNames.size(); i++) { + Optional cassandraType = toCassandraType(fieldTypes.get(i)); + if (cassandraType.isEmpty()) { + return Optional.empty(); + } + argumentTypes.put(fieldNames.get(i).toString(), cassandraType.get()); + } + + RowType trinoType = RowType.from( + argumentTypes.buildOrThrow().entrySet().stream() + .map(field -> new RowType.Field(Optional.of(field.getKey()), field.getValue().getTrinoType())) + .collect(toImmutableList())); + + return Optional.of(new CassandraType(UDT, trinoType, argumentTypes.buildOrThrow().values().stream().collect(toImmutableList()))); + } + + public NullableValue getColumnValue(CassandraType cassandraType, Row row, int position) + { + return getColumnValue(cassandraType, row, position, () -> row.getColumnDefinitions().get(position).getType()); + } + + public NullableValue getColumnValue(CassandraType cassandraType, GettableByIndex row, int position, Supplier dataTypeSupplier) + { + Type trinoType = cassandraType.getTrinoType(); + if (row.isNull(position)) { + return NullableValue.asNull(trinoType); + } + + switch (cassandraType.getKind()) { + case ASCII: + case TEXT: + case VARCHAR: + return NullableValue.of(trinoType, utf8Slice(row.getString(position))); + case INT: + return NullableValue.of(trinoType, (long) row.getInt(position)); + case SMALLINT: + return NullableValue.of(trinoType, (long) row.getShort(position)); + case TINYINT: + return NullableValue.of(trinoType, (long) row.getByte(position)); + case BIGINT: + case COUNTER: + return NullableValue.of(trinoType, row.getLong(position)); + case BOOLEAN: + return NullableValue.of(trinoType, row.getBoolean(position)); + case DOUBLE: + return NullableValue.of(trinoType, row.getDouble(position)); + case FLOAT: + return NullableValue.of(trinoType, (long) floatToRawIntBits(row.getFloat(position))); + case DECIMAL: + return NullableValue.of(trinoType, row.getBigDecimal(position).doubleValue()); + case UUID: + case TIMEUUID: + return NullableValue.of(trinoType, javaUuidToTrinoUuid(row.getUuid(position))); + case TIMESTAMP: + return NullableValue.of(trinoType, packDateTimeWithZone(row.getInstant(position).toEpochMilli(), TimeZoneKey.UTC_KEY)); + case DATE: + return NullableValue.of(trinoType, row.getLocalDate(position).toEpochDay()); + case INET: + return NullableValue.of(trinoType, castFromVarcharToIpAddress(utf8Slice(toAddrString(row.getInetAddress(position))))); + case VARINT: + return NullableValue.of(trinoType, utf8Slice(row.getBigInteger(position).toString())); + case BLOB: + case CUSTOM: + return NullableValue.of(trinoType, wrappedBuffer(row.getBytesUnsafe(position))); + case SET: + return NullableValue.of(trinoType, utf8Slice(buildArrayValueFromSetType(row, position, dataTypeSupplier.get()))); + case LIST: + return NullableValue.of(trinoType, utf8Slice(buildArrayValueFromListType(row, position, dataTypeSupplier.get()))); + case MAP: + return NullableValue.of(trinoType, utf8Slice(buildMapValue(row, position, dataTypeSupplier.get()))); + case TUPLE: + return NullableValue.of(trinoType, buildTupleValue(cassandraType, row, position)); + case UDT: + return NullableValue.of(trinoType, buildUserTypeValue(cassandraType, row, position)); + } + throw new IllegalStateException("Handling of type " + this + " is not implemented"); + } + + private String buildMapValue(GettableByIndex row, int position, DataType dataType) + { + checkArgument(dataType instanceof MapType, "Expected to deal with an instance of %s class, got: %s", MapType.class, dataType); + MapType mapType = (MapType) dataType; + return buildMapValue((Map) row.getObject(position), mapType.getKeyType(), mapType.getValueType()); + } + + private String buildMapValue(Map cassandraMap, DataType keyType, DataType valueType) + { + StringBuilder sb = new StringBuilder(); + sb.append("{"); + for (Map.Entry entry : cassandraMap.entrySet()) { + if (sb.length() > 1) { + sb.append(","); + } + sb.append(objectToJson(entry.getKey(), keyType)); + sb.append(":"); + sb.append(objectToJson(entry.getValue(), valueType)); + } + sb.append("}"); + return sb.toString(); + } + + private String buildArrayValueFromSetType(GettableByIndex row, int position, DataType type) + { + checkArgument(type instanceof SetType, "Expected to deal with an instance of %s class, got: %s", SetType.class, type); + SetType setType = (SetType) type; + return buildArrayValue((Collection) row.getObject(position), setType.getElementType()); + } + + private String buildArrayValueFromListType(GettableByIndex row, int position, DataType type) + { + checkArgument(type instanceof ListType, "Expected to deal with an instance of %s class, got: %s", ListType.class, type); + ListType listType = (ListType) type; + return buildArrayValue((Collection) row.getObject(position), listType.getElementType()); + } + + @VisibleForTesting + String buildArrayValue(Collection cassandraCollection, DataType elementType) + { + StringBuilder sb = new StringBuilder(); + sb.append("["); + for (Object value : cassandraCollection) { + if (sb.length() > 1) { + sb.append(","); + } + sb.append(objectToJson(value, elementType)); + } + sb.append("]"); + return sb.toString(); + } + + private Block buildTupleValue(CassandraType type, GettableByIndex row, int position) + { + verify(type.getKind() == TUPLE, "Not a TUPLE type"); + TupleValue tupleValue = row.getTupleValue(position); + RowBlockBuilder blockBuilder = (RowBlockBuilder) type.getTrinoType().createBlockBuilder(null, 1); + SingleRowBlockWriter singleRowBlockWriter = blockBuilder.beginBlockEntry(); + int tuplePosition = 0; + for (CassandraType argumentType : type.getArgumentTypes()) { + int finalTuplePosition = tuplePosition; + NullableValue value = getColumnValue(argumentType, tupleValue, tuplePosition, () -> tupleValue.getType().getComponentTypes().get(finalTuplePosition)); + writeNativeValue(argumentType.getTrinoType(), singleRowBlockWriter, value.getValue()); + tuplePosition++; + } + // can I just return singleRowBlockWriter here? It extends AbstractSingleRowBlock and tests pass. + blockBuilder.closeEntry(); + return (Block) type.getTrinoType().getObject(blockBuilder, 0); + } + + private Block buildUserTypeValue(CassandraType type, GettableByIndex row, int position) + { + verify(type.getKind() == UDT, "Not a user defined type: %s", type.getKind()); + UdtValue udtValue = row.getUdtValue(position); + RowBlockBuilder blockBuilder = (RowBlockBuilder) type.getTrinoType().createBlockBuilder(null, 1); + SingleRowBlockWriter singleRowBlockWriter = blockBuilder.beginBlockEntry(); + int tuplePosition = 0; + List udtTypeFieldTypes = udtValue.getType().getFieldTypes(); + for (CassandraType argumentType : type.getArgumentTypes()) { + int finalTuplePosition = tuplePosition; + NullableValue value = getColumnValue(argumentType, udtValue, tuplePosition, () -> udtTypeFieldTypes.get(finalTuplePosition)); + writeNativeValue(argumentType.getTrinoType(), singleRowBlockWriter, value.getValue()); + tuplePosition++; + } + + blockBuilder.closeEntry(); + return (Block) type.getTrinoType().getObject(blockBuilder, 0); + } + + // TODO unify with toCqlLiteral + public String getColumnValueForCql(CassandraType type, Row row, int position) + { + if (row.isNull(position)) { + return null; + } + + switch (type.getKind()) { + case ASCII: + case TEXT: + case VARCHAR: + return quoteStringLiteral(row.getString(position)); + case INT: + return Integer.toString(row.getInt(position)); + case SMALLINT: + return Short.toString(row.getShort(position)); + case TINYINT: + return Byte.toString(row.getByte(position)); + case BIGINT: + case COUNTER: + return Long.toString(row.getLong(position)); + case BOOLEAN: + return Boolean.toString(row.getBool(position)); + case DOUBLE: + return Double.toString(row.getDouble(position)); + case FLOAT: + return Float.toString(row.getFloat(position)); + case DECIMAL: + return row.getBigDecimal(position).toString(); + case UUID: + case TIMEUUID: + return row.getUuid(position).toString(); + case TIMESTAMP: + return Long.toString(row.getInstant(position).toEpochMilli()); + case DATE: + return quoteStringLiteral(row.getLocalDate(position).toString()); + case INET: + return quoteStringLiteral(toAddrString(row.getInetAddress(position))); + case VARINT: + return row.getBigInteger(position).toString(); + case BLOB: + case CUSTOM: + return Bytes.toHexString(row.getBytesUnsafe(position)); + + case LIST: + case SET: + case MAP: + case TUPLE: + case UDT: + // unsupported + break; + } + throw new IllegalStateException("Handling of type " + this + " is not implemented"); + } + + // TODO unify with getColumnValueForCql + public String toCqlLiteral(CassandraType type, Object trinoNativeValue) + { + CassandraType.Kind kind = type.getKind(); + if (kind == DATE) { + LocalDate date = LocalDate.ofEpochDay(toIntExact((long) trinoNativeValue)); + return quoteStringLiteral(date.toString()); + } + if (kind == TIMESTAMP) { + return String.valueOf(unpackMillisUtc((Long) trinoNativeValue)); + } + + String value; + if (trinoNativeValue instanceof Slice) { + value = ((Slice) trinoNativeValue).toStringUtf8(); + } + else { + value = trinoNativeValue.toString(); + } + + switch (kind) { + case ASCII: + case TEXT: + case VARCHAR: + return quoteStringLiteral(value); + case INET: + // remove '/' in the string. e.g. /127.0.0.1 + return quoteStringLiteral(value.substring(1)); + default: + return value; + } + } + + private String objectToJson(Object cassandraValue, DataType dataType) + { + CassandraType cassandraType = toCassandraType(dataType) + .orElseThrow(() -> new IllegalStateException("Unsupported type: " + dataType)); + + switch (cassandraType.getKind()) { + case ASCII: + case TEXT: + case VARCHAR: + case UUID: + case TIMEUUID: + case TIMESTAMP: + case DATE: + case INET: + case VARINT: + case TUPLE: + case UDT: + return quoteStringLiteralForJson(cassandraValue.toString()); + + case BLOB: + case CUSTOM: + return quoteStringLiteralForJson(Bytes.toHexString((ByteBuffer) cassandraValue)); + + case SMALLINT: + case TINYINT: + case INT: + case BIGINT: + case COUNTER: + case BOOLEAN: + case DOUBLE: + case FLOAT: + case DECIMAL: + return cassandraValue.toString(); + case LIST: + checkArgument(dataType instanceof ListType, "Expected to deal with an instance of %s class, got: %s", ListType.class, dataType); + ListType listType = (ListType) dataType; + return buildArrayValue((Collection) cassandraValue, listType.getElementType()); + case SET: + checkArgument(dataType instanceof SetType, "Expected to deal with an instance of %s class, got: %s", SetType.class, dataType); + SetType setType = (SetType) dataType; + return buildArrayValue((Collection) cassandraValue, setType.getElementType()); + case MAP: + checkArgument(dataType instanceof MapType, "Expected to deal with an instance of %s class, got: %s", MapType.class, dataType); + MapType mapType = (MapType) dataType; + return buildMapValue((Map) cassandraValue, mapType.getKeyType(), mapType.getValueType()); + } + throw new IllegalStateException("Unsupported type: " + cassandraType); + } + + public Object getJavaValue(CassandraType.Kind kind, Object trinoNativeValue) + { + switch (kind) { + case ASCII: + case TEXT: + case VARCHAR: + return ((Slice) trinoNativeValue).toStringUtf8(); + case BIGINT: + case BOOLEAN: + case DOUBLE: + case COUNTER: + return trinoNativeValue; + case INET: + try { + return InetAddress.getByAddress(((Slice) trinoNativeValue).getBytes()); + } + catch (UnknownHostException e) { + throw new TrinoException(INVALID_CAST_ARGUMENT, "Invalid IP address binary length: " + ((Slice) trinoNativeValue).length(), e); + } + case INT: + case SMALLINT: + case TINYINT: + return ((Long) trinoNativeValue).intValue(); + case FLOAT: + // conversion can result in precision lost + return intBitsToFloat(((Long) trinoNativeValue).intValue()); + case DECIMAL: + // conversion can result in precision lost + // Trino uses double for decimal, so to keep the floating point precision, convert it to string. + // Otherwise partition id doesn't match + return new BigDecimal(trinoNativeValue.toString()); + case TIMESTAMP: + return Instant.ofEpochMilli(unpackMillisUtc((Long) trinoNativeValue)); + case DATE: + return LocalDate.ofEpochDay(((Long) trinoNativeValue).intValue()); + case UUID: + case TIMEUUID: + return trinoUuidToJavaUuid((Slice) trinoNativeValue); + case BLOB: + case CUSTOM: + case TUPLE: + case UDT: + return ((Slice) trinoNativeValue).toStringUtf8(); + case VARINT: + return new BigInteger(((Slice) trinoNativeValue).toStringUtf8()); + case SET: + case LIST: + case MAP: + } + throw new IllegalStateException("Back conversion not implemented for " + this); + } + + public boolean isSupportedPartitionKey(CassandraType.Kind kind) + { + switch (kind) { + case ASCII: + case TEXT: + case VARCHAR: + case BIGINT: + case BOOLEAN: + case DOUBLE: + case INET: + case INT: + case TINYINT: + case SMALLINT: + case FLOAT: + case DECIMAL: + case DATE: + case TIMESTAMP: + case UUID: + case TIMEUUID: + return true; + case COUNTER: + case BLOB: + case CUSTOM: + case VARINT: + case SET: + case LIST: + case MAP: + case TUPLE: + case UDT: + default: + return false; + } + } + + public boolean isFullySupported(DataType dataType) + { + if (toCassandraType(dataType).isEmpty()) { + return false; + } + + if (dataType instanceof UserDefinedType) { + return ((UserDefinedType) dataType).getFieldTypes().stream() + .allMatch(fieldType -> isFullySupported(fieldType)); + } + + if (dataType instanceof MapType) { + MapType mapType = (MapType) dataType; + return Arrays.stream(new DataType[] {mapType.getKeyType(), mapType.getValueType()}) + .allMatch(type -> isFullySupported(type)); + } + + if (dataType instanceof ListType) { + return isFullySupported(((ListType) dataType).getElementType()); + } + + if (dataType instanceof TupleType) { + return ((TupleType) dataType).getComponentTypes().stream() + .allMatch(componentType -> isFullySupported(componentType)); + } + + if (dataType instanceof SetType) { + return isFullySupported(((SetType) dataType).getElementType()); + } + + return true; + } + + public CassandraType toCassandraType(Type type, ProtocolVersion protocolVersion) + { + if (type.equals(BooleanType.BOOLEAN)) { + return CassandraTypes.BOOLEAN; + } + if (type.equals(BigintType.BIGINT)) { + return CassandraTypes.BIGINT; + } + if (type.equals(IntegerType.INTEGER)) { + return CassandraTypes.INT; + } + if (type.equals(SmallintType.SMALLINT)) { + return CassandraTypes.SMALLINT; + } + if (type.equals(TinyintType.TINYINT)) { + return CassandraTypes.TINYINT; + } + if (type.equals(DoubleType.DOUBLE)) { + return CassandraTypes.DOUBLE; + } + if (type.equals(RealType.REAL)) { + return CassandraTypes.FLOAT; + } + if (type instanceof VarcharType) { + return CassandraTypes.TEXT; + } + if (type.equals(DateType.DATE)) { + return protocolVersion.getCode() <= ProtocolVersion.V3.getCode() + ? CassandraTypes.TEXT + : CassandraTypes.DATE; + } + if (type.equals(VarbinaryType.VARBINARY)) { + return CassandraTypes.BLOB; + } + if (type.equals(TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS)) { + return CassandraTypes.TIMESTAMP; + } + if (type.equals(UuidType.UUID)) { + return CassandraTypes.UUID; + } + if (type.equals(ipAddressType)) { + return new CassandraType( + CassandraType.Kind.INET, + ipAddressType); + } + throw new TrinoException(NOT_SUPPORTED, "Unsupported type: " + type); + } + + public boolean isIpAddressType(Type type) + { + return type.equals(ipAddressType); + } + + // This is a copy of IpAddressOperators.castFromVarcharToIpAddress method + private static Slice castFromVarcharToIpAddress(Slice slice) + { + byte[] address; + try { + address = InetAddresses.forString(slice.toStringUtf8()).getAddress(); + } + catch (IllegalArgumentException e) { + throw new TrinoException(INVALID_CAST_ARGUMENT, "Cannot cast value to IPADDRESS: " + slice.toStringUtf8()); + } + + byte[] bytes; + if (address.length == 4) { + bytes = new byte[16]; + bytes[10] = (byte) 0xff; + bytes[11] = (byte) 0xff; + arraycopy(address, 0, bytes, 12, 4); + } + else if (address.length == 16) { + bytes = address; + } + else { + throw new TrinoException(GENERIC_INTERNAL_ERROR, "Invalid InetAddress length: " + address.length); + } + + return wrappedBuffer(bytes); + } +} diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/CassandraServer.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/CassandraServer.java index f1cc33f78390..564d9f3d7a17 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/CassandraServer.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/CassandraServer.java @@ -40,6 +40,7 @@ import static com.datastax.oss.driver.api.core.config.DefaultDriverOption.REQUEST_TIMEOUT; import static com.google.common.io.Files.write; import static com.google.common.io.Resources.getResource; +import static io.trino.plugin.cassandra.CassandraTestingUtils.CASSANDRA_TYPE_MANAGER; import static java.lang.String.format; import static java.nio.charset.StandardCharsets.UTF_8; import static java.nio.file.Files.createDirectory; @@ -93,6 +94,7 @@ public CassandraServer(String cassandraVersion) .withConfigLoader(driverConfigLoaderBuilder.build()); CassandraSession session = new CassandraSession( + CASSANDRA_TYPE_MANAGER, JsonCodec.listJsonCodec(ExtraColumnMetadata.class), cqlSessionBuilder::build, new Duration(1, MINUTES)); diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/CassandraTestingUtils.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/CassandraTestingUtils.java index c33d56c708e5..fbad4da7ed52 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/CassandraTestingUtils.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/CassandraTestingUtils.java @@ -32,11 +32,14 @@ import java.util.UUID; import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.literal; +import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER; import static java.lang.String.format; import static org.testng.Assert.assertEquals; public final class CassandraTestingUtils { + public static final CassandraTypeManager CASSANDRA_TYPE_MANAGER = new CassandraTypeManager(TESTING_TYPE_MANAGER); + public static final String TABLE_ALL_TYPES = "table_all_types"; public static final String TABLE_TUPLE_TYPE = "table_tuple_type"; public static final String TABLE_USER_DEFINED_TYPE = "table_user_defined_type"; diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnector.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnector.java index b28042db4c07..f8b79b147a77 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnector.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnector.java @@ -16,6 +16,7 @@ import com.datastax.oss.protocol.internal.util.Bytes; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.net.InetAddresses; import com.google.common.primitives.Shorts; import com.google.common.primitives.SignedBytes; import io.trino.spi.block.Block; @@ -47,10 +48,13 @@ import io.trino.spi.type.VarcharType; import io.trino.testing.TestingConnectorContext; import io.trino.testing.TestingConnectorSession; +import io.trino.type.IpAddressType; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; +import java.net.InetAddress; +import java.net.UnknownHostException; import java.util.Date; import java.util.List; import java.util.Map; @@ -299,6 +303,7 @@ public void testGetTupleType() @Test public void testGetUserDefinedType() + throws UnknownHostException { ConnectorTableHandle tableHandle = getTableHandle(tableUdt); ConnectorTableMetadata tableMetadata = metadata.getTableMetadata(SESSION, tableHandle); @@ -342,7 +347,7 @@ public void testGetUserDefinedType() assertEquals(DOUBLE.getDouble(udtValue, 8), 99999999999999997748809823456034029568D); assertEquals(DOUBLE.getDouble(udtValue, 9), 4.9407e-324); assertEquals(REAL.getObjectValue(SESSION, udtValue, 10), 1.4E-45f); - assertEquals(VARCHAR.getSlice(udtValue, 11).toStringUtf8(), "0.0.0.0"); + assertEquals(InetAddresses.toAddrString(InetAddress.getByAddress(IpAddressType.IPADDRESS.getSlice(udtValue, 11).getBytes())), "0.0.0.0"); assertEquals(VARCHAR.getSlice(udtValue, 12).toStringUtf8(), "varchar"); assertEquals(VARCHAR.getSlice(udtValue, 13).toStringUtf8(), "-9223372036854775808"); assertEquals(trinoUuidToJavaUuid(UUID.getSlice(udtValue, 14)).toString(), "d2177dd0-eaa2-11de-a572-001b779c76e3"); @@ -413,6 +418,9 @@ else if (type instanceof RowType) { else if (UuidType.UUID.equals(type)) { cursor.getSlice(columnIndex); } + else if (IpAddressType.IPADDRESS.equals(type)) { + cursor.getSlice(columnIndex); + } else { fail("Unknown primitive type " + type + " for column " + columnIndex); } diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnectorTest.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnectorTest.java index bd0ca65c3092..01998d779cd0 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnectorTest.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnectorTest.java @@ -18,7 +18,6 @@ import com.google.common.primitives.Ints; import io.airlift.units.Duration; import io.trino.Session; -import io.trino.spi.type.Type; import io.trino.testing.BaseConnectorTest; import io.trino.testing.Bytes; import io.trino.testing.MaterializedResult; @@ -59,11 +58,11 @@ import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; -import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.testing.MaterializedResult.DEFAULT_PRECISION; import static io.trino.testing.MaterializedResult.resultBuilder; import static io.trino.testing.QueryAssertions.assertContains; import static io.trino.testing.QueryAssertions.assertContainsEventually; +import static io.trino.type.IpAddressType.IPADDRESS; import static java.lang.String.format; import static java.util.Comparator.comparing; import static java.util.concurrent.TimeUnit.MINUTES; @@ -324,7 +323,7 @@ public void testPushdownAllTypesPartitionKeyPredicate() " AND typeboolean = false" + " AND typedouble = 16384.0" + " AND typefloat = REAL '2097152.0'" + - " AND typeinet = '127.0.0.1'" + + " AND typeinet = IPADDRESS '127.0.0.1'" + " AND typevarchar = 'varchar 7'" + " AND typetimeuuid = UUID 'd2177dd0-eaa2-11de-a572-001b779c76e7'" + ""; @@ -412,7 +411,7 @@ public void testPartitionKeyPredicate() " AND typedecimal = 128.0" + " AND typedouble = 16384.0" + " AND typefloat = REAL '2097152.0'" + - " AND typeinet = '127.0.0.1'" + + " AND typeinet = IPADDRESS '127.0.0.1'" + " AND typevarchar = 'varchar 7'" + " AND typevarint = '10000000'" + " AND typetimeuuid = UUID 'd2177dd0-eaa2-11de-a572-001b779c76e7'" + @@ -492,7 +491,7 @@ public void testSelect() rowNumber -> format("['list-value-1%d', 'list-value-2%d']", rowNumber, rowNumber), rowNumber -> format("{%d:%d, %d:%d}", rowNumber, rowNumber + 1, rowNumber + 2, rowNumber + 3), rowNumber -> format("{false, true}"))))) { - assertSelect(testCassandraTable.getTableName(), false); + assertSelect(testCassandraTable.getTableName()); } try (TestCassandraTable testCassandraTable = testTable( @@ -541,7 +540,7 @@ public void testSelect() rowNumber -> format("['list-value-1%d', 'list-value-2%d']", rowNumber, rowNumber), rowNumber -> format("{%d:%d, %d:%d}", rowNumber, rowNumber + 1, rowNumber + 2, rowNumber + 3), rowNumber -> format("{false, true}"))))) { - assertSelect(testCassandraTable.getTableName(), false); + assertSelect(testCassandraTable.getTableName()); } } @@ -606,7 +605,7 @@ public void testCreateTableAs() rowNumber -> format("{false, true}"))))) { execute("DROP TABLE IF EXISTS table_all_types_copy"); execute("CREATE TABLE table_all_types_copy AS SELECT * FROM " + testCassandraTable.getTableName()); - assertSelect("table_all_types_copy", true); + assertSelect("table_all_types_copy"); execute("DROP TABLE table_all_types_copy"); } } @@ -1105,7 +1104,7 @@ public void testAllTypesInsert() assertEquals(execute(sql).getRowCount(), 0); // TODO Following types are not supported now. We need to change null into the value after fixing it - // blob, frozen>, inet, list, map, set, decimal, varint + // blob, frozen>, list, map, set, decimal, varint // timestamp can be inserted but the expected and actual values are not same execute("INSERT INTO " + testCassandraTable.getTableName() + " (" + "key," + @@ -1138,7 +1137,7 @@ public void testAllTypesInsert() "null, " + "0.3, " + "cast('0.4' as real), " + - "null, " + + "IPADDRESS '10.10.10.1', " + "'varchar1', " + "null, " + "UUID '50554d6e-29bb-11e5-b345-feff819cdc9f', " + @@ -1162,7 +1161,7 @@ public void testAllTypesInsert() null, 0.3, (float) 0.4, - null, + "10.10.10.1", "varchar1", null, java.util.UUID.fromString("50554d6e-29bb-11e5-b345-feff819cdc9f"), @@ -1334,10 +1333,8 @@ protected void verifyTableNameLengthFailurePermissible(Throwable e) assertThat(e).hasMessageContaining("Table names shouldn't be more than 48 characters long"); } - private void assertSelect(String tableName, boolean createdByTrino) + private void assertSelect(String tableName) { - Type inetType = createdByTrino ? createUnboundedVarcharType() : createVarcharType(45); - String sql = "SELECT " + " key, " + " typeuuid, " + @@ -1375,7 +1372,7 @@ private void assertSelect(String tableName, boolean createdByTrino) DOUBLE, DOUBLE, REAL, - inetType, + IPADDRESS, createUnboundedVarcharType(), createUnboundedVarcharType(), UUID, diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraType.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraTypeManager.java similarity index 59% rename from plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraType.java rename to plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraTypeManager.java index 099c3aad9946..dc89524abbf5 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraType.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraTypeManager.java @@ -22,20 +22,21 @@ import java.io.IOException; +import static io.trino.plugin.cassandra.CassandraTestingUtils.CASSANDRA_TYPE_MANAGER; import static org.testng.Assert.assertTrue; -public class TestCassandraType +public class TestCassandraTypeManager { @Test public void testJsonArrayEncoding() { - assertTrue(isValidJson(CassandraType.buildArrayValue(Lists.newArrayList("one", "two", "three\""), DataTypes.TEXT))); - assertTrue(isValidJson(CassandraType.buildArrayValue(Lists.newArrayList(1, 2, 3), DataTypes.INT))); - assertTrue(isValidJson(CassandraType.buildArrayValue(Lists.newArrayList(100000L, 200000000L, 3000000000L), DataTypes.BIGINT))); - assertTrue(isValidJson(CassandraType.buildArrayValue(Lists.newArrayList(1.0, 2.0, 3.0), DataTypes.DOUBLE))); - assertTrue(isValidJson(CassandraType.buildArrayValue(Lists.newArrayList((short) -32768, (short) 0, (short) 32767), DataTypes.SMALLINT))); - assertTrue(isValidJson(CassandraType.buildArrayValue(Lists.newArrayList((byte) -128, (byte) 0, (byte) 127), DataTypes.TINYINT))); - assertTrue(isValidJson(CassandraType.buildArrayValue(Lists.newArrayList("1970-01-01", "5555-06-15", "9999-12-31"), DataTypes.DATE))); + assertTrue(isValidJson(CASSANDRA_TYPE_MANAGER.buildArrayValue(Lists.newArrayList("one", "two", "three\""), DataTypes.TEXT))); + assertTrue(isValidJson(CASSANDRA_TYPE_MANAGER.buildArrayValue(Lists.newArrayList(1, 2, 3), DataTypes.INT))); + assertTrue(isValidJson(CASSANDRA_TYPE_MANAGER.buildArrayValue(Lists.newArrayList(100000L, 200000000L, 3000000000L), DataTypes.BIGINT))); + assertTrue(isValidJson(CASSANDRA_TYPE_MANAGER.buildArrayValue(Lists.newArrayList(1.0, 2.0, 3.0), DataTypes.DOUBLE))); + assertTrue(isValidJson(CASSANDRA_TYPE_MANAGER.buildArrayValue(Lists.newArrayList((short) -32768, (short) 0, (short) 32767), DataTypes.SMALLINT))); + assertTrue(isValidJson(CASSANDRA_TYPE_MANAGER.buildArrayValue(Lists.newArrayList((byte) -128, (byte) 0, (byte) 127), DataTypes.TINYINT))); + assertTrue(isValidJson(CASSANDRA_TYPE_MANAGER.buildArrayValue(Lists.newArrayList("1970-01-01", "5555-06-15", "9999-12-31"), DataTypes.DATE))); } private static void continueWhileNotNull(JsonParser parser, JsonToken token) diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraTypeMapping.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraTypeMapping.java index 480233200942..014729a445c9 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraTypeMapping.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraTypeMapping.java @@ -61,7 +61,7 @@ import static io.trino.spi.type.UuidType.UUID; import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.spi.type.VarcharType.VARCHAR; -import static io.trino.spi.type.VarcharType.createVarcharType; +import static io.trino.type.IpAddressType.IPADDRESS; import static java.lang.String.format; import static java.time.ZoneOffset.UTC; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -392,17 +392,33 @@ public void testCassandraMap() public void testCassandraInet() { SqlDataTypeTest.create() - .addRoundTrip("inet", "NULL", createVarcharType(45), "CAST(NULL AS varchar(45))") - .addRoundTrip("inet", "'0.0.0.0'", createVarcharType(45), "CAST('0.0.0.0' AS varchar(45))") - .addRoundTrip("inet", "'116.253.40.133'", createVarcharType(45), "CAST('116.253.40.133' AS varchar(45))") - .addRoundTrip("inet", "'255.255.255.255'", createVarcharType(45), "CAST('255.255.255.255' AS varchar(45))") - .addRoundTrip("inet", "'::'", createVarcharType(45), "CAST('::' AS varchar(45))") - .addRoundTrip("inet", "'2001:44c8:129:2632:33:0:252:2'", createVarcharType(45), "CAST('2001:44c8:129:2632:33:0:252:2' AS varchar(45))") - .addRoundTrip("inet", "'ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff'", createVarcharType(45), "CAST('ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff' AS varchar(45))") - .addRoundTrip("inet", "'ffff:ffff:ffff:ffff:ffff:ffff:255.255.255.255'", createVarcharType(45), "CAST('ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff' AS varchar(45))") + .addRoundTrip("inet", "NULL", IPADDRESS, "CAST(NULL AS ipaddress)") + .addRoundTrip("inet", "'0.0.0.0'", IPADDRESS, "CAST('0.0.0.0' AS ipaddress)") + .addRoundTrip("inet", "'116.253.40.133'", IPADDRESS, "CAST('116.253.40.133' AS ipaddress)") + .addRoundTrip("inet", "'255.255.255.255'", IPADDRESS, "CAST('255.255.255.255' AS ipaddress)") + .addRoundTrip("inet", "'::'", IPADDRESS, "CAST('::' AS ipaddress)") + .addRoundTrip("inet", "'2001:44c8:129:2632:33:0:252:2'", IPADDRESS, "CAST('2001:44c8:129:2632:33:0:252:2' AS ipaddress)") + .addRoundTrip("inet", "'ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff'", IPADDRESS, "CAST('ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff' AS ipaddress)") + .addRoundTrip("inet", "'ffff:ffff:ffff:ffff:ffff:ffff:255.255.255.255'", IPADDRESS, "CAST('ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff' AS ipaddress)") .execute(getQueryRunner(), cassandraCreateAndInsert("tpch.test_inet")); } + @Test + public void testIpAddress() + { + SqlDataTypeTest.create() + .addRoundTrip("ipaddress", "NULL", IPADDRESS, "CAST(NULL AS ipaddress)") + .addRoundTrip("ipaddress", "ipaddress '0.0.0.0'", IPADDRESS, "CAST('0.0.0.0' AS ipaddress)") + .addRoundTrip("ipaddress", "ipaddress '116.253.40.133'", IPADDRESS, "CAST('116.253.40.133' AS ipaddress)") + .addRoundTrip("ipaddress", "ipaddress '255.255.255.255'", IPADDRESS, "CAST('255.255.255.255' AS ipaddress)") + .addRoundTrip("ipaddress", "ipaddress '::'", IPADDRESS, "CAST('::' AS ipaddress)") + .addRoundTrip("ipaddress", "ipaddress '2001:44c8:129:2632:33:0:252:2'", IPADDRESS, "CAST('2001:44c8:129:2632:33:0:252:2' AS ipaddress)") + .addRoundTrip("ipaddress", "ipaddress 'ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff'", IPADDRESS, "CAST('ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff' AS ipaddress)") + .addRoundTrip("ipaddress", "ipaddress 'ffff:ffff:ffff:ffff:ffff:ffff:255.255.255.255'", IPADDRESS, "CAST('ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff' AS ipaddress)") + .execute(getQueryRunner(), trinoCreateAndInsert("test_ipaddress")) + .execute(getQueryRunner(), trinoCreateAsSelect("test_ipaddress")); + } + @Test public void testCassandraVarint() { diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestingScyllaServer.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestingScyllaServer.java index 73da7463bcec..75e8ccbf4ad4 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestingScyllaServer.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestingScyllaServer.java @@ -32,6 +32,7 @@ import static com.datastax.oss.driver.api.core.config.DefaultDriverOption.METADATA_SCHEMA_REFRESHED_KEYSPACES; import static com.datastax.oss.driver.api.core.config.DefaultDriverOption.PROTOCOL_VERSION; import static com.datastax.oss.driver.api.core.config.DefaultDriverOption.REQUEST_TIMEOUT; +import static io.trino.plugin.cassandra.CassandraTestingUtils.CASSANDRA_TYPE_MANAGER; import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MINUTES; @@ -78,6 +79,7 @@ public TestingScyllaServer(String version) .withConfigLoader(config.build()); session = new CassandraSession( + CASSANDRA_TYPE_MANAGER, JsonCodec.listJsonCodec(ExtraColumnMetadata.class), cqlSessionBuilder::build, new Duration(1, MINUTES)); diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/util/TestCassandraClusteringPredicatesExtractor.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/util/TestCassandraClusteringPredicatesExtractor.java index a986a5fa51dd..c7d2514813fa 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/util/TestCassandraClusteringPredicatesExtractor.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/util/TestCassandraClusteringPredicatesExtractor.java @@ -27,6 +27,7 @@ import org.testng.annotations.BeforeTest; import org.testng.annotations.Test; +import static io.trino.plugin.cassandra.CassandraTestingUtils.CASSANDRA_TYPE_MANAGER; import static io.trino.spi.type.BigintType.BIGINT; import static org.testng.Assert.assertEquals; @@ -61,7 +62,7 @@ public void testBuildClusteringPredicate() col1, Domain.singleValue(BIGINT, 23L), col2, Domain.singleValue(BIGINT, 34L), col4, Domain.singleValue(BIGINT, 26L))); - CassandraClusteringPredicatesExtractor predicatesExtractor = new CassandraClusteringPredicatesExtractor(cassandraTable.getClusteringKeyColumns(), tupleDomain, cassandraVersion); + CassandraClusteringPredicatesExtractor predicatesExtractor = new CassandraClusteringPredicatesExtractor(CASSANDRA_TYPE_MANAGER, cassandraTable.getClusteringKeyColumns(), tupleDomain, cassandraVersion); String predicate = predicatesExtractor.getClusteringKeyPredicates(); assertEquals(predicate, "\"clusteringKey1\" = 34"); } @@ -73,7 +74,7 @@ public void testGetUnenforcedPredicates() ImmutableMap.of( col2, Domain.singleValue(BIGINT, 34L), col4, Domain.singleValue(BIGINT, 26L))); - CassandraClusteringPredicatesExtractor predicatesExtractor = new CassandraClusteringPredicatesExtractor(cassandraTable.getClusteringKeyColumns(), tupleDomain, cassandraVersion); + CassandraClusteringPredicatesExtractor predicatesExtractor = new CassandraClusteringPredicatesExtractor(CASSANDRA_TYPE_MANAGER, cassandraTable.getClusteringKeyColumns(), tupleDomain, cassandraVersion); TupleDomain unenforcedPredicates = TupleDomain.withColumnDomains(ImmutableMap.of(col4, Domain.singleValue(BIGINT, 26L))); assertEquals(predicatesExtractor.getUnenforcedConstraints(), unenforcedPredicates); } diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/cassandra/TestInsertIntoCassandraTable.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/cassandra/TestInsertIntoCassandraTable.java index 109cbfabd2e9..fc14500a7f3c 100644 --- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/cassandra/TestInsertIntoCassandraTable.java +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/cassandra/TestInsertIntoCassandraTable.java @@ -74,7 +74,7 @@ public void testInsertIntoValuesToCassandraTableAllSimpleTypes() assertThat(queryResult).hasNoRows(); // TODO Following types are not supported now. We need to change null into the value after fixing it - // blob, frozen>, inet, list, map, set, decimal, varint + // blob, frozen>, list, map, set, decimal, varint onTrino().executeQuery("INSERT INTO " + tableNameInDatabase + "(a, b, bl, bo, d, do, dt, f, fr, i, ti, si, integer, l, m, s, t, ts, tu, u, v, vari) VALUES (" + "'ascii value', " + @@ -86,7 +86,7 @@ public void testInsertIntoValuesToCassandraTableAllSimpleTypes() "DATE '9999-12-31'," + "REAL '123.45678', " + "null, " + - "null, " + + "IPADDRESS '0.0.0.0', " + "TINYINT '-128', " + "SMALLINT '-32768', " + "123, " + @@ -111,7 +111,7 @@ public void testInsertIntoValuesToCassandraTableAllSimpleTypes() Date.valueOf("9999-12-31"), 123.45678, null, - null, + "0.0.0.0", 123, null, null, diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/cassandra/TestSelect.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/cassandra/TestSelect.java index fcb28bb80627..4364850062e4 100644 --- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/cassandra/TestSelect.java +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/cassandra/TestSelect.java @@ -185,7 +185,7 @@ public void testAllDataTypes() CONNECTOR_NAME, KEY_SPACE, CASSANDRA_ALL_TYPES.getName())); assertThat(query) - .hasColumns(VARCHAR, BIGINT, VARBINARY, BOOLEAN, DOUBLE, DOUBLE, DATE, REAL, VARCHAR, VARCHAR, + .hasColumns(VARCHAR, BIGINT, VARBINARY, BOOLEAN, DOUBLE, DOUBLE, DATE, REAL, VARCHAR, JAVA_OBJECT, INTEGER, VARCHAR, VARCHAR, VARCHAR, SMALLINT, VARCHAR, TINYINT, TIMESTAMP_WITH_TIMEZONE, JAVA_OBJECT, JAVA_OBJECT, VARCHAR, VARCHAR) .containsOnly( @@ -299,7 +299,7 @@ public void testSelectAllTypePartitioningMaterializedView() CONNECTOR_NAME, KEY_SPACE, materializedViewName)); assertThat(query) - .hasColumns(VARCHAR, BIGINT, VARBINARY, BOOLEAN, DOUBLE, DOUBLE, DATE, REAL, VARCHAR, VARCHAR, + .hasColumns(VARCHAR, BIGINT, VARBINARY, BOOLEAN, DOUBLE, DOUBLE, DATE, REAL, VARCHAR, JAVA_OBJECT, INTEGER, VARCHAR, VARCHAR, VARCHAR, SMALLINT, VARCHAR, TINYINT, TIMESTAMP_WITH_TIMEZONE, JAVA_OBJECT, JAVA_OBJECT, VARCHAR, VARCHAR) .containsOnly(