diff --git a/plugin/trino-functions-python/src/main/java/io/trino/plugin/functions/python/TrinoTypes.java b/plugin/trino-functions-python/src/main/java/io/trino/plugin/functions/python/TrinoTypes.java index ef72d30550a6..e99b2d9885c3 100644 --- a/plugin/trino-functions-python/src/main/java/io/trino/plugin/functions/python/TrinoTypes.java +++ b/plugin/trino-functions-python/src/main/java/io/trino/plugin/functions/python/TrinoTypes.java @@ -77,6 +77,7 @@ import static io.trino.spi.type.TypeUtils.writeNativeValue; import static java.lang.Math.toIntExact; import static java.math.RoundingMode.HALF_UP; +import static java.util.Objects.requireNonNullElse; final class TrinoTypes { @@ -381,9 +382,17 @@ public static Object binaryToJava(Type type, SliceInput input) case MapType mapType -> binaryMapToJava(mapType, input); case DecimalType decimalType -> { BigDecimal decimal = new BigDecimal(input.readSlice(input.readInt()).toStringUtf8()); - yield decimalType.isShort() - ? encodeShortScaledValue(decimal, decimalType.getScale(), HALF_UP) - : encodeScaledValue(decimal, decimalType.getScale(), HALF_UP); + try { + yield decimalType.isShort() + ? encodeShortScaledValue(decimal, decimalType.getScale(), HALF_UP) + : encodeScaledValue(decimal, decimalType.getScale(), HALF_UP); + } + catch (ArithmeticException e) { + throw new TrinoException( + FUNCTION_IMPLEMENTATION_ERROR, + "Function result cannot be converted to %s: %s".formatted(decimalType.getDisplayName(), requireNonNullElse(e.getMessage(), e)), + e); + } } case TimeType timeType -> { long micros = roundMicros(input.readLong(), timeType.getPrecision()) % MICROSECONDS_PER_DAY; diff --git a/plugin/trino-functions-python/src/test/java/io/trino/plugin/functions/python/TestPythonFunctions.java b/plugin/trino-functions-python/src/test/java/io/trino/plugin/functions/python/TestPythonFunctions.java index c8dc1779a3d6..f0401a121956 100644 --- a/plugin/trino-functions-python/src/test/java/io/trino/plugin/functions/python/TestPythonFunctions.java +++ b/plugin/trino-functions-python/src/test/java/io/trino/plugin/functions/python/TestPythonFunctions.java @@ -898,6 +898,37 @@ assert str(x) == '12345678901234567890.12340' SELECT test_decimal_long(12345678901234567890.1234) """)) .matches("VALUES cast(1524148134430814813443.07447 AS decimal(38, 5))"); + + String realToDecimalInPython = + """ + WITH FUNCTION test_cast_real_to_decimal(x real) + RETURNS decimal(38, 5) + LANGUAGE PYTHON + WITH (handler = 'test') + AS $$ + from decimal import Decimal + def test(x): + return Decimal.from_float(x) + $$ + """; + + // underflow + assertThat(assertions.query(realToDecimalInPython + "SELECT test_cast_real_to_decimal(REAL '1e-17')")) + .matches("VALUES CAST('0' AS decimal(38, 5))"); + + // overflow + assertThat(assertions.query(realToDecimalInPython + "SELECT test_cast_real_to_decimal(REAL '1e+34')")) + .failure().hasMessage("Function result cannot be converted to decimal(38,5): Decimal overflow"); + + // NaN + assertThat(assertions.query( realToDecimalInPython + "SELECT test_cast_real_to_decimal(REAL 'NaN')")) + .failure().hasMessage("Failed to convert Python result type 'decimal.Decimal' to Trino type DECIMAL: ValueError: Decimal is not finite: NaN"); + + // Infinity + assertThat(assertions.query(realToDecimalInPython + "SELECT test_cast_real_to_decimal(REAL '-Infinity')")) + .failure().hasMessage("Failed to convert Python result type 'decimal.Decimal' to Trino type DECIMAL: ValueError: Decimal is not finite: -Infinity"); + assertThat(assertions.query( realToDecimalInPython + "SELECT test_cast_real_to_decimal(REAL '+Infinity')")) + .failure().hasMessage("Failed to convert Python result type 'decimal.Decimal' to Trino type DECIMAL: ValueError: Decimal is not finite: Infinity"); } @Test