Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 28 additions & 30 deletions core/trino-main/src/main/java/io/trino/type/DecimalCasts.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Decimals;
import io.trino.spi.type.Int128;
import io.trino.spi.type.Int128Math;
import io.trino.spi.type.StandardTypes;
import io.trino.spi.type.TypeSignature;
import io.trino.spi.type.VarcharType;
import io.trino.util.JsonCastException;

import java.io.IOException;
import java.math.BigDecimal;
import java.math.BigInteger;

import static io.airlift.slice.Slices.utf8Slice;
import static io.trino.operator.scalar.JsonOperators.JSON_FACTORY;
Expand All @@ -47,11 +47,11 @@
import static io.trino.spi.function.OperatorType.CAST;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.Decimals.bigIntegerTenToNth;
import static io.trino.spi.type.Decimals.isShortDecimal;
import static io.trino.spi.type.Decimals.longTenToNth;
import static io.trino.spi.type.Decimals.overflows;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.Int128.ZERO;
import static io.trino.spi.type.Int128Math.multiply;
import static io.trino.spi.type.Int128Math.rescale;
import static io.trino.spi.type.IntegerType.INTEGER;
Expand All @@ -69,7 +69,6 @@
import static java.lang.Math.multiplyExact;
import static java.lang.Math.toIntExact;
import static java.lang.String.format;
import static java.math.BigInteger.ZERO;
import static java.math.RoundingMode.HALF_UP;
import static java.nio.charset.StandardCharsets.UTF_8;

Expand Down Expand Up @@ -109,12 +108,12 @@ private static SqlScalarFunction castFunctionFromDecimalTo(TypeSignature to, Str
.withExtraParameters((context) -> {
long precision = context.getLiteral("precision");
long scale = context.getLiteral("scale");
Number tenToScale;
Object tenToScale;
if (isShortDecimal(context.getParameterTypes().get(0))) {
tenToScale = longTenToNth(DecimalConversions.intScale(scale));
}
else {
tenToScale = bigIntegerTenToNth(DecimalConversions.intScale(scale));
tenToScale = Int128Math.POWERS_OF_TEN[DecimalConversions.intScale(scale)];
}
return ImmutableList.of(precision, scale, tenToScale);
})))
Expand Down Expand Up @@ -143,12 +142,12 @@ private static SqlScalarFunction castFunctionToDecimalFromBuilder(TypeSignature
.methods(methodNames)
.withExtraParameters((context) -> {
DecimalType resultType = (DecimalType) context.getReturnType();
Number tenToScale;
Object tenToScale;
if (isShortDecimal(resultType)) {
tenToScale = longTenToNth(resultType.getScale());
}
else {
tenToScale = bigIntegerTenToNth(resultType.getScale());
tenToScale = Int128Math.POWERS_OF_TEN[resultType.getScale()];
}
return ImmutableList.of(resultType.getPrecision(), resultType.getScale(), tenToScale);
}))).build();
Expand Down Expand Up @@ -187,7 +186,7 @@ public static boolean shortDecimalToBoolean(long decimal, long precision, long s
}

@UsedByGeneratedCode
public static boolean longDecimalToBoolean(Int128 decimal, long precision, long scale, BigInteger tenToScale)
public static boolean longDecimalToBoolean(Int128 decimal, long precision, long scale, Int128 tenToScale)
{
return !decimal.isZero();
}
Expand All @@ -199,10 +198,9 @@ public static long booleanToShortDecimal(boolean value, long precision, long sca
}

@UsedByGeneratedCode
public static Int128 booleanToLongDecimal(boolean value, long precision, long scale, BigInteger tenToScale)
public static Int128 booleanToLongDecimal(boolean value, long precision, long scale, Int128 tenToScale)
{
BigInteger unscaledValue = value ? tenToScale : ZERO;
return Int128.valueOf(unscaledValue);
return value ? tenToScale : ZERO;
}

@UsedByGeneratedCode
Expand All @@ -216,7 +214,7 @@ public static long shortDecimalToBigint(long decimal, long precision, long scale
}

