Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@
import static io.prestosql.spi.predicate.TupleDomain.withColumnDomains;
import static io.prestosql.spi.statistics.TableStatisticType.ROW_COUNT;
import static io.prestosql.spi.type.BigintType.BIGINT;
import static io.prestosql.spi.type.TypeUtils.isFloatingPointNaN;
import static io.prestosql.spi.type.VarcharType.createUnboundedVarcharType;
import static java.lang.String.format;
import static java.util.Locale.ENGLISH;
Expand Down Expand Up @@ -2217,8 +2218,9 @@ private static Domain buildColumnDomain(ColumnHandle column, List<HivePartition>
checkArgument(!partitions.isEmpty(), "partitions cannot be empty");

boolean hasNull = false;
boolean hasNaN = false;
List<Object> nonNullValues = new ArrayList<>();
Type type = null;
Type type = ((HiveColumnHandle) column).getType();

for (HivePartition partition : partitions) {
NullableValue value = partition.getKeys().get(column);
Expand All @@ -2230,24 +2232,29 @@ private static Domain buildColumnDomain(ColumnHandle column, List<HivePartition>
hasNull = true;
}
else {
if (isFloatingPointNaN(type, value.getValue())) {
hasNaN = true;
}
nonNullValues.add(value.getValue());
}

if (type == null) {
type = value.getType();
}
}

if (!nonNullValues.isEmpty()) {
Domain domain = Domain.multipleValues(type, nonNullValues);
if (hasNull) {
return domain.union(Domain.onlyNull(type));
}
Domain domain;
if (nonNullValues.isEmpty()) {
domain = Domain.none(type);
}
else if (hasNaN) {
domain = Domain.notNull(type);
}
else {
domain = Domain.multipleValues(type, nonNullValues);
}

return domain;
if (hasNull) {
domain = domain.union(Domain.onlyNull(type));
}

return Domain.onlyNull(type);
return domain;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,29 @@ public void testJoinPartitionedWithMissingPartitionFilter()
assertUpdate(session, "DROP TABLE partition_test8");
}

// Exercises a table partitioned on a double column where one of the partition values is NaN:
// checks that both partitions are readable, and that equality joins on the partition column
// never match the NaN partition.
@Test
public void testNaNPartition()
{
// Start from a clean state in case a previous run left the table behind.
assertUpdate("DROP TABLE IF EXISTS test_nan_partition");
assertUpdate("CREATE TABLE test_nan_partition(a varchar, d double) WITH (partitioned_by = ARRAY['d'])");
// Two rows, hence two partitions: d=42.0 and d=NaN.
assertUpdate("INSERT INTO test_nan_partition VALUES ('a', 42e0), ('b', nan())", 2);

// Both rows are returned; the "$path" column is reduced to its table/partition directory part
// so the expected partition directory names (d=42.0 and d=NaN) can be compared.
assertQuery(
"SELECT a, d, regexp_replace(\"$path\", '.*(/[^/]*/[^/]*/)[^/]*', '...$1...') FROM test_nan_partition",
"VALUES " +
" ('a', 42, '.../test_nan_partition/d=42.0/...'), " +
" ('b', SQRT(-1), '.../test_nan_partition/d=NaN/...')"); // SQRT(-1) is H2's recommended way to obtain NaN

// Joining against a value that is in neither partition (33) must produce no rows,
// both with a plain equality condition and with an additional rand() disjunct.
assertQueryReturnsEmptyResult("SELECT a FROM test_nan_partition JOIN (VALUES 33e0) u(x) ON d = x");
assertQueryReturnsEmptyResult("SELECT a FROM test_nan_partition JOIN (VALUES 33e0) u(x) ON d = x OR rand() = 42");

// Self-join on the partition column: only the d=42 row is expected, the NaN row is not matched.
assertQuery(
"SELECT * FROM test_nan_partition t1 JOIN test_nan_partition t2 ON t1.d = t2.d",
"VALUES ('a', 42, 'a', 42)");

assertUpdate("DROP TABLE test_nan_partition");
}

@Test
public void testJoinWithPartitionFilterOnPartionedTable()
{
Expand Down Expand Up @@ -3297,7 +3320,8 @@ private void testCreateExternalTable(
}

@Test
public void testCreateExternalTable() throws Exception
public void testCreateExternalTable()
throws Exception
{
testCreateExternalTable(
"test_create_external",
Expand All @@ -3307,7 +3331,8 @@ public void testCreateExternalTable() throws Exception
}

@Test
public void testCreateExternalTableWithFieldSeparator() throws Exception
public void testCreateExternalTableWithFieldSeparator()
throws Exception
{
testCreateExternalTable(
"test_create_external",
Expand All @@ -3317,7 +3342,8 @@ public void testCreateExternalTableWithFieldSeparator() throws Exception
}

@Test
public void testCreateExternalTableWithFieldSeparatorEscape() throws Exception
public void testCreateExternalTableWithFieldSeparatorEscape()
throws Exception
{
testCreateExternalTable(
"test_create_external_text_file_with_field_separator_and_escape",
Expand All @@ -3329,7 +3355,8 @@ public void testCreateExternalTableWithFieldSeparatorEscape() throws Exception
}

@Test
public void testCreateExternalTableWithNullFormat() throws Exception
public void testCreateExternalTableWithNullFormat()
throws Exception
{
testCreateExternalTable(
"test_create_external_textfile_with_null_format",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,23 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.slice.Slices;
import io.prestosql.spi.connector.SchemaTableName;
import io.prestosql.spi.predicate.Domain;
import io.prestosql.spi.predicate.NullableValue;
import io.prestosql.spi.predicate.ValueSet;
import org.testng.annotations.Test;

import java.util.Optional;
import java.util.stream.IntStream;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.airlift.slice.Slices.utf8Slice;
import static io.prestosql.plugin.hive.HiveColumnHandle.createBaseColumn;
import static io.prestosql.plugin.hive.HiveMetadata.createPredicate;
import static io.prestosql.spi.type.DoubleType.DOUBLE;
import static io.prestosql.spi.type.VarcharType.VARCHAR;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNull;

public class TestHiveMetadata
{
Expand All @@ -36,6 +43,14 @@ public class TestHiveMetadata
HiveColumnHandle.ColumnType.PARTITION_KEY,
Optional.empty());

private static final HiveColumnHandle DOUBLE_COLUMN_HANDLE = createBaseColumn(
"test",
0,
HiveType.HIVE_DOUBLE,
DOUBLE,
HiveColumnHandle.ColumnType.PARTITION_KEY,
Optional.empty());

@Test(timeOut = 10_000)
public void testCreatePredicate()
{
Expand All @@ -45,10 +60,17 @@ public void testCreatePredicate()
partitions.add(new HivePartition(
new SchemaTableName("test", "test"),
Integer.toString(i),
ImmutableMap.of(TEST_COLUMN_HANDLE, NullableValue.of(VARCHAR, Slices.utf8Slice(Integer.toString(i))))));
ImmutableMap.of(TEST_COLUMN_HANDLE, NullableValue.of(VARCHAR, utf8Slice(Integer.toString(i))))));
}

createPredicate(ImmutableList.of(TEST_COLUMN_HANDLE), partitions.build());
Domain domain = createPredicate(ImmutableList.of(TEST_COLUMN_HANDLE), partitions.build())
.getDomains().orElseThrow().get(TEST_COLUMN_HANDLE);
assertEquals(domain, Domain.create(
ValueSet.copyOf(VARCHAR,
IntStream.range(0, 5_000)
.mapToObj(i -> utf8Slice(Integer.toString(i)))
.collect(toImmutableList())),
false));
}

@Test
Expand All @@ -63,7 +85,56 @@ public void testCreateOnlyNullsPredicate()
ImmutableMap.of(TEST_COLUMN_HANDLE, NullableValue.asNull(VARCHAR))));
}

createPredicate(ImmutableList.of(TEST_COLUMN_HANDLE), partitions.build());
Domain domain = createPredicate(ImmutableList.of(TEST_COLUMN_HANDLE), partitions.build())
.getDomains().orElseThrow().get(TEST_COLUMN_HANDLE);
assertEquals(domain, Domain.onlyNull(VARCHAR));
}

