diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/glue/GlueExpressionUtil.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/glue/GlueExpressionUtil.java index 152af4f7131d..798791e7a617 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/glue/GlueExpressionUtil.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/glue/GlueExpressionUtil.java @@ -34,6 +34,7 @@ import java.util.Set; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; import static com.google.common.collect.Iterables.getOnlyElement; import static java.lang.String.format; import static java.util.Locale.ENGLISH; @@ -151,16 +152,18 @@ static Optional buildGlueExpressionForSingleDomain(String columnName, Do // for column <> '__HIVE_DEFAULT_PARTITION__' or column = '__HIVE_DEFAULT_PARTITION__' expression on numeric types // "IS NULL" operator in the official documentation always returns empty result regardless of the type. // https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-GetPartitions - if ((domain.getValues().isAll() || domain.getValues().isNone()) && !isQuotedType(domain.getType())) { + if ((domain.getValues().isAll() || domain.isNullAllowed()) && !isQuotedType(domain.getType())) { return Optional.empty(); } if (domain.getValues().isAll()) { + verify(!domain.isNullAllowed(), "Unexpected domain: %s", domain); return Optional.of(format("(%s <> '%s')", columnName, NULL_STRING)); } - // null must be allowed for this case since callers must filter Domain.none() out if (domain.getValues().isNone()) { + // null must be allowed for this case since callers must filter Domain.none() out + verify(domain.isNullAllowed(), "Unexpected domain: %s", domain); return Optional.of(format("(%s = '%s')", columnName, NULL_STRING)); } @@ -204,6 +207,10 @@ else if (singleValues.size() > 1) { disjuncts.add(inClause); } + if (domain.isNullAllowed()) { + disjuncts.add(format("(%s = '%s')", columnName, NULL_STRING)); + } + return Optional.of("(" + DISJUNCT_JOINER.join(disjuncts) + ")"); } } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHive.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHive.java index 65e78229fa7d..f184294c739b 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHive.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHive.java @@ -48,6 +48,7 @@ import io.trino.plugin.hive.metastore.Table; import io.trino.plugin.hive.metastore.cache.CachingHiveMetastoreConfig; import io.trino.plugin.hive.metastore.thrift.BridgingHiveMetastore; +import io.trino.plugin.hive.metastore.thrift.ThriftMetastoreConfig; import io.trino.plugin.hive.orc.OrcPageSource; import io.trino.plugin.hive.parquet.ParquetPageSource; import io.trino.plugin.hive.rcfile.RcFilePageSource; @@ -615,6 +616,7 @@ private static RowType toRowType(List columns) protected String database; protected SchemaTableName tablePartitionFormat; protected SchemaTableName tableUnpartitioned; + protected SchemaTableName tablePartitionedWithNull; protected SchemaTableName tableOffline; protected SchemaTableName tableOfflinePartition; protected SchemaTableName tableNotReadable; @@ -634,6 +636,8 @@ private static RowType toRowType(List columns) protected ColumnHandle dummyColumn; protected ColumnHandle intColumn; protected ColumnHandle invalidColumnHandle; + protected ColumnHandle pStringColumn; + protected ColumnHandle pIntegerColumn; protected ConnectorTableProperties tablePartitionFormatProperties; protected ConnectorTableProperties tableUnpartitionedProperties; @@ -692,6 +696,7 @@ protected void setupHive(String databaseName) database = databaseName; tablePartitionFormat = new SchemaTableName(database, "trino_test_partition_format"); tableUnpartitioned = new SchemaTableName(database, "trino_test_unpartitioned"); + tablePartitionedWithNull = new SchemaTableName(database, "trino_test_partitioned_with_null"); tableOffline = new SchemaTableName(database, "trino_test_offline"); tableOfflinePartition = new SchemaTableName(database, "trino_test_offline_partition"); tableNotReadable = new SchemaTableName(database, "trino_test_not_readable"); @@ -711,6 +716,8 @@ protected void setupHive(String databaseName) dummyColumn = createBaseColumn("dummy", -1, HIVE_INT, INTEGER, PARTITION_KEY, Optional.empty()); intColumn = createBaseColumn("t_int", -1, HIVE_INT, INTEGER, PARTITION_KEY, Optional.empty()); invalidColumnHandle = createBaseColumn(INVALID_COLUMN, 0, HIVE_STRING, VARCHAR, REGULAR, Optional.empty()); + pStringColumn = createBaseColumn("p_string", -1, HIVE_STRING, VARCHAR, PARTITION_KEY, Optional.empty()); + pIntegerColumn = createBaseColumn("p_integer", -1, HIVE_INT, INTEGER, PARTITION_KEY, Optional.empty()); List partitionColumns = ImmutableList.of(dsColumn, fileFormatColumn, dummyColumn); tablePartitionFormatPartitions = ImmutableList.builder() @@ -783,6 +790,8 @@ protected final void setup(String host, int port, String databaseName, String ti new BridgingHiveMetastore(testingThriftHiveMetastoreBuilder() .metastoreClient(HostAndPort.fromParts(host, port)) .hiveConfig(hiveConfig) + .thriftMetastoreConfig(new ThriftMetastoreConfig() + .setAssumeCanonicalPartitionKeys(true)) .hdfsEnvironment(hdfsEnvironment) .build()), executor, @@ -1114,6 +1123,80 @@ public void testGetPartitionsWithBindings() } } + @Test + public void testGetPartitionsWithFilter() + { + try (Transaction transaction = newTransaction()) { + ConnectorMetadata metadata = transaction.getMetadata(); + ConnectorTableHandle tableHandle = getTableHandle(metadata, tablePartitionedWithNull); + + Domain varcharSomeValue = Domain.singleValue(VARCHAR, utf8Slice("abc")); + Domain varcharOnlyNull = Domain.onlyNull(VARCHAR); + Domain varcharNotNull = Domain.notNull(VARCHAR); + + Domain integerSomeValue = Domain.singleValue(INTEGER, 123L); + Domain integerOnlyNull = Domain.onlyNull(INTEGER); + Domain integerNotNull = Domain.notNull(INTEGER); + + // all + assertThat(getPartitionNamesByFilter(metadata, tableHandle, new Constraint(TupleDomain.all()))) + .containsOnly( + "p_string=__HIVE_DEFAULT_PARTITION__/p_integer=__HIVE_DEFAULT_PARTITION__", + "p_string=abc/p_integer=123", + "p_string=def/p_integer=456"); + + // is some value + assertThat(getPartitionNamesByFilter(metadata, tableHandle, pStringColumn, varcharSomeValue)) + .containsOnly("p_string=abc/p_integer=123"); + assertThat(getPartitionNamesByFilter(metadata, tableHandle, pIntegerColumn, integerSomeValue)) + .containsOnly("p_string=abc/p_integer=123"); + + // IS NULL + assertThat(getPartitionNamesByFilter(metadata, tableHandle, pStringColumn, varcharOnlyNull)) + .containsOnly("p_string=__HIVE_DEFAULT_PARTITION__/p_integer=__HIVE_DEFAULT_PARTITION__"); + assertThat(getPartitionNamesByFilter(metadata, tableHandle, pIntegerColumn, integerOnlyNull)) + .containsOnly("p_string=__HIVE_DEFAULT_PARTITION__/p_integer=__HIVE_DEFAULT_PARTITION__"); + + // IS NOT NULL + assertThat(getPartitionNamesByFilter(metadata, tableHandle, pStringColumn, varcharNotNull)) + .containsOnly("p_string=abc/p_integer=123", "p_string=def/p_integer=456"); + assertThat(getPartitionNamesByFilter(metadata, tableHandle, pIntegerColumn, integerNotNull)) + .containsOnly("p_string=abc/p_integer=123", "p_string=def/p_integer=456"); + + // IS NULL OR is some value + assertThat(getPartitionNamesByFilter(metadata, tableHandle, pStringColumn, varcharOnlyNull.union(varcharSomeValue))) + .containsOnly("p_string=__HIVE_DEFAULT_PARTITION__/p_integer=__HIVE_DEFAULT_PARTITION__", "p_string=abc/p_integer=123"); + assertThat(getPartitionNamesByFilter(metadata, tableHandle, pIntegerColumn, integerOnlyNull.union(integerSomeValue))) + .containsOnly("p_string=__HIVE_DEFAULT_PARTITION__/p_integer=__HIVE_DEFAULT_PARTITION__", "p_string=abc/p_integer=123"); + + // IS NOT NULL AND is NOT some value + assertThat(getPartitionNamesByFilter(metadata, tableHandle, pStringColumn, varcharSomeValue.complement().intersect(varcharNotNull))) + .containsOnly("p_string=def/p_integer=456"); + assertThat(getPartitionNamesByFilter(metadata, tableHandle, pIntegerColumn, integerSomeValue.complement().intersect(integerNotNull))) + .containsOnly("p_string=def/p_integer=456"); + + // IS NULL OR is NOT some value + assertThat(getPartitionNamesByFilter(metadata, tableHandle, pStringColumn, varcharSomeValue.complement())) + .containsOnly("p_string=__HIVE_DEFAULT_PARTITION__/p_integer=__HIVE_DEFAULT_PARTITION__", "p_string=def/p_integer=456"); + assertThat(getPartitionNamesByFilter(metadata, tableHandle, pIntegerColumn, integerSomeValue.complement())) + .containsOnly("p_string=__HIVE_DEFAULT_PARTITION__/p_integer=__HIVE_DEFAULT_PARTITION__", "p_string=def/p_integer=456"); + } + } + + private Set getPartitionNamesByFilter(ConnectorMetadata metadata, ConnectorTableHandle tableHandle, ColumnHandle columnHandle, Domain domain) + { + return getPartitionNamesByFilter(metadata, tableHandle, new Constraint(TupleDomain.withColumnDomains(ImmutableMap.of(columnHandle, domain)))); + } + + private Set getPartitionNamesByFilter(ConnectorMetadata metadata, ConnectorTableHandle tableHandle, Constraint constraint) + { + return applyFilter(metadata, tableHandle, constraint) + .getPartitions().orElseThrow(() -> new IllegalStateException("No partitions")) + .stream() + .map(HivePartition::getPartitionId) + .collect(toImmutableSet()); + } + @Test public void testMismatchSchemaTable() throws Exception @@ -5090,10 +5173,11 @@ protected ConnectorTableHandle getTableHandle(ConnectorMetadata metadata, Schema return handle; } - private ConnectorTableHandle applyFilter(ConnectorMetadata metadata, ConnectorTableHandle tableHandle, Constraint constraint) + private HiveTableHandle applyFilter(ConnectorMetadata metadata, ConnectorTableHandle tableHandle, Constraint constraint) { return metadata.applyFilter(newSession(), tableHandle, constraint) .map(ConstraintApplicationResult::getHandle) + .map(HiveTableHandle.class::cast) .orElseThrow(AssertionError::new); } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestingThriftHiveMetastoreBuilder.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestingThriftHiveMetastoreBuilder.java index 815889d6e4b3..07eb9e9b309d 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestingThriftHiveMetastoreBuilder.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestingThriftHiveMetastoreBuilder.java @@ -54,6 +54,7 @@ public final class TestingThriftHiveMetastoreBuilder private MetastoreLocator metastoreLocator; private HiveConfig hiveConfig = new HiveConfig(); + private ThriftMetastoreConfig thriftMetastoreConfig = new ThriftMetastoreConfig(); private HdfsEnvironment hdfsEnvironment = HDFS_ENVIRONMENT; public static TestingThriftHiveMetastoreBuilder testingThriftHiveMetastoreBuilder() @@ -85,6 +86,12 @@ public TestingThriftHiveMetastoreBuilder hiveConfig(HiveConfig hiveConfig) return this; } + public TestingThriftHiveMetastoreBuilder thriftMetastoreConfig(ThriftMetastoreConfig thriftMetastoreConfig) + { + this.thriftMetastoreConfig = requireNonNull(thriftMetastoreConfig, "thriftMetastoreConfig is null"); + return this; + } + public TestingThriftHiveMetastoreBuilder hdfsEnvironment(HdfsEnvironment hdfsEnvironment) { this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null"); @@ -97,12 +104,12 @@ public ThriftMetastore build() ThriftHiveMetastoreFactory metastoreFactory = new ThriftHiveMetastoreFactory( new TokenDelegationThriftMetastoreFactory( metastoreLocator, - new ThriftMetastoreConfig(), + thriftMetastoreConfig, new ThriftMetastoreAuthenticationConfig(), hdfsEnvironment), new HiveMetastoreConfig().isHideDeltaLakeTables(), hiveConfig.isTranslateHiveViews(), - new ThriftMetastoreConfig(), + thriftMetastoreConfig, hdfsEnvironment); return metastoreFactory.createMetastore(Optional.empty()); } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/metastore/glue/TestGlueExpressionUtil.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/metastore/glue/TestGlueExpressionUtil.java index e8b932cc2aea..70e7e9e1718b 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/metastore/glue/TestGlueExpressionUtil.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/metastore/glue/TestGlueExpressionUtil.java @@ -144,6 +144,28 @@ public void testBuildGlueExpressionTupleDomainNotNull() assertEquals(expression, format("(col1 <> '%s')", GlueExpressionUtil.NULL_STRING)); } + @Test + public void testBuildGlueExpressionTupleDomainEqualsOrIsNull() + { + TupleDomain filter = new PartitionFilterBuilder() + .addStringValues("col1", "2020-01-01") + .addDomain("col1", Domain.onlyNull(VarcharType.VARCHAR)) + .build(); + String expression = buildGlueExpression(ImmutableList.of("col1"), filter, true); + assertEquals(expression, format("((col1 = '2020-01-01') OR (col1 = '%s'))", GlueExpressionUtil.NULL_STRING)); + } + + @Test + public void testBuildGlueExpressionTupleDomainEqualsAndIsNotNull() + { + TupleDomain filter = new PartitionFilterBuilder() + .addStringValues("col1", "2020-01-01") + .addDomain("col2", Domain.notNull(VarcharType.VARCHAR)) + .build(); + String expression = buildGlueExpression(ImmutableList.of("col1", "col2"), filter, true); + assertEquals(expression, format("((col1 = '2020-01-01')) AND (col2 <> '%s')", GlueExpressionUtil.NULL_STRING)); + } + @Test public void testBuildGlueExpressionMaxLengthNone() { diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/metastore/glue/TestHiveGlueMetastore.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/metastore/glue/TestHiveGlueMetastore.java index 8ef5ef7fb0e4..76a94c0f7840 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/metastore/glue/TestHiveGlueMetastore.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/metastore/glue/TestHiveGlueMetastore.java @@ -53,6 +53,7 @@ import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.Range; import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.predicate.ValueSet; import io.trino.spi.statistics.ColumnStatisticMetadata; import io.trino.spi.statistics.ComputedStatistics; import io.trino.spi.statistics.TableStatisticType; @@ -831,18 +832,51 @@ public void testGetPartitionsFilterIsNull() public void testGetPartitionsFilterIsNullWithValue() throws Exception { - TupleDomain isNullFilter = new PartitionFilterBuilder() - .addDomain(PARTITION_KEY, Domain.onlyNull(VarcharType.VARCHAR)) - .build(); List partitionList = new ArrayList<>(); partitionList.add("100"); partitionList.add(null); + doGetPartitionsFilterTest( CREATE_TABLE_COLUMNS_PARTITIONED_VARCHAR, PARTITION_KEY, partitionList, - ImmutableList.of(isNullFilter), + ImmutableList.of(new PartitionFilterBuilder() + // IS NULL + .addDomain(PARTITION_KEY, Domain.onlyNull(VarcharType.VARCHAR)) + .build()), ImmutableList.of(ImmutableList.of(GlueExpressionUtil.NULL_STRING))); + + doGetPartitionsFilterTest( + CREATE_TABLE_COLUMNS_PARTITIONED_VARCHAR, + PARTITION_KEY, + partitionList, + ImmutableList.of(new PartitionFilterBuilder() + // IS NULL or is a specific value + .addDomain(PARTITION_KEY, Domain.create(ValueSet.of(VARCHAR, utf8Slice("100")), true)) + .build()), + ImmutableList.of(ImmutableList.of("100", GlueExpressionUtil.NULL_STRING))); + } + + @Test + public void testGetPartitionsFilterEqualsOrIsNullWithValue() + throws Exception + { + TupleDomain equalsOrIsNullFilter = new PartitionFilterBuilder() + .addStringValues(PARTITION_KEY, "2020-03-01") + .addDomain(PARTITION_KEY, Domain.onlyNull(VarcharType.VARCHAR)) + .build(); + List partitionList = new ArrayList<>(); + partitionList.add("2020-01-01"); + partitionList.add("2020-02-01"); + partitionList.add("2020-03-01"); + partitionList.add(null); + + doGetPartitionsFilterTest( + CREATE_TABLE_COLUMNS_PARTITIONED_VARCHAR, + PARTITION_KEY, + partitionList, + ImmutableList.of(equalsOrIsNullFilter), + ImmutableList.of(ImmutableList.of("2020-03-01", GlueExpressionUtil.NULL_STRING))); } @Test @@ -922,6 +956,27 @@ public Object[][] unsupportedNullPushdownTypes() }; } + @Test + public void testGetPartitionsFilterEqualsAndIsNotNull() + throws Exception + { + TupleDomain equalsAndIsNotNullFilter = new PartitionFilterBuilder() + .addDomain(PARTITION_KEY, Domain.notNull(VarcharType.VARCHAR)) + .addBigintValues(PARTITION_KEY2, 300L) + .build(); + + doGetPartitionsFilterTest( + CREATE_TABLE_COLUMNS_PARTITIONED_TWO_KEYS, + ImmutableList.of(PARTITION_KEY, PARTITION_KEY2), + ImmutableList.of( + PartitionValues.make("2020-01-01", "100"), + PartitionValues.make("2020-02-01", "200"), + PartitionValues.make("2020-03-01", "300"), + PartitionValues.make(null, "300")), + ImmutableList.of(equalsAndIsNotNullFilter), + ImmutableList.of(ImmutableList.of(PartitionValues.make("2020-03-01", "300")))); + } + @Test public void testUpdateStatisticsOnCreate() { diff --git a/plugin/trino-hive/src/test/sql/create-test.sql b/plugin/trino-hive/src/test/sql/create-test.sql index 912577f5e2a2..ea3a6950af9e 100644 --- a/plugin/trino-hive/src/test/sql/create-test.sql +++ b/plugin/trino-hive/src/test/sql/create-test.sql @@ -1,3 +1,7 @@ +set hive.exec.dynamic.partition.mode=nonstrict; + +CREATE TABLE dummy (dummy varchar(1)); + CREATE TABLE trino_test_sequence ( n INT ) @@ -26,6 +30,13 @@ COMMENT 'Presto test data' STORED AS TEXTFILE ; +CREATE TABLE trino_test_partitioned_with_null ( + a_value STRING +) +PARTITIONED BY (p_string STRING, p_integer int) +STORED AS TEXTFILE +; + CREATE TABLE trino_test_offline ( t_string STRING ) @@ -124,6 +135,8 @@ LOAD DATA LOCAL INPATH '/docker/files/words' INTO TABLE tmp_trino_test_load ; +INSERT INTO dummy VALUES ('x'); + INSERT OVERWRITE TABLE trino_test_sequence SELECT TRANSFORM(word) USING 'awk "BEGIN { n = 0 } { print ++n }"' AS n @@ -193,6 +206,10 @@ SELECT , 1 + n FROM trino_test_sequence LIMIT 100; +INSERT INTO TABLE trino_test_partitioned_with_null PARTITION (p_string, p_integer) SELECT 'NULL row', NULL, NULL FROM dummy; +INSERT INTO TABLE trino_test_partitioned_with_null PARTITION (p_string, p_integer) SELECT 'value row', 'abc', 123 FROM dummy; +INSERT INTO TABLE trino_test_partitioned_with_null PARTITION (p_string, p_integer) SELECT 'another value row', 'def', 456 FROM dummy; + INSERT INTO TABLE trino_test_offline_partition PARTITION (ds='2012-12-29') SELECT 'test' FROM trino_test_sequence LIMIT 100;