@UsedByGeneratedCode
public static long longDecimalToBigint(Int128 decimal, long precision, long scale, BigInteger tenToScale)
public static long longDecimalToBigint(Int128 decimal, long precision, long scale, Int128 tenToScale)
{
try {
return rescale(decimal, DecimalConversions.intScale(-scale)).toLongExact();
Expand All @@ -242,10 +240,10 @@ public static long bigintToShortDecimal(long value, long precision, long scale,
}

@UsedByGeneratedCode
public static Int128 bigintToLongDecimal(long value, long precision, long scale, BigInteger tenToScale)
public static Int128 bigintToLongDecimal(long value, long precision, long scale, Int128 tenToScale)
{
try {
Int128 result = multiply(Int128.valueOf(tenToScale), value);
Int128 result = multiply(tenToScale, value);
if (Decimals.overflows(result, (int) precision)) {
throw new TrinoException(INVALID_CAST_ARGUMENT, format("Cannot cast BIGINT '%s' to DECIMAL(%s, %s)", value, precision, scale));
}
Expand Down Expand Up @@ -274,7 +272,7 @@ public static long shortDecimalToInteger(long decimal, long precision, long scal
}

@UsedByGeneratedCode
public static long longDecimalToInteger(Int128 decimal, long precision, long scale, BigInteger tenToScale)
public static long longDecimalToInteger(Int128 decimal, long precision, long scale, Int128 tenToScale)
{
try {
return toIntExact(rescale(decimal, DecimalConversions.intScale(-scale)).toLongExact());
Expand All @@ -300,10 +298,10 @@ public static long integerToShortDecimal(long value, long precision, long scale,
}

@UsedByGeneratedCode
public static Int128 integerToLongDecimal(long value, long precision, long scale, BigInteger tenToScale)
public static Int128 integerToLongDecimal(long value, long precision, long scale, Int128 tenToScale)
{
try {
Int128 result = multiply(Int128.valueOf(tenToScale), value);
Int128 result = multiply(tenToScale, value);
if (Decimals.overflows(result, (int) precision)) {
throw new TrinoException(INVALID_CAST_ARGUMENT, format("Cannot cast INTEGER '%s' to DECIMAL(%s, %s)", value, precision, scale));
}
Expand Down Expand Up @@ -332,7 +330,7 @@ public static long shortDecimalToSmallint(long decimal, long precision, long sca
}

@UsedByGeneratedCode
public static long longDecimalToSmallint(Int128 decimal, long precision, long scale, BigInteger tenToScale)
public static long longDecimalToSmallint(Int128 decimal, long precision, long scale, Int128 tenToScale)
{
try {
Int128 decimal1 = rescale(decimal, DecimalConversions.intScale(-scale));
Expand All @@ -359,10 +357,10 @@ public static long smallintToShortDecimal(long value, long precision, long scale
}

@UsedByGeneratedCode
public static Int128 smallintToLongDecimal(long value, long precision, long scale, BigInteger tenToScale)
public static Int128 smallintToLongDecimal(long value, long precision, long scale, Int128 tenToScale)
{
try {
Int128 result = multiply(Int128.valueOf(tenToScale), value);
Int128 result = multiply(tenToScale, value);
if (Decimals.overflows(result, (int) precision)) {
throw new TrinoException(INVALID_CAST_ARGUMENT, format("Cannot cast SMALLINT '%s' to DECIMAL(%s, %s)", value, precision, scale));
}
Expand Down Expand Up @@ -391,7 +389,7 @@ public static long shortDecimalToTinyint(long decimal, long precision, long scal
}

@UsedByGeneratedCode
public static long longDecimalToTinyint(Int128 decimal, long precision, long scale, BigInteger tenToScale)
public static long longDecimalToTinyint(Int128 decimal, long precision, long scale, Int128 tenToScale)
{
try {
return SignedBytes.checkedCast(rescale(decimal, DecimalConversions.intScale(-scale)).toLongExact());
Expand All @@ -417,10 +415,10 @@ public static long tinyintToShortDecimal(long value, long precision, long scale,
}

@UsedByGeneratedCode
public static Int128 tinyintToLongDecimal(long value, long precision, long scale, BigInteger tenToScale)
public static Int128 tinyintToLongDecimal(long value, long precision, long scale, Int128 tenToScale)
{
try {
Int128 result = multiply(Int128.valueOf(tenToScale), value);
Int128 result = multiply(tenToScale, value);
if (Decimals.overflows(result, (int) precision)) {
throw new TrinoException(INVALID_CAST_ARGUMENT, format("Cannot cast TINYINT '%s' to DECIMAL(%s, %s)", value, precision, scale));
}
Expand All @@ -438,7 +436,7 @@ public static double shortDecimalToDouble(long decimal, long precision, long sca
}

@UsedByGeneratedCode
public static double longDecimalToDouble(Int128 decimal, long precision, long scale, BigInteger tenToScale)
public static double longDecimalToDouble(Int128 decimal, long precision, long scale, Int128 tenToScale)
{
return DecimalConversions.longDecimalToDouble(decimal, scale);
}
Expand All @@ -450,7 +448,7 @@ public static long shortDecimalToReal(long decimal, long precision, long scale,
}

@UsedByGeneratedCode
public static long longDecimalToReal(Int128 decimal, long precision, long scale, BigInteger tenToScale)
public static long longDecimalToReal(Int128 decimal, long precision, long scale, Int128 tenToScale)
{
return DecimalConversions.longDecimalToReal(decimal, scale);
}
Expand All @@ -462,7 +460,7 @@ public static long doubleToShortDecimal(double value, long precision, long scale
}

@UsedByGeneratedCode
public static Int128 doubleToLongDecimal(double value, long precision, long scale, BigInteger tenToScale)
public static Int128 doubleToLongDecimal(double value, long precision, long scale, Int128 tenToScale)
{
return DecimalConversions.doubleToLongDecimal(value, precision, scale);
}
Expand All @@ -474,7 +472,7 @@ public static long realToShortDecimal(long value, long precision, long scale, lo
}

@UsedByGeneratedCode
public static Int128 realToLongDecimal(long value, long precision, long scale, BigInteger tenToScale)
public static Int128 realToLongDecimal(long value, long precision, long scale, Int128 tenToScale)
{
return DecimalConversions.realToLongDecimal(value, precision, scale);
}
Expand Down Expand Up @@ -523,7 +521,7 @@ public static long varcharToShortDecimal(Slice value, long precision, long scale
}

@UsedByGeneratedCode
public static Int128 varcharToLongDecimal(Slice value, long precision, long scale, BigInteger tenToScale)
public static Int128 varcharToLongDecimal(Slice value, long precision, long scale, Int128 tenToScale)
{
BigDecimal result;
String stringValue = value.toString(UTF_8);
Expand All @@ -548,7 +546,7 @@ public static Slice shortDecimalToJson(long decimal, long precision, long scale,
}

@UsedByGeneratedCode
public static Slice longDecimalToJson(Int128 decimal, long precision, long scale, BigInteger tenToScale)
public static Slice longDecimalToJson(Int128 decimal, long precision, long scale, Int128 tenToScale)
{
return decimalToJson(new BigDecimal(decimal.toBigInteger(), DecimalConversions.intScale(scale)));
}
Expand All @@ -568,7 +566,7 @@ private static Slice decimalToJson(BigDecimal bigDecimal)
}

@UsedByGeneratedCode
public static Int128 jsonToLongDecimal(Slice json, long precision, long scale, BigInteger tenToScale)
public static Int128 jsonToLongDecimal(Slice json, long precision, long scale, Int128 tenToScale)
{
try (JsonParser parser = createJsonParser(JSON_FACTORY, json)) {
parser.nextToken();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,18 @@
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;

import java.math.BigDecimal;
import java.math.BigInteger;

import static io.trino.spi.function.OperatorType.SATURATED_FLOOR_CAST;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.Decimals.bigIntegerTenToNth;
import static io.trino.spi.type.Int128Math.POWERS_OF_TEN;
import static io.trino.spi.type.Int128Math.floorDiv;
import static io.trino.spi.type.Int128Math.multiply;
import static io.trino.spi.type.Int128Math.negate;
import static io.trino.spi.type.Int128Math.subtract;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.SmallintType.SMALLINT;
import static io.trino.spi.type.TinyintType.TINYINT;
import static io.trino.spi.type.TypeSignatureParameter.typeVariable;
import static java.lang.Math.toIntExact;
import static java.math.BigInteger.ONE;
import static java.math.RoundingMode.FLOOR;

public final class DecimalSaturatedFloorCasts
{
Expand Down Expand Up @@ -62,45 +61,46 @@ private DecimalSaturatedFloorCasts() {}
@UsedByGeneratedCode
public static long shortDecimalToShortDecimal(long value, int sourcePrecision, int sourceScale, int resultPrecision, int resultScale)
{
return bigintToBigintFloorSaturatedCast(BigInteger.valueOf(value), sourceScale, resultPrecision, resultScale).longValueExact();
return saturatedCast(Int128.valueOf(value), sourceScale, resultPrecision, resultScale).toLongExact();
}

@UsedByGeneratedCode
public static Int128 shortDecimalToLongDecimal(long value, int sourcePrecision, int sourceScale, int resultPrecision, int resultScale)
{
return Int128.valueOf(bigintToBigintFloorSaturatedCast(BigInteger.valueOf(value), sourceScale, resultPrecision, resultScale));
return saturatedCast(Int128.valueOf(value), sourceScale, resultPrecision, resultScale);
}

@UsedByGeneratedCode
public static long longDecimalToShortDecimal(Int128 value, int sourcePrecision, int sourceScale, int resultPrecision, int resultScale)
{
return bigintToBigintFloorSaturatedCast(value.toBigInteger(), sourceScale, resultPrecision, resultScale).longValueExact();
return saturatedCast(value, sourceScale, resultPrecision, resultScale).toLongExact();
}

@UsedByGeneratedCode
public static Int128 longDecimalToLongDecimal(Int128 value, int sourcePrecision, int sourceScale, int resultPrecision, int resultScale)
{
return Int128.valueOf(bigintToBigintFloorSaturatedCast(value.toBigInteger(), sourceScale, resultPrecision, resultScale));
return saturatedCast(value, sourceScale, resultPrecision, resultScale);
}

private static BigInteger bigintToBigintFloorSaturatedCast(BigInteger value, int sourceScale, int resultPrecision, int resultScale)
private static Int128 saturatedCast(Int128 value, int sourceScale, int resultPrecision, int resultScale)
{
return bigDecimalToBigintFloorSaturatedCast(new BigDecimal(value, sourceScale), resultPrecision, resultScale);
}
int scale = resultScale - sourceScale;
if (scale > 0) {
value = multiply(value, POWERS_OF_TEN[scale]);
}
else if (scale < 0) {
value = floorDiv(value, POWERS_OF_TEN[-scale]);
}

private static BigInteger bigDecimalToBigintFloorSaturatedCast(BigDecimal bigDecimal, int resultPrecision, int resultScale)
{
BigDecimal rescaledValue = bigDecimal.setScale(resultScale, FLOOR);
BigInteger unscaledValue = rescaledValue.unscaledValue();
BigInteger maxUnscaledValue = bigIntegerTenToNth(resultPrecision).subtract(ONE);
if (unscaledValue.compareTo(maxUnscaledValue) > 0) {
Int128 maxUnscaledValue = subtract(POWERS_OF_TEN[resultPrecision], Int128.ONE);
if (value.compareTo(maxUnscaledValue) > 0) {
return maxUnscaledValue;
}
BigInteger minUnscaledValue = maxUnscaledValue.negate();
if (unscaledValue.compareTo(minUnscaledValue) < 0) {
Int128 minUnscaledValue = negate(maxUnscaledValue);
if (value.compareTo(minUnscaledValue) < 0) {
return minUnscaledValue;
}
return unscaledValue;
return value;
}

public static final SqlScalarFunction DECIMAL_TO_BIGINT_SATURATED_FLOOR_CAST = decimalToGenericIntegerTypeSaturatedFloorCast(BIGINT, Long.MIN_VALUE, Long.MAX_VALUE);
Expand Down Expand Up @@ -130,26 +130,28 @@ private static SqlScalarFunction decimalToGenericIntegerTypeSaturatedFloorCast(T
@UsedByGeneratedCode
public static long shortDecimalToGenericIntegerType(long value, int sourceScale, long minValue, long maxValue)
{
return bigIntegerDecimalToGenericIntegerType(BigInteger.valueOf(value), sourceScale, minValue, maxValue);
return saturatedCast(Int128.valueOf(value), sourceScale, minValue, maxValue);
}

@UsedByGeneratedCode
public static long longDecimalToGenericIntegerType(Int128 value, int sourceScale, long minValue, long maxValue)
{
return bigIntegerDecimalToGenericIntegerType(value.toBigInteger(), sourceScale, minValue, maxValue);
return saturatedCast(value, sourceScale, minValue, maxValue);
}

private static long bigIntegerDecimalToGenericIntegerType(BigInteger bigInteger, int sourceScale, long minValue, long maxValue)
private static long saturatedCast(Int128 value, int sourceScale, long minValue, long maxValue)
{
BigDecimal bigDecimal = new BigDecimal(bigInteger, sourceScale);
BigInteger unscaledValue = bigDecimal.setScale(0, FLOOR).unscaledValue();
if (unscaledValue.compareTo(BigInteger.valueOf(maxValue)) > 0) {
if (sourceScale > 0) {
value = floorDiv(value, POWERS_OF_TEN[sourceScale]);
}

if (value.compareTo(Int128.valueOf(maxValue)) > 0) {
return maxValue;
}
if (unscaledValue.compareTo(BigInteger.valueOf(minValue)) < 0) {
if (value.compareTo(Int128.valueOf(minValue)) < 0) {
return minValue;
}
return unscaledValue.longValueExact();
return value.toLongExact();
}

public static final SqlScalarFunction BIGINT_TO_DECIMAL_SATURATED_FLOOR_CAST = genericIntegerTypeToDecimalSaturatedFloorCast(BIGINT);
Expand Down Expand Up @@ -180,12 +182,12 @@ private static SqlScalarFunction genericIntegerTypeToDecimalSaturatedFloorCast(T
@UsedByGeneratedCode
public static long genericIntegerTypeToShortDecimal(long value, int resultPrecision, int resultScale)
{
return bigDecimalToBigintFloorSaturatedCast(BigDecimal.valueOf(value), resultPrecision, resultScale).longValueExact();
return saturatedCast(Int128.valueOf(value), 0, resultPrecision, resultScale).toLongExact();
}

@UsedByGeneratedCode
public static Int128 genericIntegerTypeToLongDecimal(long value, int resultPrecision, int resultScale)
{
return Int128.valueOf(bigDecimalToBigintFloorSaturatedCast(BigDecimal.valueOf(value), resultPrecision, resultScale));
return saturatedCast(Int128.valueOf(value), 0, resultPrecision, resultScale);
}
}
1 change: 1 addition & 0 deletions core/trino-spi/src/main/java/io/trino/spi/type/Int128.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public class Int128

public static final Int128 MAX_VALUE = Int128.valueOf(0x7FFF_FFFF_FFFF_FFFFL, 0xFFFF_FFFF_FFFF_FFFFL);
public static final Int128 MIN_VALUE = Int128.valueOf(0x8000_0000_0000_0000L, 0x0000_0000_0000_0000L);
public static final Int128 ONE = Int128.valueOf(0, 1);
public static final Int128 ZERO = Int128.valueOf(0, 0);

private final long high;
Expand Down
Loading