diff --git a/plugin/trino-functions-python/src/main/java/io/trino/plugin/functions/python/TrinoTypes.java b/plugin/trino-functions-python/src/main/java/io/trino/plugin/functions/python/TrinoTypes.java index 11701ef6b495..ef72d30550a6 100644 --- a/plugin/trino-functions-python/src/main/java/io/trino/plugin/functions/python/TrinoTypes.java +++ b/plugin/trino-functions-python/src/main/java/io/trino/plugin/functions/python/TrinoTypes.java @@ -17,7 +17,6 @@ import io.airlift.slice.Slice; import io.airlift.slice.SliceInput; import io.airlift.slice.SliceOutput; -import io.airlift.slice.Slices; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.SqlMap; @@ -36,11 +35,9 @@ import io.trino.spi.type.LongTimestamp; import io.trino.spi.type.LongTimestampWithTimeZone; import io.trino.spi.type.MapType; -import io.trino.spi.type.NumberType; import io.trino.spi.type.RealType; import io.trino.spi.type.RowType; import io.trino.spi.type.SmallintType; -import io.trino.spi.type.SqlNumber; import io.trino.spi.type.StandardTypes; import io.trino.spi.type.TimeType; import io.trino.spi.type.TimeWithTimeZoneType; @@ -48,12 +45,9 @@ import io.trino.spi.type.TimestampType; import io.trino.spi.type.TimestampWithTimeZoneType; import io.trino.spi.type.TinyintType; -import io.trino.spi.type.TrinoNumber; import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; -import java.io.IOException; -import java.io.UncheckedIOException; import java.math.BigDecimal; import java.util.List; @@ -113,7 +107,7 @@ public static void validateReturnType(Type type) public static Slice toRowTypeDescriptor(List types) { if (types.isEmpty()) { - SliceOutput output = Slices.allocate(8).getOutput(); + SliceOutput output = new DynamicSliceOutput(8); output.writeInt(TrinoType.ROW.id()); output.writeInt(0); return output.slice(); @@ -124,13 +118,9 @@ public static Slice toRowTypeDescriptor(List types) public static Slice toTypeDescriptor(Type type) { - try (SliceOutput output = new DynamicSliceOutput(64)) { - toTypeDescriptor(type, output); - return output.slice(); - } - catch (IOException e) { - throw new UncheckedIOException(e); - } + SliceOutput output = new DynamicSliceOutput(64); + toTypeDescriptor(type, output); + return output.slice(); } private static void toTypeDescriptor(Type type, SliceOutput output) @@ -166,7 +156,7 @@ private static TrinoType singletonType(Type type) case StandardTypes.TINYINT -> TrinoType.TINYINT; case StandardTypes.DOUBLE -> TrinoType.DOUBLE; case StandardTypes.REAL -> TrinoType.REAL; - case StandardTypes.DECIMAL, StandardTypes.NUMBER -> TrinoType.DECIMAL; + case StandardTypes.DECIMAL -> TrinoType.DECIMAL; case StandardTypes.VARCHAR -> TrinoType.VARCHAR; case StandardTypes.VARBINARY -> TrinoType.VARBINARY; case StandardTypes.DATE -> TrinoType.DATE; @@ -211,17 +201,6 @@ private static void javaToBinary(Type type, Object value, SliceOutput output) : Decimals.toString((Int128) value, decimalType.getScale()); writeVariableSlice(utf8Slice(decimalString), output); } - case NumberType _ -> { - TrinoNumber number = (TrinoNumber) value; - - Slice slice = switch (number.toBigDecimal()) { - case TrinoNumber.BigDecimalValue(BigDecimal bigDecimal) -> utf8Slice(bigDecimal.toString()); - case TrinoNumber.Infinity(boolean negative) -> negative ? utf8Slice("-Infinity") : utf8Slice("Infinity"); - case TrinoNumber.NotANumber() -> utf8Slice("NaN"); - }; - - writeVariableSlice(slice, output); - } case TimeWithTimeZoneType timeType -> { if (timeType.isShort()) { long time = (long) value; @@ -304,15 +283,6 @@ private static void blockToBinary(Type type, Block block, int position, SliceOut : Decimals.toString((Int128) decimalType.getObject(block, position), decimalType.getScale()); writeVariableSlice(utf8Slice(decimalString), output); } - case NumberType numberType -> { - SqlNumber value = (SqlNumber) numberType.getObjectValue(block, position); - Slice slice = switch (value.value()) { - case TrinoNumber.BigDecimalValue(BigDecimal bigDecimal) -> utf8Slice(bigDecimal.toString()); - case TrinoNumber.Infinity(boolean negative) -> negative ? utf8Slice("-Infinity") : utf8Slice("Infinity"); - case TrinoNumber.NotANumber _ -> utf8Slice("NaN"); - }; - writeVariableSlice(slice, output); - } case DateType dateType -> output.writeInt(dateType.getInt(block, position)); case TimeType timeType -> output.writeLong(picosToMicros(timeType.getLong(block, position))); case TimeWithTimeZoneType timeType -> { @@ -415,18 +385,6 @@ public static Object binaryToJava(Type type, SliceInput input) ? encodeShortScaledValue(decimal, decimalType.getScale(), HALF_UP) : encodeScaledValue(decimal, decimalType.getScale(), HALF_UP); } - case NumberType _ -> { - String stringUtf8 = input.readSlice(input.readInt()).toStringUtf8(); - - TrinoNumber.AsBigDecimal number = switch (stringUtf8) { - case "NaN" -> new TrinoNumber.NotANumber(); - case "Infinity" -> new TrinoNumber.Infinity(false); - case "-Infinity" -> new TrinoNumber.Infinity(true); - default -> new TrinoNumber.BigDecimalValue(new BigDecimal(stringUtf8)); - }; - - yield TrinoNumber.from(number); - } case TimeType timeType -> { long micros = roundMicros(input.readLong(), timeType.getPrecision()) % MICROSECONDS_PER_DAY; yield micros * PICOSECONDS_PER_MICROSECOND; diff --git a/plugin/trino-functions-python/src/test/java/io/trino/plugin/functions/python/TestPythonFunctions.java b/plugin/trino-functions-python/src/test/java/io/trino/plugin/functions/python/TestPythonFunctions.java index df5ef341f5ad..c8dc1779a3d6 100644 --- a/plugin/trino-functions-python/src/test/java/io/trino/plugin/functions/python/TestPythonFunctions.java +++ b/plugin/trino-functions-python/src/test/java/io/trino/plugin/functions/python/TestPythonFunctions.java @@ -35,7 +35,6 @@ import static io.trino.testing.TestingHandles.TEST_CATALOG_NAME; import static io.trino.testing.TestingSession.testSessionBuilder; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; @@ -752,38 +751,6 @@ SELECT bad_bigint_return() "TypeError: 'str' object cannot be interpreted as an integer"); } - @Test - public void testTypeNumber() - { - String query = - """ - WITH FUNCTION multiply(x number, y number) - RETURNS number - LANGUAGE PYTHON - WITH (handler = 'multiply') - AS $$ - from decimal import Decimal - def multiply(x, y): - return x * y * Decimal("100000000000000000000.000000000000000000000000000000000000001") - $$ - """; - - assertThat(assertions.query( - query + "SELECT multiply(NUMBER '1.12345', NUMBER '2.54321')")) - .matches("VALUES NUMBER '2.8571692745E+20'"); - - // TODO: https://github.com/trinodb/trino-wasm-python/pull/11 - assertThatThrownBy(() -> assertThat(assertions.query( - query + "SELECT multiply(NUMBER 'NaN', NUMBER '2.54321')")) - .matches("VALUES NUMBER 'NaN'")) - .hasMessageContaining("ValueError: Decimal is not finite: NaN"); - - assertThatThrownBy(() -> assertThat(assertions.query( - query + "SELECT multiply(NUMBER '-Infinity', NUMBER '2.54321')")) - .matches("VALUES NUMBER 'NaN'")) - .hasMessageContaining("ValueError: Decimal is not finite: -Infinity"); - } - @Test public void testTypeInteger() {