// A NaN partition value next to a regular value must widen the column domain to "any non-null value".
@Test
public void testCreatePredicateWithNaN()
{
    SchemaTableName tableName = new SchemaTableName("test", "test");

    // One partition keyed by NaN and one keyed by an ordinary double.
    HivePartition nanPartition = new HivePartition(
            tableName,
            "p1",
            ImmutableMap.of(DOUBLE_COLUMN_HANDLE, NullableValue.of(DOUBLE, Double.NaN)));
    HivePartition regularPartition = new HivePartition(
            tableName,
            "p2",
            ImmutableMap.of(DOUBLE_COLUMN_HANDLE, NullableValue.of(DOUBLE, 4.2)));
    ImmutableList<HivePartition> partitions = ImmutableList.of(nanPartition, regularPartition);

    Domain actual = createPredicate(ImmutableList.of(DOUBLE_COLUMN_HANDLE), partitions)
            .getDomains()
            .orElseThrow()
            .get(DOUBLE_COLUMN_HANDLE);

    assertEquals(actual, Domain.notNull(DOUBLE));
}

// NaN, a regular value and a null partition key together: no domain entry is expected for the column.
@Test
public void testCreatePredicateWithNaNAndNull()
{
    SchemaTableName tableName = new SchemaTableName("test", "test");

    // Partitions keyed by NaN, by an ordinary double, and by null.
    HivePartition nanPartition = new HivePartition(
            tableName,
            "p1",
            ImmutableMap.of(DOUBLE_COLUMN_HANDLE, NullableValue.of(DOUBLE, Double.NaN)));
    HivePartition regularPartition = new HivePartition(
            tableName,
            "p2",
            ImmutableMap.of(DOUBLE_COLUMN_HANDLE, NullableValue.of(DOUBLE, 4.2)));
    HivePartition nullPartition = new HivePartition(
            tableName,
            "p3",
            ImmutableMap.of(DOUBLE_COLUMN_HANDLE, NullableValue.asNull(DOUBLE)));
    ImmutableList<HivePartition> partitions = ImmutableList.of(nanPartition, regularPartition, nullPartition);

    Domain actual = createPredicate(ImmutableList.of(DOUBLE_COLUMN_HANDLE), partitions)
            .getDomains()
            .orElseThrow()
            .get(DOUBLE_COLUMN_HANDLE);

    // NOTE(review): presumably "not null" united with "only null" is unconstrained and therefore
    // not kept in the column-domain map - confirm against TupleDomain.
    assertNull(actual);
}

