Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -238,53 +238,60 @@ private Expression unwrapCast(ComparisonExpression expression)
Optional<Type.Range> sourceRange = sourceType.getRange();
if (sourceRange.isPresent()) {
Object max = sourceRange.get().getMax();
Object maxInTargetType = coerce(max, sourceToTarget);

// NaN values of `right` are excluded at this point. Otherwise, NaN would be recognized as
// greater than source type upper bound, and incorrect expression might be derived.
int upperBoundComparison = compare(targetType, right, maxInTargetType);
if (upperBoundComparison > 0) {
// larger than maximum representable value
return switch (operator) {
case EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL -> falseIfNotNull(cast.getExpression());
case NOT_EQUAL, LESS_THAN, LESS_THAN_OR_EQUAL -> trueIfNotNull(cast.getExpression());
case IS_DISTINCT_FROM -> TRUE_LITERAL;
};
Object maxInTargetType = null;
try {
maxInTargetType = coerce(max, sourceToTarget);
}

if (upperBoundComparison == 0) {
// equal to max representable value
return switch (operator) {
case GREATER_THAN -> falseIfNotNull(cast.getExpression());
case GREATER_THAN_OR_EQUAL -> new ComparisonExpression(EQUAL, cast.getExpression(), literalEncoder.toExpression(session, max, sourceType));
case LESS_THAN_OR_EQUAL -> trueIfNotNull(cast.getExpression());
case LESS_THAN -> new ComparisonExpression(NOT_EQUAL, cast.getExpression(), literalEncoder.toExpression(session, max, sourceType));
case EQUAL, NOT_EQUAL, IS_DISTINCT_FROM -> new ComparisonExpression(operator, cast.getExpression(), literalEncoder.toExpression(session, max, sourceType));
};
catch (RuntimeException e) {
// Coercion may fail e.g. for out of range values, it's not guaranteed to be "saturated"
}
if (maxInTargetType != null) {
// NaN values of `right` are excluded at this point. Otherwise, NaN would be recognized as
// greater than source type upper bound, and incorrect expression might be derived.
int upperBoundComparison = compare(targetType, right, maxInTargetType);
if (upperBoundComparison > 0) {
// larger than maximum representable value
return switch (operator) {
case EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL -> falseIfNotNull(cast.getExpression());
case NOT_EQUAL, LESS_THAN, LESS_THAN_OR_EQUAL -> trueIfNotNull(cast.getExpression());
case IS_DISTINCT_FROM -> TRUE_LITERAL;
};
}

Object min = sourceRange.get().getMin();
Object minInTargetType = coerce(min, sourceToTarget);

int lowerBoundComparison = compare(targetType, right, minInTargetType);
if (lowerBoundComparison < 0) {
// smaller than minimum representable value
return switch (operator) {
case NOT_EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL -> trueIfNotNull(cast.getExpression());
case EQUAL, LESS_THAN, LESS_THAN_OR_EQUAL -> falseIfNotNull(cast.getExpression());
case IS_DISTINCT_FROM -> TRUE_LITERAL;
};
}
if (upperBoundComparison == 0) {
// equal to max representable value
return switch (operator) {
case GREATER_THAN -> falseIfNotNull(cast.getExpression());
case GREATER_THAN_OR_EQUAL -> new ComparisonExpression(EQUAL, cast.getExpression(), literalEncoder.toExpression(session, max, sourceType));
case LESS_THAN_OR_EQUAL -> trueIfNotNull(cast.getExpression());
case LESS_THAN -> new ComparisonExpression(NOT_EQUAL, cast.getExpression(), literalEncoder.toExpression(session, max, sourceType));
case EQUAL, NOT_EQUAL, IS_DISTINCT_FROM -> new ComparisonExpression(operator, cast.getExpression(), literalEncoder.toExpression(session, max, sourceType));
};
}

if (lowerBoundComparison == 0) {
// equal to min representable value
return switch (operator) {
case LESS_THAN -> falseIfNotNull(cast.getExpression());
case LESS_THAN_OR_EQUAL -> new ComparisonExpression(EQUAL, cast.getExpression(), literalEncoder.toExpression(session, min, sourceType));
case GREATER_THAN_OR_EQUAL -> trueIfNotNull(cast.getExpression());
case GREATER_THAN -> new ComparisonExpression(NOT_EQUAL, cast.getExpression(), literalEncoder.toExpression(session, min, sourceType));
case EQUAL, NOT_EQUAL, IS_DISTINCT_FROM -> new ComparisonExpression(operator, cast.getExpression(), literalEncoder.toExpression(session, min, sourceType));
};
Object min = sourceRange.get().getMin();
Object minInTargetType = coerce(min, sourceToTarget);

int lowerBoundComparison = compare(targetType, right, minInTargetType);
if (lowerBoundComparison < 0) {
// smaller than minimum representable value
return switch (operator) {
case NOT_EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL -> trueIfNotNull(cast.getExpression());
case EQUAL, LESS_THAN, LESS_THAN_OR_EQUAL -> falseIfNotNull(cast.getExpression());
case IS_DISTINCT_FROM -> TRUE_LITERAL;
};
}

if (lowerBoundComparison == 0) {
// equal to min representable value
return switch (operator) {
case LESS_THAN -> falseIfNotNull(cast.getExpression());
case LESS_THAN_OR_EQUAL -> new ComparisonExpression(EQUAL, cast.getExpression(), literalEncoder.toExpression(session, min, sourceType));
case GREATER_THAN_OR_EQUAL -> trueIfNotNull(cast.getExpression());
case GREATER_THAN -> new ComparisonExpression(NOT_EQUAL, cast.getExpression(), literalEncoder.toExpression(session, min, sourceType));
case EQUAL, NOT_EQUAL, IS_DISTINCT_FROM -> new ComparisonExpression(operator, cast.getExpression(), literalEncoder.toExpression(session, min, sourceType));
};
}
}
}

Expand Down
10 changes: 10 additions & 0 deletions core/trino-main/src/test/java/io/trino/type/TestDateType.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.SqlDate;
import io.trino.spi.type.Type.Range;

import static io.trino.spi.type.DateType.DATE;
import static org.testng.Assert.assertEquals;

public class TestDateType
extends AbstractTestType
Expand Down Expand Up @@ -49,4 +51,12 @@ protected Object getGreaterValue(Object value)
{
return ((Long) value) + 1;
}

@Override
public void testRange()
{
Range range = type.getRange().orElseThrow();
assertEquals(range.getMin(), (long) Integer.MIN_VALUE);
assertEquals(range.getMax(), (long) Integer.MAX_VALUE);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,13 @@
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.LongTimestamp;
import io.trino.spi.type.SqlTimestamp;
import io.trino.spi.type.Type.Range;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import static io.trino.spi.type.TimestampType.TIMESTAMP_NANOS;
import static io.trino.spi.type.TimestampType.createTimestampType;
import static org.testng.Assert.assertEquals;

public class TestLongTimestampType
extends AbstractTestType
Expand Down Expand Up @@ -51,4 +56,33 @@ protected Object getGreaterValue(Object value)
LongTimestamp timestamp = (LongTimestamp) value;
return new LongTimestamp(timestamp.getEpochMicros() + 1, 0);
}

@Override
public void testRange()
{
Range range = type.getRange().orElseThrow();
assertEquals(range.getMin(), new LongTimestamp(Long.MIN_VALUE, 0));
assertEquals(range.getMax(), new LongTimestamp(Long.MAX_VALUE, 999_000));
}

@Test(dataProvider = "testRangeEveryPrecisionDataProvider")
public void testRangeEveryPrecision(int precision, LongTimestamp expectedMax)
{
Range range = createTimestampType(precision).getRange().orElseThrow();
assertEquals(range.getMin(), new LongTimestamp(Long.MIN_VALUE, 0));
assertEquals(range.getMax(), expectedMax);
}

@DataProvider
public static Object[][] testRangeEveryPrecisionDataProvider()
{
return new Object[][] {
{7, new LongTimestamp(Long.MAX_VALUE, 900_000)},
{8, new LongTimestamp(Long.MAX_VALUE, 990_000)},
{9, new LongTimestamp(Long.MAX_VALUE, 999_000)},
{10, new LongTimestamp(Long.MAX_VALUE, 999_900)},
{11, new LongTimestamp(Long.MAX_VALUE, 999_990)},
{12, new LongTimestamp(Long.MAX_VALUE, 999_999)},
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,13 @@
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.SqlTimestamp;
import io.trino.spi.type.Type.Range;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import static io.trino.spi.type.TimestampType.TIMESTAMP_MILLIS;
import static io.trino.spi.type.TimestampType.createTimestampType;
import static org.testng.Assert.assertEquals;

public class TestShortTimestampType
extends AbstractTestType
Expand Down Expand Up @@ -49,4 +54,34 @@ protected Object getGreaterValue(Object value)
{
return ((Long) value) + 1_000;
}

@Override
public void testRange()
{
Range range = type.getRange().orElseThrow();
assertEquals(range.getMin(), Long.MIN_VALUE + 808);
assertEquals(range.getMax(), Long.MAX_VALUE - 807);
}

@Test(dataProvider = "testRangeEveryPrecisionDataProvider")
public void testRangeEveryPrecision(int precision, long expectedMin, long expectedMax)
{
Range range = createTimestampType(precision).getRange().orElseThrow();
assertEquals(range.getMin(), expectedMin);
assertEquals(range.getMax(), expectedMax);
}

@DataProvider
public static Object[][] testRangeEveryPrecisionDataProvider()
{
return new Object[][] {
{0, Long.MIN_VALUE + 775808, Long.MAX_VALUE - 775807},
{1, Long.MIN_VALUE + 75808, Long.MAX_VALUE - 75807},
{2, Long.MIN_VALUE + 5808, Long.MAX_VALUE - 5807},
{3, Long.MIN_VALUE + 808, Long.MAX_VALUE - 807},
{4, Long.MIN_VALUE + 8, Long.MAX_VALUE - 7},
{5, Long.MIN_VALUE + 8, Long.MAX_VALUE - 7},
{6, Long.MIN_VALUE, Long.MAX_VALUE},
};
}
}
8 changes: 8 additions & 0 deletions core/trino-spi/src/main/java/io/trino/spi/type/DateType.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import io.trino.spi.block.Block;
import io.trino.spi.connector.ConnectorSession;

import java.util.Optional;

//
// A date is stored as days from 1970-01-01.
//
Expand Down Expand Up @@ -45,6 +47,12 @@ public Object getObjectValue(ConnectorSession session, Block block, int position
return new SqlDate(days);
}

@Override
public Optional<Range> getRange()
{
return Optional.of(new Range((long) Integer.MIN_VALUE, (long) Integer.MAX_VALUE));
}

@Override
@SuppressWarnings("EqualsWhichDoesntCheckParameterClass")
public boolean equals(Object other)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,18 @@
import io.trino.spi.function.BlockPosition;
import io.trino.spi.function.ScalarOperator;

import java.util.Optional;

import static io.airlift.slice.SizeOf.SIZE_OF_LONG;
import static io.trino.spi.function.OperatorType.COMPARISON_UNORDERED_LAST;
import static io.trino.spi.function.OperatorType.EQUAL;
import static io.trino.spi.function.OperatorType.LESS_THAN;
import static io.trino.spi.function.OperatorType.LESS_THAN_OR_EQUAL;
import static io.trino.spi.function.OperatorType.XX_HASH_64;
import static io.trino.spi.type.Timestamps.PICOSECONDS_PER_MICROSECOND;
import static io.trino.spi.type.Timestamps.rescale;
import static io.trino.spi.type.TypeOperatorDeclaration.extractOperatorDeclaration;
import static java.lang.Math.toIntExact;
import static java.lang.String.format;
import static java.lang.invoke.MethodHandles.lookup;

Expand All @@ -43,6 +48,7 @@ class LongTimestampType
extends TimestampType
{
private static final TypeOperatorDeclaration TYPE_OPERATOR_DECLARATION = extractOperatorDeclaration(LongTimestampType.class, lookup(), LongTimestamp.class);
private final Range range;

public LongTimestampType(int precision)
{
Expand All @@ -51,6 +57,10 @@ public LongTimestampType(int precision)
if (precision < MAX_SHORT_PRECISION + 1 || precision > MAX_PRECISION) {
throw new IllegalArgumentException(format("Precision must be in the range [%s, %s]", MAX_SHORT_PRECISION + 1, MAX_PRECISION));
}

// ShortTimestampType instances are created eagerly and shared so it's OK to precompute some things.
int picosOfMicroMax = toIntExact(PICOSECONDS_PER_MICROSECOND - rescale(1, 0, 12 - getPrecision()));
range = new Range(new LongTimestamp(Long.MIN_VALUE, 0), new LongTimestamp(Long.MAX_VALUE, picosOfMicroMax));
}

@Override
Expand Down Expand Up @@ -148,6 +158,12 @@ private static int getFraction(Block block, int position)
return block.getInt(position, SIZE_OF_LONG);
}

@Override
public Optional<Range> getRange()
{
return Optional.of(range);
}

@ScalarOperator(EQUAL)
private static boolean equalOperator(LongTimestamp left, LongTimestamp right)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.function.ScalarOperator;

import java.util.Optional;

import static io.trino.spi.function.OperatorType.COMPARISON_UNORDERED_LAST;
import static io.trino.spi.function.OperatorType.EQUAL;
import static io.trino.spi.function.OperatorType.HASH_CODE;
import static io.trino.spi.function.OperatorType.LESS_THAN;
import static io.trino.spi.function.OperatorType.LESS_THAN_OR_EQUAL;
import static io.trino.spi.function.OperatorType.XX_HASH_64;
import static io.trino.spi.type.Timestamps.rescale;
import static io.trino.spi.type.TypeOperatorDeclaration.extractOperatorDeclaration;
import static java.lang.String.format;
import static java.lang.invoke.MethodHandles.lookup;
Expand All @@ -42,6 +45,7 @@ class ShortTimestampType
extends TimestampType
{
private static final TypeOperatorDeclaration TYPE_OPERATOR_DECLARATION = extractOperatorDeclaration(ShortTimestampType.class, lookup(), long.class);
private final Range range;

public ShortTimestampType(int precision)
{
Expand All @@ -50,6 +54,15 @@ public ShortTimestampType(int precision)
if (precision < 0 || precision > MAX_SHORT_PRECISION) {
throw new IllegalArgumentException(format("Precision must be in the range [0, %s]", MAX_SHORT_PRECISION));
}

// ShortTimestampType instances are created eagerly and shared so it's OK to precompute some things.
if (getPrecision() == MAX_SHORT_PRECISION) {
range = new Range(Long.MIN_VALUE, Long.MAX_VALUE);
}
else {
long max = rescale(rescale(Long.MAX_VALUE, MAX_SHORT_PRECISION, getPrecision()), getPrecision(), MAX_SHORT_PRECISION);
range = new Range(-max, max);
}
}

@Override
Expand Down Expand Up @@ -125,6 +138,12 @@ public Object getObjectValue(ConnectorSession session, Block block, int position
return SqlTimestamp.newInstance(getPrecision(), epochMicros, 0);
}

@Override
public Optional<Range> getRange()
{
return Optional.of(range);
}

@ScalarOperator(EQUAL)
private static boolean equalOperator(long left, long right)
{
Expand Down
Loading