diff --git a/plugin/trino-functions-python/src/main/java/io/trino/plugin/functions/python/TrinoTypes.java b/plugin/trino-functions-python/src/main/java/io/trino/plugin/functions/python/TrinoTypes.java index ef72d30550a6..11701ef6b495 100644 --- a/plugin/trino-functions-python/src/main/java/io/trino/plugin/functions/python/TrinoTypes.java +++ b/plugin/trino-functions-python/src/main/java/io/trino/plugin/functions/python/TrinoTypes.java @@ -17,6 +17,7 @@ import io.airlift.slice.Slice; import io.airlift.slice.SliceInput; import io.airlift.slice.SliceOutput; +import io.airlift.slice.Slices; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.SqlMap; @@ -35,9 +36,11 @@ import io.trino.spi.type.LongTimestamp; import io.trino.spi.type.LongTimestampWithTimeZone; import io.trino.spi.type.MapType; +import io.trino.spi.type.NumberType; import io.trino.spi.type.RealType; import io.trino.spi.type.RowType; import io.trino.spi.type.SmallintType; +import io.trino.spi.type.SqlNumber; import io.trino.spi.type.StandardTypes; import io.trino.spi.type.TimeType; import io.trino.spi.type.TimeWithTimeZoneType; @@ -45,9 +48,12 @@ import io.trino.spi.type.TimestampType; import io.trino.spi.type.TimestampWithTimeZoneType; import io.trino.spi.type.TinyintType; +import io.trino.spi.type.TrinoNumber; import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; +import java.io.IOException; +import java.io.UncheckedIOException; import java.math.BigDecimal; import java.util.List; @@ -107,7 +113,7 @@ public static void validateReturnType(Type type) public static Slice toRowTypeDescriptor(List types) { if (types.isEmpty()) { - SliceOutput output = new DynamicSliceOutput(8); + SliceOutput output = Slices.allocate(8).getOutput(); output.writeInt(TrinoType.ROW.id()); output.writeInt(0); return output.slice(); @@ -118,9 +124,13 @@ public static Slice toRowTypeDescriptor(List types) public static Slice toTypeDescriptor(Type type) { - SliceOutput output = new DynamicSliceOutput(64); - toTypeDescriptor(type, output); - return output.slice(); + try (SliceOutput output = new DynamicSliceOutput(64)) { + toTypeDescriptor(type, output); + return output.slice(); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } } private static void toTypeDescriptor(Type type, SliceOutput output) @@ -156,7 +166,7 @@ private static TrinoType singletonType(Type type) case StandardTypes.TINYINT -> TrinoType.TINYINT; case StandardTypes.DOUBLE -> TrinoType.DOUBLE; case StandardTypes.REAL -> TrinoType.REAL; - case StandardTypes.DECIMAL -> TrinoType.DECIMAL; + case StandardTypes.DECIMAL, StandardTypes.NUMBER -> TrinoType.DECIMAL; case StandardTypes.VARCHAR -> TrinoType.VARCHAR; case StandardTypes.VARBINARY -> TrinoType.VARBINARY; case StandardTypes.DATE -> TrinoType.DATE; @@ -201,6 +211,17 @@ private static void javaToBinary(Type type, Object value, SliceOutput output) : Decimals.toString((Int128) value, decimalType.getScale()); writeVariableSlice(utf8Slice(decimalString), output); } + case NumberType _ -> { + TrinoNumber number = (TrinoNumber) value; + + Slice slice = switch (number.toBigDecimal()) { + case TrinoNumber.BigDecimalValue(BigDecimal bigDecimal) -> utf8Slice(bigDecimal.toString()); + case TrinoNumber.Infinity(boolean negative) -> negative ? utf8Slice("-Infinity") : utf8Slice("Infinity"); + case TrinoNumber.NotANumber() -> utf8Slice("NaN"); + }; + + writeVariableSlice(slice, output); + } case TimeWithTimeZoneType timeType -> { if (timeType.isShort()) { long time = (long) value; @@ -283,6 +304,15 @@ private static void blockToBinary(Type type, Block block, int position, SliceOut : Decimals.toString((Int128) decimalType.getObject(block, position), decimalType.getScale()); writeVariableSlice(utf8Slice(decimalString), output); } + case NumberType numberType -> { + SqlNumber value = (SqlNumber) numberType.getObjectValue(block, position); + Slice slice = switch (value.value()) { + case TrinoNumber.BigDecimalValue(BigDecimal bigDecimal) -> utf8Slice(bigDecimal.toString()); + case TrinoNumber.Infinity(boolean negative) -> negative ? utf8Slice("-Infinity") : utf8Slice("Infinity"); + case TrinoNumber.NotANumber _ -> utf8Slice("NaN"); + }; + writeVariableSlice(slice, output); + } case DateType dateType -> output.writeInt(dateType.getInt(block, position)); case TimeType timeType -> output.writeLong(picosToMicros(timeType.getLong(block, position))); case TimeWithTimeZoneType timeType -> { @@ -385,6 +415,18 @@ public static Object binaryToJava(Type type, SliceInput input) ? encodeShortScaledValue(decimal, decimalType.getScale(), HALF_UP) : encodeScaledValue(decimal, decimalType.getScale(), HALF_UP); } + case NumberType _ -> { + String stringUtf8 = input.readSlice(input.readInt()).toStringUtf8(); + + TrinoNumber.AsBigDecimal number = switch (stringUtf8) { + case "NaN" -> new TrinoNumber.NotANumber(); + case "Infinity" -> new TrinoNumber.Infinity(false); + case "-Infinity" -> new TrinoNumber.Infinity(true); + default -> new TrinoNumber.BigDecimalValue(new BigDecimal(stringUtf8)); + }; + + yield TrinoNumber.from(number); + } case TimeType timeType -> { long micros = roundMicros(input.readLong(), timeType.getPrecision()) % MICROSECONDS_PER_DAY; yield micros * PICOSECONDS_PER_MICROSECOND; diff --git a/plugin/trino-functions-python/src/test/java/io/trino/plugin/functions/python/TestPythonFunctions.java b/plugin/trino-functions-python/src/test/java/io/trino/plugin/functions/python/TestPythonFunctions.java index c8dc1779a3d6..df5ef341f5ad 100644 --- a/plugin/trino-functions-python/src/test/java/io/trino/plugin/functions/python/TestPythonFunctions.java +++ b/plugin/trino-functions-python/src/test/java/io/trino/plugin/functions/python/TestPythonFunctions.java @@ -35,6 +35,7 @@ import static io.trino.testing.TestingHandles.TEST_CATALOG_NAME; import static io.trino.testing.TestingSession.testSessionBuilder; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; @@ -751,6 +752,38 @@ SELECT bad_bigint_return() "TypeError: 'str' object cannot be interpreted as an integer"); } + @Test + public void testTypeNumber() + { + String query = + """ + WITH FUNCTION multiply(x number, y number) + RETURNS number + LANGUAGE PYTHON + WITH (handler = 'multiply') + AS $$ + from decimal import Decimal + def multiply(x, y): + return x * y * Decimal("100000000000000000000.000000000000000000000000000000000000001") + $$ + """; + + assertThat(assertions.query( + query + "SELECT multiply(NUMBER '1.12345', NUMBER '2.54321')")) + .matches("VALUES NUMBER '2.8571692745E+20'"); + + // TODO: https://github.com/trinodb/trino-wasm-python/pull/11 + assertThatThrownBy(() -> assertThat(assertions.query( + query + "SELECT multiply(NUMBER 'NaN', NUMBER '2.54321')")) + .matches("VALUES NUMBER 'NaN'")) + .hasMessageContaining("ValueError: Decimal is not finite: NaN"); + + assertThatThrownBy(() -> assertThat(assertions.query( + query + "SELECT multiply(NUMBER '-Infinity', NUMBER '2.54321')")) + .matches("VALUES NUMBER 'NaN'")) + .hasMessageContaining("ValueError: Decimal is not finite: -Infinity"); + } + @Test public void testTypeInteger() {