@Test
Expand All @@ -75,14 +146,16 @@ public void testCreateMixedPredicate()
partitions.add(new HivePartition(
new SchemaTableName("test", "test"),
Integer.toString(i),
ImmutableMap.of(TEST_COLUMN_HANDLE, NullableValue.of(VARCHAR, Slices.utf8Slice(Integer.toString(i))))));
ImmutableMap.of(TEST_COLUMN_HANDLE, NullableValue.of(VARCHAR, utf8Slice(Integer.toString(i))))));
}

partitions.add(new HivePartition(
new SchemaTableName("test", "test"),
"null",
ImmutableMap.of(TEST_COLUMN_HANDLE, NullableValue.asNull(VARCHAR))));

createPredicate(ImmutableList.of(TEST_COLUMN_HANDLE), partitions.build());
Domain domain = createPredicate(ImmutableList.of(TEST_COLUMN_HANDLE), partitions.build())
.getDomains().orElseThrow().get(TEST_COLUMN_HANDLE);
assertEquals(domain, Domain.create(ValueSet.of(VARCHAR, utf8Slice("0"), utf8Slice("1"), utf8Slice("2"), utf8Slice("3"), utf8Slice("4")), true));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static io.prestosql.spi.type.TypeUtils.isFloatingPointNaN;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toSet;

Expand Down Expand Up @@ -253,9 +254,13 @@ private Domain convertToDomain(Type type, Block block)
for (int position = 0; position < block.getPositionCount(); ++position) {
Object value = TypeUtils.readNativeValue(type, block, position);
if (value != null) {
values.add(value);
// join doesn't match rows with NaN values.
if (!isFloatingPointNaN(type, value)) {
values.add(value);
}
}
}

