Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
import io.trino.spi.predicate.SortedRangeSet;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.predicate.ValueSet;
import io.trino.spi.type.DecimalConversions;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Decimals;
import io.trino.spi.type.Int128;
import io.trino.spi.type.TimestampType;
import io.trino.spi.type.Type;
Expand All @@ -48,6 +50,7 @@
import org.apache.parquet.io.ParquetDecodingException;
import org.apache.parquet.io.api.Binary;
import org.apache.parquet.schema.LogicalTypeAnnotation;
import org.apache.parquet.schema.LogicalTypeAnnotation.DecimalLogicalTypeAnnotation;
import org.apache.parquet.schema.LogicalTypeAnnotation.TimestampLogicalTypeAnnotation;
import org.apache.parquet.schema.PrimitiveType;
import org.joda.time.DateTimeZone;
Expand All @@ -68,10 +71,12 @@
import static io.trino.parquet.ParquetTimestampUtils.decodeInt96Timestamp;
import static io.trino.parquet.ParquetTypeUtils.getShortDecimalValue;
import static io.trino.parquet.predicate.PredicateUtils.isStatisticsOverflow;
import static io.trino.parquet.reader.ColumnReaderFactory.isDecimalRescaled;
import static io.trino.plugin.base.type.TrinoTimestampEncoderFactory.createTimestampEncoder;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.DateType.DATE;
import static io.trino.spi.type.Decimals.longTenToNth;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.RealType.REAL;
Expand Down Expand Up @@ -377,11 +382,8 @@ private static Domain getDomain(
SortedRangeSet.Builder rangesBuilder = SortedRangeSet.builder(type, minimums.size());
if (decimalType.isShort()) {
for (int i = 0; i < minimums.size(); i++) {
Object min = minimums.get(i);
Object max = maximums.get(i);

long minValue = min instanceof Slice minSlice ? getShortDecimalValue(minSlice.getBytes()) : asLong(min);
long maxValue = max instanceof Slice maxSlice ? getShortDecimalValue(maxSlice.getBytes()) : asLong(max);
long minValue = getShortDecimal(minimums.get(i), decimalType, column);
long maxValue = getShortDecimal(maximums.get(i), decimalType, column);

if (isStatisticsOverflow(type, minValue, maxValue)) {
return Domain.create(ValueSet.all(type), hasNullValue);
Expand All @@ -392,11 +394,8 @@ private static Domain getDomain(
}
else {
for (int i = 0; i < minimums.size(); i++) {
Object min = minimums.get(i);
Object max = maximums.get(i);

Int128 minValue = min instanceof Slice minSlice ? Int128.fromBigEndian(minSlice.getBytes()) : Int128.valueOf(asLong(min));
Int128 maxValue = max instanceof Slice maxSlice ? Int128.fromBigEndian(maxSlice.getBytes()) : Int128.valueOf(asLong(max));
Int128 minValue = getLongDecimal(minimums.get(i), decimalType, column);
Int128 maxValue = getLongDecimal(maximums.get(i), decimalType, column);

rangesBuilder.addRangeInclusive(minValue, maxValue);
}
Expand Down Expand Up @@ -494,6 +493,61 @@ private static Domain getDomain(
return Domain.create(ValueSet.all(type), hasNullValue);
}

/**
 * Decodes a Parquet statistics value into the unscaled {@code long} of a short decimal column.
 * <p>
 * If the Parquet column has a decimal logical type annotation whose precision or scale differs
 * from {@code columnType}, the value is cast from the annotated precision/scale to the column's
 * precision/scale. Otherwise the raw value is decoded unchanged.
 *
 * @param value statistics value; a {@link Slice} is decoded from its bytes, anything else goes through {@code asLong}
 * @param columnType decimal type the column is being read as
 * @param column Parquet column descriptor supplying the logical type annotation
 * @return the unscaled short decimal value
 */
private static long getShortDecimal(Object value, DecimalType columnType, ColumnDescriptor column)
{
    LogicalTypeAnnotation annotation = column.getPrimitiveType().getLogicalTypeAnnotation();

    // No decimal annotation, or the annotation already matches the column type: decode as-is
    if (!(annotation instanceof DecimalLogicalTypeAnnotation fileDecimal) || !isDecimalRescaled(fileDecimal, columnType)) {
        if (value instanceof Slice rawBytes) {
            return getShortDecimalValue(rawBytes.getBytes());
        }
        return asLong(value);
    }

    int filePrecision = fileDecimal.getPrecision();
    int fileScale = fileDecimal.getScale();

    // Annotated precision does not fit a short decimal: go through the 128-bit representation
    if (filePrecision > Decimals.MAX_SHORT_PRECISION) {
        Int128 wideValue;
        if (value instanceof Slice wideBytes) {
            wideValue = Int128.fromBigEndian(wideBytes.getBytes());
        }
        else {
            wideValue = Int128.valueOf(asLong(value));
        }
        return DecimalConversions.longToShortCast(
                wideValue,
                filePrecision,
                fileScale,
                columnType.getPrecision(),
                columnType.getScale());
    }

    // Short-to-short cast: the scaling factor is computed before the value is decoded
    long scalingFactor = longTenToNth(Math.abs(columnType.getScale() - fileScale));
    long narrowValue;
    if (value instanceof Slice narrowBytes) {
        narrowValue = getShortDecimalValue(narrowBytes.getBytes());
    }
    else {
        narrowValue = asLong(value);
    }
    return DecimalConversions.shortToShortCast(
            narrowValue,
            filePrecision,
            fileScale,
            columnType.getPrecision(),
            columnType.getScale(),
            scalingFactor,
            scalingFactor / 2);
}

/**
 * Decodes a Parquet statistics value into the {@link Int128} of a long decimal column.
 * <p>
 * If the Parquet column has a decimal logical type annotation whose precision or scale differs
 * from {@code columnType}, the value is cast from the annotated precision/scale to the column's
 * precision/scale. Otherwise the raw value is decoded unchanged.
 *
 * @param value statistics value; a {@link Slice} is decoded from its bytes, anything else goes through {@code asLong}
 * @param columnType decimal type the column is being read as
 * @param column Parquet column descriptor supplying the logical type annotation
 * @return the unscaled long decimal value
 */
private static Int128 getLongDecimal(Object value, DecimalType columnType, ColumnDescriptor column)
{
    LogicalTypeAnnotation annotation = column.getPrimitiveType().getLogicalTypeAnnotation();

    // No decimal annotation, or the annotation already matches the column type: decode as-is
    if (!(annotation instanceof DecimalLogicalTypeAnnotation fileDecimal) || !isDecimalRescaled(fileDecimal, columnType)) {
        if (value instanceof Slice rawBytes) {
            return Int128.fromBigEndian(rawBytes.getBytes());
        }
        return Int128.valueOf(asLong(value));
    }

    int filePrecision = fileDecimal.getPrecision();
    int fileScale = fileDecimal.getScale();

    // Annotated precision is beyond the short range: the source is already a 128-bit decimal
    if (filePrecision > Decimals.MAX_SHORT_PRECISION) {
        Int128 wideValue;
        if (value instanceof Slice wideBytes) {
            wideValue = Int128.fromBigEndian(wideBytes.getBytes());
        }
        else {
            wideValue = Int128.valueOf(asLong(value));
        }
        return DecimalConversions.longToLongCast(
                wideValue,
                filePrecision,
                fileScale,
                columnType.getPrecision(),
                columnType.getScale());
    }

    // Source fits a short decimal: decode it as a long and widen
    long narrowValue;
    if (value instanceof Slice narrowBytes) {
        narrowValue = getShortDecimalValue(narrowBytes.getBytes());
    }
    else {
        narrowValue = asLong(value);
    }
    return DecimalConversions.shortToLongCast(
            narrowValue,
            filePrecision,
            fileScale,
            columnType.getPrecision(),
            columnType.getScale());
}

@VisibleForTesting
public static Domain getDomain(
Type type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ public Optional<Boolean> visit(UUIDLogicalTypeAnnotation uuidLogicalType)
.orElse(FALSE);
}

private static boolean isDecimalRescaled(DecimalLogicalTypeAnnotation decimalAnnotation, DecimalType trinoType)
public static boolean isDecimalRescaled(DecimalLogicalTypeAnnotation decimalAnnotation, DecimalType trinoType)
{
return decimalAnnotation.getPrecision() != trinoType.getPrecision()
|| decimalAnnotation.getScale() != trinoType.getScale();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
import static io.airlift.slice.Slices.utf8Slice;
import static io.trino.parquet.ParquetEncoding.PLAIN_DICTIONARY;
import static io.trino.parquet.ParquetTimestampUtils.JULIAN_EPOCH_OFFSET_DAYS;
import static io.trino.parquet.ParquetTypeUtils.paddingBigInteger;
import static io.trino.parquet.predicate.TupleDomainParquetPredicate.getDomain;
import static io.trino.spi.predicate.Domain.all;
import static io.trino.spi.predicate.Domain.create;
Expand Down Expand Up @@ -100,6 +101,7 @@
import static java.util.Collections.singletonList;
import static java.util.Collections.singletonMap;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static org.apache.parquet.schema.LogicalTypeAnnotation.decimalType;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BINARY;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY;
Expand Down Expand Up @@ -237,6 +239,25 @@ public void testShortDecimal()
.withMessage("Malformed Parquet file. Corrupted statistics for column \"[] required int32 ShortDecimalColumn\": [min: 100, max: 10, num_nulls: 0] [testFile]");
}

@Test
public void testShortDecimalWithInt64()
        throws Exception
{
    // INT64 column without a decimal annotation, read as decimal(5, 2)
    ColumnDescriptor int64Column = createColumnDescriptor(INT64, "ShortDecimalColumn");
    Type shortDecimal = createDecimalType(5, 2);

    // No rows and no statistics: nothing can be excluded
    assertThat(getDomain(int64Column, shortDecimal, 0, null, ID, UTC))
            .isEqualTo(all(shortDecimal));

    // min == max collapses to a single value
    assertThat(getDomain(int64Column, shortDecimal, 10, longColumnStats(10012L, 10012L), ID, UTC))
            .isEqualTo(singleValue(shortDecimal, 10012L));

    // Test that statistics overflowing the size of the type are not used
    assertThat(getDomain(int64Column, shortDecimal, 10, longColumnStats(100012L, 100012L), ID, UTC))
            .isEqualTo(notNull(shortDecimal));

    // Distinct min and max produce an inclusive range without nulls
    assertThat(getDomain(int64Column, shortDecimal, 10, longColumnStats(0L, 100L), ID, UTC))
            .isEqualTo(create(ValueSet.ofRanges(range(shortDecimal, 0L, true, 100L, true)), false));

    // fail on corrupted statistics (min greater than max)
    assertThatExceptionOfType(ParquetCorruptionException.class)
            .isThrownBy(() -> getDomain(int64Column, shortDecimal, 10, longColumnStats(100L, 10L), ID, UTC))
            .withMessage("Malformed Parquet file. Corrupted statistics for column \"[] required int64 ShortDecimalColumn\": [min: 100, max: 10, num_nulls: 0] [testFile]");
}

@Test
public void testShortDecimalWithNoScale()
throws Exception
Expand All @@ -256,6 +277,43 @@ public void testShortDecimalWithNoScale()
.withMessage("Malformed Parquet file. Corrupted statistics for column \"[] required int32 ShortDecimalColumnWithNoScale\": [min: 100, max: 10, num_nulls: 0] [testFile]");
}

@Test
public void testShortDecimalWithLongDecimalAnnotation()
        throws Exception
{
    // Parquet column annotated as DECIMAL(38,2), read through several short decimal types
    ColumnDescriptor annotatedColumn = createColumnDescriptor(FIXED_LEN_BYTE_ARRAY, decimalType(2, 38), "ShortDecimalColumnWithDecimalAnnotation");
    BigInteger unscaledMax = new BigInteger("12345");

    // Same scale as the annotation, precision 5
    Type decimal5Scale2 = createDecimalType(5, 2);
    assertThat(getDomain(annotatedColumn, decimal5Scale2, 0, null, ID, UTC))
            .isEqualTo(all(decimal5Scale2));
    assertThat(getDomain(annotatedColumn, decimal5Scale2, 10, binaryColumnStats(unscaledMax, unscaledMax), ID, UTC))
            .isEqualTo(singleValue(decimal5Scale2, 12345L));

    assertThat(getDomain(annotatedColumn, decimal5Scale2, 10, binaryColumnStats(0L, 100L), ID, UTC))
            .isEqualTo(create(ValueSet.ofRanges(range(decimal5Scale2, 0L, true, 100L, true)), false));
    assertThat(getDomain(annotatedColumn, decimal5Scale2, 10, intColumnStats(0, 100), ID, UTC))
            .isEqualTo(create(ValueSet.ofRanges(range(decimal5Scale2, 0L, true, 100L, true)), false));

    // Same scale as the annotation, precision 15
    Type decimal15Scale2 = createDecimalType(15, 2);
    assertThat(getDomain(annotatedColumn, decimal15Scale2, 0, null, ID, UTC))
            .isEqualTo(all(decimal15Scale2));
    assertThat(getDomain(annotatedColumn, decimal15Scale2, 10, binaryColumnStats(unscaledMax, unscaledMax), ID, UTC))
            .isEqualTo(singleValue(decimal15Scale2, 12345L));

    assertThat(getDomain(annotatedColumn, decimal15Scale2, 10, binaryColumnStats(0L, 100L), ID, UTC))
            .isEqualTo(create(ValueSet.ofRanges(range(decimal15Scale2, 0L, true, 100L, true)), false));
    assertThat(getDomain(annotatedColumn, decimal15Scale2, 10, intColumnStats(0, 100), ID, UTC))
            .isEqualTo(create(ValueSet.ofRanges(range(decimal15Scale2, 0L, true, 100L, true)), false));

    // Scale 1 instead of the annotated scale 2: statistics are rescaled
    Type decimal5Scale1 = createDecimalType(5, 1);
    assertThat(getDomain(annotatedColumn, decimal5Scale1, 0, null, ID, UTC))
            .isEqualTo(all(decimal5Scale1));

    // 10012 at scale 2 becomes 1001 at scale 1
    assertThat(getDomain(annotatedColumn, decimal5Scale1, 10, longColumnStats(10012L, 10012L), ID, UTC))
            .isEqualTo(singleValue(decimal5Scale1, 1001L));

    // 100012 at scale 2 becomes 10001 at scale 1, so the statistics are still used
    assertThat(getDomain(annotatedColumn, decimal5Scale1, 10, longColumnStats(100012L, 100012L), ID, UTC))
            .isEqualTo(singleValue(decimal5Scale1, 10001L));

    // Range bounds are rescaled too: [0, 100] at scale 2 becomes [0, 10] at scale 1
    assertThat(getDomain(annotatedColumn, decimal5Scale1, 10, longColumnStats(0L, 100L), ID, UTC))
            .isEqualTo(create(ValueSet.ofRanges(range(decimal5Scale1, 0L, true, 10L, true)), false));

    // fail on higher precision values
    assertThatExceptionOfType(ParquetCorruptionException.class)
            .isThrownBy(() -> getDomain(annotatedColumn, createDecimalType(4, 2), 10, binaryColumnStats(unscaledMax, unscaledMax), ID, UTC))
            .withMessage("Malformed Parquet file. Corrupted statistics for column \"[] required fixed_len_byte_array(0) ShortDecimalColumnWithDecimalAnnotation (DECIMAL(38,2))\": [min: 0x00000000000000000000000000003039, max: 0x00000000000000000000000000003039, num_nulls: 0] [testFile]");
}

@Test
public void testLongDecimal()
throws Exception
Expand All @@ -277,7 +335,7 @@ public void testLongDecimal()
// fail on corrupted statistics
assertThatExceptionOfType(ParquetCorruptionException.class)
.isThrownBy(() -> getDomain(columnDescriptor, type, 10, binaryColumnStats(100L, 10L), ID, UTC))
.withMessage("Malformed Parquet file. Corrupted statistics for column \"[] required fixed_len_byte_array(0) LongDecimalColumn\": [min: 0x64, max: 0x0A, num_nulls: 0] [testFile]");
.withMessage("Malformed Parquet file. Corrupted statistics for column \"[] required fixed_len_byte_array(0) LongDecimalColumn\": [min: 0x00000000000000000000000000000064, max: 0x0000000000000000000000000000000A, num_nulls: 0] [testFile]");
}

@Test
Expand All @@ -296,7 +354,50 @@ public void testLongDecimalWithNoScale()
// fail on corrupted statistics
assertThatExceptionOfType(ParquetCorruptionException.class)
.isThrownBy(() -> getDomain(columnDescriptor, type, 10, binaryColumnStats(100L, 10L), ID, UTC))
.withMessage("Malformed Parquet file. Corrupted statistics for column \"[] required fixed_len_byte_array(0) LongDecimalColumnWithNoScale\": [min: 0x64, max: 0x0A, num_nulls: 0] [testFile]");
.withMessage("Malformed Parquet file. Corrupted statistics for column \"[] required fixed_len_byte_array(0) LongDecimalColumnWithNoScale\": [min: 0x00000000000000000000000000000064, max: 0x0000000000000000000000000000000A, num_nulls: 0] [testFile]");
}

@Test
public void testLongDecimalWithShortDecimalAnnotation()
        throws Exception
{
    // INT32 column annotated as DECIMAL(5,2), read as the long decimal(20, 2)
    ColumnDescriptor int32Column = createColumnDescriptor(INT32, decimalType(2, 5), "ShortDecimalColumn");
    DecimalType longDecimal = createDecimalType(20, 2);

    Int128 zero = Int128.valueOf(0L);
    Int128 hundred = Int128.valueOf(100L);
    Int128 tenThousandTwelve = Int128.valueOf(10012L);

    // No rows and no statistics: nothing can be excluded
    assertThat(getDomain(int32Column, longDecimal, 0, null, ID, UTC))
            .isEqualTo(all(longDecimal));

    // min == max collapses to a single Int128 value
    assertThat(getDomain(int32Column, longDecimal, 10, longColumnStats(10012L, 10012L), ID, UTC))
            .isEqualTo(singleValue(longDecimal, tenThousandTwelve));

    // Distinct bounds produce inclusive Int128 ranges without nulls
    assertThat(getDomain(int32Column, longDecimal, 10, longColumnStats(0L, 10012L), ID, UTC))
            .isEqualTo(create(ValueSet.ofRanges(range(longDecimal, zero, true, tenThousandTwelve, true)), false));
    assertThat(getDomain(int32Column, longDecimal, 10, longColumnStats(0, 100L), ID, UTC))
            .isEqualTo(create(ValueSet.ofRanges(range(longDecimal, zero, true, hundred, true)), false));

    // fail on corrupted statistics (min greater than max)
    assertThatExceptionOfType(ParquetCorruptionException.class)
            .isThrownBy(() -> getDomain(int32Column, longDecimal, 10, longColumnStats(100L, 10L), ID, UTC))
            .withMessage("Malformed Parquet file. Corrupted statistics for column \"[] required int32 ShortDecimalColumn (DECIMAL(5,2))\": [min: 100, max: 10, num_nulls: 0] [testFile]");
}

@Test
public void testLongDecimalWithInt64DecimalAnnotation()
        throws Exception
{
    // INT64 column annotated as DECIMAL(5,2), read as the long decimal(20, 2)
    ColumnDescriptor int64Column = createColumnDescriptor(INT64, decimalType(2, 5), "ShortDecimalColumn");
    DecimalType longDecimal = createDecimalType(20, 2);
    BigInteger unscaledMax = new BigInteger("12345");

    Int128 expectedZero = Int128.ZERO;
    Int128 expectedHundred = Int128.valueOf(100L);
    Int128 expectedMax = Int128.valueOf(unscaledMax);

    // No rows and no statistics: nothing can be excluded
    assertThat(getDomain(int64Column, longDecimal, 0, null, ID, UTC))
            .isEqualTo(all(longDecimal));

    // min == max collapses to a single Int128 value
    assertThat(getDomain(int64Column, longDecimal, 10, longColumnStats(unscaledMax.longValue(), unscaledMax.longValue()), ID, UTC))
            .isEqualTo(singleValue(longDecimal, expectedMax));

    // Distinct bounds produce an inclusive Int128 range without nulls
    assertThat(getDomain(int64Column, longDecimal, 10, longColumnStats(0L, 100L), ID, UTC))
            .isEqualTo(create(ValueSet.ofRanges(range(longDecimal, expectedZero, true, expectedHundred, true)), false));
    // NOTE(review): presumably identical to the assertion above once the int literals widen to long - confirm intent
    assertThat(getDomain(int64Column, longDecimal, 10, longColumnStats(0, 100), ID, UTC))
            .isEqualTo(create(ValueSet.ofRanges(range(longDecimal, expectedZero, true, expectedHundred, true)), false));

    // fail on corrupted statistics (min greater than max)
    assertThatExceptionOfType(ParquetCorruptionException.class)
            .isThrownBy(() -> getDomain(int64Column, longDecimal, 10, longColumnStats(100L, 10L), ID, UTC))
            .withMessage("Malformed Parquet file. Corrupted statistics for column \"[] required int64 ShortDecimalColumn (DECIMAL(5,2))\": [min: 100, max: 10, num_nulls: 0] [testFile]");
}

@Test
Expand Down Expand Up @@ -724,6 +825,11 @@ private ColumnDescriptor createColumnDescriptor(PrimitiveTypeName typeName, Stri
return new ColumnDescriptor(new String[] {}, new PrimitiveType(REQUIRED, typeName, columnName), 0, 0);
}

/**
 * Creates a descriptor for a required primitive column that carries the given logical type annotation.
 * The descriptor uses an empty path and passes 0 for both trailing level arguments.
 */
private ColumnDescriptor createColumnDescriptor(PrimitiveTypeName typeName, LogicalTypeAnnotation typeAnnotation, String columnName)
{
    PrimitiveType annotatedType = new PrimitiveType(REQUIRED, typeName, columnName)
            .withLogicalTypeAnnotation(typeAnnotation);
    return new ColumnDescriptor(new String[] {}, annotatedType, 0, 0);
}

private TupleDomain<ColumnDescriptor> getEffectivePredicate(ColumnDescriptor column, VarcharType type, Slice value)
{
ColumnDescriptor predicateColumn = new ColumnDescriptor(column.getPath(), column.getPrimitiveType(), 0, 0);
Expand Down Expand Up @@ -786,8 +892,8 @@ private static BinaryStatistics binaryColumnStats(long minimum, long maximum)
/**
 * Builds binary column statistics with zero nulls for the given bounds.
 * Each bound is encoded with {@code paddingBigInteger(value, 16)}.
 * The earlier unpadded {@code toByteArray()} min/max calls were overwritten by the padded
 * ones later in the same builder chain, so only the padded calls are kept.
 *
 * @param minimum lower bound of the statistics
 * @param maximum upper bound of the statistics
 * @return statistics built for an optional BINARY column named "BinaryColumn"
 */
private static BinaryStatistics binaryColumnStats(BigInteger minimum, BigInteger maximum)
{
    return (BinaryStatistics) Statistics.getBuilderForReading(Types.optional(BINARY).named("BinaryColumn"))
            .withMin(paddingBigInteger(minimum, 16))
            .withMax(paddingBigInteger(maximum, 16))
            .withNumNulls(0)
            .build();
}
Expand Down