Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;

import java.math.BigDecimal;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.Decimals.encodeScaledValue;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.RealType.REAL;
Expand Down Expand Up @@ -132,7 +134,12 @@ else if (type.getJavaType() == double.class) {
type.writeDouble(blockBuilder, ((Number) value).doubleValue());
}
else if (type.getJavaType() == long.class) {
type.writeLong(blockBuilder, ((Number) value).longValue());
if (value instanceof BigDecimal) {
type.writeLong(blockBuilder, ((BigDecimal) value).unscaledValue().longValue());
}
else {
type.writeLong(blockBuilder, ((Number) value).longValue());
}
}
else if (type.getJavaType() == Slice.class) {
Slice slice;
Expand All @@ -142,6 +149,9 @@ else if (type.getJavaType() == Slice.class) {
else if (value instanceof String) {
slice = Slices.utf8Slice((String) value);
}
else if (value instanceof BigDecimal) {
slice = encodeScaledValue((BigDecimal) value);
}
else {
slice = (Slice) value;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.iceberg.StructLike;
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.Types.DecimalType;

import java.io.IOException;
import java.io.StringWriter;
Expand Down Expand Up @@ -168,7 +169,7 @@ public static Object getValue(JsonNode partitionValue, Type type)
throw new UncheckedIOException("Failed during JSON conversion of " + partitionValue, e);
}
case DECIMAL:
return partitionValue.decimalValue();
return partitionValue.decimalValue().setScale(((DecimalType) type).scale());
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this preserve the precision?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It didn't, and the precision would always be 0 in the result value of BigDecimal. But precision of 0 in BigDecimal is ok as it is not a strictly required parameter as scale. It could be gotten by function precision() if necessary.

}
throw new UnsupportedOperationException("Type not supported as partition column: " + type);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,147 @@ private void testCreatePartitionedTableAs(Session session, FileFormat fileFormat
dropTable(session, "test_create_partitioned_table_as_" + fileFormat.toString().toLowerCase(ENGLISH));
}

@Test
public void testPartitionOnDecimalColumn()
{
testWithAllFileFormats(this::testPartitionedByShortDecimalType);
testWithAllFileFormats(this::testPartitionedByLongDecimalType);
testWithAllFileFormats(this::testTruncateShortDecimalTransform);
testWithAllFileFormats(this::testTruncateLongDecimalTransform);
}

public void testPartitionedByShortDecimalType(Session session, FileFormat format)
{
// create iceberg table partitioned by column of ShortDecimalType, and insert some data
assertUpdate(session, "drop table if exists test_partition_columns_short_decimal");
assertUpdate(session, format("create table test_partition_columns_short_decimal(a bigint, b decimal(9, 2))" +
" with (format = '%s', partitioning = ARRAY['b'])", format.name()));
assertUpdate(session, "insert into test_partition_columns_short_decimal values(1, 12.31), (2, 133.28)", 2);
assertQuery(session, "select * from test_partition_columns_short_decimal", "values(1, 12.31), (2, 133.28)");

// validate column of ShortDecimalType exists in query filter
assertQuery(session, "select * from test_partition_columns_short_decimal where b = 133.28", "values(2, 133.28)");
assertQuery(session, "select * from test_partition_columns_short_decimal where b = 12.31", "values(1, 12.31)");

// validate column of ShortDecimalType in system table "partitions"
assertQuery(session, "select b, row_count from \"test_partition_columns_short_decimal$partitions\"", "values(12.31, 1), (133.28, 1)");

// validate column of TimestampType exists in delete filter
assertUpdate(session, "delete from test_partition_columns_short_decimal WHERE b = 12.31", 1);
assertQuery(session, "select * from test_partition_columns_short_decimal", "values(2, 133.28)");
assertQuery(session, "select * from test_partition_columns_short_decimal where b = 133.28", "values(2, 133.28)");

assertQuery(session, "select b, row_count from \"test_partition_columns_short_decimal$partitions\"", "values(133.28, 1)");

assertUpdate(session, "drop table test_partition_columns_short_decimal");
}

public void testPartitionedByLongDecimalType(Session session, FileFormat format)
{
// create iceberg table partitioned by column of ShortDecimalType, and insert some data
assertUpdate(session, "drop table if exists test_partition_columns_long_decimal");
assertUpdate(session, format("create table test_partition_columns_long_decimal(a bigint, b decimal(20, 2))" +
" with (format = '%s', partitioning = ARRAY['b'])", format.name()));
assertUpdate(session, "insert into test_partition_columns_long_decimal values(1, 11111111111111112.31), (2, 133.28)", 2);
assertQuery(session, "select * from test_partition_columns_long_decimal", "values(1, 11111111111111112.31), (2, 133.28)");

// validate column of ShortDecimalType exists in query filter
assertQuery(session, "select * from test_partition_columns_long_decimal where b = 133.28", "values(2, 133.28)");
assertQuery(session, "select * from test_partition_columns_long_decimal where b = 11111111111111112.31", "values(1, 11111111111111112.31)");

// validate column of ShortDecimalType in system table "partitions"
assertQuery(session, "select b, row_count from \"test_partition_columns_long_decimal$partitions\"",
"values(11111111111111112.31, 1), (133.28, 1)");

// validate column of TimestampType exists in delete filter
assertUpdate(session, "delete from test_partition_columns_long_decimal WHERE b = 11111111111111112.31", 1);
assertQuery(session, "select * from test_partition_columns_long_decimal", "values(2, 133.28)");
assertQuery(session, "select * from test_partition_columns_long_decimal where b = 133.28", "values(2, 133.28)");

assertQuery(session, "select b, row_count from \"test_partition_columns_long_decimal$partitions\"",
"values(133.28, 1)");

assertUpdate(session, "drop table test_partition_columns_long_decimal");
}

public void testTruncateShortDecimalTransform(Session session, FileFormat format)
{
assertUpdate(session, format("CREATE TABLE test_truncate_decimal_transform (d DECIMAL(9, 2), b BIGINT)" +
" WITH (format = '%s', partitioning = ARRAY['truncate(d, 10)'])", format.name()));
String select = "SELECT d_trunc, row_count, d.min, d.max FROM \"test_truncate_decimal_transform$partitions\"";

assertUpdate(session, "INSERT INTO test_truncate_decimal_transform VALUES" +
"(NULL, 101)," +
"(12.34, 1)," +
"(12.30, 2)," +
"(12.29, 3)," +
"(0.05, 4)," +
"(-0.05, 5)", 6);

assertQuery(session, "SELECT d_trunc FROM \"test_truncate_decimal_transform$partitions\"", "VALUES NULL, 12.30, 12.20, 0.00, -0.10");

assertQuery(session, "SELECT b FROM test_truncate_decimal_transform WHERE d IN (12.34, 12.30)", "VALUES 1, 2");
assertQuery(session, select + " WHERE d_trunc = 12.30",
"VALUES (12.30, 2, 12.30, 12.34)");

assertQuery(session, "SELECT b FROM test_truncate_decimal_transform WHERE d = 12.29", "VALUES 3");
assertQuery(session, select + " WHERE d_trunc = 12.20",
"VALUES (12.20, 1, 12.29, 12.29)");

assertQuery(session, "SELECT b FROM test_truncate_decimal_transform WHERE d = 0.05", "VALUES 4");
assertQuery(session, select + " WHERE d_trunc = 0.00",
"VALUES (0.00, 1, 0.05, 0.05)");

assertQuery(session, "SELECT b FROM test_truncate_decimal_transform WHERE d = -0.05", "VALUES 5");
assertQuery(session, select + " WHERE d_trunc = -0.10",
"VALUES (-0.10, 1, -0.05, -0.05)");

// Exercise IcebergMetadata.applyFilter with non-empty Constraint.predicate, via non-pushdownable predicates
assertQuery(session, "SELECT * FROM test_truncate_decimal_transform WHERE d * 100 % 10 = 9 AND b % 7 = 3",
"VALUES (12.29, 3)");

assertUpdate(session, "DROP TABLE test_truncate_decimal_transform");
}

public void testTruncateLongDecimalTransform(Session session, FileFormat format)
{
assertUpdate(session, format("CREATE TABLE test_truncate_long_decimal_transform (d DECIMAL(20, 2), b BIGINT)" +
" WITH (format = '%s', partitioning = ARRAY['truncate(d, 10)'])", format.name()));
String select = "SELECT d_trunc, row_count, d.min, d.max FROM \"test_truncate_long_decimal_transform$partitions\"";

assertUpdate(session, "INSERT INTO test_truncate_long_decimal_transform VALUES" +
"(NULL, 101)," +
"(12.34, 1)," +
"(12.30, 2)," +
"(11111111111111112.29, 3)," +
"(0.05, 4)," +
"(-0.05, 5)", 6);

assertQuery(session, "SELECT d_trunc FROM \"test_truncate_long_decimal_transform$partitions\"", "VALUES NULL, 12.30, 11111111111111112.20, 0.00, -0.10");

assertQuery(session, "SELECT b FROM test_truncate_long_decimal_transform WHERE d IN (12.34, 12.30)", "VALUES 1, 2");
assertQuery(session, select + " WHERE d_trunc = 12.30",
"VALUES (12.30, 2, 12.30, 12.34)");

assertQuery(session, "SELECT b FROM test_truncate_long_decimal_transform WHERE d = 11111111111111112.29", "VALUES 3");
assertQuery(session, select + " WHERE d_trunc = 11111111111111112.20",
"VALUES (11111111111111112.20, 1, 11111111111111112.29, 11111111111111112.29)");

assertQuery(session, "SELECT b FROM test_truncate_long_decimal_transform WHERE d = 0.05", "VALUES 4");
assertQuery(session, select + " WHERE d_trunc = 0.00",
"VALUES (0.00, 1, 0.05, 0.05)");

assertQuery(session, "SELECT b FROM test_truncate_long_decimal_transform WHERE d = -0.05", "VALUES 5");
assertQuery(session, select + " WHERE d_trunc = -0.10",
"VALUES (-0.10, 1, -0.05, -0.05)");

// Exercise IcebergMetadata.applyFilter with non-empty Constraint.predicate, via non-pushdownable predicates
assertQuery(session, "SELECT * FROM test_truncate_long_decimal_transform WHERE d * 100 % 10 = 9 AND b % 7 = 3",
"VALUES (11111111111111112.29, 3)");

assertUpdate(session, "DROP TABLE test_truncate_long_decimal_transform");
}

@Test
public void testColumnComments()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
package com.facebook.presto.spi;

import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.type.DecimalType;
import com.facebook.presto.common.type.Type;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
Expand All @@ -28,6 +30,7 @@
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.DateType.DATE;
import static com.facebook.presto.common.type.Decimals.encodeScaledValue;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.TimestampType.TIMESTAMP;
Expand Down Expand Up @@ -118,7 +121,15 @@ public long getLong(int field)
{
checkState(record != null, "no current record");
requireNonNull(record.get(field), "value is null");
return ((Number) record.get(field)).longValue();
Object value = record.get(field);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps add a preconditions check to ensure we don't return just the unscaled value of a LongDecimalType?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, good suggestion. Fixed!

if (value instanceof BigDecimal) {
checkState(((DecimalType) this.getType(field)).isShort(),
"Expected ShortDecimalType");
return ((BigDecimal) value).unscaledValue().longValue();
}
else {
return ((Number) record.get(field)).longValue();
}
}

@Override
Expand All @@ -144,6 +155,9 @@ public Slice getSlice(int field)
if (value instanceof Slice) {
return (Slice) value;
}
if (value instanceof BigDecimal) {
return encodeScaledValue((BigDecimal) value);
}
throw new IllegalArgumentException("Field " + field + " is not a String, but is a " + value.getClass().getName());
}

Expand Down