// Inner and right join doesn't match rows with null key column values.
return Domain.create(ValueSet.copyOf(type, values.build()), false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
import static io.airlift.slice.SliceUtf8.lengthOfCodePoint;
import static io.airlift.slice.SliceUtf8.setCodePointAt;
import static io.prestosql.spi.function.OperatorType.SATURATED_FLOOR_CAST;
import static io.prestosql.spi.type.TypeUtils.isFloatingPointNaN;
import static io.prestosql.sql.ExpressionUtils.and;
import static io.prestosql.sql.ExpressionUtils.combineConjuncts;
import static io.prestosql.sql.ExpressionUtils.combineDisjunctsWithDefault;
Expand All @@ -89,8 +90,6 @@
import static io.prestosql.sql.tree.ComparisonExpression.Operator.LESS_THAN;
import static io.prestosql.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL;
import static io.prestosql.sql.tree.ComparisonExpression.Operator.NOT_EQUAL;
import static java.lang.Float.intBitsToFloat;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.collectingAndThen;
import static java.util.stream.Collectors.toList;
Expand Down Expand Up @@ -621,8 +620,7 @@ private static Optional<Domain> extractOrderableDomain(ComparisonExpression.Oper
}

// Handle comparisons against NaN
if ((type instanceof DoubleType && Double.isNaN((double) value)) ||
(type instanceof RealType && Float.isNaN(intBitsToFloat(toIntExact((long) value))))) {
if (isFloatingPointNaN(type, value)) {
switch (comparisonOperator) {
case EQUAL:
case GREATER_THAN:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import java.util.stream.Collectors;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.prestosql.spi.type.TypeUtils.isFloatingPointNaN;
import static io.prestosql.sql.ExpressionUtils.combineConjuncts;
import static io.prestosql.sql.ExpressionUtils.expressionOrNullSymbols;
import static io.prestosql.sql.ExpressionUtils.extractConjuncts;
Expand Down Expand Up @@ -332,6 +333,7 @@ public Expression visitValues(ValuesNode node, Void context)

ImmutableList.Builder<Object> builder = ImmutableList.builder();
boolean hasNull = false;
boolean hasNaN = false;
boolean nonDeterministic = false;
for (int row = 0; row < node.getRows().size(); row++) {
Expression value = node.getRows().get(row).get(column);
Expand All @@ -352,6 +354,9 @@ public Expression visitValues(ValuesNode node, Void context)
hasNull = true;
}
else {
if (isFloatingPointNaN(type, evaluated)) {
hasNaN = true;
}
builder.add(evaluated);
}
}
Expand All @@ -364,10 +369,15 @@ public Expression visitValues(ValuesNode node, Void context)

List<Object> values = builder.build();

Domain domain = Domain.none(type);

if (!values.isEmpty()) {
domain = domain.union(Domain.multipleValues(type, values));
Domain domain;
if (values.isEmpty()) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to handle empty as a special case? Should Domain.multipleValues handle that
in order to reduce ifs?
For example:

Domain.create(ValueSet.copyOf(type, ImmutableList.of()), false).isNone() == true

Yet multipleValues is implemented as

    public static Domain multipleValues(Type type, List<?> values)
    {
        if (values.isEmpty()) {
            throw new IllegalArgumentException("values cannot be empty");
        }
        if (values.size() == 1) {
            return singleValue(type, values.get(0));
        }
        return new Domain(ValueSet.of(type, values.get(0), values.subList(1, values.size()).toArray()), false);
    }

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not have a strong opinion. There are exactly 3 non-test call sites for multipleValues, so updating it in a follow-up would not be a problem.

domain = Domain.none(type);
}
else if (hasNaN) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding a comment. It might not be obvious to the reader why NaN makes the TupleDomain accept all non-null values.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This knowledge is already sprinkled in multiple places in the codebase.
There seems to be no good place to put such a comment, though.

domain = Domain.notNull(type);
}
else {
domain = Domain.multipleValues(type, values);
}

if (hasNull) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import static io.prestosql.spi.type.DoubleType.DOUBLE;
import static io.prestosql.spi.type.IntegerType.INTEGER;
import static io.prestosql.spi.type.RealType.REAL;
import static io.prestosql.spi.type.TypeUtils.isFloatingPointNaN;
import static io.prestosql.sql.ExpressionUtils.and;
import static io.prestosql.sql.ExpressionUtils.or;
import static io.prestosql.sql.analyzer.TypeSignatureTranslator.toSqlType;
Expand Down Expand Up @@ -188,7 +189,7 @@ private Expression unwrapCast(ComparisonExpression expression)

// Handle comparison against NaN.
// It must be done before source type range bounds are compared to target value.
if ((targetType instanceof DoubleType && Double.isNaN((double) right)) || (targetType instanceof RealType && Float.isNaN(intBitsToFloat(toIntExact((long) right))))) {
if (isFloatingPointNaN(targetType, right)) {
switch (operator) {
case EQUAL:
case GREATER_THAN:
Expand Down
Loading